add chi-router, auth middleware & user roles.
group config options & split database logic
This commit is contained in:
parent
e4c9563961
commit
a21494f94b
|
@ -1,4 +1,16 @@
|
|||
sqlite3_file: 'YetAnotherToDoList.sqlite3'
|
||||
log_file: 'YetAnotherToDoList.log'
|
||||
log_UTC: false
|
||||
port: 4242
|
||||
database:
|
||||
sqlite3File: 'YetAnotherToDoList.sqlite3'
|
||||
secret: 'aS3cureAppl1cationk3y'
|
||||
initialAdmin:
|
||||
userName: 'admin'
|
||||
password: 'temporaryPassword'
|
||||
|
||||
logging:
|
||||
logFile: 'YetAnotherToDoList.log'
|
||||
logUTC: false
|
||||
|
||||
server:
|
||||
portHTTP: 4242
|
||||
portHTTPS: 4241
|
||||
certFile: 'certFile.crt'
|
||||
keyFile: 'keyFile.key'
|
||||
|
|
|
@ -62,3 +62,8 @@ Commands were run in the order listed below on a debian based system.
|
|||
pnpm install villus graphql
|
||||
pnpm install graphql-tag
|
||||
```
|
||||
- Add go-chi
|
||||
```bash
|
||||
go get -u github.com/go-chi/chi/v5
|
||||
go mod tidy
|
||||
```
|
||||
|
|
26
cmd/root.go
26
cmd/root.go
|
@ -74,8 +74,8 @@ func init() {
|
|||
// will be global for your application.
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.YetAnotherToDoList.yaml)")
|
||||
rootCmd.PersistentFlags().String("log_file", "", "Path to log file")
|
||||
rootCmd.PersistentFlags().String("sqlite3_file", "", "Path to SQLite3 database")
|
||||
rootCmd.PersistentFlags().String("logFile", "", "Path to log file")
|
||||
rootCmd.PersistentFlags().String("sqlite3File", "", "Path to SQLite3 database")
|
||||
|
||||
// Cobra also supports local flags, which will only run
|
||||
// when this action is called directly.
|
||||
|
@ -120,7 +120,7 @@ func initLog() {
|
|||
time_zone_local, _ := time.Now().Zone()
|
||||
time_zone_offset := strings.Split(time.Now().In(time.Local).String(), " ")[2]
|
||||
|
||||
if viper.GetBool("log_UTC") {
|
||||
if viper.GetBool("logging.logUTC") {
|
||||
utc = log.LUTC
|
||||
time_zone_use = "UTC"
|
||||
time_zone_alt = time_zone_local
|
||||
|
@ -140,25 +140,25 @@ func initLog() {
|
|||
logger_flags := log.Ldate | log.Ltime | utc
|
||||
globals.Logger = log.New(os.Stdout, "", logger_flags)
|
||||
|
||||
if err := viper.BindPFlag("log_file", rootCmd.Flags().Lookup("log_file")); err != nil {
|
||||
if err := viper.BindPFlag("logging.logFile", rootCmd.Flags().Lookup("logFile")); err != nil {
|
||||
fmt.Println("Unable to bind flag:", err)
|
||||
}
|
||||
|
||||
if viper.GetString("log_file") != "" {
|
||||
log_path, err := filepath.Abs(viper.GetString("log_file"))
|
||||
if viper.GetString("logging.logFile") != "" {
|
||||
log_path, err := filepath.Abs(viper.GetString("logging.logFile"))
|
||||
|
||||
globals.Logger.SetOutput(os.Stdout)
|
||||
if err != nil {
|
||||
globals.Logger.Println("Invalid path for log file", log_path)
|
||||
}
|
||||
|
||||
log_file, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
logFile, err := os.OpenFile(log_path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
|
||||
if err != nil {
|
||||
globals.Logger.Println("Failed to write to log file:", err)
|
||||
} else {
|
||||
globals.Logger.Println("Switching to log file", log_path)
|
||||
globals.Logger.SetOutput(log_file)
|
||||
globals.Logger.SetOutput(logFile)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -170,19 +170,21 @@ func initLog() {
|
|||
}
|
||||
|
||||
func initDB() {
|
||||
if err := viper.BindPFlag("sqlite3_file", rootCmd.Flags().Lookup("sqlite3_file")); err != nil {
|
||||
if err := viper.BindPFlag("database.sqlite3File", rootCmd.Flags().Lookup("sqlite3File")); err != nil {
|
||||
fmt.Println("Unable to bind flag:", err)
|
||||
}
|
||||
|
||||
if viper.GetString("sqlite3_file") == "" {
|
||||
if viper.GetString("database.sqlite3File") == "" {
|
||||
globals.Logger.Fatalln("No SQLite3 file specified")
|
||||
}
|
||||
|
||||
db_path, err := filepath.Abs(viper.GetString("sqlite3_file"))
|
||||
db_path, err := filepath.Abs(viper.GetString("database.sqlite3File"))
|
||||
if err != nil {
|
||||
globals.Logger.Fatalln("Invalid path for SQLite3 file", db_path)
|
||||
}
|
||||
|
||||
globals.Logger.Println("Connecting to SQLite3", db_path)
|
||||
globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger)
|
||||
globals.DB = database.InitSQLite3(db_path, globals.DB_schema, globals.Logger, []byte(viper.GetString("database.secret")), viper.GetString("database.initialAdmin.userName"), viper.GetString("database.initialAdmin.password"))
|
||||
globals.DB.CleanExpiredRefreshTokensTicker(time.Minute * 10)
|
||||
globals.DB.CleanRevokedAccessTokensTicker(time.Minute * 10)
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ var serverCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
globals.Logger.Println("starting http server...")
|
||||
server.StartServer(viper.GetInt("port"))
|
||||
server.StartServer(viper.GetInt("server.portHTTP"), viper.GetInt("server.portHTTPS"), viper.GetString("server.certFile"), viper.GetString("server.keyFile"))
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,475 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
type RefreshToken struct {
|
||||
Selector string `json:"selector"`
|
||||
Token string `json:"token"`
|
||||
ExpiryDate int `json:"expiryDate"`
|
||||
}
|
||||
|
||||
type AccessToken struct {
|
||||
UserId string `json:"userId"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
IsUserCreator bool `json:"isUserCreator"`
|
||||
ExpiryDate int `json:"expiryDate"`
|
||||
}
|
||||
|
||||
const refreshTokenLifetime = "+10 day" // TODO: add to viper
|
||||
const accessTokenLifetime = time.Minute * 10
|
||||
|
||||
const minPasswordLength = 10
|
||||
const minUserNameLength = 10
|
||||
|
||||
var RevokedAccessTokens []*AccessToken
|
||||
|
||||
// ValidatePassword validates the passed string against the password criteria.
|
||||
func ValidatePassword(password string) error {
|
||||
if len([]rune(password)) < minPasswordLength {
|
||||
return fmt.Errorf("password must be at least %d characters long", minPasswordLength)
|
||||
}
|
||||
for i, c := range password {
|
||||
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
|
||||
return fmt.Errorf("password contains none-alphanumeric symbol at index %d", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUserName validates the passed string against the user name criteria.
|
||||
func ValidateUserName(userName string) error {
|
||||
if len([]rune(userName)) < minUserNameLength {
|
||||
return fmt.Errorf("userName must be at least %d characters long", minUserNameLength)
|
||||
}
|
||||
for i, c := range userName {
|
||||
if !(unicode.IsLetter(c) || unicode.IsNumber(c)) {
|
||||
return fmt.Errorf("userName contains none-alphanumeric symbol at index %d", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateHashFromPassword generates a hash of the passed password. Returns salt and salted & peppered hash.
|
||||
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
|
||||
hashBytes, err := bcrypt.GenerateFromPassword(bytes.Join([][]byte{db.secret, []byte(password)}, nil), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hashBytes, nil
|
||||
}
|
||||
|
||||
// GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin.
|
||||
func (db CustomDB) GetRefreshTokenOwner(tokenId string) (string, error) {
|
||||
numTokenId, err := strconv.Atoi(tokenId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM RefreshToken WHERE tokenId = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTokenId)
|
||||
var owner string
|
||||
if err := result.Scan(&owner); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return owner, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) {
|
||||
var result *sql.Row
|
||||
var hash string
|
||||
|
||||
if userId != nil { // user userId
|
||||
numUserId, err := strconv.Atoi(*userId)
|
||||
if err != nil {
|
||||
return "", errors.New("userId not numeric")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = statement.QueryRow(numUserId)
|
||||
if err := result.Scan(&hash); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if userName != nil { // user userName
|
||||
statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = statement.QueryRow(&userName)
|
||||
if err := result.Scan(&userId, &hash); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
return "", errors.New("neither userId nor userName specified")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), bytes.Join([][]byte{db.secret, []byte(password)}, nil)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return *userId, nil
|
||||
}
|
||||
|
||||
// IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken.
|
||||
func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) {
|
||||
numUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("userId not numeric")
|
||||
}
|
||||
selector := make([]byte, 9)
|
||||
if _, err := rand.Read(selector); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
token := make([]byte, 33)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
encSelector := base64.RawURLEncoding.EncodeToString(selector)
|
||||
encToken := base64.RawURLEncoding.EncodeToString(token)
|
||||
|
||||
tokenHash := sha256.Sum256(token)
|
||||
|
||||
var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix())
|
||||
var tokenId string
|
||||
result := statement.QueryRow(numUserId, encSelector, base64.RawURLEncoding.EncodeToString(tokenHash[:]), &tokenName)
|
||||
|
||||
if err := result.Scan(&tokenId, &expiryDate); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) {
|
||||
numTokenId, err := strconv.Atoi(token.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTokenId)
|
||||
if err := result.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) {
|
||||
numUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(numUserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.RefreshToken
|
||||
for rows.Next() {
|
||||
token := model.RefreshToken{}
|
||||
if err := rows.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &token)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.RefreshToken
|
||||
for rows.Next() {
|
||||
var token model.RefreshToken
|
||||
if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &token)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) {
|
||||
numTokenId, err := strconv.Atoi(tokenId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE AuthToken SET tokenName = ? WHERE tokenId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.TokenName, numTokenId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId})
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get updated token")
|
||||
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken.
|
||||
func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) {
|
||||
// TODO: return string instead of *string
|
||||
numTokenId, err := strconv.Atoi(tokenId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid tokenId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTokenId)
|
||||
|
||||
var userId string
|
||||
if err := result.Scan(&userId); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
||||
return &tokenId, nil
|
||||
}
|
||||
|
||||
// IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted.
|
||||
func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) {
|
||||
statement, err := db.connection.Prepare("SELECT RefreshToken.tokenHash, RefreshToken.FK_User_userId, Role.IS_admin, ROLE.IS_userCreator FROM RefreshToken INNER JOIN R_User_Role ON RefreshToken.FK_User_userId = R_User_Role.FK_User_userId INNER JOIN Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE RefreshToken.selector = ? AND RefreshToken.expiryDate >= unixepoch('now')")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(refreshToken.Selector)
|
||||
|
||||
var tokenHash string
|
||||
var newAccessToken AccessToken
|
||||
if err := result.Scan(&tokenHash, &newAccessToken.UserId, &newAccessToken.IsAdmin, &newAccessToken.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userTokenHash := sha256.Sum256(decUserToken)
|
||||
encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:])
|
||||
if encUserTokenHash != tokenHash {
|
||||
return nil, errors.New("failed to issue access token: refreshToken does not match")
|
||||
}
|
||||
|
||||
newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix())
|
||||
return &newAccessToken, nil
|
||||
}
|
||||
|
||||
// SignAccessToken signs an access token and attaches a header. Returns access token encoded as jwt.
|
||||
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
|
||||
|
||||
data, err := json.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body := base64.RawURLEncoding.EncodeToString(data)
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
||||
combined := header + "." + body
|
||||
|
||||
mac := hmac.New(sha256.New, db.secret)
|
||||
mac.Write([]byte(combined))
|
||||
signature := mac.Sum(nil)
|
||||
|
||||
return combined + "." + base64.RawURLEncoding.EncodeToString(signature), nil
|
||||
}
|
||||
|
||||
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
|
||||
accessTokenBlob := strings.Split(encAccessToken, ".")
|
||||
|
||||
header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) {
|
||||
return nil, errors.New("access token header wrong")
|
||||
}
|
||||
|
||||
body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data AccessToken
|
||||
err = json.Unmarshal(body, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: maybe change part of signKey instead?
|
||||
for _, revokedAccessToken := range RevokedAccessTokens {
|
||||
if data.UserId == revokedAccessToken.UserId && data.ExpiryDate <= revokedAccessToken.ExpiryDate {
|
||||
return nil, errors.New("access token revoked")
|
||||
}
|
||||
}
|
||||
|
||||
if data.ExpiryDate < int(time.Now().Unix()) {
|
||||
return nil, errors.New("access token expired")
|
||||
}
|
||||
|
||||
if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
|
||||
return nil, errors.New("access token expired prematurely")
|
||||
}
|
||||
|
||||
signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, db.secret)
|
||||
mac.Write([]byte(accessTokenBlob[0]))
|
||||
mac.Write([]byte("."))
|
||||
mac.Write([]byte(accessTokenBlob[1]))
|
||||
expectedSignature := mac.Sum(nil)
|
||||
if !hmac.Equal(signature, expectedSignature) {
|
||||
return nil, errors.New("access token signature does not match")
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// RevokeAccessToken revokes all access tokens with matching UserId and UserRole that don't have a later ExpiryDate.
|
||||
// revokedAccessToken.ExpiryDate should be set to now + token-lifetime.
|
||||
func RevokeAccessToken(accessToken *AccessToken) {
|
||||
RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
|
||||
}
|
||||
|
||||
// CleanRevokedTokensTicker removes expired tokens from the list. This should be called in an interval of > accessTokenLifetime.
|
||||
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
|
||||
ticker := time.NewTicker(interval)
|
||||
stop := make(chan bool)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
db.logger.Println("cleaning revoked access tokens")
|
||||
for i, revokedAccessToken := range RevokedAccessTokens {
|
||||
if revokedAccessToken.ExpiryDate < int(time.Now().Unix()) {
|
||||
revokedAccessToken = nil
|
||||
slices.Delete(RevokedAccessTokens, i, i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return stop
|
||||
}
|
||||
|
||||
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
|
||||
ticker := time.NewTicker(interval)
|
||||
stop := make(chan bool)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
db.logger.Println("cleaning expired refresh tokens")
|
||||
_, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')")
|
||||
if err != nil {
|
||||
db.logger.Println("failed to clean expired refresh tokens")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return stop
|
||||
}
|
283
database/main.go
283
database/main.go
|
@ -18,10 +18,9 @@ package database
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
@ -30,11 +29,14 @@ type CustomDB struct {
|
|||
connection *sql.DB
|
||||
logger *log.Logger
|
||||
schema uint
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
|
||||
var initTimeStamp int
|
||||
|
||||
db := CustomDB{logger: logger, schema: schema}
|
||||
func InitSQLite3(path string, schema uint, logger *log.Logger, secret []byte, initialAdminName string, initialAdminPassword string) *CustomDB {
|
||||
initTimeStamp = int(time.Now().Unix())
|
||||
db := CustomDB{logger: logger, schema: schema, secret: secret}
|
||||
var err error
|
||||
|
||||
db.connection, err = sql.Open("sqlite3", "file:"+path+"?_foreign_keys=1")
|
||||
|
@ -58,6 +60,10 @@ func InitSQLite3(path string, schema uint, logger *log.Logger) *CustomDB {
|
|||
if err = db.createSQLite3Tables(); err != nil {
|
||||
db.logger.Fatalln("Error in creating table: ", err)
|
||||
}
|
||||
err = db.CreateInitialAdmin(initialAdminName, initialAdminPassword)
|
||||
if err != nil {
|
||||
db.logger.Fatal("failed to create initial admin. Try to fix and delete old database file: ", err)
|
||||
}
|
||||
case user_version > db.schema:
|
||||
db.logger.Fatalln("Incompatible database schema version. Try updating this software.")
|
||||
case user_version < db.schema:
|
||||
|
@ -71,8 +77,11 @@ func (db CustomDB) createSQLite3Tables() error {
|
|||
name string
|
||||
sql string
|
||||
}{
|
||||
{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE, fullName VARCHAR"},
|
||||
{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL, IS_done BOOL NOT NULL DEFAULT false, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
{"User", "userId INTEGER PRIMARY KEY NOT NULL, userName VARCHAR NOT NULL UNIQUE CHECK(length(userName)!=0), passwordHash VARCHAR NOT NULL CHECK(length(passwordHash)!=0), fullName VARCHAR CHECK(length(fullName)!=0)"},
|
||||
{"Todo", "todoId INTEGER PRIMARY KEY NOT NULL, text VARCHAR NOT NULL CHECK(length(text)!=0), IS_done BOOL NOT NULL, FK_User_userId INTEGER NOT NULL, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
{"R_User_Role", "relationId INTEGER PRIMARY KEY NOT NULL, FK_Role_roleId INTEGER NOT NULL, FK_User_userId INTEGER NOT NULL, UNIQUE(FK_Role_roleId, FK_User_userId), FOREIGN KEY(FK_Role_roleId) REFERENCES Role(roleId) ON UPDATE CASCADE ON DELETE CASCADE, FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
{"Role", "roleId INTEGER PRIMARY KEY NOT NULL, roleName VARCHAR NOT NULL UNIQUE CHECK(length(roleName)!=0), IS_admin BOOL NOT NULL, IS_userCreator BOOL NOT NULL"},
|
||||
{"RefreshToken", "tokenId INTEGER PRIMARY KEY NOT NULL, FK_User_userId INTEGER NOT NULL, selector VARCHAR NOT NULL CHECK(length(selector)!=0) UNIQUE, tokenHash VARCHAR NOT NULL CHECK(length(tokenHash)!=0), expiryDate INTEGER NOT NULL, tokenName VARCHAR CHECK(length(tokenName)!=0), FOREIGN KEY(FK_User_userId) REFERENCES User(userId) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
}
|
||||
for _, table := range tables {
|
||||
_, err := db.connection.Exec("CREATE TABLE IF NOT EXISTS " + table.name + " (" + table.sql + ")")
|
||||
|
@ -91,262 +100,18 @@ func (db CustomDB) createSQLite3Tables() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
|
||||
id, err := strconv.Atoi(user.ID)
|
||||
func (db CustomDB) CreateInitialAdmin(initialAdminName string, initialAdminPassword string) error {
|
||||
role, err := db.CreateRole(&model.NewRole{RoleName: "admin", IsAdmin: true, IsUserCreator: true})
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
return err
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT userName, fullName FROM User WHERE userId = ?")
|
||||
user, err := db.CreateUser(model.NewUser{UserName: initialAdminName, Password: initialAdminPassword})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(id)
|
||||
if err := result.Scan(&user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
_, err = db.AddRole(user.ID, role.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
|
||||
id, err := strconv.Atoi(todo.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(id)
|
||||
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
|
||||
id, err := strconv.Atoi(user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Todo
|
||||
for rows.Next() {
|
||||
todo := model.Todo{User: user}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &todo)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
|
||||
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.User
|
||||
for rows.Next() {
|
||||
var user model.User
|
||||
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &user)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
|
||||
rows, err := db.connection.Query("SELECT Todo.todoID, Todo.text, Todo.IS_done, User.userID FROM Todo INNER JOIN User ON Todo.FK_User_userID=User.userID")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var todos []*model.Todo
|
||||
for rows.Next() {
|
||||
var todo = model.Todo{User: &model.User{}}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
todos = append(todos, &todo)
|
||||
}
|
||||
return todos, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) AddUser(newUser model.NewUser) (*model.User, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newUser.UserName, newUser.FullName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
|
||||
|
||||
id, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(?, fullName) WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.UserName, changes.FullName, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetUser(&model.User{ID: userId})
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
|
||||
|
||||
id, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.Text, changes.Done, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetTodo(&model.Todo{ID: todoId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteUser(userId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &userId, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteTodo(todoId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(todoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &todoId, nil
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
func (db CustomDB) CreateRole(newRole *model.NewRole) (role *model.Role, err error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO Role (roleName, IS_admin, IS_userCreator) VALUES (?, ?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newRole.RoleName, newRole.IsAdmin, newRole.IsUserCreator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.Role{ID: strconv.FormatInt(insertId, 10), RoleName: newRole.RoleName, IsAdmin: newRole.IsAdmin, IsUserCreator: newRole.IsUserCreator}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRole(role *model.Role) (*model.Role, error) {
|
||||
id, err := strconv.Atoi(role.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid roleId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT roleName, IS_admin, IS_userCreator FROM Role WHERE roleId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(id)
|
||||
if err := result.Scan(&role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetRolesFrom(userId string) ([]*model.Role, error) {
|
||||
numUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT Role.roleId, Role.roleName, Role.IS_admin, Role.IS_userCreator FROM Role INNER JOIN R_User_Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE R_User_Role.FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(numUserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Role
|
||||
for rows.Next() {
|
||||
role := model.Role{}
|
||||
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &role)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllRoles() ([]*model.Role, error) {
|
||||
rows, err := db.connection.Query("SELECT roleId, roleName, IS_admin, IS_userCreator FROM Role")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Role
|
||||
for rows.Next() {
|
||||
var role model.Role
|
||||
if err := rows.Scan(&role.ID, &role.RoleName, &role.IsAdmin, &role.IsUserCreator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &role)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateRole(roleId string, changes *model.UpdateRole) (*model.Role, error) {
|
||||
|
||||
id, err := strconv.Atoi(roleId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE Role SET roleName = IFNULL(?, roleName), IS_admin = IFNULL(?, IS_admin), IS_userCreator = IFNULL(?, IS_userCreator) WHERE roleId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.RoleName, changes.IsAdmin, changes.IsUserCreator, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetRole(&model.Role{ID: roleId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteRole(roleId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM Role WHERE roleId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(roleId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &roleId, nil
|
||||
}
|
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
// GetOwner takes a todoId and return the owner's userId. Call before Update/Get/DeleteTodo when not IS_admin.
|
||||
func (db CustomDB) GetOwner(todoId string) (string, error) {
|
||||
numTodoId, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTodoId)
|
||||
var owner string
|
||||
if err := result.Scan(&owner); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return owner, nil
|
||||
}
|
||||
|
||||
// GetTodo takes a *model.Todo with at least ID set and adds the missing fields.
|
||||
func (db CustomDB) GetTodo(todo *model.Todo) (*model.Todo, error) {
|
||||
numTodoId, err := strconv.Atoi(todo.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("SELECT text, IS_done FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numTodoId)
|
||||
if err := result.Scan(&todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
// GetTodosFrom gets all todos from the passed *model.User. ID must be set.
|
||||
func (db CustomDB) GetTodosFrom(user *model.User) ([]*model.Todo, error) {
|
||||
id, err := strconv.Atoi(user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT todoId, text, IS_done FROM Todo WHERE FK_User_userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Query(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.Todo
|
||||
for rows.Next() {
|
||||
todo := model.Todo{User: user}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &todo)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// GetAllTodos gets all todos from the database. Check if the source has the rights to call this.
|
||||
func (db CustomDB) GetAllTodos() ([]*model.Todo, error) {
|
||||
rows, err := db.connection.Query("SELECT todoID, text, IS_done, FK_User_userID FROM Todo")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var todos []*model.Todo
|
||||
for rows.Next() {
|
||||
var todo = model.Todo{User: &model.User{}}
|
||||
if err := rows.Scan(&todo.ID, &todo.Text, &todo.Done, &todo.User.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
todos = append(todos, &todo)
|
||||
}
|
||||
return todos, nil
|
||||
}
|
||||
|
||||
// AddTodo adds a Todo to passed UserID. Check if the source has the rights to call this.
|
||||
func (db CustomDB) AddTodo(newTodo model.NewTodo) (*model.Todo, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO Todo (text, FK_User_userID) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newTodo.Text, newTodo.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.Todo{ID: strconv.FormatInt(insertId, 10), Text: newTodo.Text, Done: false}, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateTodo(todoId string, changes *model.UpdateTodo) (*model.Todo, error) {
|
||||
|
||||
numTodoId, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE Todo SET text = IFNULL(?, text), IS_done = IFNULL(?, IS_done) WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.Text, changes.Done, numTodoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return db.GetTodo(&model.Todo{ID: todoId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteTodo(todoId string) (deletedTodoId *string, err error) {
|
||||
numTodoId, err := strconv.Atoi(todoId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid todoId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("DELETE FROM Todo WHERE todoId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(numTodoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
return &todoId, nil
|
||||
}
|
|
@ -0,0 +1,257 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
)
|
||||
|
||||
func (db CustomDB) CreateUser(newUser model.NewUser) (*model.User, error) {
|
||||
statement, err := db.connection.Prepare("INSERT INTO User (userName, fullName, passwordHash) VALUES (?, NULLIF(?, ''), ?)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ValidateUserName(newUser.UserName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ValidatePassword(newUser.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
passwordHash, err := db.GenerateHashFromPassword(newUser.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(newUser.UserName, newUser.FullName, string(passwordHash))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.User{ID: strconv.FormatInt(insertId, 10), UserName: newUser.UserName, FullName: newUser.FullName}, nil
|
||||
}
|
||||
|
||||
// GetUser takes a *model.User with at least ID or UserName set and adds the missing fields.
|
||||
func (db CustomDB) GetUser(user *model.User) (*model.User, error) {
|
||||
numUserId, err := strconv.Atoi(user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
statement, err := db.connection.Prepare("SELECT userID, userName, fullName FROM User WHERE userId = ? OR userName = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := statement.QueryRow(numUserId, user.UserName)
|
||||
if err := result.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) GetAllUsers() ([]*model.User, error) {
|
||||
rows, err := db.connection.Query("SELECT userId, userName, fullName FROM User")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var all []*model.User
|
||||
for rows.Next() {
|
||||
var user model.User
|
||||
if err := rows.Scan(&user.ID, &user.UserName, &user.FullName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, &user)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) UpdateUser(userId string, changes *model.UpdateUser) (*model.User, error) {
|
||||
var passwordHash *string
|
||||
needAccessTokenRefresh := false
|
||||
|
||||
id, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("UPDATE User SET userName = IFNULL(?, userName), fullName = IFNULL(NULLIF(?, ''), fullName), passwordHash = IFNULL(?, passwordHash) WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if *changes.UserName == "" { // interpret empty string as nil
|
||||
changes.UserName = nil
|
||||
} else {
|
||||
if err := ValidateUserName(*changes.UserName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
needAccessTokenRefresh = true
|
||||
}
|
||||
|
||||
if *changes.Password == "" { // interpret empty string as nil
|
||||
passwordHash = nil
|
||||
} else {
|
||||
if err := ValidatePassword(*changes.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passwordHashByte, err := db.GenerateHashFromPassword(*changes.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passwordHashString := string(passwordHashByte)
|
||||
passwordHash = &passwordHashString
|
||||
needAccessTokenRefresh = true
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(changes.UserName, changes.FullName, passwordHash, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
if needAccessTokenRefresh {
|
||||
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
||||
}
|
||||
|
||||
return db.GetUser(&model.User{ID: userId})
|
||||
}
|
||||
|
||||
func (db CustomDB) DeleteUser(userId string) (*string, error) {
|
||||
statement, err := db.connection.Prepare("DELETE FROM User WHERE userId = ?")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return nil, errors.New("no rows affected")
|
||||
}
|
||||
|
||||
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
||||
return &userId, nil
|
||||
}
|
||||
|
||||
func (db CustomDB) AddRole(userId string, roleId string) (relationId string, err error) {
|
||||
encUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid userId")
|
||||
}
|
||||
encRoleId, err := strconv.Atoi(roleId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid roleId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("INSERT INTO R_User_Role (FK_User_userId, FK_Role_roleId) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(encUserId, encRoleId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return "", errors.New("no rows affected")
|
||||
}
|
||||
|
||||
insertId, err := rows.LastInsertId()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
||||
return strconv.FormatInt(insertId, 10), nil
|
||||
}
|
||||
|
||||
func (db CustomDB) RemoveRole(userId string, roleId string) (relationId string, err error) {
|
||||
encUserId, err := strconv.Atoi(userId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid userId")
|
||||
}
|
||||
encRoleId, err := strconv.Atoi(roleId)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid roleId")
|
||||
}
|
||||
|
||||
statement, err := db.connection.Prepare("DELETE FROM R_User_Role WHERE FK_User_userId = ? AND FK_Role_roleId = ?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rows, err := statement.Exec(encUserId, encRoleId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
num, err := rows.RowsAffected()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if num < 1 {
|
||||
return "", errors.New("no rows affected")
|
||||
}
|
||||
|
||||
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
||||
return strconv.FormatInt(int64(encRoleId), 10), nil
|
||||
}
|
|
@ -7,7 +7,7 @@ export default defineConfig({
|
|||
server: {
|
||||
port: 4243,
|
||||
proxy: {
|
||||
'^/(api|playground|version)': 'http://localhost:4242/'
|
||||
'^/(api|playground|version|auth)': 'http://localhost:4242/'
|
||||
}
|
||||
},
|
||||
build: {
|
||||
|
|
2
go.mod
2
go.mod
|
@ -4,10 +4,12 @@ go 1.20
|
|||
|
||||
require (
|
||||
github.com/99designs/gqlgen v0.17.37
|
||||
github.com/go-chi/chi v1.5.5
|
||||
github.com/mattn/go-sqlite3 v1.14.17
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/viper v1.16.0
|
||||
github.com/vektah/gqlparser/v2 v2.5.9
|
||||
golang.org/x/crypto v0.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
4
go.sum
4
go.sum
|
@ -69,6 +69,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7
|
|||
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE=
|
||||
github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw=
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
|
@ -216,6 +218,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
|||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
|
|
|
@ -108,3 +108,5 @@ models:
|
|||
fields:
|
||||
todos:
|
||||
resolver: true
|
||||
roles:
|
||||
resolver: true
|
||||
|
|
2358
graph/generated.go
2358
graph/generated.go
File diff suppressed because it is too large
Load Diff
|
@ -2,21 +2,51 @@
|
|||
|
||||
package model
|
||||
|
||||
type NewRefreshToken struct {
|
||||
TokenName *string `json:"tokenName,omitempty"`
|
||||
}
|
||||
|
||||
type NewRole struct {
|
||||
RoleName string `json:"roleName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
IsUserCreator bool `json:"isUserCreator"`
|
||||
}
|
||||
|
||||
type NewTodo struct {
|
||||
Text string `json:"text"`
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
|
||||
type NewUser struct {
|
||||
UserName string `json:"userName"`
|
||||
FullName string `json:"fullName"`
|
||||
UserName string `json:"userName"`
|
||||
FullName *string `json:"fullName,omitempty"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
UserName string `json:"userName"`
|
||||
FullName string `json:"fullName"`
|
||||
Todos []*Todo `json:"todos"`
|
||||
type RefreshToken struct {
|
||||
ID string `json:"id"`
|
||||
ExpiryDate int `json:"expiryDate"`
|
||||
TokenName *string `json:"tokenName,omitempty"`
|
||||
Selector *string `json:"selector,omitempty"`
|
||||
Token *string `json:"token,omitempty"`
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
ID string `json:"id"`
|
||||
RoleName string `json:"roleName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
IsUserCreator bool `json:"isUserCreator"`
|
||||
}
|
||||
|
||||
type UpdateRefreshToken struct {
|
||||
TokenName *string `json:"tokenName,omitempty"`
|
||||
}
|
||||
|
||||
type UpdateRole struct {
|
||||
RoleName *string `json:"roleName,omitempty"`
|
||||
IsAdmin *bool `json:"isAdmin,omitempty"`
|
||||
IsUserCreator *bool `json:"isUserCreator,omitempty"`
|
||||
}
|
||||
|
||||
type UpdateTodo struct {
|
||||
|
@ -27,4 +57,13 @@ type UpdateTodo struct {
|
|||
type UpdateUser struct {
|
||||
UserName *string `json:"userName,omitempty"`
|
||||
FullName *string `json:"fullName,omitempty"`
|
||||
Password *string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
UserName string `json:"userName"`
|
||||
FullName *string `json:"fullName,omitempty"`
|
||||
Todos []*Todo `json:"todos"`
|
||||
Roles []*Role `json:"roles"`
|
||||
}
|
||||
|
|
|
@ -27,20 +27,42 @@ type Todo {
|
|||
type User {
|
||||
id: ID!
|
||||
userName: String!
|
||||
fullName: String!
|
||||
fullName: String
|
||||
todos: [Todo!]!
|
||||
roles: [Role!]!
|
||||
}
|
||||
|
||||
type Role {
|
||||
id: ID!
|
||||
roleName: String!
|
||||
isAdmin: Boolean!
|
||||
isUserCreator: Boolean!
|
||||
}
|
||||
|
||||
type RefreshToken {
|
||||
id: ID!
|
||||
expiryDate: Int!
|
||||
tokenName: String
|
||||
selector: String
|
||||
token: String
|
||||
userId: String!
|
||||
}
|
||||
|
||||
type Query {
|
||||
todos: [Todo!]!
|
||||
users: [User!]!
|
||||
roles: [Role!]!
|
||||
refreshTokens: [RefreshToken!]!
|
||||
user(id: ID!): User!
|
||||
todo(id: ID!): Todo!
|
||||
role(id: ID!): Role!
|
||||
refreshToken(id: ID!): RefreshToken!
|
||||
}
|
||||
|
||||
input NewUser {
|
||||
userName: String!
|
||||
fullName: String!
|
||||
fullName: String
|
||||
password: String!
|
||||
}
|
||||
|
||||
input NewTodo {
|
||||
|
@ -48,21 +70,50 @@ input NewTodo {
|
|||
userId: ID!
|
||||
}
|
||||
|
||||
input updateTodo {
|
||||
input NewRole {
|
||||
roleName: String!
|
||||
isAdmin: Boolean!
|
||||
isUserCreator: Boolean!
|
||||
}
|
||||
|
||||
input NewRefreshToken {
|
||||
tokenName: String
|
||||
}
|
||||
|
||||
input UpdateTodo {
|
||||
text: String
|
||||
done: Boolean
|
||||
}
|
||||
|
||||
input updateUser {
|
||||
input UpdateUser {
|
||||
userName: String
|
||||
fullName: String
|
||||
password: String
|
||||
}
|
||||
|
||||
input UpdateRole {
|
||||
roleName: String
|
||||
isAdmin: Boolean
|
||||
isUserCreator: Boolean
|
||||
}
|
||||
|
||||
input UpdateRefreshToken {
|
||||
tokenName: String
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
createUser(input: NewUser!): User!
|
||||
createTodo(input: NewTodo!): Todo!
|
||||
updateTodo(id: ID!, changes: updateTodo!): Todo!
|
||||
updateUser(id: ID!, changes: updateUser!): User!
|
||||
createRole(input: NewRole!): Role!
|
||||
createRefreshToken(input: NewRefreshToken!): RefreshToken!
|
||||
updateTodo(id: ID!, changes: UpdateTodo!): Todo!
|
||||
updateUser(id: ID!, changes: UpdateUser!): User!
|
||||
updateRole(id: ID!, changes: UpdateRole!): Role!
|
||||
updateRefreshToken(id: ID!, changes: UpdateRefreshToken!): RefreshToken!
|
||||
deleteUser(id: ID!): ID
|
||||
deleteTodo(id: ID!): ID
|
||||
deleteRole(id: ID!): ID
|
||||
deleteRefreshToken(id: ID!): ID
|
||||
addRole(userId: ID!, roleId: ID!): [Role!]!
|
||||
removeRole(userId: ID!, roleId: ID!): [Role!]!
|
||||
}
|
||||
|
|
|
@ -10,11 +10,12 @@ import (
|
|||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/server/auth"
|
||||
)
|
||||
|
||||
// CreateUser is the resolver for the createUser field.
|
||||
func (r *mutationResolver) CreateUser(ctx context.Context, input model.NewUser) (*model.User, error) {
|
||||
todo, err := globals.DB.AddUser(input)
|
||||
todo, err := globals.DB.CreateUser(input)
|
||||
if err != nil {
|
||||
globals.Logger.Println("Failed to add new user:", err)
|
||||
return nil, errors.New("failed to add new user")
|
||||
|
@ -32,6 +33,28 @@ func (r *mutationResolver) CreateTodo(ctx context.Context, input model.NewTodo)
|
|||
return todo, nil
|
||||
}
|
||||
|
||||
// CreateRole is the resolver for the createRole field.
|
||||
func (r *mutationResolver) CreateRole(ctx context.Context, input model.NewRole) (*model.Role, error) {
|
||||
role, err := globals.DB.CreateRole(&input)
|
||||
if err != nil {
|
||||
globals.Logger.Println("Failed to add new role:", err)
|
||||
return nil, errors.New("failed to add new role")
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// CreateRefreshToken is the resolver for the createRefreshToken field.
|
||||
func (r *mutationResolver) CreateRefreshToken(ctx context.Context, input model.NewRefreshToken) (*model.RefreshToken, error) {
|
||||
// TODO: unify model.RefreshToken & auth.RefreshToken
|
||||
userToken := auth.ForContext(ctx)
|
||||
refreshToken, tokenId, err := globals.DB.IssueRefreshToken(userToken.UserId, input.TokenName)
|
||||
if err != nil {
|
||||
globals.Logger.Println("Failed to create refresh token:", err)
|
||||
return nil, errors.New("failed to create refresh token")
|
||||
}
|
||||
return &model.RefreshToken{ID: tokenId, ExpiryDate: refreshToken.ExpiryDate, TokenName: input.TokenName, Selector: &refreshToken.Selector, Token: &refreshToken.Token, UserID: userToken.UserId}, nil
|
||||
}
|
||||
|
||||
// UpdateTodo is the resolver for the updateTodo field.
|
||||
func (r *mutationResolver) UpdateTodo(ctx context.Context, id string, changes model.UpdateTodo) (*model.Todo, error) {
|
||||
return globals.DB.UpdateTodo(id, &changes)
|
||||
|
@ -42,6 +65,16 @@ func (r *mutationResolver) UpdateUser(ctx context.Context, id string, changes mo
|
|||
return globals.DB.UpdateUser(id, &changes)
|
||||
}
|
||||
|
||||
// UpdateRole is the resolver for the updateRole field.
|
||||
func (r *mutationResolver) UpdateRole(ctx context.Context, id string, changes model.UpdateRole) (*model.Role, error) {
|
||||
return globals.DB.UpdateRole(id, &changes)
|
||||
}
|
||||
|
||||
// UpdateRefreshToken is the resolver for the updateRefreshToken field.
|
||||
func (r *mutationResolver) UpdateRefreshToken(ctx context.Context, id string, changes model.UpdateRefreshToken) (*model.RefreshToken, error) {
|
||||
return globals.DB.UpdateRefreshToken(id, &changes)
|
||||
}
|
||||
|
||||
// DeleteUser is the resolver for the deleteUser field.
|
||||
func (r *mutationResolver) DeleteUser(ctx context.Context, id string) (*string, error) {
|
||||
return globals.DB.DeleteUser(id)
|
||||
|
@ -52,6 +85,32 @@ func (r *mutationResolver) DeleteTodo(ctx context.Context, id string) (*string,
|
|||
return globals.DB.DeleteTodo(id)
|
||||
}
|
||||
|
||||
// DeleteRole is the resolver for the deleteRole field.
|
||||
func (r *mutationResolver) DeleteRole(ctx context.Context, id string) (*string, error) {
|
||||
return globals.DB.DeleteRole(id)
|
||||
}
|
||||
|
||||
// DeleteRefreshToken is the resolver for the deleteRefreshToken field.
|
||||
func (r *mutationResolver) DeleteRefreshToken(ctx context.Context, id string) (*string, error) {
|
||||
return globals.DB.RevokeRefreshToken(id)
|
||||
}
|
||||
|
||||
// AddRole is the resolver for the addRole field.
|
||||
func (r *mutationResolver) AddRole(ctx context.Context, userID string, roleID string) ([]*model.Role, error) {
|
||||
if _, err := globals.DB.AddRole(userID, roleID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return globals.DB.GetRolesFrom(userID)
|
||||
}
|
||||
|
||||
// RemoveRole is the resolver for the removeRole field.
|
||||
func (r *mutationResolver) RemoveRole(ctx context.Context, userID string, roleID string) ([]*model.Role, error) {
|
||||
if _, err := globals.DB.RemoveRole(userID, roleID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return globals.DB.GetRolesFrom(userID)
|
||||
}
|
||||
|
||||
// Todos is the resolver for the todos field.
|
||||
func (r *queryResolver) Todos(ctx context.Context) ([]*model.Todo, error) {
|
||||
return globals.DB.GetAllTodos()
|
||||
|
@ -62,6 +121,16 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
|
|||
return globals.DB.GetAllUsers()
|
||||
}
|
||||
|
||||
// Roles is the resolver for the roles field.
|
||||
func (r *queryResolver) Roles(ctx context.Context) ([]*model.Role, error) {
|
||||
return globals.DB.GetAllRoles()
|
||||
}
|
||||
|
||||
// RefreshTokens is the resolver for the refreshTokens field.
|
||||
func (r *queryResolver) RefreshTokens(ctx context.Context) ([]*model.RefreshToken, error) {
|
||||
return globals.DB.GetAllRefreshTokens()
|
||||
}
|
||||
|
||||
// User is the resolver for the user field.
|
||||
func (r *queryResolver) User(ctx context.Context, id string) (*model.User, error) {
|
||||
return globals.DB.GetUser(&model.User{ID: id})
|
||||
|
@ -72,6 +141,16 @@ func (r *queryResolver) Todo(ctx context.Context, id string) (*model.Todo, error
|
|||
return globals.DB.GetTodo(&model.Todo{ID: id})
|
||||
}
|
||||
|
||||
// Role is the resolver for the role field.
|
||||
func (r *queryResolver) Role(ctx context.Context, id string) (*model.Role, error) {
|
||||
return globals.DB.GetRole(&model.Role{ID: id})
|
||||
}
|
||||
|
||||
// RefreshToken is the resolver for the refreshToken field.
|
||||
func (r *queryResolver) RefreshToken(ctx context.Context, id string) (*model.RefreshToken, error) {
|
||||
return globals.DB.GetRefreshToken(&model.RefreshToken{ID: id})
|
||||
}
|
||||
|
||||
// User is the resolver for the user field.
|
||||
func (r *todoResolver) User(ctx context.Context, obj *model.Todo) (*model.User, error) {
|
||||
// TODO: implement dataloader
|
||||
|
@ -83,6 +162,11 @@ func (r *userResolver) Todos(ctx context.Context, obj *model.User) ([]*model.Tod
|
|||
return globals.DB.GetTodosFrom(obj)
|
||||
}
|
||||
|
||||
// Roles is the resolver for the roles field.
|
||||
func (r *userResolver) Roles(ctx context.Context, obj *model.User) ([]*model.Role, error) {
|
||||
return globals.DB.GetRolesFrom(obj.ID)
|
||||
}
|
||||
|
||||
// Mutation returns MutationResolver implementation.
|
||||
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
|
||||
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
# Authentication in Golang
|
||||
|
||||
## Use header
|
||||
|
||||
Use the http header or the body, but avoid using cookies to transport tokens etc
|
||||
(because of CSRF)
|
||||
|
||||
## Implementation
|
||||
|
||||
```
|
||||
# create password/user
|
||||
update_password(password):
|
||||
generate salt
|
||||
db.store(hash(salt + password))
|
||||
|
||||
#login
|
||||
get_token(userId, password):
|
||||
if not hash(salt + password) == db.get(salted_hash_of_password): # use timing-attack resistant compare here
|
||||
return FAILED
|
||||
generate selector
|
||||
db.store(selector)
|
||||
generate salt
|
||||
generate auth_token
|
||||
db.store(salt, hash(salt+auth_token))
|
||||
return selector:auth_token
|
||||
|
||||
#authenticate
|
||||
validate_token(selector, auth_token):
|
||||
if not hash(salt+auth_token) == db.get(salt, salted_hash_of_auth_token WHERE selector): # use timing-attack resistant compare here
|
||||
return UNKNOWN_TOKEN
|
||||
return AUTHENTICATED
|
||||
```
|
||||
|
||||
idea: replace selector with userId?
|
||||
|
||||
## JWT
|
||||
|
||||
```json
|
||||
{
|
||||
"userId": "id",
|
||||
"userRole": "role",
|
||||
"expiryTime": "now+10min"
|
||||
}
|
||||
```
|
||||
|
||||
We use JWT with a lifespan of 10min and an in-memory db to blacklist revoked
|
||||
tokens. So if for e.g. a user changes it's password, we would add the userId and
|
||||
the time of the change to the blacklist, filtering out all tokens that have been
|
||||
issued before.
|
||||
|
||||
After a server restart, all tokens will become invalid as well, since we can not
|
||||
be sure which ones were 'on the blacklist'. This could be mitigated by making
|
||||
the blacklist persistent during restarts.
|
||||
|
||||
If a token has expired (either by a server restart or after 10min), a new token
|
||||
is requested with a 'long lived' refresh-token (lifetime of ~1 week) that is
|
||||
stored in a database.
|
||||
|
||||
### Pros:
|
||||
|
||||
- less DB lookups in general
|
||||
|
||||
### Cons:
|
||||
|
||||
- timestamps of blacklisting could become a problem (maybe add 1 second or use
|
||||
timestamp returned by n-th node when it has received the update).
|
||||
- increased load on db + service since we need to issue new jwt for everybody.
|
||||
|
||||
## SSL/TLS
|
||||
|
||||
```bash
|
||||
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout keyFile.key -out certFile.crt
|
||||
```
|
||||
|
||||
## CRIME/BREACH http compression attack
|
||||
|
||||
While we do not use a CSRF token, headers sent to the `/api` still contain
|
||||
private data.
|
||||
|
||||
If you are using http/1.1 or lower and have compression enabled on your proxy,
|
||||
you would be .
|
||||
|
||||
## CRSF
|
||||
|
||||
We check for a http header like this `X-YOURSITE-CSRF-PROTECTION=1`. This should
|
||||
be enough, according to
|
||||
[cheatsheetseries.owasp.org](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#custom-request-headers)
|
||||
|
||||
## XSS
|
||||
|
||||
We rely on Vue.js's ability to escape user-input in templates.
|
|
@ -0,0 +1,139 @@
|
|||
/*
|
||||
YetAnotherToDoList
|
||||
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/database"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
|
||||
)
|
||||
|
||||
var userCtxKey = &contextKey{"user"}
|
||||
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func Middleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
const headerPrefix = "Bearer "
|
||||
authField := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authField, headerPrefix) {
|
||||
http.Error(w, fmt.Sprintf(`{"error":"wrong token type, expect '%s'"}`, headerPrefix), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// get the user from the database
|
||||
user, err := globals.DB.CheckAccessToken(strings.TrimPrefix(authField, headerPrefix))
|
||||
if err != nil {
|
||||
http.Error(w, strings.ReplaceAll(fmt.Sprintf(`{"error":"failed token check: '%s'"}`, err), "\\", "\\\\"), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// put it in context
|
||||
ctx := context.WithValue(r.Context(), userCtxKey, user)
|
||||
// and call the next with our new context
|
||||
r = r.WithContext(ctx)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ForContext(ctx context.Context) *database.AccessToken {
|
||||
raw, _ := ctx.Value(userCtxKey).(*database.AccessToken)
|
||||
return raw
|
||||
}
|
||||
|
||||
type userAuth struct {
|
||||
UserId *string `json:"userId"`
|
||||
UserName *string `json:"userName"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func IssueRefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
const headerPrefix = "Login "
|
||||
authField := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authField, headerPrefix) {
|
||||
http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userCredentials := userAuth{}
|
||||
|
||||
err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &userCredentials)
|
||||
if err != nil {
|
||||
http.Error(w, "malformed or missing user credentials", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userId, err := globals.DB.ValidateUserCredentials(userCredentials.UserId, userCredentials.UserName, userCredentials.Password)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken, _, err := globals.DB.IssueRefreshToken(userId, nil)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to issue refresh token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jsonRefreshToken, err := json.Marshal(refreshToken)
|
||||
if err != nil {
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Write([]byte(jsonRefreshToken))
|
||||
}
|
||||
|
||||
func IssueAccessTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
const headerPrefix = "Refresh "
|
||||
authField := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authField, headerPrefix) {
|
||||
http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := database.RefreshToken{}
|
||||
|
||||
err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &refreshToken)
|
||||
if err != nil {
|
||||
http.Error(w, "malformed or missing refresh token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := globals.DB.IssueAccessToken(&refreshToken)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid access token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
encAccessToken, err := globals.DB.SignAccessToken(*accessToken)
|
||||
if err != nil {
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Write([]byte(encAccessToken))
|
||||
}
|
|
@ -20,11 +20,24 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/playground"
|
||||
"github.com/go-chi/chi"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
|
||||
)
|
||||
|
||||
func handleDevTools(port int) {
|
||||
mux.Handle("/playground", playground.Handler("GraphQL playground", "/api"))
|
||||
globals.Logger.Printf("connect to http://localhost:%v/ for GraphQL playground", port)
|
||||
func handleDevTools(router *chi.Mux, portHTTP int, portHTTPS int) {
|
||||
router.Handle("/playground", playground.Handler("GraphQL playground", "/api"))
|
||||
url := "connect to "
|
||||
if portHTTP != -1 && portHTTPS != -1 {
|
||||
url += fmt.Sprintf("http://localhost:%v/ or https://localhost:%v/", portHTTP, portHTTPS)
|
||||
} else if portHTTP != -1 {
|
||||
url += fmt.Sprintf("http://localhost:%v/", portHTTP)
|
||||
} else if portHTTPS != -1 {
|
||||
url += fmt.Sprintf("https://localhost:%v/", portHTTPS)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
globals.Logger.Println(url + " for GraphQL playground")
|
||||
}
|
||||
|
|
|
@ -19,5 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
*/
|
||||
package server
|
||||
|
||||
func handleDevTools(_ int) {
|
||||
import "github.com/go-chi/chi"
|
||||
|
||||
func handleDevTools(router *chi.Mux, _ int, _ int) {
|
||||
}
|
||||
|
|
|
@ -24,19 +24,21 @@ import (
|
|||
"io/fs"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
)
|
||||
|
||||
//go:embed dist/*
|
||||
var frontend embed.FS
|
||||
|
||||
func handleFrontend() {
|
||||
func handleFrontend(router *chi.Mux) {
|
||||
stripped, err := fs.Sub(frontend, "dist")
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
frontendFS := http.FileServer(http.FS(stripped))
|
||||
mux.Handle("/assets/", frontendFS)
|
||||
mux.HandleFunc("/", indexHandler)
|
||||
router.Handle("/assets/*", frontendFS)
|
||||
router.HandleFunc("/", indexHandler)
|
||||
// TODO: redirect from vue to 404 page (on go/proxy server)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,5 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
*/
|
||||
package server
|
||||
|
||||
func handleFrontend() {
|
||||
import "github.com/go-chi/chi"
|
||||
|
||||
func handleFrontend(router *chi.Mux) {
|
||||
}
|
||||
|
|
|
@ -22,24 +22,68 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/handler"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph"
|
||||
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/server/auth"
|
||||
)
|
||||
|
||||
var mux *http.ServeMux
|
||||
func StartServer(portHTTP int, portHTTPS int, certFile string, keyFile string) {
|
||||
router := chi.NewRouter()
|
||||
router.Use(middleware.StripSlashes)
|
||||
|
||||
func StartServer(port int) {
|
||||
mux = http.NewServeMux()
|
||||
|
||||
handleDevTools(port) // controlled by 'dev' tag
|
||||
handleFrontend() // controlled by 'headless' tag
|
||||
|
||||
mux.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
|
||||
handleDevTools(router, portHTTP, portHTTPS) // controlled by 'dev' tag
|
||||
handleFrontend(router) // controlled by 'headless' tag
|
||||
router.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%s %s", globals.Version, globals.CommitHash)
|
||||
})
|
||||
|
||||
srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{Resolvers: &graph.Resolver{}}))
|
||||
mux.Handle("/api", srv)
|
||||
router.HandleFunc("/auth/login", auth.IssueRefreshTokenHandler)
|
||||
router.HandleFunc("/auth", auth.IssueAccessTokenHandler)
|
||||
|
||||
router.Group(func(r chi.Router) {
|
||||
r.Use(auth.Middleware())
|
||||
r.Handle("/api", srv)
|
||||
r.HandleFunc("/protected", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Printf("user is %+v\n", auth.ForContext(r.Context()))
|
||||
})
|
||||
})
|
||||
// router.HandleFunc("/auth/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// http.Redirect(w, r, "/auth", http.StatusMovedPermanently)
|
||||
// })
|
||||
|
||||
err := <-listen(portHTTP, portHTTPS, certFile, keyFile, router)
|
||||
globals.Logger.Fatalf("Could not start serving service due to (error: %s)", err)
|
||||
|
||||
globals.Logger.Fatal(http.ListenAndServe(":"+strconv.FormatInt(int64(port), 10), mux))
|
||||
}
|
||||
|
||||
func listen(portHTTP int, portHTTPS int, certFile string, keyFile string, router *chi.Mux) chan error {
|
||||
|
||||
errs := make(chan error)
|
||||
|
||||
// Starting HTTP server
|
||||
if portHTTP != -1 {
|
||||
go func() {
|
||||
globals.Logger.Printf("Staring HTTP service on %d ...", portHTTP)
|
||||
|
||||
if err := http.ListenAndServe(":"+strconv.FormatInt(int64(portHTTP), 10), router); err != nil {
|
||||
errs <- err
|
||||
}
|
||||
|
||||
}()
|
||||
}
|
||||
|
||||
// Starting HTTPS server
|
||||
if portHTTPS != -1 && certFile != "" && keyFile != "" {
|
||||
go func() {
|
||||
globals.Logger.Printf("Staring HTTPS service on %d ...", portHTTPS)
|
||||
if err := http.ListenAndServeTLS(":"+strconv.FormatInt(int64(portHTTPS), 10), certFile, keyFile, router); err != nil {
|
||||
errs <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue