add chi-router, auth middleware & user roles.
group config options & split database logic
This commit is contained in:
		
							parent
							
								
									e4c9563961
								
							
						
					
					
						commit
						a21494f94b
					
				@ -1,4 +1,16 @@
 | 
			
		||||
sqlite3_file: 'YetAnotherToDoList.sqlite3'
 | 
			
		||||
log_file: 'YetAnotherToDoList.log'
 | 
			
		||||
log_UTC: false
 | 
			
		||||
port: 4242
 | 
			
		||||
database:
 | 
			
		||||
 sqlite3File: 'YetAnotherToDoList.sqlite3'
 | 
			
		||||
 secret: 'aS3cureAppl1cationk3y'
 | 
			
		||||
 initialAdmin:
 | 
			
		||||
  userName: 'admin'
 | 
			
		||||
  password: 'temporaryPassword'
 | 
			
		||||
 | 
			
		||||
logging:
 | 
			
		||||
 logFile: 'YetAnotherToDoList.log'
 | 
			
		||||
 logUTC: false
 | 
			
		||||
 | 
			
		||||
server:
 | 
			
		||||
 portHTTP: 4242
 | 
			
		||||
 portHTTPS: 4241
 | 
			
		||||
 certFile: 'certFile.crt'
 | 
			
		||||
 keyFile: 'keyFile.key'
 | 
			
		||||
 | 
			
		||||
