add chi-router, auth middleware & user roles.

group config options & split database logic
This commit is contained in:
gilex-dev 2023-11-05 17:42:14 +01:00
parent e4c9563961
commit a21494f94b
24 changed files with 3988 additions and 346 deletions

View File

@ -1,4 +1,16 @@
sqlite3_file: 'YetAnotherToDoList.sqlite3' database:
log_file: 'YetAnotherToDoList.log' sqlite3File: 'YetAnotherToDoList.sqlite3'
log_UTC: false secret: 'aS3cureAppl1cationk3y'
port: 4242 initialAdmin:
userName: 'admin'
password: 'temporaryPassword'
logging:
logFile: 'YetAnotherToDoList.log'
logUTC: false
server:
portHTTP: 4242
portHTTPS: 4241
certFile: 'certFile.crt'
keyFile: 'keyFile.key'

View File

@ -62,3 +62,8 @@ Commands were run in the order listed below on a debian based system.
pnpm install villus graphql pnpm install villus graphql
pnpm install graphql-tag pnpm install graphql-tag
``` ```
- Add go-chi
```bash
go get -u github.com/go-chi/chi/v5
go mod tidy
```

View File

@ -74,8 +74,8 @@ func init() {
// will be global for your application. // will be global for your application.
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.YetAnotherToDoList.yaml)") rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.YetAnotherToDoList.yaml)")
rootCmd.PersistentFlags().String("log_file", "", "Path to log file") rootCmd.PersistentFlags().String("logFile", "", "Path to log file")
rootCmd.PersistentFlags().String("sqlite3_file", "", "Path to SQLite3 database") rootCmd.PersistentFlags().String("sqlite3File", "", "Path to SQLite3 database")
// Cobra also supports local flags, which will only run // Cobra also supports local flags, which will only run
// when this action is called directly. // when this action is called directly.
@ -120,7 +120,7 @@ func initLog() {
time_zone_local, _ := time.Now().Zone() time_zone_local, _ := time.Now().Zone()
time_zone_offset := strings.Split(time.Now().In(time.Local).String(), " ")[2] time_zone_offset := strings.Split(time.Now().In(time.Local).String(), " ")[2]
if viper.GetBool("log_UTC") { if viper.GetBool("logging.logUTC") {
utc = log.LUTC utc = log.LUTC
time_zone_use = "UTC" time_zone_use = "UTC"
time_zone_alt = time_zone_local time_zone_alt = time_zone_local
@ -140,25 +140,25 @@ func initLog() {
logger_flags := log.Ldate | log.Ltime | utc logger_flags := log.Ldate | log.Ltime | utc
globals.Logger = log.New(os.Stdout, "", logger_flags) globals.Logger = log.New(os.Stdout, "", logger_flags)
if err := viper.BindPFlag("log_file", rootCmd.Flags().Lookup("log_file")); err != nil { if err := viper.BindPFlag("logging.logFile", rootCmd.Flags().Lookup("logFile")); err != nil {
fmt.Println("Unable to bind flag:", err) fmt.Println("Unable to bind flag:", err)
} }
if viper.GetString("log_file") != "" { if viper.GetString("logging.logFile") != "" {
log_path, err := filepath.Abs(viper.GetString("log_file")) log_path, err := filepath.Abs(viper.GetString("logging.logFile"))
globals.Logger.SetOutput(os.Stdout) globals.Logger.SetOutput(os.Stdout)
if err != nil { if err != nil {
globals.Logger.Println("Invalid path for log file", log_path) globals.Logger.Println("Invalid path for log file", log_path)
} }
log_file, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) logFile, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil { if err != nil {
globals.Logger.Println("Failed to write to log file:", err) globals.Logger.Println("Failed to write to log file:", err)
} else { } else {
globals.Logger.Println("Switching to log file", log_path) globals.Logger.Println("Switching to log file", log_path)
globals.Logger.SetOutput(log_file) globals.Logger.SetOutput(logFile)
} }
} }
@ -170,19 +170,21 @@ func initLog() {
} }
func initDB() { func initDB() {
if err := viper.BindPFlag("sqlite3_file", rootCmd.Flags().Lookup("sqlite3_file")); err != nil { if err := viper.BindPFlag("database.sqlite3File", rootCmd.Flags().Lookup("sqlite3File")); err != nil {
fmt.Println("Unable to bind flag:", err) fmt.Println("Unable to bind flag:", err)
} }
if viper.GetString("sqlite3_file") == "" { if viper.GetString("database.sqlite3File") == "" {
globals.Logger.Fatalln("No SQLite3 file specified") globals.Logger.Fatalln("No SQLite3 file specified")
} }
db_path, err := filepath.Abs(viper.GetString("sqlite3_file")) db_path, err := filepath.Abs(viper.GetString("database.sqlite3File"))
if err != nil { if err != nil {
globals.Logger.Fatalln("Invalid path for SQLite3 file", db_path) globals.Logger.Fatalln("Invalid path for SQLite3 file", db_path)
} }
globals.Logger.Println("Connecting to SQLite3", db_path) globals.Logger.Println("Connecting to SQLite3", db_path)
globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger) globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger, []byte(viper.GetString("database.secret")), viper.GetString("database.initialAdmin.userName"), viper.GetString("database.initialAdmin.password"))
globals.DB.CleanExpiredRefreshTokensTicker(time.Minute * 10)
globals.DB.CleanRevokedAccessTokensTicker(time.Minute * 10)
} }

View File

@ -38,7 +38,7 @@ var serverCmd = &cobra.Command{
} }
globals.Logger.Println("starting http server...") globals.Logger.Println("starting http server...")
server.StartServer(viper.GetInt("port")) server.StartServer(viper.GetInt("server.portHTTP"), viper.GetInt("server.portHTTPS"), viper.GetString("server.certFile"), viper.GetString("server.keyFile"))
}, },
} }

475
database/crypto_helpers.go Normal file
View File

