add chi-router, auth middleware & user roles.
group config options & split database logic
This commit is contained in:
parent
e4c9563961
commit
a21494f94b
|
@ -1,4 +1,16 @@
|
|||
sqlite3_file: 'YetAnotherToDoList.sqlite3'
|
||||
log_file: 'YetAnotherToDoList.log'
|
||||
log_UTC: false
|
||||
port: 4242
|
||||
database:
|
||||
sqlite3File: 'YetAnotherToDoList.sqlite3'
|
||||
secret: 'aS3cureAppl1cationk3y'
|
||||
initialAdmin:
|
||||
userName: 'admin'
|
||||
password: 'temporaryPassword'
|
||||
|
||||
logging:
|
||||
logFile: 'YetAnotherToDoList.log'
|
||||
logUTC: false
|
||||
|
||||
server:
|
||||
portHTTP: 4242
|
||||
portHTTPS: 4241
|
||||
certFile: 'certFile.crt'
|
||||
keyFile: 'keyFile.key'
|
||||
|
|
|
@ -62,3 +62,8 @@ Commands were run in the order listed below on a debian based system.
|
|||
pnpm install villus graphql
|
||||
pnpm install graphql-tag
|
||||
```
|
||||
- Add go-chi
|
||||
```bash
|
||||
go get -u github.com/go-chi/chi/v5
|
||||
go mod tidy
|
||||
```
|
||||
|
|
26
cmd/root.go
26
cmd/root.go
|
@ -74,8 +74,8 @@ func init() {
|
|||
// will be global for your application.
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.YetAnotherToDoList.yaml)")
|
||||
rootCmd.PersistentFlags().String("log_file", "", "Path to log file")
|
||||
rootCmd.PersistentFlags().String("sqlite3_file", "", "Path to SQLite3 database")
|
||||
rootCmd.PersistentFlags().String("logFile", "", "Path to log file")
|
||||
rootCmd.PersistentFlags().String("sqlite3File", "", "Path to SQLite3 database")
|
||||
|
||||
// Cobra also supports local flags, which will only run
|
||||
// when this action is called directly.
|
||||
|
@ -120,7 +120,7 @@ func initLog() {
|
|||
time_zone_local, _ := time.Now().Zone()
|
||||
time_zone_offset := strings.Split(time.Now().In(time.Local).String(), " ")[2]
|
||||
|
||||
if viper.GetBool("log_UTC") {
|
||||
if viper.GetBool("logging.logUTC") {
|
||||
utc = log.LUTC
|
||||
time_zone_use = "UTC"
|
||||
time_zone_alt = time_zone_local
|
||||
|
@ -140,25 +140,25 @@ func initLog() {
|
|||
logger_flags := log.Ldate | log.Ltime | utc
|
||||
globals.Logger = log.New(os.Stdout, "", logger_flags)
|
||||
|
||||
if err := viper.BindPFlag("log_file", rootCmd.Flags().Lookup("log_file")); err != nil {
|
||||
if err := viper.BindPFlag("logging.logFile", rootCmd.Flags().Lookup("logFile")); err != nil {
|
||||
fmt.Println("Unable to bind flag:", err)
|
||||
}
|
||||
|
||||
if viper.GetString("log_file") != "" {
|
||||
log_path, err := filepath.Abs(viper.GetString("log_file"))
|
||||
if viper.GetString("logging.logFile") != "" {
|
||||
log_path, err := filepath.Abs(viper.GetString("logging.logFile"))
|
||||
|
||||
globals.Logger.SetOutput(os.Stdout)
|
||||
if err != nil {
|
||||
globals.Logger.Println("Invalid path for log file", log_path)
|
||||
}
|
||||
|
||||
log_file, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
logFile, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
|
||||
if err != nil {
|
||||
globals.Logger.Println("Failed to write to log file:", err)
|
||||
} else {
|
||||
globals.Logger.Println("Switching to log file", log_path)
|
||||
globals.Logger.SetOutput(log_file)
|
||||
globals.Logger.SetOutput(logFile)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -170,19 +170,21 @@ func initLog() {
|
|||
}
|
||||
|
||||
func initDB() {
|
||||
if err := viper.BindPFlag("sqlite3_file", rootCmd.Flags().Lookup("sqlite3_file")); err != nil {
|
||||
if err := viper.BindPFlag("database.sqlite3File", rootCmd.Flags().Lookup("sqlite3File")); err != nil {
|
||||
fmt.Println("Unable to bind flag:", err)
|
||||
}
|
||||
|
||||
if viper.GetString("sqlite3_file") == "" {
|
||||
if viper.GetString("database.sqlite3File") == "" {
|
||||
globals.Logger.Fatalln("No SQLite3 file specified")
|
||||
}
|
||||
|
||||
db_path, err := filepath.Abs(viper.GetString("sqlite3_file"))
|
||||
db_path, err := filepath.Abs(viper.GetString("database.sqlite3File"))
|
||||
if err != nil {
|
||||
globals.Logger.Fatalln("Invalid path for SQLite3 file", db_path)
|
||||
}
|
||||
|
||||
globals.Logger.Println("Connecting to SQLite3", db_path)
|
||||
globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger)
|
||||
globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger, []byte(viper.GetString("database.secret")), viper.GetString("database.initialAdmin.userName"), viper.GetString("database.initialAdmin.password"))
|
||||
globals.DB.CleanExpiredRefreshTokensTicker(time.Minute * 10)
|
||||
globals.DB.CleanRevokedAccessTokensTicker(time.Minute * 10)
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ var serverCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
globals.Logger.Println("starting http server...")
|
||||
server.StartServer(viper.GetInt("port"))
|
||||
server.StartServer(viper.GetInt("server.portHTTP"), viper.GetInt("server.portHTTPS"), viper.GetString("server.certFile"), viper.GetString("server.keyFile"))
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,475 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
type RefreshToken struct {
|
||||
Selector string `json:"selector"`
|
||||
Token string `json:"token"`
|
||||
ExpiryDate int `json:"expiryDate"`
|
||||
}
|
||||
|
||||
type AccessToken struct {
|
||||
UserId string `json:"userId"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
IsUserCreator bool `json:"isUserCreator"`
|
||||
ExpiryDate int `json:"expiryDate"`
|
||||
}
|
||||
|
||||
const refreshTokenLifetime = "+10 day" // TODO: add to viper
|
||||
const accessTokenLifetime = time.Minute * 10
|
||||
|
||||
const minPasswordLength = 10
|
||||
const minUserNameLength = 10
|
||||
|
||||
var RevokedAccessTokens []*AccessToken
|
||||
|
||||
// ValidatePassword validates the passed string against the password criteria.
|
||||
func ValidatePassword(password string) error {
|
||||
if len([]rune(password)) < minPasswordLength {
|
||||
return fmt.Errorf("password must be at least %d characters long", minPasswordLength)
|
||||
}
|
||||
for i, c := range password {
|
||||
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
|
||||
return fmt.Errorf("password contains none-alphanumeric symbol at index %d", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUserName validates the passed string against the user name criteria.
|
||||
func ValidateUserName(userName string) error {
|
||||
if len([]rune(userName)) < minUserNameLength {
|
||||
return fmt.Errorf("userName must be at least %d characters long", minUserNameLength)
|
||||
}
|
||||
for i, c := range userName {
|
||||
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
|
||||
return fmt.Errorf("userName contains none-alphanumeric symbol at index %d", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateHashFromPassword generates a hash of the passed password. Returns salt and salted & peppered hash.
|
||||
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
|
||||
hashBytes, err := bcrypt.GenerateFromPassword(bytes.Join([][]byte{db.secret, []byte(password)}, nil), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hashBytes, nil
|
||||
}
|
||||
|
||||
// GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin.
|
||||
func (db CustomDB) GetRefreshTokenOwner(tokenId string) (string, error) {
|
||||
numTokenId, err := strconv.Atoi(tokenId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM RefreshToken WHERE tokenId = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTokenId)
|
||||
var owner string
|
||||
if err := result.Scan(&owner); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return owner, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) {
|
||||
var result *sql.Row
|
||||
var hash string
|
||||
|
||||
if userId != nil { // user userId
|
||||
numUserId, err := strconv.Atoi(*userId)
|
||||
if err != nil {
|
||||
return "", errors.New("userId not numeric")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = statement.QueryRow(numUserId)
|
||||
if err := result.Scan(&hash); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if userName != nil { // user userName
|
||||
statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = statement.QueryRow(&userName)
|
||||
if err := result.Scan(&userId, &hash); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
return "", errors.New("neither userId nor userName specified")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), bytes.Join([][]byte{db.secret, []byte(password)}, nil)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return *userId, nil
|
||||
}
|
||||
|
||||
// IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken.
|
||||
func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) {
|
||||
numUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("userId not numeric")
|
||||
}
|
||||
selector := make([]byte, 9)
|
||||
if _, err := rand.Read(selector); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
token := make([]byte, 33)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
encSelector := base64.RawURLEncoding.EncodeToString(selector)
|
||||
encToken := base64.RawURLEncoding.EncodeToString(token)
|
||||
|
||||
tokenHash := sha256.Sum256(token)
|
||||
|
||||
var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix())
|
||||
var tokenId string
|
||||
result := statement.QueryRow(numUserId, encSelector, base64.RawURLEncoding.EncodeToString(tokenHash[:]), &tokenName)
|
||||
|
||||
if err := result.Scan(&tokenId, &expiryDate); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) {
|
||||
numTokenId, err := strconv.Atoi(token.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTokenId)
|
||||
if err := result.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) {
|
||||
numUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(numUserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.RefreshToken
|
||||
for rows.Next() {
|
||||
token := model.RefreshToken{}
|
||||
if err := rows.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &token)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.RefreshToken
|
||||
for rows.Next() {
|
||||
var token model.RefreshToken
|
||||
if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &token)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) {
|
||||
numTokenId, err := strconv.Atoi(tokenId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE AuthToken SET tokenName = ? WHERE tokenId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.TokenName, numTokenId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId})
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get updated token")
|
||||
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken.
|
||||
func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) {
|
||||
// TODO: return string instead of *string
|
||||
numTokenId, err := strconv.Atoi(tokenId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTokenId)
|
||||
|
||||
var userId string
|
||||
if err := result.Scan(&userId); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
||||
return &tokenId, nil
|
||||
}
|
||||
|
||||
// IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted.
|
||||
func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) {
|
||||
statement, err := db.connection.Prepare("SELECT RefreshToken.tokenHash, RefreshToken.FK_User_userId, Role.IS_admin, ROLE.IS_userCreator FROM RefreshToken INNER JOIN R_User_Role ON RefreshToken.FK_User_userId = R_User_Role.FK_User_userId INNER JOIN Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE RefreshToken.selector = ? AND RefreshToken.expiryDate >= unixepoch('now')")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(refreshToken.Selector)
|
||||
|
||||
var tokenHash string
|
||||
var newAccessToken AccessToken
|
||||
if err := result.Scan(&tokenHash, &newAccessToken.UserId, &newAccessToken.IsAdmin, &newAccessToken.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userTokenHash := sha256.Sum256(decUserToken)
|
||||
encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:])
|
||||
if encUserTokenHash != tokenHash {
|
||||
return nil, errors.New("failed to issue access token: refreshToken does not match")
|
||||
}
|
||||
|
||||
newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix())
|
||||
return &newAccessToken, nil
|
||||
}
|
||||
|
||||
// SignAccessToken signs an access token and attaches a header. Returns access token encoded as jwt.
|
||||
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
|
||||
|
||||
data, err := json.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body := base64.RawURLEncoding.EncodeToString(data)
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
||||
combined := header + "." + body
|
||||
|
||||
mac := hmac.New(sha256.New, db.secret)
|
||||
mac.Write([]byte(combined))
|
||||
signature := mac.Sum(nil)
|
||||
|
||||
return combined + "." + base64.RawURLEncoding.EncodeToString(signature), nil
|
||||
}
|
||||
|
||||
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
|
||||
accessTokenBlob := strings.Split(encAccessToken, ".")
|
||||
|
||||
header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) {
|
||||
return nil, errors.New("access token header wrong")
|
||||
}
|
||||
|
||||
body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data AccessToken
|
||||
err = json.Unmarshal(body, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: maybe change part of signKey instead?
|
||||
for _, revokedAccessToken := range RevokedAccessTokens {
|
||||
if data.UserId == revokedAccessToken.UserId && data.ExpiryDate <= revokedAccessToken.ExpiryDate {
|
||||
return nil, errors.New("access token revoked")
|
||||
}
|
||||
}
|
||||
|
||||
if data.ExpiryDate < int(time.Now().Unix()) {
|
||||
return nil, errors.New("access token expired")
|
||||
}
|
||||
|
||||
if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
|
||||
return nil, errors.New("access token expired prematurely")
|
||||
}
|
||||
|
||||
signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, db.secret)
|
||||
mac.Write([]byte(accessTokenBlob[0]))
|
||||
mac.Write([]byte("."))
|
||||
mac.Write([]byte(accessTokenBlob[1]))
|
||||
expectedSignature := mac.Sum(nil)
|
||||
if !hmac.Equal(signature, expectedSignature) {
|
||||
return nil, errors.New("access token signature does not match")
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// RevokeAccessToken revokes all access tokens with matching UserId and UserRole that don't have a later ExpiryDate.
|
||||
// revokedAccessToken.ExpiryDate should be set to now + token-lifetime.
|
||||
func RevokeAccessToken(accessToken *AccessToken) {
|
||||
RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
|
||||
}
|
||||
|
||||
// CleanRevokedTokensTicker removes expired tokens from the list. This should be called in an interval of > accessTokenLifetime.
|
||||
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
|
||||
ticker := time.NewTicker(interval)
|
||||
stop := make(chan bool)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
db.logger.Println("cleaning revoked access tokens")
|
||||
for i, revokedAccessToken := range RevokedAccessTokens {
|
||||
if revokedAccessToken.ExpiryDate < int(time.Now().Unix()) {
|
||||
revokedAccessToken = nil
|
||||
slices.Delete(RevokedAccessTokens, i, i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return stop
|
||||
}
|
||||
|
||||
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
|
||||
ticker := time.NewTicker(interval)
|
||||
stop := make(chan bool)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
db.logger.Println("cleaning expired refresh tokens")
|
||||
_, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')")
|
||||
if err != nil {
|
||||
db.logger.Println("failed to clean expired refresh tokens")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return stop
|
||||
}
|
283
database/main.go
283
database/main.go
|
@ -18,10 +18,9 @@ package database
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
@ -30,11 +29,14 @@ type CustomDB struct {
|
|||
connection *sql.DB
|
||||
logger *log.Logger
|
||||
schema uint
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
|
||||
var initTimeStamp int
|
||||
|
||||
db := CustomDB{logger: logger, schema: schema}
|
||||
func InitSQLite3(path string, schema uint, logger *log.Logger, secret []byte, initialAdminName string, initialAdminPassword string) *CustomDB {
|
||||
initTimeStamp = int(time.Now().Unix())
|
||||
db := CustomDB{logger: logger, schema: schema, secret: secret}
|
||||
var err error
|
||||
|
||||
db.connection, err = sql.Open("sqlite3", "file:"+path+"?_foreign_keys=1")
|
||||
|
@ -58,6 +60,10 @@ func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
|
|||
if err = db.createSQLite3Tables(); err != nil {
|
||||
db.logger.Fatalln("Error in creating table: ", err)
|
||||
}
|
||||
err = db.CreateInitialAdmin(initialAdminName, initialAdminPassword)
|
||||
if err != nil {
|
||||
db.logger.Fatal("failed to create initial admin. Try to fix and delete old database file: ", err)
|
||||
}
|
||||
case user_version > db.schema:
|
||||
db.logger.Fatalln("Incompatible database schema version. Try updating this software.")
|
||||
case user_version < db.schema:
|
||||
|
@ -71,8 +77,11 @@ func (db CustomDB) createSQLite3Tables() error {
|
|||
name string
|
||||
sql string
|
||||
}{
|
||||
{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE, fullName VARCHAR"},
|
||||
{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL, IS_done BOOL NOT NULL DEFAULT false, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE CHECK(length(userName)!=0), passwordHash VARCHAR NOT NULL CHECK(length(passwordHash)!=0), fullName VARCHAR CHECK(length(fullName)!=0)"},
|
||||
{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL CHECK(length(text)!=0), IS_done BOOL NOT NULL, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
{"R_User_Role", "relationId INTEGER PRIMARY KEY NOT NULL, FK_Role_roleId INTEGER NOT NULL, FK_User_userId INTEGER NOT NULL, UNIQUE(FK_Role_roleId, FK_User_userId), FOREIGN KEY(FK_Role_roleId) REFERENCES Role(roleId) ON UPDATE CASCADE ON DELETE CASCADE, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
{"Role", "roleId INTEGER PRIMARY KEY NOT NULL, roleName VARCHAR NOT NULL UNIQUE CHECK(length(roleName)!=0), IS_admin BOOL NOT NULL, IS_userCreator BOOL NOT NULL"},
|
||||
{"RefreshToken", "tokenId INTEGER PRIMARY KEY NOT NULL, FK_User_userId INTEGER NOT NULL, selector VARCHAR NOT NULL CHECK(length(selector)!=0) UNIQUE, tokenHash VARCHAR NOT NULL CHECK(length(tokenHash)!=0), expiryDate INTEGER NOT NULL, tokenName VARCHAR CHECK(length(tokenName)!=0), FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
}
|
||||
for _, table := range tables {
|
||||
_, err := db.connection.Exec("CREATE TABLE IF NOT EXISTS " + table.name + " (" + table.sql + ")")
|
||||
|
@ -91,262 +100,18 @@ func (db CustomDB) createSQLite3Tables() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
|
||||
id, err := strconv.Atoi(user.ID)
|
||||
func (db CustomDB) CreateInitialAdmin(initialAdminName string, initialAdminPassword string) error {
|
||||
role, err := db.CreateRole(&model.NewRole{RoleName: "admin", IsAdmin: true, IsUserCreator: true})
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
return err
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT userName, fullName FROM User WHERE userId = ?")
|
||||
user, err := db.CreateUser(model.NewUser{UserName: initialAdminName, Password: initialAdminPassword})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(id)
|
||||
if err := result.Scan(&user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
_, err = db.AddRole(user.ID, role.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
|
||||
id, err := strconv.Atoi(todo.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(id)
|
||||
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
|
||||
id, err := strconv.Atoi(user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Todo
|
||||
for rows.Next() {
|
||||
todo := model.Todo{User: user}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &todo)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
|
||||
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.User
|
||||
for rows.Next() {
|
||||
var user model.User
|
||||
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &user)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
|
||||
rows, err := db.connection.Query("SELECT Todo.todoID, Todo.text, Todo.IS_done, User.userID FROM Todo INNER JOIN User ON Todo.FK_User_userID=User.userID")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var todos []*model.Todo
|
||||
for rows.Next() {
|
||||
var todo = model.Todo{User: &model.User{}}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
todos = append(todos, &todo)
|
||||
}
|
||||
return todos, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) AddUser(newUser model.NewUser) (*model.User, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newUser.UserName, newUser.FullName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
|
||||
|
||||
id, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(?, fullName) WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.UserName, changes.FullName, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetUser(&model.User{ID: userId})
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
|
||||
|
||||
id, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.Text, changes.Done, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetTodo(&model.Todo{ID: todoId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteUser(userId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &userId, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteTodo(todoId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(todoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &todoId, nil
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
func (db CustomDB) CreateRole(newRole *model.NewRole) (role *model.Role, err error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO Role (roleName, IS_admin, IS_userCreator) VALUES (?, ?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newRole.RoleName, newRole.IsAdmin, newRole.IsUserCreator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.Role{ID: strconv.FormatInt(insertId, 10), RoleName: newRole.RoleName, IsAdmin: newRole.IsAdmin, IsUserCreator: newRole.IsUserCreator}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRole(role *model.Role) (*model.Role, error) {
|
||||
id, err := strconv.Atoi(role.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid roleId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT roleName, IS_admin, IS_userCreator FROM Role WHERE roleId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(id)
|
||||
if err := result.Scan(&role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRolesFrom(userId string) ([]*model.Role, error) {
|
||||
numUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT Role.roleId, Role.roleName, Role.IS_admin, Role.IS_userCreator FROM Role INNER JOIN R_User_Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE R_User_Role.FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(numUserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Role
|
||||
for rows.Next() {
|
||||
role := model.Role{}
|
||||
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &role)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllRoles() ([]*model.Role, error) {
|
||||
rows, err := db.connection.Query("SELECT roleId, roleName, IS_admin, IS_userCreator FROM Role")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Role
|
||||
for rows.Next() {
|
||||
var role model.Role
|
||||
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &role)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateRole(roleId string, changes *model.UpdateRole) (*model.Role, error) {
|
||||
|
||||
id, err := strconv.Atoi(roleId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE Role SET roleName = IFNULL(?, roleName), IS_admin = IFNULL(?, IS_admin), IS_userCreator = IFNULL(?, IS_userCreator) WHERE roleId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.RoleName, changes.IsAdmin, changes.IsUserCreator, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetRole(&model.Role{ID: roleId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteRole(roleId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM Role WHERE roleId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(roleId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &roleId, nil
|
||||
}
|
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
// GetOwner takes a todoId and return the owner's userId. Call before Update/Get/DeleteTodo when not IS_admin.
|
||||
func (db CustomDB) GetOwner(todoId string) (string, error) {
|
||||
numTodoId, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTodoId)
|
||||
var owner string
|
||||
if err := result.Scan(&owner); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return owner, nil
|
||||
}
|
||||
|
||||
// GetTodo takes a *model.Todo with at least ID set and adds the missing fields.
|
||||
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
|
||||
numTodoId, err := strconv.Atoi(todo.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTodoId)
|
||||
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
// GetTodosFrom gets all todos from the passed *model.User. ID must be set.
|
||||
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
|
||||
id, err := strconv.Atoi(user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Todo
|
||||
for rows.Next() {
|
||||
todo := model.Todo{User: user}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &todo)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// GetAllTodos gets all todos from the database. Check if the source has the rights to call this.
|
||||
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
|
||||
rows, err := db.connection.Query("SELECT todoID, text, IS_done, FK_User_userID FROM Todo")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var todos []*model.Todo
|
||||
for rows.Next() {
|
||||
var todo = model.Todo{User: &model.User{}}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
todos = append(todos, &todo)
|
||||
}
|
||||
return todos, nil
|
||||
}
|
||||
|
||||
// AddTodo adds a Todo to passed UserID. Check if the source has the rights to call this.
|
||||
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
|
||||
|
||||
numTodoId, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.Text, changes.Done, numTodoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetTodo(&model.Todo{ID: todoId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteTodo(todoId string) (deletedTodoId *string, err error) {
|
||||
numTodoId, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(numTodoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &todoId, nil
|
||||
}
|
|
@ -0,0 +1,257 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
func (db CustomDB) CreateUser(newUser model.NewUser) (*model.User, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName, passwordHash) VALUES (?, NULLIF(?, ''), ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ValidateUserName(newUser.UserName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ValidatePassword(newUser.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
passwordHash, err := db.GenerateHashFromPassword(newUser.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newUser.UserName, newUser.FullName, string(passwordHash))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
|
||||
}
|
||||
|
||||
// GetUser takes a *model.User with at least ID or UserName set and adds the missing fields.
|
||||
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
|
||||
numUserId, err := strconv.Atoi(user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT userID, userName, fullName FROM User WHERE userId = ? OR userName = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numUserId, user.UserName)
|
||||
if err := result.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
|
||||
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.User
|
||||
for rows.Next() {
|
||||
var user model.User
|
||||
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &user)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
|
||||
var passwordHash *string
|
||||
needAccessTokenRefresh := false
|
||||
|
||||
id, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(NULLIF(?, ''), fullName), passwordHash = IFNULL(?, passwordHash) WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if *changes.UserName == "" { // interpret empty string as nil
|
||||
changes.UserName = nil
|
||||
} else {
|
||||
if err := ValidateUserName(*changes.UserName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
needAccessTokenRefresh = true
|
||||
}
|
||||
|
||||
if *changes.Password == "" { // interpret empty string as nil
|
||||
passwordHash = nil
|
||||
} else {
|
||||
if err := ValidatePassword(*changes.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passwordHashByte, err := db.GenerateHashFromPassword(*changes.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passwordHashString := string(passwordHashByte)
|
||||
passwordHash = &passwordHashString
|
||||
needAccessTokenRefresh = true
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.UserName, changes.FullName, passwordHash, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
if needAccessTokenRefresh {
|
||||
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
||||
}
|
||||
|
||||
return db.GetUser(&model.User{ID: userId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteUser(userId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||