@ -62,3 +62,8 @@ Commands were run in the order listed below on a debian based system.
 | 
			
		||||
  pnpm install villus graphql
 | 
			
		||||
  pnpm install graphql-tag
 | 
			
		||||
  ```
 | 
			
		||||
- Add go-chi
 | 
			
		||||
  ```bash
 | 
			
		||||
  go get -u github.com/go-chi/chi/v5
 | 
			
		||||
  go mod tidy
 | 
			
		||||
  ```
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										26
									
								
								cmd/root.go
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								cmd/root.go
									
									
									
									
									
								
							@ -74,8 +74,8 @@ func init() {
 | 
			
		||||
	// will be global for your application.
 | 
			
		||||
 | 
			
		||||
	rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.YetAnotherToDoList.yaml)")
 | 
			
		||||
	rootCmd.PersistentFlags().String("log_file", "", "Path to log file")
 | 
			
		||||
	rootCmd.PersistentFlags().String("sqlite3_file", "", "Path to SQLite3 database")
 | 
			
		||||
	rootCmd.PersistentFlags().String("logFile", "", "Path to log file")
 | 
			
		||||
	rootCmd.PersistentFlags().String("sqlite3File", "", "Path to SQLite3 database")
 | 
			
		||||
 | 
			
		||||
	// Cobra also supports local flags, which will only run
 | 
			
		||||
	// when this action is called directly.
 | 
			
		||||
@ -120,7 +120,7 @@ func initLog() {
 | 
			
		||||
	time_zone_local, _ := time.Now().Zone()
 | 
			
		||||
	time_zone_offset := strings.Split(time.Now().In(time.Local).String(), " ")[2]
 | 
			
		||||
 | 
			
		||||
	if viper.GetBool("log_UTC") {
 | 
			
		||||
	if viper.GetBool("logging.logUTC") {
 | 
			
		||||
		utc = log.LUTC
 | 
			
		||||
		time_zone_use = "UTC"
 | 
			
		||||
		time_zone_alt = time_zone_local
 | 
			
		||||
@ -140,25 +140,25 @@ func initLog() {
 | 
			
		||||
	logger_flags := log.Ldate | log.Ltime | utc
 | 
			
		||||
	globals.Logger = log.New(os.Stdout, "", logger_flags)
 | 
			
		||||
 | 
			
		||||
	if err := viper.BindPFlag("log_file", rootCmd.Flags().Lookup("log_file")); err != nil {
 | 
			
		||||
	if err := viper.BindPFlag("logging.logFile", rootCmd.Flags().Lookup("logFile")); err != nil {
 | 
			
		||||
		fmt.Println("Unable to bind flag:", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if viper.GetString("log_file") != "" {
 | 
			
		||||
		log_path, err := filepath.Abs(viper.GetString("log_file"))
 | 
			
		||||
	if viper.GetString("logging.logFile") != "" {
 | 
			
		||||
		log_path, err := filepath.Abs(viper.GetString("logging.logFile"))
 | 
			
		||||
 | 
			
		||||
		globals.Logger.SetOutput(os.Stdout)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			globals.Logger.Println("Invalid path for log file", log_path)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		log_file, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
 | 
			
		||||
		logFile, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			globals.Logger.Println("Failed to write to log file:", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			globals.Logger.Println("Switching to log file", log_path)
 | 
			
		||||
			globals.Logger.SetOutput(log_file)
 | 
			
		||||
			globals.Logger.SetOutput(logFile)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -170,19 +170,21 @@ func initLog() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func initDB() {
 | 
			
		||||
	if err := viper.BindPFlag("sqlite3_file", rootCmd.Flags().Lookup("sqlite3_file")); err != nil {
 | 
			
		||||
	if err := viper.BindPFlag("database.sqlite3File", rootCmd.Flags().Lookup("sqlite3File")); err != nil {
 | 
			
		||||
		fmt.Println("Unable to bind flag:", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if viper.GetString("sqlite3_file") == "" {
 | 
			
		||||
	if viper.GetString("database.sqlite3File") == "" {
 | 
			
		||||
		globals.Logger.Fatalln("No SQLite3 file specified")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db_path, err := filepath.Abs(viper.GetString("sqlite3_file"))
 | 
			
		||||
	db_path, err := filepath.Abs(viper.GetString("database.sqlite3File"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		globals.Logger.Fatalln("Invalid path for SQLite3 file", db_path)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	globals.Logger.Println("Connecting to SQLite3", db_path)
 | 
			
		||||
	globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger)
 | 
			
		||||
	globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger, []byte(viper.GetString("database.secret")), viper.GetString("database.initialAdmin.userName"), viper.GetString("database.initialAdmin.password"))
 | 
			
		||||
	globals.DB.CleanExpiredRefreshTokensTicker(time.Minute * 10)
 | 
			
		||||
	globals.DB.CleanRevokedAccessTokensTicker(time.Minute * 10)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,7 @@ var serverCmd = &cobra.Command{
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		globals.Logger.Println("starting http server...")
 | 
			
		||||
		server.StartServer(viper.GetInt("port"))
 | 
			
		||||
		server.StartServer(viper.GetInt("server.portHTTP"), viper.GetInt("server.portHTTPS"), viper.GetString("server.certFile"), viper.GetString("server.keyFile"))
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										475
									
								
								database/crypto_helpers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										475
									
								
								database/crypto_helpers.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,475 @@
 | 
			
		||||
/*
 | 
			
		||||
YetAnotherToDoList
 | 
			
		||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
 | 
			
		||||
 | 
			
		||||
This program is free software: you can redistribute it and/or modify
 | 
			
		||||
it under the terms of the GNU General Public License as published by
 | 
			
		||||
the Free Software Foundation, version 3.
 | 
			
		||||
 | 
			
		||||
This program is distributed in the hope that it will be useful,
 | 
			
		||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU General Public License
 | 
			
		||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
*/
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/crypto/bcrypt"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RefreshToken struct {
 | 
			
		||||
	Selector   string `json:"selector"`
 | 
			
		||||
	Token      string `json:"token"`
 | 
			
		||||
	ExpiryDate int    `json:"expiryDate"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AccessToken struct {
 | 
			
		||||
	UserId        string `json:"userId"`
 | 
			
		||||
	IsAdmin       bool   `json:"isAdmin"`
 | 
			
		||||
	IsUserCreator bool   `json:"isUserCreator"`
 | 
			
		||||
	ExpiryDate    int    `json:"expiryDate"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const refreshTokenLifetime = "+10 day" // TODO: add to viper
 | 
			
		||||
const accessTokenLifetime = time.Minute * 10
 | 
			
		||||
 | 
			
		||||
const minPasswordLength = 10
 | 
			
		||||
const minUserNameLength = 10
 | 
			
		||||
 | 
			
		||||
var RevokedAccessTokens []*AccessToken
 | 
			
		||||
 | 
			
		||||
// ValidatePassword validates the passed string against the password criteria.
 | 
			
		||||
func ValidatePassword(password string) error {
 | 
			
		||||
	if len([]rune(password)) < minPasswordLength {
 | 
			
		||||
		return fmt.Errorf("password must be at least %d characters long", minPasswordLength)
 | 
			
		||||
	}
 | 
			
		||||
	for i, c := range password {
 | 
			
		||||
		if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
 | 
			
		||||
			return fmt.Errorf("password contains none-alphanumeric symbol at index %d", i)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValidateUserName validates the passed string against the user name criteria.
 | 
			
		||||
func ValidateUserName(userName string) error {
 | 
			
		||||
	if len([]rune(userName)) < minUserNameLength {
 | 
			
		||||
		return fmt.Errorf("userName must be at least %d characters long", minUserNameLength)
 | 
			
		||||
	}
 | 
			
		||||
	for i, c := range userName {
 | 
			
		||||
		if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
 | 
			
		||||
			return fmt.Errorf("userName contains none-alphanumeric symbol at index %d", i)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateHashFromPassword generates a hash of the passed password. Returns salt and salted & peppered hash.
 | 
			
		||||
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
 | 
			
		||||
	hashBytes, err := bcrypt.GenerateFromPassword(bytes.Join([][]byte{db.secret, []byte(password)}, nil), bcrypt.DefaultCost)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return hashBytes, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin.
 | 
			
		||||
func (db CustomDB) GetRefreshTokenOwner(tokenId string) (string, error) {
 | 
			
		||||
	numTokenId, err := strconv.Atoi(tokenId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.New("invalid tokenId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM RefreshToken WHERE tokenId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(numTokenId)
 | 
			
		||||
	var owner string
 | 
			
		||||
	if err := result.Scan(&owner); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return owner, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) {
 | 
			
		||||
	var result *sql.Row
 | 
			
		||||
	var hash string
 | 
			
		||||
 | 
			
		||||
	if userId != nil { // user userId
 | 
			
		||||
		numUserId, err := strconv.Atoi(*userId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", errors.New("userId not numeric")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		result = statement.QueryRow(numUserId)
 | 
			
		||||
		if err := result.Scan(&hash); err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
	} else if userName != nil { // user userName
 | 
			
		||||
		statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		result = statement.QueryRow(&userName)
 | 
			
		||||
		if err := result.Scan(&userId, &hash); err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		return "", errors.New("neither userId nor userName specified")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := bcrypt.CompareHashAndPassword([]byte(hash), bytes.Join([][]byte{db.secret, []byte(password)}, nil)); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return *userId, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken.
 | 
			
		||||
func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) {
 | 
			
		||||
	numUserId, err := strconv.Atoi(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, "", errors.New("userId not numeric")
 | 
			
		||||
	}
 | 
			
		||||
	selector := make([]byte, 9)
 | 
			
		||||
	if _, err := rand.Read(selector); err != nil {
 | 
			
		||||
		return nil, "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := make([]byte, 33)
 | 
			
		||||
	if _, err := rand.Read(token); err != nil {
 | 
			
		||||
		return nil, "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	encSelector := base64.RawURLEncoding.EncodeToString(selector)
 | 
			
		||||
	encToken := base64.RawURLEncoding.EncodeToString(token)
 | 
			
		||||
 | 
			
		||||
	tokenHash := sha256.Sum256(token)
 | 
			
		||||
 | 
			
		||||
	var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix())
 | 
			
		||||
	var tokenId string
 | 
			
		||||
	result := statement.QueryRow(numUserId, encSelector, base64.RawURLEncoding.EncodeToString(tokenHash[:]), &tokenName)
 | 
			
		||||
 | 
			
		||||
	if err := result.Scan(&tokenId, &expiryDate); err != nil {
 | 
			
		||||
		return nil, "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) {
 | 
			
		||||
	numTokenId, err := strconv.Atoi(token.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid tokenId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(numTokenId)
 | 
			
		||||
	if err := result.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return token, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) {
 | 
			
		||||
	numUserId, err := strconv.Atoi(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Query(numUserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.RefreshToken
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		token := model.RefreshToken{}
 | 
			
		||||
		if err := rows.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &token)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Query()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.RefreshToken
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var token model.RefreshToken
 | 
			
		||||
		if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &token)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) {
 | 
			
		||||
	numTokenId, err := strconv.Atoi(tokenId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid tokenId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("UPDATE AuthToken SET tokenName = ? WHERE tokenId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(changes.TokenName, numTokenId)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("failed to get updated token")
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	return token, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken.
 | 
			
		||||
func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) {
 | 
			
		||||
	// TODO: return string instead of *string
 | 
			
		||||
	numTokenId, err := strconv.Atoi(tokenId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid tokenId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(numTokenId)
 | 
			
		||||
 | 
			
		||||
	var userId string
 | 
			
		||||
	if err := result.Scan(&userId); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
 | 
			
		||||
	return &tokenId, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted.
 | 
			
		||||
func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT RefreshToken.tokenHash, RefreshToken.FK_User_userId, Role.IS_admin, ROLE.IS_userCreator FROM RefreshToken INNER JOIN R_User_Role ON RefreshToken.FK_User_userId = R_User_Role.FK_User_userId INNER JOIN Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE RefreshToken.selector = ? AND RefreshToken.expiryDate >= unixepoch('now')")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(refreshToken.Selector)
 | 
			
		||||
 | 
			
		||||
	var tokenHash string
 | 
			
		||||
	var newAccessToken AccessToken
 | 
			
		||||
	if err := result.Scan(&tokenHash, &newAccessToken.UserId, &newAccessToken.IsAdmin, &newAccessToken.IsUserCreator); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userTokenHash := sha256.Sum256(decUserToken)
 | 
			
		||||
	encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:])
 | 
			
		||||
	if encUserTokenHash != tokenHash {
 | 
			
		||||
		return nil, errors.New("failed to issue access token: refreshToken does not match")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix())
 | 
			
		||||
	return &newAccessToken, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SignAccessToken signs an access token and attaches a header. Returns access token encoded as jwt.
 | 
			
		||||
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
 | 
			
		||||
 | 
			
		||||
	data, err := json.Marshal(accessToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body := base64.RawURLEncoding.EncodeToString(data)
 | 
			
		||||
	header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
 | 
			
		||||
	combined := header + "." + body
 | 
			
		||||
 | 
			
		||||
	mac := hmac.New(sha256.New, db.secret)
 | 
			
		||||
	mac.Write([]byte(combined))
 | 
			
		||||
	signature := mac.Sum(nil)
 | 
			
		||||
 | 
			
		||||
	return combined + "." + base64.RawURLEncoding.EncodeToString(signature), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
 | 
			
		||||
	accessTokenBlob := strings.Split(encAccessToken, ".")
 | 
			
		||||
 | 
			
		||||
	header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) {
 | 
			
		||||
		return nil, errors.New("access token header wrong")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data AccessToken
 | 
			
		||||
	err = json.Unmarshal(body, &data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: maybe change part of signKey instead?
 | 
			
		||||
	for _, revokedAccessToken := range RevokedAccessTokens {
 | 
			
		||||
		if data.UserId == revokedAccessToken.UserId && data.ExpiryDate <= revokedAccessToken.ExpiryDate {
 | 
			
		||||
			return nil, errors.New("access token revoked")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if data.ExpiryDate < int(time.Now().Unix()) {
 | 
			
		||||
		return nil, errors.New("access token expired")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
 | 
			
		||||
		return nil, errors.New("access token expired prematurely")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mac := hmac.New(sha256.New, db.secret)
 | 
			
		||||
	mac.Write([]byte(accessTokenBlob[0]))
 | 
			
		||||
	mac.Write([]byte("."))
 | 
			
		||||
	mac.Write([]byte(accessTokenBlob[1]))
 | 
			
		||||
	expectedSignature := mac.Sum(nil)
 | 
			
		||||
	if !hmac.Equal(signature, expectedSignature) {
 | 
			
		||||
		return nil, errors.New("access token signature does not match")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &data, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RevokeAccessToken revokes all access tokens with matching UserId and UserRole that don't have a later ExpiryDate.
 | 
			
		||||
// revokedAccessToken.ExpiryDate should be set to now + token-lifetime.
 | 
			
		||||
func RevokeAccessToken(accessToken *AccessToken) {
 | 
			
		||||
	RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CleanRevokedTokensTicker removes expired tokens from the list. This should be called in an interval of > accessTokenLifetime.
 | 
			
		||||
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
 | 
			
		||||
	ticker := time.NewTicker(interval)
 | 
			
		||||
	stop := make(chan bool)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-stop:
 | 
			
		||||
				return
 | 
			
		||||
			case <-ticker.C:
 | 
			
		||||
				db.logger.Println("cleaning revoked access tokens")
 | 
			
		||||
				for i, revokedAccessToken := range RevokedAccessTokens {
 | 
			
		||||
					if revokedAccessToken.ExpiryDate < int(time.Now().Unix()) {
 | 
			
		||||
						revokedAccessToken = nil
 | 
			
		||||
						slices.Delete(RevokedAccessTokens, i, i+1)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	return stop
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
 | 
			
		||||
	ticker := time.NewTicker(interval)
 | 
			
		||||
	stop := make(chan bool)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-stop:
 | 
			
		||||
				return
 | 
			
		||||
			case <-ticker.C:
 | 
			
		||||
				db.logger.Println("cleaning expired refresh tokens")
 | 
			
		||||
				_, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')")
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					db.logger.Println("failed to clean expired refresh tokens")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	return stop
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										283
									
								
								database/main.go
									
									
									
									
									
								
							
							
						
						
									
										283
									
								
								database/main.go
									
									
									
									
									
								
							@ -18,10 +18,9 @@ package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
 | 
			
		||||
)
 | 
			
		||||
@ -30,11 +29,14 @@ type CustomDB struct {
 | 
			
		||||
	connection *sql.DB
 | 
			
		||||
	logger     *log.Logger
 | 
			
		||||
	schema     uint
 | 
			
		||||
	secret     []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
 | 
			
		||||
var initTimeStamp int
 | 
			
		||||
 | 
			
		||||
	db := CustomDB{logger: logger, schema: schema}
 | 
			
		||||
func InitSQLite3(path string, schema uint, logger *log.Logger, secret []byte, initialAdminName string, initialAdminPassword string) *CustomDB {
 | 
			
		||||
	initTimeStamp = int(time.Now().Unix())
 | 
			
		||||
	db := CustomDB{logger: logger, schema: schema, secret: secret}
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	db.connection, err = sql.Open("sqlite3", "file:"+path+"?_foreign_keys=1")
 | 
			
		||||
@ -58,6 +60,10 @@ func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
 | 
			
		||||
		if err = db.createSQLite3Tables(); err != nil {
 | 
			
		||||
			db.logger.Fatalln("Error in creating table: ", err)
 | 
			
		||||
		}
 | 
			
		||||
		err = db.CreateInitialAdmin(initialAdminName, initialAdminPassword)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			db.logger.Fatal("failed to create initial admin. Try to fix and delete old database file: ", err)
 | 
			
		||||
		}
 | 
			
		||||
	case user_version > db.schema:
 | 
			
		||||
		db.logger.Fatalln("Incompatible database schema version. Try updating this software.")
 | 
			
		||||
	case user_version < db.schema:
 | 
			
		||||
@ -71,8 +77,11 @@ func (db CustomDB) createSQLite3Tables() error {
 | 
			
		||||
		name string
 | 
			
		||||
		sql  string
 | 
			
		||||
	}{
 | 
			
		||||
		{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE, fullName VARCHAR"},
 | 
			
		||||
		{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL, IS_done BOOL NOT NULL DEFAULT false, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
 | 
			
		||||
		{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE CHECK(length(userName)!=0), passwordHash VARCHAR NOT NULL CHECK(length(passwordHash)!=0), fullName VARCHAR CHECK(length(fullName)!=0)"},
 | 
			
		||||
		{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL CHECK(length(text)!=0), IS_done BOOL NOT NULL, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
 | 
			
		||||
		{"R_User_Role", "relationId INTEGER PRIMARY KEY NOT NULL, FK_Role_roleId INTEGER NOT NULL, FK_User_userId INTEGER NOT NULL, UNIQUE(FK_Role_roleId, FK_User_userId), FOREIGN KEY(FK_Role_roleId) REFERENCES Role(roleId) ON UPDATE CASCADE ON DELETE CASCADE, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
 | 
			
		||||
		{"Role", "roleId INTEGER PRIMARY KEY NOT NULL, roleName VARCHAR NOT NULL UNIQUE CHECK(length(roleName)!=0), IS_admin BOOL NOT NULL, IS_userCreator BOOL NOT NULL"},
 | 
			
		||||
		{"RefreshToken", "tokenId INTEGER PRIMARY KEY NOT NULL, FK_User_userId INTEGER NOT NULL, selector VARCHAR NOT NULL CHECK(length(selector)!=0) UNIQUE, tokenHash VARCHAR NOT NULL CHECK(length(tokenHash)!=0), expiryDate INTEGER NOT NULL, tokenName VARCHAR CHECK(length(tokenName)!=0), FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
 | 
			
		||||
	}
 | 
			
		||||
	for _, table := range tables {
 | 
			
		||||
		_, err := db.connection.Exec("CREATE TABLE IF NOT EXISTS " + table.name + " (" + table.sql + ")")
 | 
			
		||||
@ -91,262 +100,18 @@ func (db CustomDB) createSQLite3Tables() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
 | 
			
		||||
	id, err := strconv.Atoi(user.ID)
 | 
			
		||||
func (db CustomDB) CreateInitialAdmin(initialAdminName string, initialAdminPassword string) error {
 | 
			
		||||
	role, err := db.CreateRole(&model.NewRole{RoleName: "admin", IsAdmin: true, IsUserCreator: true})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT userName, fullName FROM User WHERE userId = ?")
 | 
			
		||||
	user, err := db.CreateUser(model.NewUser{UserName: initialAdminName, Password: initialAdminPassword})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(id)
 | 
			
		||||
	if err := result.Scan(&user.UserName, &user.FullName); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	_, err = db.AddRole(user.ID, role.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
 | 
			
		||||
	id, err := strconv.Atoi(todo.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid todoId")
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(id)
 | 
			
		||||
	if err := result.Scan(&todo.Text, &todo.Done); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return todo, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
 | 
			
		||||
	id, err := strconv.Atoi(user.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Query(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.Todo
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		todo := model.Todo{User: user}
 | 
			
		||||
		if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &todo)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
 | 
			
		||||
	rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.User
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var user model.User
 | 
			
		||||
		if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &user)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
 | 
			
		||||
	rows, err := db.connection.Query("SELECT Todo.todoID, Todo.text, Todo.IS_done, User.userID FROM Todo INNER JOIN User ON Todo.FK_User_userID=User.userID")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var todos []*model.Todo
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var todo = model.Todo{User: &model.User{}}
 | 
			
		||||
		if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		todos = append(todos, &todo)
 | 
			
		||||
	}
 | 
			
		||||
	return todos, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) AddUser(newUser model.NewUser) (*model.User, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName) VALUES (?, ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(newUser.UserName, newUser.FullName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insertId, err := rows.LastInsertId()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insertId, err := rows.LastInsertId()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
 | 
			
		||||
 | 
			
		||||
	id, err := strconv.Atoi(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(?, fullName) WHERE userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(changes.UserName, changes.FullName, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db.GetUser(&model.User{ID: userId})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
 | 
			
		||||
 | 
			
		||||
	id, err := strconv.Atoi(todoId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(changes.Text, changes.Done, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db.GetTodo(&model.Todo{ID: todoId})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) DeleteUser(userId string) (*string, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &userId, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) DeleteTodo(todoId string) (*string, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(todoId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &todoId, nil
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										168
									
								
								database/role.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								database/role.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,168 @@
 | 
			
		||||
/*
 | 
			
		||||
YetAnotherToDoList
 | 
			
		||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
 | 
			
		||||
 | 
			
		||||
This program is free software: you can redistribute it and/or modify
 | 
			
		||||
it under the terms of the GNU General Public License as published by
 | 
			
		||||
the Free Software Foundation, version 3.
 | 
			
		||||
 | 
			
		||||
This program is distributed in the hope that it will be useful,
 | 
			
		||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU General Public License
 | 
			
		||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
*/
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) CreateRole(newRole *model.NewRole) (role *model.Role, err error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("INSERT INTO Role (roleName, IS_admin, IS_userCreator) VALUES (?, ?, ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(newRole.RoleName, newRole.IsAdmin, newRole.IsUserCreator)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insertId, err := rows.LastInsertId()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &model.Role{ID: strconv.FormatInt(insertId, 10), RoleName: newRole.RoleName, IsAdmin: newRole.IsAdmin, IsUserCreator: newRole.IsUserCreator}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetRole(role *model.Role) (*model.Role, error) {
 | 
			
		||||
	id, err := strconv.Atoi(role.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid roleId")
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT roleName, IS_admin, IS_userCreator FROM Role WHERE roleId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(id)
 | 
			
		||||
	if err := result.Scan(&role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return role, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetRolesFrom(userId string) ([]*model.Role, error) {
 | 
			
		||||
	numUserId, err := strconv.Atoi(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT Role.roleId, Role.roleName, Role.IS_admin, Role.IS_userCreator FROM Role INNER JOIN R_User_Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE R_User_Role.FK_User_userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Query(numUserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.Role
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		role := model.Role{}
 | 
			
		||||
		if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &role)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetAllRoles() ([]*model.Role, error) {
 | 
			
		||||
	rows, err := db.connection.Query("SELECT roleId, roleName, IS_admin, IS_userCreator FROM Role")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.Role
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var role model.Role
 | 
			
		||||
		if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &role)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) UpdateRole(roleId string, changes *model.UpdateRole) (*model.Role, error) {
 | 
			
		||||
 | 
			
		||||
	id, err := strconv.Atoi(roleId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("UPDATE Role SET roleName = IFNULL(?, roleName), IS_admin = IFNULL(?, IS_admin), IS_userCreator = IFNULL(?, IS_userCreator) WHERE roleId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(changes.RoleName, changes.IsAdmin, changes.IsUserCreator, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db.GetRole(&model.Role{ID: roleId})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) DeleteRole(roleId string) (*string, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("DELETE FROM Role WHERE roleId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(roleId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &roleId, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										199
									
								
								database/todo.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										199
									
								
								database/todo.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,199 @@
 | 
			
		||||
/*
 | 
			
		||||
YetAnotherToDoList
 | 
			
		||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
 | 
			
		||||
 | 
			
		||||
This program is free software: you can redistribute it and/or modify
 | 
			
		||||
it under the terms of the GNU General Public License as published by
 | 
			
		||||
the Free Software Foundation, version 3.
 | 
			
		||||
 | 
			
		||||
This program is distributed in the hope that it will be useful,
 | 
			
		||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU General Public License
 | 
			
		||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
*/
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetOwner takes a todoId and return the owner's userId. Call before Update/Get/DeleteTodo when not IS_admin.
 | 
			
		||||
func (db CustomDB) GetOwner(todoId string) (string, error) {
 | 
			
		||||
	numTodoId, err := strconv.Atoi(todoId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.New("invalid todoId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM Todo WHERE todoId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(numTodoId)
 | 
			
		||||
	var owner string
 | 
			
		||||
	if err := result.Scan(&owner); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return owner, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetTodo takes a *model.Todo with at least ID set and adds the missing fields.
 | 
			
		||||
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
 | 
			
		||||
	numTodoId, err := strconv.Atoi(todo.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid todoId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(numTodoId)
 | 
			
		||||
	if err := result.Scan(&todo.Text, &todo.Done); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return todo, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetTodosFrom gets all todos from the passed *model.User. ID must be set.
 | 
			
		||||
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
 | 
			
		||||
	id, err := strconv.Atoi(user.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Query(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.Todo
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		todo := model.Todo{User: user}
 | 
			
		||||
		if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &todo)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAllTodos gets all todos from the database. Check if the source has the rights to call this.
 | 
			
		||||
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
 | 
			
		||||
	rows, err := db.connection.Query("SELECT todoID, text, IS_done, FK_User_userID FROM Todo")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var todos []*model.Todo
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var todo = model.Todo{User: &model.User{}}
 | 
			
		||||
		if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		todos = append(todos, &todo)
 | 
			
		||||
	}
 | 
			
		||||
	return todos, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AddTodo adds a Todo to passed UserID. Check if the source has the rights to call this.
 | 
			
		||||
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insertId, err := rows.LastInsertId()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
 | 
			
		||||
 | 
			
		||||
	numTodoId, err := strconv.Atoi(todoId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid todoId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(changes.Text, changes.Done, numTodoId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db.GetTodo(&model.Todo{ID: todoId})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) DeleteTodo(todoId string) (deletedTodoId *string, err error) {
 | 
			
		||||
	numTodoId, err := strconv.Atoi(todoId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid todoId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(numTodoId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &todoId, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										257
									
								
								database/user.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										257
									
								
								database/user.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,257 @@
 | 
			
		||||
/*
 | 
			
		||||
YetAnotherToDoList
 | 
			
		||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
 | 
			
		||||
 | 
			
		||||
This program is free software: you can redistribute it and/or modify
 | 
			
		||||
it under the terms of the GNU General Public License as published by
 | 
			
		||||
the Free Software Foundation, version 3.
 | 
			
		||||
 | 
			
		||||
This program is distributed in the hope that it will be useful,
 | 
			
		||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU General Public License
 | 
			
		||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
*/
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) CreateUser(newUser model.NewUser) (*model.User, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName, passwordHash) VALUES (?, NULLIF(?, ''), ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ValidateUserName(newUser.UserName); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ValidatePassword(newUser.Password); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	passwordHash, err := db.GenerateHashFromPassword(newUser.Password)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(newUser.UserName, newUser.FullName, string(passwordHash))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insertId, err := rows.LastInsertId()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetUser takes a *model.User with at least ID or UserName set and adds the missing fields.
 | 
			
		||||
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
 | 
			
		||||
	numUserId, err := strconv.Atoi(user.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
	statement, err := db.connection.Prepare("SELECT userID, userName, fullName FROM User WHERE userId = ? OR userName = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := statement.QueryRow(numUserId, user.UserName)
 | 
			
		||||
	if err := result.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
 | 
			
		||||
	rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var all []*model.User
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var user model.User
 | 
			
		||||
		if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		all = append(all, &user)
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
 | 
			
		||||
	var passwordHash *string
 | 
			
		||||
	needAccessTokenRefresh := false
 | 
			
		||||
 | 
			
		||||
	id, err := strconv.Atoi(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(NULLIF(?, ''), fullName), passwordHash = IFNULL(?, passwordHash) WHERE userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if *changes.UserName == "" { // interpret empty string as nil
 | 
			
		||||
		changes.UserName = nil
 | 
			
		||||
	} else {
 | 
			
		||||
		if err := ValidateUserName(*changes.UserName); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		needAccessTokenRefresh = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if *changes.Password == "" { // interpret empty string as nil
 | 
			
		||||
		passwordHash = nil
 | 
			
		||||
	} else {
 | 
			
		||||
		if err := ValidatePassword(*changes.Password); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		passwordHashByte, err := db.GenerateHashFromPassword(*changes.Password)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		passwordHashString := string(passwordHashByte)
 | 
			
		||||
		passwordHash = &passwordHashString
 | 
			
		||||
		needAccessTokenRefresh = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(changes.UserName, changes.FullName, passwordHash, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if needAccessTokenRefresh {
 | 
			
		||||
		RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db.GetUser(&model.User{ID: userId})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) DeleteUser(userId string) (*string, error) {
 | 
			
		||||
	statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return nil, errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
 | 
			
		||||
	return &userId, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) AddRole(userId string, roleId string) (relationId string, err error) {
 | 
			
		||||
	encUserId, err := strconv.Atoi(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
	encRoleId, err := strconv.Atoi(roleId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.New("invalid roleId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("INSERT INTO R_User_Role (FK_User_userId, FK_Role_roleId) VALUES (?, ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(encUserId, encRoleId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return "", errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insertId, err := rows.LastInsertId()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
 | 
			
		||||
	return strconv.FormatInt(insertId, 10), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db CustomDB) RemoveRole(userId string, roleId string) (relationId string, err error) {
 | 
			
		||||
	encUserId, err := strconv.Atoi(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.New("invalid userId")
 | 
			
		||||
	}
 | 
			
		||||
	encRoleId, err := strconv.Atoi(roleId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.New("invalid roleId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statement, err := db.connection.Prepare("DELETE FROM R_User_Role WHERE FK_User_userId = ? AND FK_Role_roleId = ?")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := statement.Exec(encUserId, encRoleId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	num, err := rows.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num < 1 {
 | 
			
		||||
		return "", errors.New("no rows affected")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
 | 
			
		||||
	return strconv.FormatInt(int64(encRoleId), 10), nil
 | 
			
		||||
}
 | 
			
		||||
@ -7,7 +7,7 @@ export default defineConfig({
 | 
			
		||||
	server: {
 | 
			
		||||
		port: 4243,
 | 
			
		||||
		proxy: {
 | 
			
		||||
			'^/(api|playground|version)': 'http://localhost:4242/'
 | 
			
		||||
			'^/(api|playground|version|auth)': 'http://localhost:4242/'
 | 
			
		||||
		}
 | 
			
		||||
	},
 | 
			
		||||
	build: {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@ -4,10 +4,12 @@ go 1.20
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/99designs/gqlgen v0.17.37
 | 
			
		||||
	github.com/go-chi/chi v1.5.5
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.17
 | 
			
		||||
	github.com/spf13/cobra v1.7.0
 | 
			
		||||
	github.com/spf13/viper v1.16.0
 | 
			
		||||
	github.com/vektah/gqlparser/v2 v2.5.9
 | 
			
		||||
	golang.org/x/crypto v0.9.0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@ -69,6 +69,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7
 | 
			
		||||
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
 | 
			
		||||
github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE=
 | 
			
		||||
github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw=
 | 
			
		||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
 | 
			
		||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
 | 
			
		||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
 | 
			
		||||
@ -216,6 +218,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
 | 
			
		||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
 | 
			
		||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
 | 
			
		||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
 | 
			
		||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 | 
			
		||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 | 
			
		||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
 | 
			
		||||
 | 
			
		||||
@ -108,3 +108,5 @@ models:
 | 
			
		||||
  fields:
 | 
			
		||||
   todos:
 | 
			
		||||
    resolver: true
 | 
			
		||||
   roles:
 | 
			
		||||
    resolver: true
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2358
									
								
								graph/generated.go
									
									
									
									
									
								
							
							
						
						
									
										2358
									
								
								graph/generated.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -2,21 +2,51 @@
 | 
			
		||||
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
type NewRefreshToken struct {
 | 
			
		||||
	TokenName *string `json:"tokenName,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NewRole struct {
 | 
			
		||||
	RoleName      string `json:"roleName"`
 | 
			
		||||
	IsAdmin       bool   `json:"isAdmin"`
 | 
			
		||||
	IsUserCreator bool   `json:"isUserCreator"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NewTodo struct {
 | 
			
		||||
	Text   string `json:"text"`
 | 
			
		||||
	UserID string `json:"userId"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NewUser struct {
 | 
			
		||||
	UserName string `json:"userName"`
 | 
			
		||||
	FullName string `json:"fullName"`
 | 
			
		||||
	UserName string  `json:"userName"`
 | 
			
		||||
	FullName *string `json:"fullName,omitempty"`
 | 
			
		||||
	Password string  `json:"password"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type User struct {
 | 
			
		||||
	ID       string  `json:"id"`
 | 
			
		||||
	UserName string  `json:"userName"`
 | 
			
		||||
	FullName string  `json:"fullName"`
 | 
			
		||||
	Todos    []*Todo `json:"todos"`
 | 
			
		||||
type RefreshToken struct {
 | 
			
		||||
	ID         string  `json:"id"`
 | 
			
		||||
	ExpiryDate int     `json:"expiryDate"`
 | 
			
		||||
	TokenName  *string `json:"tokenName,omitempty"`
 | 
			
		||||
	Selector   *string `json:"selector,omitempty"`
 | 
			
		||||
	Token      *string `json:"token,omitempty"`
 | 
			
		||||
	UserID     string  `json:"userId"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Role struct {
 | 
			
		||||
	ID            string `json:"id"`
 | 
			
		||||
	RoleName      string `json:"roleName"`
 | 
			
		||||
	IsAdmin       bool   `json:"isAdmin"`
 | 
			
		||||
	IsUserCreator bool   `json:"isUserCreator"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type UpdateRefreshToken struct {
 | 
			
		||||
	TokenName *string `json:"tokenName,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type UpdateRole struct {
 | 
			
		||||
	RoleName      *string `json:"roleName,omitempty"`
 | 
			
		||||
	IsAdmin       *bool   `json:"isAdmin,omitempty"`
 | 
			
		||||
	IsUserCreator *bool   `json:"isUserCreator,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type UpdateTodo struct {
 | 
			
		||||
@ -27,4 +57,13 @@ type UpdateTodo struct {
 | 
			
		||||
type UpdateUser struct {
 | 
			
		||||
	UserName *string `json:"userName,omitempty"`
 | 
			
		||||
	FullName *string `json:"fullName,omitempty"`
 | 
			
		||||
	Password *string `json:"password,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type User struct {
 | 
			
		||||
	ID       string  `json:"id"`
 | 
			
		||||
	UserName string  `json:"userName"`
 | 
			
		||||
	FullName *string `json:"fullName,omitempty"`
 | 
			
		||||
	Todos    []*Todo `json:"todos"`
 | 
			
		||||
	Roles    []*Role `json:"roles"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -27,20 +27,42 @@ type Todo {
 | 
			
		||||
type User {
 | 
			
		||||
	id: ID!
 | 
			
		||||
	userName: String!
 | 
			
		||||
	fullName: String!
 | 
			
		||||
	fullName: String
 | 
			
		||||
	todos: [Todo!]!
 | 
			
		||||
	roles: [Role!]!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Role {
 | 
			
		||||
	id: ID!
 | 
			
		||||
	roleName: String!
 | 
			
		||||
	isAdmin: Boolean!
 | 
			
		||||
	isUserCreator: Boolean!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RefreshToken {
 | 
			
		||||
	id: ID!
 | 
			
		||||
	expiryDate: Int!
 | 
			
		||||
	tokenName: String
 | 
			
		||||
	selector: String
 | 
			
		||||
	token: String
 | 
			
		||||
	userId: String!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Query {
 | 
			
		||||
	todos: [Todo!]!
 | 
			
		||||
	users: [User!]!
 | 
			
		||||
	roles: [Role!]!
 | 
			
		||||
	refreshTokens: [RefreshToken!]!
 | 
			
		||||
	user(id: ID!): User!
 | 
			
		||||
	todo(id: ID!): Todo!
 | 
			
		||||
	role(id: ID!): Role!
 | 
			
		||||
	refreshToken(id: ID!): RefreshToken!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input NewUser {
 | 
			
		||||
	userName: String!
 | 
			
		||||
	fullName: String!
 | 
			
		||||
	fullName: String
 | 
			
		||||
	password: String!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input NewTodo {
 | 
			
		||||
@ -48,21 +70,50 @@ input NewTodo {
 | 
			
		||||
	userId: ID!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input updateTodo {
 | 
			
		||||
input NewRole {
 | 
			
		||||
	roleName: String!
 | 
			
		||||
	isAdmin: Boolean!
 | 
			
		||||
	isUserCreator: Boolean!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input NewRefreshToken {
 | 
			
		||||
	tokenName: String
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input UpdateTodo {
 | 
			
		||||
	text: String
 | 
			
		||||
	done: Boolean
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input updateUser {
 | 
			
		||||
input UpdateUser {
 | 
			
		||||
	userName: String
 | 
			
		||||
	fullName: String
 | 
			
		||||
	password: String
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input UpdateRole {
 | 
			
		||||
	roleName: String
 | 
			
		||||
	isAdmin: Boolean
 | 
			
		||||
	isUserCreator: Boolean
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
input UpdateRefreshToken {
 | 
			
		||||
	tokenName: String
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Mutation {
 | 
			
		||||
	createUser(input: NewUser!): User!
 | 
			
		||||
	createTodo(input: NewTodo!): Todo!
 | 
			
		||||
	updateTodo(id: ID!, changes: updateTodo!): Todo!
 | 
			
		||||
	updateUser(id: ID!, changes: updateUser!): User!
 | 
			
		||||
	createRole(input: NewRole!): Role!
 | 
			
		||||
	createRefreshToken(input: NewRefreshToken!): RefreshToken!
 | 
			
		||||
	updateTodo(id: ID!, changes: UpdateTodo!): Todo!
 | 
			
		||||
	updateUser(id: ID!, changes: UpdateUser!): User!
 | 
			
		||||
	updateRole(id: ID!, changes: UpdateRole!): Role!
 | 
			
		||||
	updateRefreshToken(id: ID!, changes: UpdateRefreshToken!): RefreshToken!
 | 
			
		||||
	deleteUser(id: ID!): ID
 | 
			
		||||
	deleteTodo(id: ID!): ID
 | 
			
		||||
	deleteRole(id: ID!): ID
 | 
			
		||||
	deleteRefreshToken(id: ID!): ID
 | 
			
		||||
	addRole(userId: ID!, roleId: ID!): [Role!]!
 | 
			
		||||
	removeRole(userId: ID!, roleId: ID!): [Role!]!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -10,11 +10,12 @@ import (
 | 
			
		||||
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/server/auth"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateUser is the resolver for the createUser field.
 | 
			
		||||
func (r *mutationResolver) CreateUser(ctx context.Context, input model.NewUser) (*model.User, error) {
 | 
			
		||||
	todo, err := globals.DB.AddUser(input)
 | 
			
		||||
	todo, err := globals.DB.CreateUser(input)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		globals.Logger.Println("Failed to add new user:", err)
 | 
			
		||||
		return nil, errors.New("failed to add new user")
 | 
			
		||||
@ -32,6 +33,28 @@ func (r *mutationResolver) CreateTodo(ctx context.Context, input model.NewTodo)
 | 
			
		||||
	return todo, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateRole is the resolver for the createRole field.
 | 
			
		||||
func (r *mutationResolver) CreateRole(ctx context.Context, input model.NewRole) (*model.Role, error) {
 | 
			
		||||
	role, err := globals.DB.CreateRole(&input)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		globals.Logger.Println("Failed to add new role:", err)
 | 
			
		||||
		return nil, errors.New("failed to add new role")
 | 
			
		||||
	}
 | 
			
		||||
	return role, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateRefreshToken is the resolver for the createRefreshToken field.
 | 
			
		||||
func (r *mutationResolver) CreateRefreshToken(ctx context.Context, input model.NewRefreshToken) (*model.RefreshToken, error) {
 | 
			
		||||
	// TODO: unify model.RefreshToken & auth.RefreshToken
 | 
			
		||||
	userToken := auth.ForContext(ctx)
 | 
			
		||||
	refreshToken, tokenId, err := globals.DB.IssueRefreshToken(userToken.UserId, input.TokenName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		globals.Logger.Println("Failed to create refresh token:", err)
 | 
			
		||||
		return nil, errors.New("failed to create refresh token")
 | 
			
		||||
	}
 | 
			
		||||
	return &model.RefreshToken{ID: tokenId, ExpiryDate: refreshToken.ExpiryDate, TokenName: input.TokenName, Selector: &refreshToken.Selector, Token: &refreshToken.Token, UserID: userToken.UserId}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateTodo is the resolver for the updateTodo field.
 | 
			
		||||
func (r *mutationResolver) UpdateTodo(ctx context.Context, id string, changes model.UpdateTodo) (*model.Todo, error) {
 | 
			
		||||
	return globals.DB.UpdateTodo(id, &changes)
 | 
			
		||||
@ -42,6 +65,16 @@ func (r *mutationResolver) UpdateUser(ctx context.Context, id string, changes mo
 | 
			
		||||
	return globals.DB.UpdateUser(id, &changes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateRole is the resolver for the updateRole field.
 | 
			
		||||
func (r *mutationResolver) UpdateRole(ctx context.Context, id string, changes model.UpdateRole) (*model.Role, error) {
 | 
			
		||||
	return globals.DB.UpdateRole(id, &changes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateRefreshToken is the resolver for the updateRefreshToken field.
 | 
			
		||||
func (r *mutationResolver) UpdateRefreshToken(ctx context.Context, id string, changes model.UpdateRefreshToken) (*model.RefreshToken, error) {
 | 
			
		||||
	return globals.DB.UpdateRefreshToken(id, &changes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteUser is the resolver for the deleteUser field.
 | 
			
		||||
func (r *mutationResolver) DeleteUser(ctx context.Context, id string) (*string, error) {
 | 
			
		||||
	return globals.DB.DeleteUser(id)
 | 
			
		||||
@ -52,6 +85,32 @@ func (r *mutationResolver) DeleteTodo(ctx context.Context, id string) (*string,
 | 
			
		||||
	return globals.DB.DeleteTodo(id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteRole is the resolver for the deleteRole field.
 | 
			
		||||
func (r *mutationResolver) DeleteRole(ctx context.Context, id string) (*string, error) {
 | 
			
		||||
	return globals.DB.DeleteRole(id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteRefreshToken is the resolver for the deleteRefreshToken field.
 | 
			
		||||
func (r *mutationResolver) DeleteRefreshToken(ctx context.Context, id string) (*string, error) {
 | 
			
		||||
	return globals.DB.RevokeRefreshToken(id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AddRole is the resolver for the addRole field.
 | 
			
		||||
func (r *mutationResolver) AddRole(ctx context.Context, userID string, roleID string) ([]*model.Role, error) {
 | 
			
		||||
	if _, err := globals.DB.AddRole(userID, roleID); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return globals.DB.GetRolesFrom(userID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveRole is the resolver for the removeRole field.
 | 
			
		||||
func (r *mutationResolver) RemoveRole(ctx context.Context, userID string, roleID string) ([]*model.Role, error) {
 | 
			
		||||
	if _, err := globals.DB.RemoveRole(userID, roleID); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return globals.DB.GetRolesFrom(userID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Todos is the resolver for the todos field.
 | 
			
		||||
func (r *queryResolver) Todos(ctx context.Context) ([]*model.Todo, error) {
 | 
			
		||||
	return globals.DB.GetAllTodos()
 | 
			
		||||
@ -62,6 +121,16 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
 | 
			
		||||
	return globals.DB.GetAllUsers()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Roles is the resolver for the roles field.
 | 
			
		||||
func (r *queryResolver) Roles(ctx context.Context) ([]*model.Role, error) {
 | 
			
		||||
	return globals.DB.GetAllRoles()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RefreshTokens is the resolver for the refreshTokens field.
 | 
			
		||||
func (r *queryResolver) RefreshTokens(ctx context.Context) ([]*model.RefreshToken, error) {
 | 
			
		||||
	return globals.DB.GetAllRefreshTokens()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// User is the resolver for the user field.
 | 
			
		||||
func (r *queryResolver) User(ctx context.Context, id string) (*model.User, error) {
 | 
			
		||||
	return globals.DB.GetUser(&model.User{ID: id})
 | 
			
		||||
@ -72,6 +141,16 @@ func (r *queryResolver) Todo(ctx context.Context, id string) (*model.Todo, error
 | 
			
		||||
	return globals.DB.GetTodo(&model.Todo{ID: id})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Role is the resolver for the role field.
 | 
			
		||||
func (r *queryResolver) Role(ctx context.Context, id string) (*model.Role, error) {
 | 
			
		||||
	return globals.DB.GetRole(&model.Role{ID: id})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RefreshToken is the resolver for the refreshToken field.
 | 
			
		||||
func (r *queryResolver) RefreshToken(ctx context.Context, id string) (*model.RefreshToken, error) {
 | 
			
		||||
	return globals.DB.GetRefreshToken(&model.RefreshToken{ID: id})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// User is the resolver for the user field.
 | 
			
		||||
func (r *todoResolver) User(ctx context.Context, obj *model.Todo) (*model.User, error) {
 | 
			
		||||
	// TODO: implement dataloader
 | 
			
		||||
@ -83,6 +162,11 @@ func (r *userResolver) Todos(ctx context.Context, obj *model.User) ([]*model.Tod
 | 
			
		||||
	return globals.DB.GetTodosFrom(obj)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Roles is the resolver for the roles field.
 | 
			
		||||
func (r *userResolver) Roles(ctx context.Context, obj *model.User) ([]*model.Role, error) {
 | 
			
		||||
	return globals.DB.GetRolesFrom(obj.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Mutation returns MutationResolver implementation.
 | 
			
		||||
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										91
									
								
								server/auth/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								server/auth/README.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
			
		||||
# Authentication in Golang
 | 
			
		||||
 | 
			
		||||
## Use header
 | 
			
		||||
 | 
			
		||||
Use the http header or the body, but avoid using cookies to transport tokens etc
 | 
			
		||||
(because of CSRF)
 | 
			
		||||
 | 
			
		||||
## Implementation
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
# create password/user
 | 
			
		||||
update_password(password):
 | 
			
		||||
 generate salt
 | 
			
		||||
 db.store(hash(salt + password))
 | 
			
		||||
 | 
			
		||||
#login
 | 
			
		||||
get_token(userId, password):
 | 
			
		||||
  if not hash(salt + password) == db.get(salted_hash_of_password): # use timing-attack resistant compare here
 | 
			
		||||
    return FAILED
 | 
			
		||||
  generate selector
 | 
			
		||||
  db.store(selector)
 | 
			
		||||
		generate salt
 | 
			
		||||
  generate auth_token
 | 
			
		||||
  db.store(salt, hash(salt+auth_token))
 | 
			
		||||
  return selector:auth_token
 | 
			
		||||
 | 
			
		||||
#authenticate
 | 
			
		||||
validate_token(selector, auth_token):
 | 
			
		||||
  if not hash(salt+auth_token) == db.get(salt, salted_hash_of_auth_token WHERE selector): # use timing-attack resistant compare here
 | 
			
		||||
    return UNKNOWN_TOKEN
 | 
			
		||||
  return AUTHENTICATED
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
idea: replace selector with userId?
 | 
			
		||||
 | 
			
		||||
## JWT
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
{
 | 
			
		||||
	"userId": "id",
 | 
			
		||||
	"userRole": "role",
 | 
			
		||||
	"expiryTime": "now+10min"
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
We use JWT with a lifespan of 10min and an in-memory db to blacklist revoked
 | 
			
		||||
tokens. So if for e.g. a user changes it's password, we would add the userId and
 | 
			
		||||
the time of the change to the blacklist, filtering out all tokens that have been
 | 
			
		||||
issued before.
 | 
			
		||||
 | 
			
		||||
After a server restart, all tokens will become invalid as well, since we can not
 | 
			
		||||
be sure which ones were 'on the blacklist'. This could be mitigated by making
 | 
			
		||||
the blacklist persistent during restarts.
 | 
			
		||||
 | 
			
		||||
If a token has expired (either by a server restart or after 10min), a new token
 | 
			
		||||
is requested with a 'long lived' refresh-token (lifetime of ~1 week) that is
 | 
			
		||||
stored in a database.
 | 
			
		||||
 | 
			
		||||
### Pros:
 | 
			
		||||
 | 
			
		||||
- less DB lookups in general
 | 
			
		||||
 | 
			
		||||
### Cons:
 | 
			
		||||
 | 
			
		||||
- timestamps of blacklisting could become a problem (maybe add 1 second or use
 | 
			
		||||
  timestamp returned by n-th node when it has received the update).
 | 
			
		||||
- increased load on db + service since we need to issue new jwt for everybody.
 | 
			
		||||
 | 
			
		||||
## SSL/TLS
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout keyFile.key -out certFile.crt
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## CRIME/BREACH http compression attack
 | 
			
		||||
 | 
			
		||||
While we do not use a CSRF token, headers sent to the `/api` still contain
 | 
			
		||||
private data.
 | 
			
		||||
 | 
			
		||||
If you are using http/1.1 or lower and have compression enabled on your proxy,
 | 
			
		||||
you would be .
 | 
			
		||||
 | 
			
		||||
## CRSF
 | 
			
		||||
 | 
			
		||||
We check for a http header like this `X-YOURSITE-CSRF-PROTECTION=1`. This should
 | 
			
		||||
be enough, according to
 | 
			
		||||
[cheatsheetseries.owasp.org](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#custom-request-headers)
 | 
			
		||||
 | 
			
		||||
## XSS
 | 
			
		||||
 | 
			
		||||
We rely on Vue.js's ability to escape user-input in templates.
 | 
			
		||||
							
								
								
									
										139
									
								
								server/auth/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								server/auth/main.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,139 @@
 | 
			
		||||
/*
 | 
			
		||||
YetAnotherToDoList
 | 
			
		||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
 | 
			
		||||
 | 
			
		||||
This program is free software: you can redistribute it and/or modify
 | 
			
		||||
it under the terms of the GNU General Public License as published by
 | 
			
		||||
the Free Software Foundation, version 3.
 | 
			
		||||
 | 
			
		||||
This program is distributed in the hope that it will be useful,
 | 
			
		||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU General Public License
 | 
			
		||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
*/
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/database"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var userCtxKey = &contextKey{"user"}
 | 
			
		||||
 | 
			
		||||
type contextKey struct {
 | 
			
		||||
	name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Middleware() func(http.Handler) http.Handler {
 | 
			
		||||
	return func(next http.Handler) http.Handler {
 | 
			
		||||
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			const headerPrefix = "Bearer "
 | 
			
		||||
			authField := r.Header.Get("Authorization")
 | 
			
		||||
			if !strings.HasPrefix(authField, headerPrefix) {
 | 
			
		||||
				http.Error(w, fmt.Sprintf(`{"error":"wrong token type, expect '%s'"}`, headerPrefix), http.StatusBadRequest)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// get the user from the database
 | 
			
		||||
			user, err := globals.DB.CheckAccessToken(strings.TrimPrefix(authField, headerPrefix))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				http.Error(w, strings.ReplaceAll(fmt.Sprintf(`{"error":"failed token check: '%s'"}`, err), "\\", "\\\\"), http.StatusInternalServerError)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// put it in context
 | 
			
		||||
			ctx := context.WithValue(r.Context(), userCtxKey, user)
 | 
			
		||||
			// and call the next with our new context
 | 
			
		||||
			r = r.WithContext(ctx)
 | 
			
		||||
			next.ServeHTTP(w, r)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ForContext(ctx context.Context) *database.AccessToken {
 | 
			
		||||
	raw, _ := ctx.Value(userCtxKey).(*database.AccessToken)
 | 
			
		||||
	return raw
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type userAuth struct {
 | 
			
		||||
	UserId   *string `json:"userId"`
 | 
			
		||||
	UserName *string `json:"userName"`
 | 
			
		||||
	Password string  `json:"password"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IssueRefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	const headerPrefix = "Login "
 | 
			
		||||
	authField := r.Header.Get("Authorization")
 | 
			
		||||
	if !strings.HasPrefix(authField, headerPrefix) {
 | 
			
		||||
		http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userCredentials := userAuth{}
 | 
			
		||||
 | 
			
		||||
	err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &userCredentials)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, "malformed or missing user credentials", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId, err := globals.DB.ValidateUserCredentials(userCredentials.UserId, userCredentials.UserName, userCredentials.Password)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, "invalid credentials", http.StatusUnauthorized)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	refreshToken, _, err := globals.DB.IssueRefreshToken(userId, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, "failed to issue refresh token", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jsonRefreshToken, err := json.Marshal(refreshToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, "internal server error", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	w.Write([]byte(jsonRefreshToken))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IssueAccessTokenHandler(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	const headerPrefix = "Refresh "
 | 
			
		||||
	authField := r.Header.Get("Authorization")
 | 
			
		||||
	if !strings.HasPrefix(authField, headerPrefix) {
 | 
			
		||||
		http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	refreshToken := database.RefreshToken{}
 | 
			
		||||
 | 
			
		||||
	err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &refreshToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, "malformed or missing refresh token", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	accessToken, err := globals.DB.IssueAccessToken(&refreshToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, "invalid access token", http.StatusUnauthorized)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	encAccessToken, err := globals.DB.SignAccessToken(*accessToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, "internal server error", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	w.Write([]byte(encAccessToken))
 | 
			
		||||
}
 | 
			
		||||
@ -20,11 +20,24 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/99designs/gqlgen/graphql/playground"
 | 
			
		||||
	"github.com/go-chi/chi"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func handleDevTools(port int) {
 | 
			
		||||
	mux.Handle("/playground", playground.Handler("GraphQL playground", "/api"))
 | 
			
		||||
	globals.Logger.Printf("connect to http://localhost:%v/ for GraphQL playground", port)
 | 
			
		||||
func handleDevTools(router *chi.Mux, portHTTP int, portHTTPS int) {
 | 
			
		||||
	router.Handle("/playground", playground.Handler("GraphQL playground", "/api"))
 | 
			
		||||
	url := "connect to "
 | 
			
		||||
	if portHTTP != -1 && portHTTPS != -1 {
 | 
			
		||||
		url += fmt.Sprintf("http://localhost:%v/ or https://localhost:%v/", portHTTP, portHTTPS)
 | 
			
		||||
	} else if portHTTP != -1 {
 | 
			
		||||
		url += fmt.Sprintf("http://localhost:%v/", portHTTP)
 | 
			
		||||
	} else if portHTTPS != -1 {
 | 
			
		||||
		url += fmt.Sprintf("https://localhost:%v/", portHTTPS)
 | 
			
		||||
	} else {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	globals.Logger.Println(url + " for GraphQL playground")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -19,5 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
*/
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
func handleDevTools(_ int) {
 | 
			
		||||
import "github.com/go-chi/chi"
 | 
			
		||||
 | 
			
		||||
func handleDevTools(router *chi.Mux, _ int, _ int) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -24,19 +24,21 @@ import (
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-chi/chi"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//go:embed dist/*
 | 
			
		||||
var frontend embed.FS
 | 
			
		||||
 | 
			
		||||
func handleFrontend() {
 | 
			
		||||
func handleFrontend(router *chi.Mux) {
 | 
			
		||||
	stripped, err := fs.Sub(frontend, "dist")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalln(err)
 | 
			
		||||
	}
 | 
			
		||||
	frontendFS := http.FileServer(http.FS(stripped))
 | 
			
		||||
	mux.Handle("/assets/", frontendFS)
 | 
			
		||||
	mux.HandleFunc("/", indexHandler)
 | 
			
		||||
	router.Handle("/assets/*", frontendFS)
 | 
			
		||||
	router.HandleFunc("/", indexHandler)
 | 
			
		||||
	// TODO: redirect from vue to 404 page (on go/proxy server)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -19,5 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
*/
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
func handleFrontend() {
 | 
			
		||||
import "github.com/go-chi/chi"
 | 
			
		||||
 | 
			
		||||
func handleFrontend(router *chi.Mux) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -22,24 +22,68 @@ import (
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"github.com/99designs/gqlgen/graphql/handler"
 | 
			
		||||
	"github.com/go-chi/chi"
 | 
			
		||||
	"github.com/go-chi/chi/middleware"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph"
 | 
			
		||||
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/server/auth"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var mux *http.ServeMux
 | 
			
		||||
func StartServer(portHTTP int, portHTTPS int, certFile string, keyFile string) {
 | 
			
		||||
	router := chi.NewRouter()
 | 
			
		||||
	router.Use(middleware.StripSlashes)
 | 
			
		||||
 | 
			
		||||
func StartServer(port int) {
 | 
			
		||||
	mux = http.NewServeMux()
 | 
			
		||||
 | 
			
		||||
	handleDevTools(port) // controlled by 'dev' tag
 | 
			
		||||
	handleFrontend()     // controlled by 'headless' tag
 | 
			
		||||
 | 
			
		||||
	mux.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	handleDevTools(router, portHTTP, portHTTPS) // controlled by 'dev' tag
 | 
			
		||||
	handleFrontend(router)                      // controlled by 'headless' tag
 | 
			
		||||
	router.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		fmt.Fprintf(w, "%s %s", globals.Version, globals.CommitHash)
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{Resolvers: &graph.Resolver{}}))
 | 
			
		||||
	mux.Handle("/api", srv)
 | 
			
		||||
	router.HandleFunc("/auth/login", auth.IssueRefreshTokenHandler)
 | 
			
		||||
	router.HandleFunc("/auth", auth.IssueAccessTokenHandler)
 | 
			
		||||
 | 
			
		||||
	router.Group(func(r chi.Router) {
 | 
			
		||||
		r.Use(auth.Middleware())
 | 
			
		||||
		r.Handle("/api", srv)
 | 
			
		||||
		r.HandleFunc("/protected", func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			fmt.Printf("user is %+v\n", auth.ForContext(r.Context()))
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
	// router.HandleFunc("/auth/", func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	// 	http.Redirect(w, r, "/auth", http.StatusMovedPermanently)
 | 
			
		||||
	// })
 | 
			
		||||
 | 
			
		||||
	err := <-listen(portHTTP, portHTTPS, certFile, keyFile, router)
 | 
			
		||||
	globals.Logger.Fatalf("Could not start serving service due to (error: %s)", err)
 | 
			
		||||
 | 
			
		||||
	globals.Logger.Fatal(http.ListenAndServe(":"+strconv.FormatInt(int64(port), 10), mux))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func listen(portHTTP int, portHTTPS int, certFile string, keyFile string, router *chi.Mux) chan error {
 | 
			
		||||
 | 
			
		||||
	errs := make(chan error)
 | 
			
		||||
 | 
			
		||||
	// Starting HTTP server
 | 
			
		||||
	if portHTTP != -1 {
 | 
			
		||||
		go func() {
 | 
			
		||||
			globals.Logger.Printf("Staring HTTP service on %d ...", portHTTP)
 | 
			
		||||
 | 
			
		||||
			if err := http.ListenAndServe(":"+strconv.FormatInt(int64(portHTTP), 10), router); err != nil {
 | 
			
		||||
				errs <- err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Starting HTTPS server
 | 
			
		||||
	if portHTTPS != -1 && certFile != "" && keyFile != "" {
 | 
			
		||||
		go func() {
 | 
			
		||||
			globals.Logger.Printf("Staring HTTPS service on %d ...", portHTTPS)
 | 
			
		||||
			if err := http.ListenAndServeTLS(":"+strconv.FormatInt(int64(portHTTPS), 10), certFile, keyFile, router); err != nil {
 | 
			
		||||
				errs <- err
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return errs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user