/* YetAnotherToDoList Copyright © 2023 gilex-dev gilex-dev@proton.me This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 3. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ package database import ( "bytes" "crypto/hmac" "crypto/rand" "crypto/sha256" "database/sql" "encoding/base64" "encoding/json" "errors" "fmt" "slices" "strconv" "strings" "time" "unicode" "golang.org/x/crypto/bcrypt" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model" ) type RefreshToken struct { Selector string `json:"selector"` Token string `json:"token"` ExpiryDate int `json:"expiryDate"` } type AccessToken struct { UserId string `json:"userId"` IsAdmin bool `json:"isAdmin"` IsUserCreator bool `json:"isUserCreator"` ExpiryDate int `json:"expiryDate"` } const refreshTokenLifetime = "+10 day" // TODO: add to viper const accessTokenLifetime = time.Minute * 10 const minPasswordLength = 10 const minUserNameLength = 10 var RevokedAccessTokens []*AccessToken // ValidatePassword validates the passed string against the password criteria. func ValidatePassword(password string) error { if len([]rune(password)) < minPasswordLength { return fmt.Errorf("password must be at least %d characters long", minPasswordLength) } for i, c := range password { if !(unicode.IsLetter(c) || unicode.IsNumber(c)) { return fmt.Errorf("password contains none-alphanumeric symbol at index %d", i) } } return nil } // ValidateUserName validates the passed string against the user name criteria. func ValidateUserName(userName string) error { if len([]rune(userName)) < minUserNameLength { return fmt.Errorf("userName must be at least %d characters long", minUserNameLength) } for i, c := range userName { if !(unicode.IsLetter(c) || unicode.IsNumber(c)) { return fmt.Errorf("userName contains none-alphanumeric symbol at index %d", i) } } return nil } // GenerateHashFromPassword generates a hash of the passed password. Returns salt and salted & peppered hash. func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) { hashBytes, err := bcrypt.GenerateFromPassword(bytes.Join([][]byte{db.secret, []byte(password)}, nil), bcrypt.DefaultCost) if err != nil { return nil, err } return hashBytes, nil } // GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin. func (db CustomDB) GetRefreshTokenOwner(tokenId string) (ownerId string, err error) { numTokenId, err := strconv.Atoi(tokenId) if err != nil { return "", errors.New("malformed refresh token Id") } statement, err := db.connection.Prepare("SELECT FK_User_userId FROM RefreshToken WHERE tokenId = ?") if err != nil { return "", err } result := statement.QueryRow(numTokenId) var owner string if err := result.Scan(&owner); err != nil { if err == sql.ErrNoRows { return "", errors.New("invalid refresh token Id") } return "", err } return owner, nil } func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) { var result *sql.Row var hash string if userId != nil { // use userId numUserId, err := strconv.Atoi(*userId) if err != nil { return "", errors.New("userId not numeric") } statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?") if err != nil { return "", err } result = statement.QueryRow(numUserId) if err := result.Scan(&hash); err != nil { if err == sql.ErrNoRows { return "", errors.New("invalid user Id") } return "", err } } else if userName != nil { // use userName statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?") if err != nil { return "", err } result = statement.QueryRow(&userName) if err := result.Scan(&userId, &hash); err != nil { if err == sql.ErrNoRows { return "", errors.New("invalid user Id") } return "", err } } else { return "", errors.New("neither userId nor userName specified") } if err := bcrypt.CompareHashAndPassword([]byte(hash), bytes.Join([][]byte{db.secret, []byte(password)}, nil)); err != nil { return "", err } return *userId, nil } // IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken. func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) { numUserId, err := strconv.Atoi(userId) if err != nil { return nil, "", errors.New("userId not numeric") } selector := make([]byte, 9) if _, err := rand.Read(selector); err != nil { return nil, "", err } token := make([]byte, 33) if _, err := rand.Read(token); err != nil { return nil, "", err } statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate") if err != nil { return nil, "", err } encSelector := base64.RawURLEncoding.EncodeToString(selector) encToken := base64.RawURLEncoding.EncodeToString(token) tokenHash := sha256.Sum256(token) var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix()) var tokenId string result := statement.QueryRow(numUserId, encSelector, base64.RawURLEncoding.EncodeToString(tokenHash[:]), &tokenName) if err := result.Scan(&tokenId, &expiryDate); err != nil { if err == sql.ErrNoRows { return nil, "", errors.New("failed to add new refresh token") } return nil, "", err } return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil } func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) { numTokenId, err := strconv.Atoi(token.ID) if err != nil { return nil, errors.New("malformed refresh token Id") } statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?") if err != nil { return nil, err } result := statement.QueryRow(numTokenId) if err := result.Scan(&token.ExpiryDate, &token.TokenName); err != nil { if err == sql.ErrNoRows { return nil, errors.New("invalid refresh token Id") } return nil, err } return token, nil } func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) { numUserId, err := strconv.Atoi(userId) if err != nil { return nil, errors.New("malformed userId") } statement, err := db.connection.Prepare("SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?") if err != nil { return nil, err } rows, err := statement.Query(numUserId) if err != nil { return nil, err } defer rows.Close() var all []*model.RefreshToken for rows.Next() { token := model.RefreshToken{} if err := rows.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil { return nil, err } all = append(all, &token) } return all, nil } func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) { statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken") if err != nil { return nil, err } rows, err := statement.Query() if err != nil { return nil, err } defer rows.Close() var all []*model.RefreshToken for rows.Next() { var token model.RefreshToken if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil { return nil, err } all = append(all, &token) } return all, nil } func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) { numTokenId, err := strconv.Atoi(tokenId) if err != nil { return nil, errors.New("malformed refresh token Id") } statement, err := db.connection.Prepare("UPDATE RefreshToken SET tokenName = ? WHERE tokenId = ?") if err != nil { return nil, err } rows, err := statement.Exec(changes.TokenName, numTokenId) if err != nil { return nil, err } num, err := rows.RowsAffected() if err != nil { return nil, err } if num < 1 { return nil, errors.New("no rows affected") } token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId}) if err != nil { return nil, errors.New("failed to get updated token") } return token, nil } // RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken. func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) { // TODO: return string instead of *string numTokenId, err := strconv.Atoi(tokenId) if err != nil { return nil, errors.New("malformed refresh token Id") } statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId") if err != nil { return nil, err } result := statement.QueryRow(numTokenId) var userId string if err := result.Scan(&userId); err != nil { if err == sql.ErrNoRows { return nil, errors.New("invalid refresh token Id") } return nil, err } RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())}) return &tokenId, nil } // IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted. func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) { statement, err := db.connection.Prepare("SELECT tokenHash, FK_User_userId FROM RefreshToken WHERE selector = ? AND expiryDate >= unixepoch('now')") if err != nil { return nil, err } result := statement.QueryRow(refreshToken.Selector) var tokenHash string var newAccessToken AccessToken if err := result.Scan(&tokenHash, &newAccessToken.UserId); err != nil { if err == sql.ErrNoRows { return nil, errors.New("invalid refresh token selector") } return nil, err } newAccessToken.IsAdmin, newAccessToken.IsUserCreator, err = db.GetUserPermissions(newAccessToken.UserId) if err != nil { return nil, err } decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token) if err != nil { return nil, err } userTokenHash := sha256.Sum256(decUserToken) encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:]) if encUserTokenHash != tokenHash { return nil, errors.New("failed to issue access token: refreshToken does not match") } newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix()) return &newAccessToken, nil } // SignAccessToken signs an access token and attaches a header. Returns access token encoded as jwt. func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) { data, err := json.Marshal(accessToken) if err != nil { return "", err } body := base64.RawURLEncoding.EncodeToString(data) header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) combined := header + "." + body mac := hmac.New(sha256.New, db.secret) mac.Write([]byte(combined)) signature := mac.Sum(nil) return combined + "." + base64.RawURLEncoding.EncodeToString(signature), nil } func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) { accessTokenBlob := strings.Split(encAccessToken, ".") header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0]) if err != nil { return nil, err } if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) { return nil, errors.New("access token header wrong") } body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1]) if err != nil { return nil, err } var data AccessToken err = json.Unmarshal(body, &data) if err != nil { return nil, err } // TODO: maybe change part of signKey instead? for _, revokedAccessToken := range RevokedAccessTokens { if data.UserId == revokedAccessToken.UserId && data.ExpiryDate <= revokedAccessToken.ExpiryDate { return nil, errors.New("access token revoked") } } if data.ExpiryDate < int(time.Now().Unix()) { return nil, errors.New("access token expired") } if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) { return nil, errors.New("access token expired prematurely") } signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2]) if err != nil { return nil, err } mac := hmac.New(sha256.New, db.secret) mac.Write([]byte(accessTokenBlob[0])) mac.Write([]byte(".")) mac.Write([]byte(accessTokenBlob[1])) expectedSignature := mac.Sum(nil) if !hmac.Equal(signature, expectedSignature) { return nil, errors.New("access token signature does not match") } return &data, nil } // RevokeAccessToken revokes all access tokens with matching UserId and UserRole that don't have a later ExpiryDate. // revokedAccessToken.ExpiryDate should be set to now + token-lifetime. func RevokeAccessToken(accessToken *AccessToken) { RevokedAccessTokens = append(RevokedAccessTokens, accessToken) } // CleanRevokedTokensTicker removes expired tokens from the list. This should be called in an interval of > accessTokenLifetime. func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) { ticker := time.NewTicker(interval) stop := make(chan bool) go func() { for { select { case <-stop: return case <-ticker.C: db.logger.Println("cleaning revoked access tokens") for i, revokedAccessToken := range RevokedAccessTokens { if revokedAccessToken.ExpiryDate < int(time.Now().Unix()) { revokedAccessToken = nil slices.Delete(RevokedAccessTokens, i, i+1) } } } } }() return stop } func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) { ticker := time.NewTicker(interval) stop := make(chan bool) go func() { for { select { case <-stop: return case <-ticker.C: db.logger.Println("cleaning expired refresh tokens") _, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')") if err != nil { db.logger.Println("failed to clean expired refresh tokens") } } } }() return stop }