@ -0,0 +1,475 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"time"
"unicode"
"golang.org/x/crypto/bcrypt"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
type RefreshToken struct {
Selector string `json:"selector"`
Token string `json:"token"`
ExpiryDate int `json:"expiryDate"`
}
type AccessToken struct {
UserId string `json:"userId"`
IsAdmin bool `json:"isAdmin"`
IsUserCreator bool `json:"isUserCreator"`
ExpiryDate int `json:"expiryDate"`
}
const refreshTokenLifetime = "+10 day" // TODO: add to viper
const accessTokenLifetime = time.Minute * 10
const minPasswordLength = 10
const minUserNameLength = 10
var RevokedAccessTokens []*AccessToken
// ValidatePassword validates the passed string against the password criteria.
func ValidatePassword(password string) error {
if len([]rune(password)) < minPasswordLength {
return fmt.Errorf("password must be at least %d characters long", minPasswordLength)
}
for i, c := range password {
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
return fmt.Errorf("password contains none-alphanumeric symbol at index %d", i)
}
}
return nil
}
// ValidateUserName validates the passed string against the user name criteria.
func ValidateUserName(userName string) error {
if len([]rune(userName)) < minUserNameLength {
return fmt.Errorf("userName must be at least %d characters long", minUserNameLength)
}
for i, c := range userName {
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
return fmt.Errorf("userName contains none-alphanumeric symbol at index %d", i)
}
}
return nil
}
// GenerateHashFromPassword generates a hash of the passed password. Returns salt and salted & peppered hash.
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
hashBytes, err := bcrypt.GenerateFromPassword(bytes.Join([][]byte{db.secret, []byte(password)}, nil), bcrypt.DefaultCost)
if err != nil {
return nil, err
}
return hashBytes, nil
}
// GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin.
func (db CustomDB) GetRefreshTokenOwner(tokenId string) (string, error) {
numTokenId, err := strconv.Atoi(tokenId)
if err != nil {
return "", errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM RefreshToken WHERE tokenId = ?")
if err != nil {
return "", err
}
result := statement.QueryRow(numTokenId)
var owner string
if err := result.Scan(&owner); err != nil {
return "", err
}
return owner, nil
}
func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) {
var result *sql.Row
var hash string
if userId != nil { // user userId
numUserId, err := strconv.Atoi(*userId)
if err != nil {
return "", errors.New("userId not numeric")
}
statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?")
if err != nil {
return "", err
}
result = statement.QueryRow(numUserId)
if err := result.Scan(&hash); err != nil {
return "", err
}
} else if userName != nil { // user userName
statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?")
if err != nil {
return "", err
}
result = statement.QueryRow(&userName)
if err := result.Scan(&userId, &hash); err != nil {
return "", err
}
} else {
return "", errors.New("neither userId nor userName specified")
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), bytes.Join([][]byte{db.secret, []byte(password)}, nil)); err != nil {
return "", err
}
return *userId, nil
}
// IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken.
func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) {
numUserId, err := strconv.Atoi(userId)
if err != nil {
return nil, "", errors.New("userId not numeric")
}
selector := make([]byte, 9)
if _, err := rand.Read(selector); err != nil {
return nil, "", err
}
token := make([]byte, 33)
if _, err := rand.Read(token); err != nil {
return nil, "", err
}
statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate")
if err != nil {
return nil, "", err
}
encSelector := base64.RawURLEncoding.EncodeToString(selector)
encToken := base64.RawURLEncoding.EncodeToString(token)
tokenHash := sha256.Sum256(token)
var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix())
var tokenId string
result := statement.QueryRow(numUserId, encSelector, base64.RawURLEncoding.EncodeToString(tokenHash[:]), &tokenName)
if err := result.Scan(&tokenId, &expiryDate); err != nil {
return nil, "", err
}
return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil
}
func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) {
numTokenId, err := strconv.Atoi(token.ID)
if err != nil {
return nil, errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(numTokenId)
if err := result.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
return nil, err
}
return token, nil
}
func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) {
numUserId, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(numUserId)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.RefreshToken
for rows.Next() {
token := model.RefreshToken{}
if err := rows.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
return nil, err
}
all = append(all, &token)
}
return all, nil
}
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
if err != nil {
return nil, err
}
rows, err := statement.Query()
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.RefreshToken
for rows.Next() {
var token model.RefreshToken
if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
return nil, err
}
all = append(all, &token)
}
return all, nil
}
func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) {
numTokenId, err := strconv.Atoi(tokenId)
if err != nil {
return nil, errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("UPDATE AuthToken SET tokenName = ? WHERE tokenId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.TokenName, numTokenId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId})
if err != nil {
return nil, errors.New("failed to get updated token")
}
return token, nil
}
// RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken.
func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) {
// TODO: return string instead of *string
numTokenId, err := strconv.Atoi(tokenId)
if err != nil {
return nil, errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId")
if err != nil {
return nil, err
}
result := statement.QueryRow(numTokenId)
var userId string
if err := result.Scan(&userId); err != nil {
return nil, err
}
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
return &tokenId, nil
}
// IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted.
func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) {
statement, err := db.connection.Prepare("SELECT RefreshToken.tokenHash, RefreshToken.FK_User_userId, Role.IS_admin, ROLE.IS_userCreator FROM RefreshToken INNER JOIN R_User_Role ON RefreshToken.FK_User_userId = R_User_Role.FK_User_userId INNER JOIN Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE RefreshToken.selector = ? AND RefreshToken.expiryDate >= unixepoch('now')")
if err != nil {
return nil, err
}
result := statement.QueryRow(refreshToken.Selector)
var tokenHash string
var newAccessToken AccessToken
if err := result.Scan(&tokenHash, &newAccessToken.UserId, &newAccessToken.IsAdmin, &newAccessToken.IsUserCreator); err != nil {
return nil, err
}
decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token)
if err != nil {
return nil, err
}
userTokenHash := sha256.Sum256(decUserToken)
encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:])
if encUserTokenHash != tokenHash {
return nil, errors.New("failed to issue access token: refreshToken does not match")
}
newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix())
return &newAccessToken, nil
}
// SignAccessToken signs an access token and attaches a header. Returns access token encoded as jwt.
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
data, err := json.Marshal(accessToken)
if err != nil {
return "", err
}
body := base64.RawURLEncoding.EncodeToString(data)
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
combined := header + "." + body
mac := hmac.New(sha256.New, db.secret)
mac.Write([]byte(combined))
signature := mac.Sum(nil)
return combined + "." + base64.RawURLEncoding.EncodeToString(signature), nil
}
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
accessTokenBlob := strings.Split(encAccessToken, ".")
header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0])
if err != nil {
return nil, err
}
if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) {
return nil, errors.New("access token header wrong")
}
body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1])
if err != nil {
return nil, err
}
var data AccessToken
err = json.Unmarshal(body, &data)
if err != nil {
return nil, err
}
// TODO: maybe change part of signKey instead?
for _, revokedAccessToken := range RevokedAccessTokens {
if data.UserId == revokedAccessToken.UserId && data.ExpiryDate <= revokedAccessToken.ExpiryDate {
return nil, errors.New("access token revoked")
}
}
if data.ExpiryDate < int(time.Now().Unix()) {
return nil, errors.New("access token expired")
}
if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
return nil, errors.New("access token expired prematurely")
}
signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2])
if err != nil {
return nil, err
}
mac := hmac.New(sha256.New, db.secret)
mac.Write([]byte(accessTokenBlob[0]))
mac.Write([]byte("."))
mac.Write([]byte(accessTokenBlob[1]))
expectedSignature := mac.Sum(nil)
if !hmac.Equal(signature, expectedSignature) {
return nil, errors.New("access token signature does not match")
}
return &data, nil
}
// RevokeAccessToken revokes all access tokens with matching UserId and UserRole that don't have a later ExpiryDate.
// revokedAccessToken.ExpiryDate should be set to now + token-lifetime.
func RevokeAccessToken(accessToken *AccessToken) {
RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
}
// CleanRevokedTokensTicker removes expired tokens from the list. This should be called in an interval of > accessTokenLifetime.
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
ticker := time.NewTicker(interval)
stop := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case <-ticker.C:
db.logger.Println("cleaning revoked access tokens")
for i, revokedAccessToken := range RevokedAccessTokens {
if revokedAccessToken.ExpiryDate < int(time.Now().Unix()) {
revokedAccessToken = nil
slices.Delete(RevokedAccessTokens, i, i+1)
}
}
}
}
}()
return stop
}
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
ticker := time.NewTicker(interval)
stop := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case <-ticker.C:
db.logger.Println("cleaning expired refresh tokens")
_, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')")
if err != nil {
db.logger.Println("failed to clean expired refresh tokens")
}
}
}
}()
return stop
}

