add chi-router, auth middleware & user roles.

group config options & split database logic
This commit is contained in:
gilex-dev 2023-11-05 17:42:14 +01:00
parent e4c9563961
commit a21494f94b
24 changed files with 3988 additions and 346 deletions

View File

@ -1,4 +1,16 @@
sqlite3_file: 'YetAnotherToDoList.sqlite3'
log_file: 'YetAnotherToDoList.log'
log_UTC: false
port: 4242
database:
sqlite3File: 'YetAnotherToDoList.sqlite3'
secret: 'aS3cureAppl1cationk3y'
initialAdmin:
userName: 'admin'
password: 'temporaryPassword'
logging:
logFile: 'YetAnotherToDoList.log'
logUTC: false
server:
portHTTP: 4242
portHTTPS: 4241
certFile: 'certFile.crt'
keyFile: 'keyFile.key'

View File

@ -62,3 +62,8 @@ Commands were run in the order listed below on a debian based system.
pnpm install villus graphql
pnpm install graphql-tag
```
- Add go-chi
```bash
go get -u github.com/go-chi/chi/v5
go mod tidy
```

View File

@ -74,8 +74,8 @@ func init() {
// will be global for your application.
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.YetAnotherToDoList.yaml)")
rootCmd.PersistentFlags().String("log_file", "", "Path to log file")
rootCmd.PersistentFlags().String("sqlite3_file", "", "Path to SQLite3 database")
rootCmd.PersistentFlags().String("logFile", "", "Path to log file")
rootCmd.PersistentFlags().String("sqlite3File", "", "Path to SQLite3 database")
// Cobra also supports local flags, which will only run
// when this action is called directly.
@ -120,7 +120,7 @@ func initLog() {
time_zone_local, _ := time.Now().Zone()
time_zone_offset := strings.Split(time.Now().In(time.Local).String(), " ")[2]
if viper.GetBool("log_UTC") {
if viper.GetBool("logging.logUTC") {
utc = log.LUTC
time_zone_use = "UTC"
time_zone_alt = time_zone_local
@ -140,25 +140,25 @@ func initLog() {
logger_flags := log.Ldate | log.Ltime | utc
globals.Logger = log.New(os.Stdout, "", logger_flags)
if err := viper.BindPFlag("log_file", rootCmd.Flags().Lookup("log_file")); err != nil {
if err := viper.BindPFlag("logging.logFile", rootCmd.Flags().Lookup("logFile")); err != nil {
fmt.Println("Unable to bind flag:", err)
}
if viper.GetString("log_file") != "" {
log_path, err := filepath.Abs(viper.GetString("log_file"))
if viper.GetString("logging.logFile") != "" {
log_path, err := filepath.Abs(viper.GetString("logging.logFile"))
globals.Logger.SetOutput(os.Stdout)
if err != nil {
globals.Logger.Println("Invalid path for log file", log_path)
}
log_file, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
logFile, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
globals.Logger.Println("Failed to write to log file:", err)
} else {
globals.Logger.Println("Switching to log file", log_path)
globals.Logger.SetOutput(log_file)
globals.Logger.SetOutput(logFile)
}
}
@ -170,19 +170,21 @@ func initLog() {
}
func initDB() {
if err := viper.BindPFlag("sqlite3_file", rootCmd.Flags().Lookup("sqlite3_file")); err != nil {
if err := viper.BindPFlag("database.sqlite3File", rootCmd.Flags().Lookup("sqlite3File")); err != nil {
fmt.Println("Unable to bind flag:", err)
}
if viper.GetString("sqlite3_file") == "" {
if viper.GetString("database.sqlite3File") == "" {
globals.Logger.Fatalln("No SQLite3 file specified")
}
db_path, err := filepath.Abs(viper.GetString("sqlite3_file"))
db_path, err := filepath.Abs(viper.GetString("database.sqlite3File"))
if err != nil {
globals.Logger.Fatalln("Invalid path for SQLite3 file", db_path)
}
globals.Logger.Println("Connecting to SQLite3", db_path)
globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger)
globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger, []byte(viper.GetString("database.secret")), viper.GetString("database.initialAdmin.userName"), viper.GetString("database.initialAdmin.password"))
globals.DB.CleanExpiredRefreshTokensTicker(time.Minute * 10)
globals.DB.CleanRevokedAccessTokensTicker(time.Minute * 10)
}

View File

@ -38,7 +38,7 @@ var serverCmd = &cobra.Command{
}
globals.Logger.Println("starting http server...")
server.StartServer(viper.GetInt("port"))
server.StartServer(viper.GetInt("server.portHTTP"), viper.GetInt("server.portHTTPS"), viper.GetString("server.certFile"), viper.GetString("server.keyFile"))
},
}

475
database/crypto_helpers.go Normal file
View File

@ -0,0 +1,475 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"time"
"unicode"
"golang.org/x/crypto/bcrypt"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
type RefreshToken struct {
Selector string `json:"selector"`
Token string `json:"token"`
ExpiryDate int `json:"expiryDate"`
}
type AccessToken struct {
UserId string `json:"userId"`
IsAdmin bool `json:"isAdmin"`
IsUserCreator bool `json:"isUserCreator"`
ExpiryDate int `json:"expiryDate"`
}
const refreshTokenLifetime = "+10 day" // TODO: add to viper
const accessTokenLifetime = time.Minute * 10
const minPasswordLength = 10
const minUserNameLength = 10
var RevokedAccessTokens []*AccessToken
// ValidatePassword validates the passed string against the password criteria.
func ValidatePassword(password string) error {
if len([]rune(password)) < minPasswordLength {
return fmt.Errorf("password must be at least %d characters long", minPasswordLength)
}
for i, c := range password {
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
return fmt.Errorf("password contains none-alphanumeric symbol at index %d", i)
}
}
return nil
}
// ValidateUserName validates the passed string against the user name criteria.
func ValidateUserName(userName string) error {
if len([]rune(userName)) < minUserNameLength {
return fmt.Errorf("userName must be at least %d characters long", minUserNameLength)
}
for i, c := range userName {
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
return fmt.Errorf("userName contains none-alphanumeric symbol at index %d", i)
}
}
return nil
}
// GenerateHashFromPassword generates a hash of the passed password. Returns salt and salted & peppered hash.
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
hashBytes, err := bcrypt.GenerateFromPassword(bytes.Join([][]byte{db.secret, []byte(password)}, nil), bcrypt.DefaultCost)
if err != nil {
return nil, err
}
return hashBytes, nil
}
// GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin.
func (db CustomDB) GetRefreshTokenOwner(tokenId string) (string, error) {
numTokenId, err := strconv.Atoi(tokenId)
if err != nil {
return "", errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM RefreshToken WHERE tokenId = ?")
if err != nil {
return "", err
}
result := statement.QueryRow(numTokenId)
var owner string
if err := result.Scan(&owner); err != nil {
return "", err
}
return owner, nil
}
func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) {
var result *sql.Row
var hash string
if userId != nil { // user userId
numUserId, err := strconv.Atoi(*userId)
if err != nil {
return "", errors.New("userId not numeric")
}
statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?")
if err != nil {
return "", err
}
result = statement.QueryRow(numUserId)
if err := result.Scan(&hash); err != nil {
return "", err
}
} else if userName != nil { // user userName
statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?")
if err != nil {
return "", err
}
result = statement.QueryRow(&userName)
if err := result.Scan(&userId, &hash); err != nil {
return "", err
}
} else {
return "", errors.New("neither userId nor userName specified")
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), bytes.Join([][]byte{db.secret, []byte(password)}, nil)); err != nil {
return "", err
}
return *userId, nil
}
// IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken.
func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) {
numUserId, err := strconv.Atoi(userId)
if err != nil {
return nil, "", errors.New("userId not numeric")
}
selector := make([]byte, 9)
if _, err := rand.Read(selector); err != nil {
return nil, "", err
}
token := make([]byte, 33)
if _, err := rand.Read(token); err != nil {
return nil, "", err
}
statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate")
if err != nil {
return nil, "", err
}
encSelector := base64.RawURLEncoding.EncodeToString(selector)
encToken := base64.RawURLEncoding.EncodeToString(token)
tokenHash := sha256.Sum256(token)
var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix())
var tokenId string
result := statement.QueryRow(numUserId, encSelector, base64.RawURLEncoding.EncodeToString(tokenHash[:]), &tokenName)
if err := result.Scan(&tokenId, &expiryDate); err != nil {
return nil, "", err
}
return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil
}
func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) {
numTokenId, err := strconv.Atoi(token.ID)
if err != nil {
return nil, errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(numTokenId)
if err := result.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
return nil, err
}
return token, nil
}
func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) {
numUserId, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(numUserId)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.RefreshToken
for rows.Next() {
token := model.RefreshToken{}
if err := rows.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
return nil, err
}
all = append(all, &token)
}
return all, nil
}
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
if err != nil {
return nil, err
}
rows, err := statement.Query()
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.RefreshToken
for rows.Next() {
var token model.RefreshToken
if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
return nil, err
}
all = append(all, &token)
}
return all, nil
}
func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) {
numTokenId, err := strconv.Atoi(tokenId)
if err != nil {
return nil, errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("UPDATE AuthToken SET tokenName = ? WHERE tokenId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.TokenName, numTokenId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId})
if err != nil {
return nil, errors.New("failed to get updated token")
}
return token, nil
}
// RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken.
func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) {
// TODO: return string instead of *string
numTokenId, err := strconv.Atoi(tokenId)
if err != nil {
return nil, errors.New("invalid tokenId")
}
statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId")
if err != nil {
return nil, err
}
result := statement.QueryRow(numTokenId)
var userId string
if err := result.Scan(&userId); err != nil {
return nil, err
}
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
return &tokenId, nil
}
// IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted.
func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) {
statement, err := db.connection.Prepare("SELECT RefreshToken.tokenHash, RefreshToken.FK_User_userId, Role.IS_admin, ROLE.IS_userCreator FROM RefreshToken INNER JOIN R_User_Role ON RefreshToken.FK_User_userId = R_User_Role.FK_User_userId INNER JOIN Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE RefreshToken.selector = ? AND RefreshToken.expiryDate >= unixepoch('now')")
if err != nil {
return nil, err
}
result := statement.QueryRow(refreshToken.Selector)
var tokenHash string
var newAccessToken AccessToken
if err := result.Scan(&tokenHash, &newAccessToken.UserId, &newAccessToken.IsAdmin, &newAccessToken.IsUserCreator); err != nil {
return nil, err
}
decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token)
if err != nil {
return nil, err
}
userTokenHash := sha256.Sum256(decUserToken)
encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:])
if encUserTokenHash != tokenHash {
return nil, errors.New("failed to issue access token: refreshToken does not match")
}
newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix())
return &newAccessToken, nil
}
// SignAccessToken signs an access token and attaches a header. Returns access token encoded as jwt.
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
data, err := json.Marshal(accessToken)
if err != nil {
return "", err
}
body := base64.RawURLEncoding.EncodeToString(data)
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
combined := header + "." + body
mac := hmac.New(sha256.New, db.secret)
mac.Write([]byte(combined))
signature := mac.Sum(nil)
return combined + "." + base64.RawURLEncoding.EncodeToString(signature), nil
}
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
accessTokenBlob := strings.Split(encAccessToken, ".")
header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0])
if err != nil {
return nil, err
}
if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) {
return nil, errors.New("access token header wrong")
}
body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1])
if err != nil {
return nil, err
}
var data AccessToken
err = json.Unmarshal(body, &data)
if err != nil {
return nil, err
}
// TODO: maybe change part of signKey instead?
for _, revokedAccessToken := range RevokedAccessTokens {
if data.UserId == revokedAccessToken.UserId && data.ExpiryDate <= revokedAccessToken.ExpiryDate {
return nil, errors.New("access token revoked")
}
}
if data.ExpiryDate < int(time.Now().Unix()) {
return nil, errors.New("access token expired")
}
if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
return nil, errors.New("access token expired prematurely")
}
signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2])
if err != nil {
return nil, err
}
mac := hmac.New(sha256.New, db.secret)
mac.Write([]byte(accessTokenBlob[0]))
mac.Write([]byte("."))
mac.Write([]byte(accessTokenBlob[1]))
expectedSignature := mac.Sum(nil)
if !hmac.Equal(signature, expectedSignature) {
return nil, errors.New("access token signature does not match")
}
return &data, nil
}
// RevokeAccessToken revokes all access tokens with matching UserId and UserRole that don't have a later ExpiryDate.
// revokedAccessToken.ExpiryDate should be set to now + token-lifetime.
func RevokeAccessToken(accessToken *AccessToken) {
RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
}
// CleanRevokedTokensTicker removes expired tokens from the list. This should be called in an interval of > accessTokenLifetime.
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
ticker := time.NewTicker(interval)
stop := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case <-ticker.C:
db.logger.Println("cleaning revoked access tokens")
for i, revokedAccessToken := range RevokedAccessTokens {
if revokedAccessToken.ExpiryDate < int(time.Now().Unix()) {
revokedAccessToken = nil
slices.Delete(RevokedAccessTokens, i, i+1)
}
}
}
}
}()
return stop
}
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
ticker := time.NewTicker(interval)
stop := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case <-ticker.C:
db.logger.Println("cleaning expired refresh tokens")
_, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')")
if err != nil {
db.logger.Println("failed to clean expired refresh tokens")
}
}
}
}()
return stop
}

View File

@ -18,10 +18,9 @@ package database
import (
"database/sql"
"errors"
"fmt"
"log"
"strconv"
"time"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
@ -30,11 +29,14 @@ type CustomDB struct {
connection *sql.DB
logger *log.Logger
schema uint
secret []byte
}
func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
var initTimeStamp int
db := CustomDB{logger: logger, schema: schema}
func InitSQLite3(path string, schema uint, logger *log.Logger, secret []byte, initialAdminName string, initialAdminPassword string) *CustomDB {
initTimeStamp = int(time.Now().Unix())
db := CustomDB{logger: logger, schema: schema, secret: secret}
var err error
db.connection, err = sql.Open("sqlite3", "file:"+path+"?_foreign_keys=1")
@ -58,6 +60,10 @@ func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
if err = db.createSQLite3Tables(); err != nil {
db.logger.Fatalln("Error in creating table: ", err)
}
err = db.CreateInitialAdmin(initialAdminName, initialAdminPassword)
if err != nil {
db.logger.Fatal("failed to create initial admin. Try to fix and delete old database file: ", err)
}
case user_version > db.schema:
db.logger.Fatalln("Incompatible database schema version. Try updating this software.")
case user_version < db.schema:
@ -71,8 +77,11 @@ func (db CustomDB) createSQLite3Tables() error {
name string
sql string
}{
{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE, fullName VARCHAR"},
{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL, IS_done BOOL NOT NULL DEFAULT false, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE CHECK(length(userName)!=0), passwordHash VARCHAR NOT NULL CHECK(length(passwordHash)!=0), fullName VARCHAR CHECK(length(fullName)!=0)"},
{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL CHECK(length(text)!=0), IS_done BOOL NOT NULL, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
{"R_User_Role", "relationId INTEGER PRIMARY KEY NOT NULL, FK_Role_roleId INTEGER NOT NULL, FK_User_userId INTEGER NOT NULL, UNIQUE(FK_Role_roleId, FK_User_userId), FOREIGN KEY(FK_Role_roleId) REFERENCES Role(roleId) ON UPDATE CASCADE ON DELETE CASCADE, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
{"Role", "roleId INTEGER PRIMARY KEY NOT NULL, roleName VARCHAR NOT NULL UNIQUE CHECK(length(roleName)!=0), IS_admin BOOL NOT NULL, IS_userCreator BOOL NOT NULL"},
{"RefreshToken", "tokenId INTEGER PRIMARY KEY NOT NULL, FK_User_userId INTEGER NOT NULL, selector VARCHAR NOT NULL CHECK(length(selector)!=0) UNIQUE, tokenHash VARCHAR NOT NULL CHECK(length(tokenHash)!=0), expiryDate INTEGER NOT NULL, tokenName VARCHAR CHECK(length(tokenName)!=0), FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
}
for _, table := range tables {
_, err := db.connection.Exec("CREATE TABLE IF NOT EXISTS " + table.name + " (" + table.sql + ")")
@ -91,262 +100,18 @@ func (db CustomDB) createSQLite3Tables() error {
return nil
}
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
id, err := strconv.Atoi(user.ID)
func (db CustomDB) CreateInitialAdmin(initialAdminName string, initialAdminPassword string) error {
role, err := db.CreateRole(&model.NewRole{RoleName: "admin", IsAdmin: true, IsUserCreator: true})
if err != nil {
return nil, errors.New("invalid userId")
return err
}
statement, err := db.connection.Prepare("SELECT userName, fullName FROM User WHERE userId = ?")
user, err := db.CreateUser(model.NewUser{UserName: initialAdminName, Password: initialAdminPassword})
if err != nil {
return nil, err
return err
}
result := statement.QueryRow(id)
if err := result.Scan(&user.UserName, &user.FullName); err != nil {
return nil, err
_, err = db.AddRole(user.ID, role.ID)
if err != nil {
return err
}
return user, nil
}
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
id, err := strconv.Atoi(todo.ID)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(id)
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
return nil, err
}
return todo, nil
}
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
id, err := strconv.Atoi(user.ID)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(id)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Todo
for rows.Next() {
todo := model.Todo{User: user}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
return nil, err
}
all = append(all, &todo)
}
return all, nil
}
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.User
for rows.Next() {
var user model.User
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
return nil, err
}
all = append(all, &user)
}
return all, nil
}
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
rows, err := db.connection.Query("SELECT Todo.todoID, Todo.text, Todo.IS_done, User.userID FROM Todo INNER JOIN User ON Todo.FK_User_userID=User.userID")
if err != nil {
return nil, err
}
defer rows.Close()
var todos []*model.Todo
for rows.Next() {
var todo = model.Todo{User: &model.User{}}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
return nil, err
}
todos = append(todos, &todo)
}
return todos, nil
}
func (db CustomDB) AddUser(newUser model.NewUser) (*model.User, error) {
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName) VALUES (?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newUser.UserName, newUser.FullName)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
}
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
}
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
id, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(?, fullName) WHERE userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.UserName, changes.FullName, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetUser(&model.User{ID: userId})
}
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
id, err := strconv.Atoi(todoId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.Text, changes.Done, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetTodo(&model.Todo{ID: todoId})
}
func (db CustomDB) DeleteUser(userId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(userId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &userId, nil
}
func (db CustomDB) DeleteTodo(todoId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(todoId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &todoId, nil
return nil
}

168
database/role.go Normal file
View File

@ -0,0 +1,168 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"errors"
"strconv"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
func (db CustomDB) CreateRole(newRole *model.NewRole) (role *model.Role, err error) {
statement, err := db.connection.Prepare("INSERT INTO Role (roleName, IS_admin, IS_userCreator) VALUES (?, ?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newRole.RoleName, newRole.IsAdmin, newRole.IsUserCreator)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.Role{ID: strconv.FormatInt(insertId, 10), RoleName: newRole.RoleName, IsAdmin: newRole.IsAdmin, IsUserCreator: newRole.IsUserCreator}, nil
}
func (db CustomDB) GetRole(role *model.Role) (*model.Role, error) {
id, err := strconv.Atoi(role.ID)
if err != nil {
return nil, errors.New("invalid roleId")
}
statement, err := db.connection.Prepare("SELECT roleName, IS_admin, IS_userCreator FROM Role WHERE roleId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(id)
if err := result.Scan(&role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
return nil, err
}
return role, nil
}
func (db CustomDB) GetRolesFrom(userId string) ([]*model.Role, error) {
numUserId, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT Role.roleId, Role.roleName, Role.IS_admin, Role.IS_userCreator FROM Role INNER JOIN R_User_Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE R_User_Role.FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(numUserId)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Role
for rows.Next() {
role := model.Role{}
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
return nil, err
}
all = append(all, &role)
}
return all, nil
}
func (db CustomDB) GetAllRoles() ([]*model.Role, error) {
rows, err := db.connection.Query("SELECT roleId, roleName, IS_admin, IS_userCreator FROM Role")
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Role
for rows.Next() {
var role model.Role
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
return nil, err
}
all = append(all, &role)
}
return all, nil
}
func (db CustomDB) UpdateRole(roleId string, changes *model.UpdateRole) (*model.Role, error) {
id, err := strconv.Atoi(roleId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE Role SET roleName = IFNULL(?, roleName), IS_admin = IFNULL(?, IS_admin), IS_userCreator = IFNULL(?, IS_userCreator) WHERE roleId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.RoleName, changes.IsAdmin, changes.IsUserCreator, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetRole(&model.Role{ID: roleId})
}
func (db CustomDB) DeleteRole(roleId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM Role WHERE roleId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(roleId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &roleId, nil
}

199
database/todo.go Normal file
View File

@ -0,0 +1,199 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"errors"
"strconv"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
// GetOwner takes a todoId and return the owner's userId. Call before Update/Get/DeleteTodo when not IS_admin.
func (db CustomDB) GetOwner(todoId string) (string, error) {
numTodoId, err := strconv.Atoi(todoId)
if err != nil {
return "", errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM Todo WHERE todoId = ?")
if err != nil {
return "", err
}
result := statement.QueryRow(numTodoId)
var owner string
if err := result.Scan(&owner); err != nil {
return "", err
}
return owner, nil
}
// GetTodo takes a *model.Todo with at least ID set and adds the missing fields.
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
numTodoId, err := strconv.Atoi(todo.ID)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(numTodoId)
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
return nil, err
}
return todo, nil
}
// GetTodosFrom gets all todos from the passed *model.User. ID must be set.
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
id, err := strconv.Atoi(user.ID)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Query(id)
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.Todo
for rows.Next() {
todo := model.Todo{User: user}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
return nil, err
}
all = append(all, &todo)
}
return all, nil
}
// GetAllTodos gets all todos from the database. Check if the source has the rights to call this.
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
rows, err := db.connection.Query("SELECT todoID, text, IS_done, FK_User_userID FROM Todo")
if err != nil {
return nil, err
}
defer rows.Close()
var todos []*model.Todo
for rows.Next() {
var todo = model.Todo{User: &model.User{}}
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
return nil, err
}
todos = append(todos, &todo)
}
return todos, nil
}
// AddTodo adds a Todo to passed UserID. Check if the source has the rights to call this.
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
if err != nil {
return nil, err
}
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
}
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
numTodoId, err := strconv.Atoi(todoId)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(changes.Text, changes.Done, numTodoId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return db.GetTodo(&model.Todo{ID: todoId})
}
func (db CustomDB) DeleteTodo(todoId string) (deletedTodoId *string, err error) {
numTodoId, err := strconv.Atoi(todoId)
if err != nil {
return nil, errors.New("invalid todoId")
}
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
if err != nil {
return nil, err
}
rows, err := statement.Exec(numTodoId)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
return &todoId, nil
}

257
database/user.go Normal file
View File

@ -0,0 +1,257 @@
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"errors"
"strconv"
"time"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
func (db CustomDB) CreateUser(newUser model.NewUser) (*model.User, error) {
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName, passwordHash) VALUES (?, NULLIF(?, ''), ?)")
if err != nil {
return nil, err
}
if ValidateUserName(newUser.UserName); err != nil {
return nil, err
}
if ValidatePassword(newUser.Password); err != nil {
return nil, err
}
passwordHash, err := db.GenerateHashFromPassword(newUser.Password)
if err != nil {
return nil, err
}
rows, err := statement.Exec(newUser.UserName, newUser.FullName, string(passwordHash))
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
insertId, err := rows.LastInsertId()
if err != nil {
return nil, err
}
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
}
// GetUser takes a *model.User with at least ID or UserName set and adds the missing fields.
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
numUserId, err := strconv.Atoi(user.ID)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("SELECT userID, userName, fullName FROM User WHERE userId = ? OR userName = ?")
if err != nil {
return nil, err
}
result := statement.QueryRow(numUserId, user.UserName)
if err := result.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
return nil, err
}
return user, nil
}
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var all []*model.User
for rows.Next() {
var user model.User
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
return nil, err
}
all = append(all, &user)
}
return all, nil
}
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
var passwordHash *string
needAccessTokenRefresh := false
id, err := strconv.Atoi(userId)
if err != nil {
return nil, errors.New("invalid userId")
}
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(NULLIF(?, ''), fullName), passwordHash = IFNULL(?, passwordHash) WHERE userId = ?")
if err != nil {
return nil, err
}
if *changes.UserName == "" { // interpret empty string as nil
changes.UserName = nil
} else {
if err := ValidateUserName(*changes.UserName); err != nil {
return nil, err
}
needAccessTokenRefresh = true
}
if *changes.Password == "" { // interpret empty string as nil
passwordHash = nil
} else {
if err := ValidatePassword(*changes.Password); err != nil {
return nil, err
}
passwordHashByte, err := db.GenerateHashFromPassword(*changes.Password)
if err != nil {
return nil, err
}
passwordHashString := string(passwordHashByte)
passwordHash = &passwordHashString
needAccessTokenRefresh = true
}
rows, err := statement.Exec(changes.UserName, changes.FullName, passwordHash, id)
if err != nil {
return nil, err
}
num, err := rows.RowsAffected()
if err != nil {
return nil, err
}
if num < 1 {
return nil, errors.New("no rows affected")
}
if needAccessTokenRefresh {
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
}
return db.GetUser(&model.User{ID: userId})
}
func (db CustomDB) DeleteUser(userId string) (*string, error) {
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
if err != nil {
return nil, err
}