View File

@ -18,10 +18,9 @@ package database
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"log" "log"
"strconv" "time"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
) )
@ -30,11 +29,14 @@ type CustomDB struct {
connection *sql.DB connection *sql.DB
logger *log.Logger logger *log.Logger
schema uint schema uint
secret []byte
} }
func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB { var initTimeStamp int
db := CustomDB{logger: logger, schema: schema} func InitSQLite3(path string, schema uint, logger *log.Logger, secret []byte, initialAdminName string, initialAdminPassword string) *CustomDB {
initTimeStamp = int(time.Now().Unix())
db := CustomDB{logger: logger, schema: schema, secret: secret}
var err error var err error
db.connection, err = sql.Open("sqlite3", "file:"+path+"?_foreign_keys=1") db.connection, err = sql.Open("sqlite3", "file:"+path+"?_foreign_keys=1")
@ -58,6 +60,10 @@ func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
if err = db.createSQLite3Tables(); err != nil { if err = db.createSQLite3Tables(); err != nil {
db.logger.Fatalln("Error in creating table: ", err) db.logger.Fatalln("Error in creating table: ", err)
} }
err = db.CreateInitialAdmin(initialAdminName, initialAdminPassword)
if err != nil {
db.logger.Fatal("failed to create initial admin. Try to fix and delete old database file: ", err)
}
case user_version > db.schema: case user_version > db.schema:
db.logger.Fatalln("Incompatible database schema version. Try updating this software.") db.logger.Fatalln("Incompatible database schema version. Try updating this software.")
case user_version < db.schema: case user_version < db.schema:
@ -71,8 +77,11 @@ func (db CustomDB) createSQLite3Tables() error {
name string name string
sql string sql string
}{ }{
{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE, fullName VARCHAR"}, {"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE CHECK(length(userName)!=0), passwordHash VARCHAR NOT NULL CHECK(length(passwordHash)!=0), fullName VARCHAR CHECK(length(fullName)!=0)"},
{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL, IS_done BOOL NOT NULL DEFAULT false, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"}, {"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL CHECK(length(text)!=0), IS_done BOOL NOT NULL, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
{"R_User_Role", "relationId INTEGER PRIMARY KEY NOT NULL, FK_Role_roleId INTEGER NOT NULL, FK_User_userId INTEGER NOT NULL, UNIQUE(FK_Role_roleId, FK_User_userId), FOREIGN KEY(FK_Role_roleId) REFERENCES Role(roleId) ON UPDATE CASCADE ON DELETE CASCADE, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
{"Role", "roleId INTEGER PRIMARY KEY NOT NULL, roleName VARCHAR NOT NULL UNIQUE CHECK(length(roleName)!=0), IS_admin BOOL NOT NULL, IS_userCreator BOOL NOT NULL"},
{"RefreshToken", "tokenId INTEGER PRIMARY KEY NOT NULL, FK_User_userId INTEGER NOT NULL, selector VARCHAR NOT NULL CHECK(length(selector)!=0) UNIQUE, tokenHash VARCHAR NOT NULL CHECK(length(tokenHash)!=0), expiryDate INTEGER NOT NULL, tokenName VARCHAR CHECK(length(tokenName)!=0), FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
} }
for _, table := range tables { for _, table := range tables {
_, err := db.connection.Exec("CREATE TABLE IF NOT EXISTS " + table.name + " (" + table.sql + ")") _, err := db.connection.Exec("CREATE TABLE IF NOT EXISTS " + table.name + " (" + table.sql + ")")
@ -91,262 +100,18 @@ func (db CustomDB) createSQLite3Tables() error {
return nil return nil
} }
func (db CustomDB) GetUser(user *model.User) (*model.User, error) { func (db CustomDB) CreateInitialAdmin(initialAdminName string, initialAdminPassword string) error {
id, err := strconv.Atoi(user.ID) role, err := db.CreateRole(&model.NewRole{RoleName: "admin", IsAdmin: true, IsUserCreator: true})
if err != nil { if err != nil {
return nil, errors.New("invalid userId") return err
} }
statement, err := db.connection.Prepare("SELECT userName, fullName FROM User WHERE userId = ?") user, err := db.CreateUser(model.NewUser{UserName: initialAdminName, Password: initialAdminPassword})
if err != nil { if err != nil {
return nil, err return err
} }
_, err = db.AddRole(user.ID, role.ID)
result := statement.QueryRow(id) if err != nil {
if err := result.Scan(&user.UserName, &user.FullName); err != nil { return err
return nil, err
} }
return nil
return user, nil
}
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
id, err := strconv.Atoi(todo.ID)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(id)
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
return nil, err
}
return todo, nil
}
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
id, err := strconv.Atoi(user.ID)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(id)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Todo
for rows.Next() {
todo := model.Todo{User: user}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
return nil, err
}
all = append(all, &todo)
}
return all, nil
}
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.User
for rows.Next() {
var user model.User
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
return nil, err
}
all = append(all, &user)
}
return all, nil
}
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
rows, err := db.connection.Query("SELECT Todo.todoID, Todo.text, Todo.IS_done, User.userID FROM Todo INNER JOIN User ON Todo.FK_User_userID=User.userID")
if err != nil {
return nil, err
}
defer rows.Close()
var todos []*model.Todo
for rows.Next() {
var todo = model.Todo{User: &model.User{}}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
return nil, err
}
todos = append(todos, &todo)
}
return todos, nil
}
func (db CustomDB) AddUser(newUser model.NewUser) (*model.User, error) {
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName) VALUES (?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newUser.UserName, newUser.FullName)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
}
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
}
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
id, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(?, fullName) WHERE userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.UserName, changes.FullName, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetUser(&model.User{ID: userId})
}
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
id, err := strconv.Atoi(todoId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.Text, changes.Done, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetTodo(&model.Todo{ID: todoId})
}
func (db CustomDB) DeleteUser(userId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(userId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &userId, nil
}
func (db CustomDB) DeleteTodo(todoId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(todoId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &todoId, nil
} }

168
database/role.go Normal file
View File

@ -0,0 +1,168 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"errors"
"strconv"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
func (db CustomDB) CreateRole(newRole *model.NewRole) (role *model.Role, err error) {
statement, err := db.connection.Prepare("INSERT INTO Role (roleName, IS_admin, IS_userCreator) VALUES (?, ?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newRole.RoleName, newRole.IsAdmin, newRole.IsUserCreator)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.Role{ID: strconv.FormatInt(insertId, 10), RoleName: newRole.RoleName, IsAdmin: newRole.IsAdmin, IsUserCreator: newRole.IsUserCreator}, nil
}
func (db CustomDB) GetRole(role *model.Role) (*model.Role, error) {
id, err := strconv.Atoi(role.ID)
if err != nil {
return nil, errors.New("invalid roleId")
}
statement, err := db.connection.Prepare("SELECT roleName, IS_admin, IS_userCreator FROM Role WHERE roleId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(id)
if err := result.Scan(&role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
return nil, err
}
return role, nil
}
func (db CustomDB) GetRolesFrom(userId string) ([]*model.Role, error) {
numUserId, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT Role.roleId, Role.roleName, Role.IS_admin, Role.IS_userCreator FROM Role INNER JOIN R_User_Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE R_User_Role.FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(numUserId)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Role
for rows.Next() {
role := model.Role{}
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
return nil, err
}
all = append(all, &role)
}
return all, nil
}
func (db CustomDB) GetAllRoles() ([]*model.Role, error) {
rows, err := db.connection.Query("SELECT roleId, roleName, IS_admin, IS_userCreator FROM Role")
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Role
for rows.Next() {
var role model.Role
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
return nil, err
}
all = append(all, &role)
}
return all, nil
}
func (db CustomDB) UpdateRole(roleId string, changes *model.UpdateRole) (*model.Role, error) {
id, err := strconv.Atoi(roleId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE Role SET roleName = IFNULL(?, roleName), IS_admin = IFNULL(?, IS_admin), IS_userCreator = IFNULL(?, IS_userCreator) WHERE roleId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.RoleName, changes.IsAdmin, changes.IsUserCreator, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetRole(&model.Role{ID: roleId})
}
func (db CustomDB) DeleteRole(roleId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM Role WHERE roleId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(roleId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &roleId, nil
}

199
database/todo.go Normal file
View File

@ -0,0 +1,199 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"errors"
"strconv"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
// GetOwner takes a todoId and return the owner's userId. Call before Update/Get/DeleteTodo when not IS_admin.
func (db CustomDB) GetOwner(todoId string) (string, error) {
numTodoId, err := strconv.Atoi(todoId)
if err != nil {
return "", errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM Todo WHERE todoId = ?")
if err != nil {
return "", err
}
result := statement.QueryRow(numTodoId)
var owner string
if err := result.Scan(&owner); err != nil {
return "", err
}
return owner, nil
}
// GetTodo takes a *model.Todo with at least ID set and adds the missing fields.
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
numTodoId, err := strconv.Atoi(todo.ID)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(numTodoId)
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
return nil, err
}
return todo, nil
}
// GetTodosFrom gets all todos from the passed *model.User. ID must be set.
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
id, err := strconv.Atoi(user.ID)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(id)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Todo
for rows.Next() {
todo := model.Todo{User: user}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
return nil, err
}
all = append(all, &todo)
}
return all, nil
}
// GetAllTodos gets all todos from the database. Check if the source has the rights to call this.
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
rows, err := db.connection.Query("SELECT todoID, text, IS_done, FK_User_userID FROM Todo")
if err != nil {
return nil, err
}
defer rows.Close()
var todos []*model.Todo
for rows.Next() {
var todo = model.Todo{User: &model.User{}}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
return nil, err
}
todos = append(todos, &todo)
}
return todos, nil
}
// AddTodo adds a Todo to passed UserID. Check if the source has the rights to call this.
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
}
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
numTodoId, err := strconv.Atoi(todoId)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.Text, changes.Done, numTodoId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetTodo(&model.Todo{ID: todoId})
}
func (db CustomDB) DeleteTodo(todoId string) (deletedTodoId *string, err error) {
numTodoId, err := strconv.Atoi(todoId)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(numTodoId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &todoId, nil
}

257
database/user.go Normal file
View File

@ -0,0 +1,257 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"errors"
"strconv"
"time"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
func (db CustomDB) CreateUser(newUser model.NewUser) (*model.User, error) {
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName, passwordHash) VALUES (?, NULLIF(?, ''), ?)")
if err != nil {
return nil, err
}
if ValidateUserName(newUser.UserName); err != nil {
return nil, err
}
if ValidatePassword(newUser.Password); err != nil {
return nil, err
}
passwordHash, err := db.GenerateHashFromPassword(newUser.Password)
if err != nil {
return nil, err
}
rows, err := statement.Exec(newUser.UserName, newUser.FullName, string(passwordHash))
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
}
// GetUser takes a *model.User with at least ID or UserName set and adds the missing fields.
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
numUserId, err := strconv.Atoi(user.ID)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT userID, userName, fullName FROM User WHERE userId = ? OR userName = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(numUserId, user.UserName)
if err := result.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
return nil, err
}
return user, nil
}
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.User
for rows.Next() {
var user model.User
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
return nil, err
}
all = append(all, &user)
}
return all, nil
}
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
var passwordHash *string
needAccessTokenRefresh := false
id, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(NULLIF(?, ''), fullName), passwordHash = IFNULL(?, passwordHash) WHERE userId = ?")
if err != nil {
return nil, err
}
if *changes.UserName == "" { // interpret empty string as nil
changes.UserName = nil
} else {
if err := ValidateUserName(*changes.UserName); err != nil {
return nil, err
}
needAccessTokenRefresh = true
}
if *changes.Password == "" { // interpret empty string as nil
passwordHash = nil
} else {
if err := ValidatePassword(*changes.Password); err != nil {
return nil, err
}
passwordHashByte, err := db.GenerateHashFromPassword(*changes.Password)
if err != nil {
return nil, err
}
passwordHashString := string(passwordHashByte)
passwordHash = &passwordHashString
needAccessTokenRefresh = true
}
rows, err := statement.Exec(changes.UserName, changes.FullName, passwordHash, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
if needAccessTokenRefresh {
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
}
return db.GetUser(&model.User{ID: userId})
}
func (db CustomDB) DeleteUser(userId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(userId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
return &userId, nil
}
func (db CustomDB) AddRole(userId string, roleId string) (relationId string, err error) {
encUserId, err := strconv.Atoi(userId)
if err != nil {
return "", errors.New("invalid userId")
}
encRoleId, err := strconv.Atoi(roleId)
if err != nil {
return "", errors.New("invalid roleId")
}
statement, err := db.connection.Prepare("INSERT INTO R_User_Role (FK_User_userId, FK_Role_roleId) VALUES (?, ?)")
if err != nil {
return "", err
}
rows, err := statement.Exec(encUserId, encRoleId)
if err != nil {
return "", err
}
num, err := rows.RowsAffected()
if err != nil {
return "", err
}
if num < 1 {
return "", errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return "", err
}
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
return strconv.FormatInt(insertId, 10), nil
}
func (db CustomDB) RemoveRole(userId string, roleId string) (relationId string, err error) {
encUserId, err := strconv.Atoi(userId)
if err != nil {
return "", errors.New("invalid userId")
}
encRoleId, err := strconv.Atoi(roleId)
if err != nil {
return "", errors.New("invalid roleId")
}
statement, err := db.connection.Prepare("DELETE FROM R_User_Role WHERE FK_User_userId = ? AND FK_Role_roleId = ?")
if err != nil {
return "", err
}
rows, err := statement.Exec(encUserId, encRoleId)
if err != nil {
return "", err
}
num, err := rows.RowsAffected()
if err != nil {
return "", err
}
if num < 1 {
return "", errors.New("no rows affected")
}
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
return strconv.FormatInt(int64(encRoleId), 10), nil
}

View File

@ -7,7 +7,7 @@ export default defineConfig({
server: { server: {
port: 4243, port: 4243,
proxy: { proxy: {
'^/(api|playground|version)': 'http://localhost:4242/' '^/(api|playground|version|auth)': 'http://localhost:4242/'
} }
}, },
build: { build: {

2
go.mod
View File

@ -4,10 +4,12 @@ go 1.20
require ( require (
github.com/99designs/gqlgen v0.17.37 github.com/99designs/gqlgen v0.17.37
github.com/go-chi/chi v1.5.5
github.com/mattn/go-sqlite3 v1.14.17 github.com/mattn/go-sqlite3 v1.14.17
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/viper v1.16.0 github.com/spf13/viper v1.16.0
github.com/vektah/gqlparser/v2 v2.5.9 github.com/vektah/gqlparser/v2 v2.5.9
golang.org/x/crypto v0.9.0
) )
require ( require (

4
go.sum
View File

@ -69,6 +69,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE=
github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
@ -216,6 +218,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=

View File

@ -108,3 +108,5 @@ models:
fields: fields:
todos: todos:
resolver: true resolver: true
roles:
resolver: true

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,16 @@
package model package model
type NewRefreshToken struct {
TokenName *string `json:"tokenName,omitempty"`
}
type NewRole struct {
RoleName string `json:"roleName"`
IsAdmin bool `json:"isAdmin"`
IsUserCreator bool `json:"isUserCreator"`
}
type NewTodo struct { type NewTodo struct {
Text string `json:"text"` Text string `json:"text"`
UserID string `json:"userId"` UserID string `json:"userId"`
@ -9,14 +19,34 @@ type NewTodo struct {
type NewUser struct { type NewUser struct {
UserName string `json:"userName"` UserName string `json:"userName"`
FullName string `json:"fullName"` FullName *string `json:"fullName,omitempty"`
Password string `json:"password"`
} }
type User struct { type RefreshToken struct {
ID string `json:"id"` ID string `json:"id"`
UserName string `json:"userName"` ExpiryDate int `json:"expiryDate"`
FullName string `json:"fullName"` TokenName *string `json:"tokenName,omitempty"`
Todos []*Todo `json:"todos"` Selector *string `json:"selector,omitempty"`
Token *string `json:"token,omitempty"`
UserID string `json:"userId"`
}
type Role struct {
ID string `json:"id"`
RoleName string `json:"roleName"`
IsAdmin bool `json:"isAdmin"`
IsUserCreator bool `json:"isUserCreator"`
}
type UpdateRefreshToken struct {
TokenName *string `json:"tokenName,omitempty"`
}
type UpdateRole struct {
RoleName *string `json:"roleName,omitempty"`
IsAdmin *bool `json:"isAdmin,omitempty"`
IsUserCreator *bool `json:"isUserCreator,omitempty"`
} }
type UpdateTodo struct { type UpdateTodo struct {
@ -27,4 +57,13 @@ type UpdateTodo struct {
type UpdateUser struct { type UpdateUser struct {
UserName *string `json:"userName,omitempty"` UserName *string `json:"userName,omitempty"`
FullName *string `json:"fullName,omitempty"` FullName *string `json:"fullName,omitempty"`
Password *string `json:"password,omitempty"`
}
type User struct {
ID string `json:"id"`
UserName string `json:"userName"`
FullName *string `json:"fullName,omitempty"`
Todos []*Todo `json:"todos"`
Roles []*Role `json:"roles"`
} }

View File

@ -27,20 +27,42 @@ type Todo {
type User { type User {
id: ID! id: ID!
userName: String! userName: String!
fullName: String! fullName: String
todos: [Todo!]! todos: [Todo!]!
roles: [Role!]!
}
type Role {
id: ID!
roleName: String!
isAdmin: Boolean!
isUserCreator: Boolean!
}
type RefreshToken {
id: ID!
expiryDate: Int!
tokenName: String
selector: String
token: String
userId: String!
} }
type Query { type Query {
todos: [Todo!]! todos: [Todo!]!
users: [User!]! users: [User!]!
roles: [Role!]!
refreshTokens: [RefreshToken!]!
user(id: ID!): User! user(id: ID!): User!
todo(id: ID!): Todo! todo(id: ID!): Todo!
role(id: ID!): Role!
refreshToken(id: ID!): RefreshToken!
} }
input NewUser { input NewUser {
userName: String! userName: String!
fullName: String! fullName: String
password: String!
} }
input NewTodo { input NewTodo {
@ -48,21 +70,50 @@ input NewTodo {
userId: ID! userId: ID!
} }
input updateTodo { input NewRole {
roleName: String!
isAdmin: Boolean!
isUserCreator: Boolean!
}
input NewRefreshToken {
tokenName: String
}
input UpdateTodo {
text: String text: String
done: Boolean done: Boolean
} }
input updateUser { input UpdateUser {
userName: String userName: String
fullName: String fullName: String
password: String
}
input UpdateRole {
roleName: String
isAdmin: Boolean
isUserCreator: Boolean
}
input UpdateRefreshToken {
tokenName: String
} }
type Mutation { type Mutation {
createUser(input: NewUser!): User! createUser(input: NewUser!): User!
createTodo(input: NewTodo!): Todo! createTodo(input: NewTodo!): Todo!
updateTodo(id: ID!, changes: updateTodo!): Todo! createRole(input: NewRole!): Role!
updateUser(id: ID!, changes: updateUser!): User! createRefreshToken(input: NewRefreshToken!): RefreshToken!
updateTodo(id: ID!, changes: UpdateTodo!): Todo!
updateUser(id: ID!, changes: UpdateUser!): User!
updateRole(id: ID!, changes: UpdateRole!): Role!
updateRefreshToken(id: ID!, changes: UpdateRefreshToken!): RefreshToken!
deleteUser(id: ID!): ID deleteUser(id: ID!): ID
deleteTodo(id: ID!): ID deleteTodo(id: ID!): ID
deleteRole(id: ID!): ID
deleteRefreshToken(id: ID!): ID
addRole(userId: ID!, roleId: ID!): [Role!]!
removeRole(userId: ID!, roleId: ID!): [Role!]!
} }

View File

@ -10,11 +10,12 @@ import (
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/server/auth"
) )
// CreateUser is the resolver for the createUser field. // CreateUser is the resolver for the createUser field.
func (r *mutationResolver) CreateUser(ctx context.Context, input model.NewUser) (*model.User, error) { func (r *mutationResolver) CreateUser(ctx context.Context, input model.NewUser) (*model.User, error) {
todo, err := globals.DB.AddUser(input) todo, err := globals.DB.CreateUser(input)
if err != nil { if err != nil {
globals.Logger.Println("Failed to add new user:", err) globals.Logger.Println("Failed to add new user:", err)
return nil, errors.New("failed to add new user") return nil, errors.New("failed to add new user")
@ -32,6 +33,28 @@ func (r *mutationResolver) CreateTodo(ctx context.Context, input model.NewTodo)
return todo, nil return todo, nil
} }
// CreateRole is the resolver for the createRole field.
func (r *mutationResolver) CreateRole(ctx context.Context, input model.NewRole) (*model.Role, error) {
role, err := globals.DB.CreateRole(&input)
if err != nil {
globals.Logger.Println("Failed to add new role:", err)
return nil, errors.New("failed to add new role")
}
return role, nil
}
// CreateRefreshToken is the resolver for the createRefreshToken field.
func (r *mutationResolver) CreateRefreshToken(ctx context.Context, input model.NewRefreshToken) (*model.RefreshToken, error) {
// TODO: unify model.RefreshToken & auth.RefreshToken
userToken := auth.ForContext(ctx)
refreshToken, tokenId, err := globals.DB.IssueRefreshToken(userToken.UserId, input.TokenName)
if err != nil {
globals.Logger.Println("Failed to create refresh token:", err)
return nil, errors.New("failed to create refresh token")
}
return &model.RefreshToken{ID: tokenId, ExpiryDate: refreshToken.ExpiryDate, TokenName: input.TokenName, Selector: &refreshToken.Selector, Token: &refreshToken.Token, UserID: userToken.UserId}, nil
}
// UpdateTodo is the resolver for the updateTodo field. // UpdateTodo is the resolver for the updateTodo field.
func (r *mutationResolver) UpdateTodo(ctx context.Context, id string, changes model.UpdateTodo) (*model.Todo, error) { func (r *mutationResolver) UpdateTodo(ctx context.Context, id string, changes model.UpdateTodo) (*model.Todo, error) {
return globals.DB.UpdateTodo(id, &changes) return globals.DB.UpdateTodo(id, &changes)
@ -42,6 +65,16 @@ func (r *mutationResolver) UpdateUser(ctx context.Context, id string, changes mo
return globals.DB.UpdateUser(id, &changes) return globals.DB.UpdateUser(id, &changes)
} }
// UpdateRole is the resolver for the updateRole field.
func (r *mutationResolver) UpdateRole(ctx context.Context, id string, changes model.UpdateRole) (*model.Role, error) {
return globals.DB.UpdateRole(id, &changes)
}
// UpdateRefreshToken is the resolver for the updateRefreshToken field.
func (r *mutationResolver) UpdateRefreshToken(ctx context.Context, id string, changes model.UpdateRefreshToken) (*model.RefreshToken, error) {
return globals.DB.UpdateRefreshToken(id, &changes)
}
// DeleteUser is the resolver for the deleteUser field. // DeleteUser is the resolver for the deleteUser field.
func (r *mutationResolver) DeleteUser(ctx context.Context, id string) (*string, error) { func (r *mutationResolver) DeleteUser(ctx context.Context, id string) (*string, error) {
return globals.DB.DeleteUser(id) return globals.DB.DeleteUser(id)
@ -52,6 +85,32 @@ func (r *mutationResolver) DeleteTodo(ctx context.Context, id string) (*string,
return globals.DB.DeleteTodo(id) return globals.DB.DeleteTodo(id)
} }
// DeleteRole is the resolver for the deleteRole field.
func (r *mutationResolver) DeleteRole(ctx context.Context, id string) (*string, error) {
return globals.DB.DeleteRole(id)
}
// DeleteRefreshToken is the resolver for the deleteRefreshToken field.
func (r *mutationResolver) DeleteRefreshToken(ctx context.Context, id string) (*string, error) {
return globals.DB.RevokeRefreshToken(id)
}
// AddRole is the resolver for the addRole field.
func (r *mutationResolver) AddRole(ctx context.Context, userID string, roleID string) ([]*model.Role, error) {
if _, err := globals.DB.AddRole(userID, roleID); err != nil {
return nil, err
}
return globals.DB.GetRolesFrom(userID)
}
// RemoveRole is the resolver for the removeRole field.
func (r *mutationResolver) RemoveRole(ctx context.Context, userID string, roleID string) ([]*model.Role, error) {
if _, err := globals.DB.RemoveRole(userID, roleID); err != nil {
return nil, err
}
return globals.DB.GetRolesFrom(userID)
}
// Todos is the resolver for the todos field. // Todos is the resolver for the todos field.
func (r *queryResolver) Todos(ctx context.Context) ([]*model.Todo, error) { func (r *queryResolver) Todos(ctx context.Context) ([]*model.Todo, error) {
return globals.DB.GetAllTodos() return globals.DB.GetAllTodos()
@ -62,6 +121,16 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
return globals.DB.GetAllUsers() return globals.DB.GetAllUsers()
} }
// Roles is the resolver for the roles field.
func (r *queryResolver) Roles(ctx context.Context) ([]*model.Role, error) {
return globals.DB.GetAllRoles()
}
// RefreshTokens is the resolver for the refreshTokens field.
func (r *queryResolver) RefreshTokens(ctx context.Context) ([]*model.RefreshToken, error) {
return globals.DB.GetAllRefreshTokens()
}
// User is the resolver for the user field. // User is the resolver for the user field.
func (r *queryResolver) User(ctx context.Context, id string) (*model.User, error) { func (r *queryResolver) User(ctx context.Context, id string) (*model.User, error) {
return globals.DB.GetUser(&model.User{ID: id}) return globals.DB.GetUser(&model.User{ID: id})
@ -72,6 +141,16 @@ func (r *queryResolver) Todo(ctx context.Context, id string) (*model.Todo, error
return globals.DB.GetTodo(&model.Todo{ID: id}) return globals.DB.GetTodo(&model.Todo{ID: id})
} }
// Role is the resolver for the role field.
func (r *queryResolver) Role(ctx context.Context, id string) (*model.Role, error) {
return globals.DB.GetRole(&model.Role{ID: id})
}
// RefreshToken is the resolver for the refreshToken field.
func (r *queryResolver) RefreshToken(ctx context.Context, id string) (*model.RefreshToken, error) {
return globals.DB.GetRefreshToken(&model.RefreshToken{ID: id})
}
// User is the resolver for the user field. // User is the resolver for the user field.
func (r *todoResolver) User(ctx context.Context, obj *model.Todo) (*model.User, error) { func (r *todoResolver) User(ctx context.Context, obj *model.Todo) (*model.User, error) {
// TODO: implement dataloader // TODO: implement dataloader
@ -83,6 +162,11 @@ func (r *userResolver) Todos(ctx context.Context, obj *model.User) ([]*model.Tod
return globals.DB.GetTodosFrom(obj) return globals.DB.GetTodosFrom(obj)
} }
// Roles is the resolver for the roles field.
func (r *userResolver) Roles(ctx context.Context, obj *model.User) ([]*model.Role, error) {
return globals.DB.GetRolesFrom(obj.ID)
}
// Mutation returns MutationResolver implementation. // Mutation returns MutationResolver implementation.
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} } func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }

91
server/auth/README.md Normal file
View File

@ -0,0 +1,91 @@
# Authentication in Golang
## Use header
Use the http header or the body, but avoid using cookies to transport tokens etc
(because of CSRF)
## Implementation
```
# create password/user
update_password(password):
generate salt
db.store(hash(salt + password))
#login
get_token(userId, password):
if not hash(salt + password) == db.get(salted_hash_of_password): # use timing-attack resistant compare here
return FAILED
generate selector
db.store(selector)
generate salt
generate auth_token
db.store(salt, hash(salt+auth_token))
return selector:auth_token
#authenticate
validate_token(selector, auth_token):
if not hash(salt+auth_token) == db.get(salt, salted_hash_of_auth_token WHERE selector): # use timing-attack resistant compare here
return UNKNOWN_TOKEN
return AUTHENTICATED
```
idea: replace selector with userId?
## JWT
```json
{
"userId": "id",
"userRole": "role",
"expiryTime": "now+10min"
}
```
We use JWT with a lifespan of 10min and an in-memory db to blacklist revoked
tokens. So if for e.g. a user changes it's password, we would add the userId and
the time of the change to the blacklist, filtering out all tokens that have been
issued before.
After a server restart, all tokens will become invalid as well, since we can not
be sure which ones were 'on the blacklist'. This could be mitigated by making
the blacklist persistent during restarts.
If a token has expired (either by a server restart or after 10min), a new token
is requested with a 'long lived' refresh-token (lifetime of ~1 week) that is
stored in a database.
### Pros:
- less DB lookups in general
### Cons:
- timestamps of blacklisting could become a problem (maybe add 1 second or use
timestamp returned by n-th node when it has received the update).
- increased load on db + service since we need to issue new jwt for everybody.
## SSL/TLS
```bash
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout keyFile.key -out certFile.crt
```
## CRIME/BREACH http compression attack
While we do not use a CSRF token, headers sent to the `/api` still contain
private data.
If you are using http/1.1 or lower and have compression enabled on your proxy,
you would be .
## CRSF
We check for a http header like this `X-YOURSITE-CSRF-PROTECTION=1`. This should
be enough, according to
[cheatsheetseries.owasp.org](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#custom-request-headers)
## XSS
We rely on Vue.js's ability to escape user-input in templates.

139
server/auth/main.go Normal file
View File

@ -0,0 +1,139 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/database"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
)
var userCtxKey = &contextKey{"user"}
type contextKey struct {
name string
}
func Middleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
const headerPrefix = "Bearer "
authField := r.Header.Get("Authorization")
if !strings.HasPrefix(authField, headerPrefix) {
http.Error(w, fmt.Sprintf(`{"error":"wrong token type, expect '%s'"}`, headerPrefix), http.StatusBadRequest)
return
}
// get the user from the database
user, err := globals.DB.CheckAccessToken(strings.TrimPrefix(authField, headerPrefix))
if err != nil {
http.Error(w, strings.ReplaceAll(fmt.Sprintf(`{"error":"failed token check: '%s'"}`, err), "\\", "\\\\"), http.StatusInternalServerError)
return
}
// put it in context
ctx := context.WithValue(r.Context(), userCtxKey, user)
// and call the next with our new context
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}
func ForContext(ctx context.Context) *database.AccessToken {
raw, _ := ctx.Value(userCtxKey).(*database.AccessToken)
return raw
}
type userAuth struct {
UserId *string `json:"userId"`
UserName *string `json:"userName"`
Password string `json:"password"`
}
func IssueRefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
const headerPrefix = "Login "
authField := r.Header.Get("Authorization")
if !strings.HasPrefix(authField, headerPrefix) {
http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
return
}
userCredentials := userAuth{}
err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &userCredentials)
if err != nil {
http.Error(w, "malformed or missing user credentials", http.StatusBadRequest)
return
}
userId, err := globals.DB.ValidateUserCredentials(userCredentials.UserId, userCredentials.UserName, userCredentials.Password)
if err != nil {
http.Error(w, "invalid credentials", http.StatusUnauthorized)
return
}
refreshToken, _, err := globals.DB.IssueRefreshToken(userId, nil)
if err != nil {
http.Error(w, "failed to issue refresh token", http.StatusInternalServerError)
return
}
jsonRefreshToken, err := json.Marshal(refreshToken)
if err != nil {
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
w.Write([]byte(jsonRefreshToken))
}
func IssueAccessTokenHandler(w http.ResponseWriter, r *http.Request) {
const headerPrefix = "Refresh "
authField := r.Header.Get("Authorization")
if !strings.HasPrefix(authField, headerPrefix) {
http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
return
}
refreshToken := database.RefreshToken{}
err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &refreshToken)
if err != nil {
http.Error(w, "malformed or missing refresh token", http.StatusBadRequest)
return
}
accessToken, err := globals.DB.IssueAccessToken(&refreshToken)
if err != nil {
http.Error(w, "invalid access token", http.StatusUnauthorized)
return
}
encAccessToken, err := globals.DB.SignAccessToken(*accessToken)
if err != nil {
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
w.Write([]byte(encAccessToken))
}

View File

@ -20,11 +20,24 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package server package server
import ( import (
"fmt"
"github.com/99designs/gqlgen/graphql/playground" "github.com/99designs/gqlgen/graphql/playground"
"github.com/go-chi/chi"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
) )
func handleDevTools(port int) { func handleDevTools(router *chi.Mux, portHTTP int, portHTTPS int) {
mux.Handle("/playground", playground.Handler("GraphQL playground", "/api")) router.Handle("/playground", playground.Handler("GraphQL playground", "/api"))
globals.Logger.Printf("connect to http://localhost:%v/ for GraphQL playground", port) url := "connect to "
if portHTTP != -1 && portHTTPS != -1 {
url += fmt.Sprintf("http://localhost:%v/ or https://localhost:%v/", portHTTP, portHTTPS)
} else if portHTTP != -1 {
url += fmt.Sprintf("http://localhost:%v/", portHTTP)
} else if portHTTPS != -1 {
url += fmt.Sprintf("https://localhost:%v/", portHTTPS)
} else {
return
}
globals.Logger.Println(url + " for GraphQL playground")
} }

View File

@ -19,5 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package server package server
func handleDevTools(_ int) { import "github.com/go-chi/chi"
func handleDevTools(router *chi.Mux, _ int, _ int) {
} }

View File

@ -24,19 +24,21 @@ import (
"io/fs" "io/fs"
"log" "log"
"net/http" "net/http"
"github.com/go-chi/chi"
) )
//go:embed dist/* //go:embed dist/*
var frontend embed.FS var frontend embed.FS
func handleFrontend() { func handleFrontend(router *chi.Mux) {
stripped, err := fs.Sub(frontend, "dist") stripped, err := fs.Sub(frontend, "dist")
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
frontendFS := http.FileServer(http.FS(stripped)) frontendFS := http.FileServer(http.FS(stripped))
mux.Handle("/assets/", frontendFS) router.Handle("/assets/*", frontendFS)
mux.HandleFunc("/", indexHandler) router.HandleFunc("/", indexHandler)
// TODO: redirect from vue to 404 page (on go/proxy server) // TODO: redirect from vue to 404 page (on go/proxy server)
} }

View File

@ -19,5 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package server package server
func handleFrontend() { import "github.com/go-chi/chi"
func handleFrontend(router *chi.Mux) {
} }

View File

@ -22,24 +22,68 @@ import (
"strconv" "strconv"
"github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/server/auth"
) )
var mux *http.ServeMux func StartServer(portHTTP int, portHTTPS int, certFile string, keyFile string) {
router := chi.NewRouter()
router.Use(middleware.StripSlashes)
func StartServer(port int) { handleDevTools(router, portHTTP, portHTTPS) // controlled by 'dev' tag
mux = http.NewServeMux() handleFrontend(router) // controlled by 'headless' tag
router.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
handleDevTools(port) // controlled by 'dev' tag
handleFrontend() // controlled by 'headless' tag
mux.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s %s", globals.Version, globals.CommitHash) fmt.Fprintf(w, "%s %s", globals.Version, globals.CommitHash)
}) })
srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{Resolvers: &graph.Resolver{}})) srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{Resolvers: &graph.Resolver{}}))
mux.Handle("/api", srv) router.HandleFunc("/auth/login", auth.IssueRefreshTokenHandler)
router.HandleFunc("/auth", auth.IssueAccessTokenHandler)
router.Group(func(r chi.Router) {
r.Use(auth.Middleware())
r.Handle("/api", srv)
r.HandleFunc("/protected", func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("user is %+v\n", auth.ForContext(r.Context()))
})
})
// router.HandleFunc("/auth/", func(w http.ResponseWriter, r *http.Request) {
// http.Redirect(w, r, "/auth", http.StatusMovedPermanently)
// })
err := <-listen(portHTTP, portHTTPS, certFile, keyFile, router)
globals.Logger.Fatalf("Could not start serving service due to (error: %s)", err)
globals.Logger.Fatal(http.ListenAndServe(":"+strconv.FormatInt(int64(port), 10), mux)) }
func listen(portHTTP int, portHTTPS int, certFile string, keyFile string, router *chi.Mux) chan error {
errs := make(chan error)
// Starting HTTP server
if portHTTP != -1 {
go func() {
globals.Logger.Printf("Staring HTTP service on %d ...", portHTTP)
if err := http.ListenAndServe(":"+strconv.FormatInt(int64(portHTTP), 10), router); err != nil {
errs <- err
}
}()
}
// Starting HTTPS server
if portHTTPS != -1 && certFile != "" && keyFile != "" {
go func() {
globals.Logger.Printf("Staring HTTPS service on %d ...", portHTTPS)
if err := http.ListenAndServeTLS(":"+strconv.FormatInt(int64(portHTTPS), 10), certFile, keyFile, router); err != nil {
errs <- err
}
}()
}
return errs
} }