2023-11-05 17:42:14 +01:00
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"time"
"unicode"
"golang.org/x/crypto/bcrypt"
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
// RefreshToken is the client-side form of a refresh token. Selector is used to
// look the token up in the database, while only the SHA-256 hash of Token is
// stored there (see IssueRefreshToken and IssueAccessToken).
type RefreshToken struct {
	Selector   string `json:"selector"`
	Token      string `json:"token"`
	ExpiryDate int    `json:"expiryDate"` // Unix time in seconds, as produced by SQLite's unixepoch()
}
// AccessToken holds the claims of a short-lived access token. SignAccessToken
// serializes it as the JSON body of a JWT and CheckAccessToken parses it back.
type AccessToken struct {
	UserId        string `json:"userId"`
	IsAdmin       bool   `json:"isAdmin"`
	IsUserCreator bool   `json:"isUserCreator"`
	ExpiryDate    int    `json:"expiryDate"` // Unix time in seconds
}
const (
	// refreshTokenLifetime is an SQLite date-time modifier applied to 'now'
	// when a refresh token is stored (see IssueRefreshToken).
	refreshTokenLifetime = "+10 day" // TODO: add to viper
	// accessTokenLifetime is how long an issued access token stays valid.
	accessTokenLifetime = 10 * time.Minute
	// Minimum lengths, counted in runes, enforced by ValidatePassword and ValidateUserName.
	minPasswordLength = 10
	minUserNameLength = 10
)

// RevokedAccessTokens holds revocation entries consulted by CheckAccessToken:
// an entry rejects every access token of the same UserId whose ExpiryDate is
// not later than the entry's ExpiryDate.
// NOTE(review): also modified by the goroutine started in
// CleanRevokedAccessTokensTicker without any lock.
var RevokedAccessTokens []*AccessToken
// ValidatePassword validates the passed string against the password criteria:
// at least minPasswordLength characters (counted as runes), each of which must
// be a Unicode letter or number.
func ValidatePassword(password string) error {
	return validateAlphanumeric("password", password, minPasswordLength)
}

// ValidateUserName validates the passed string against the user name criteria:
// at least minUserNameLength characters (counted as runes), each of which must
// be a Unicode letter or number.
func ValidateUserName(userName string) error {
	return validateAlphanumeric("userName", userName, minUserNameLength)
}

// validateAlphanumeric checks that value has at least minLength runes and
// consists only of Unicode letters and numbers. field names the value in the
// returned error messages. The reported index is a rune (character) position,
// consistent with how the length is counted.
func validateAlphanumeric(field, value string, minLength int) error {
	if len([]rune(value)) < minLength {
		return fmt.Errorf("%s must be at least %d characters long", field, minLength)
	}
	runeIndex := 0
	for _, c := range value {
		if !unicode.IsLetter(c) && !unicode.IsNumber(c) {
			return fmt.Errorf("%s contains non-alphanumeric symbol at index %d", field, runeIndex)
		}
		runeIndex++
	}
	return nil
}
// GenerateHashFromPassword returns a bcrypt hash (bcrypt.DefaultCost) of the
// server secret followed by password. bcrypt generates and embeds its own salt,
// so the returned hash is the only value to store. ValidateUserCredentials
// checks it by prepending the same secret.
// NOTE(review): bcrypt limits its input to 72 bytes and the secret prefix uses
// part of that budget. Confirm len(db.secret) leaves room for long passwords.
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
	peppered := bytes.Join([][]byte{db.secret, []byte(password)}, nil)
	passwordHash, err = bcrypt.GenerateFromPassword(peppered, bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}
	return passwordHash, nil
}
// GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin.
2024-02-02 21:23:32 +01:00
func ( db CustomDB ) GetRefreshTokenOwner ( tokenId string ) ( ownerId string , err error ) {
2023-11-05 17:42:14 +01:00
numTokenId , err := strconv . Atoi ( tokenId )
if err != nil {
2024-02-02 21:23:32 +01:00
return "" , errors . New ( "malformed refresh token Id" )
2023-11-05 17:42:14 +01:00
}
2024-02-02 21:23:32 +01:00
statement , err := db . connection . Prepare ( "SELECT FK_User_userId FROM RefreshToken WHERE tokenId = ?" )
2023-11-05 17:42:14 +01:00
if err != nil {
return "" , err
}
result := statement . QueryRow ( numTokenId )
var owner string
if err := result . Scan ( & owner ) ; err != nil {
2024-02-02 21:23:32 +01:00
if err == sql . ErrNoRows {
return "" , errors . New ( "invalid refresh token Id" )
}
2023-11-05 17:42:14 +01:00
return "" , err
}
return owner , nil
}
func ( db CustomDB ) ValidateUserCredentials ( userId * string , userName * string , password string ) ( validUserId string , err error ) {
var result * sql . Row
var hash string
2024-02-02 21:23:32 +01:00
if userId != nil { // use userId
2023-11-05 17:42:14 +01:00
numUserId , err := strconv . Atoi ( * userId )
if err != nil {
return "" , errors . New ( "userId not numeric" )
}
statement , err := db . connection . Prepare ( "SELECT passwordHash FROM User WHERE userId = ?" )
if err != nil {
return "" , err
}
result = statement . QueryRow ( numUserId )
if err := result . Scan ( & hash ) ; err != nil {
2024-02-02 21:23:32 +01:00
if err == sql . ErrNoRows {
return "" , errors . New ( "invalid user Id" )
}
2023-11-05 17:42:14 +01:00
return "" , err
}
2024-02-02 21:23:32 +01:00
} else if userName != nil { // use userName
2023-11-05 17:42:14 +01:00
statement , err := db . connection . Prepare ( "SELECT userId, passwordHash FROM User WHERE userName = ?" )
if err != nil {
return "" , err
}
result = statement . QueryRow ( & userName )
if err := result . Scan ( & userId , & hash ) ; err != nil {
2024-02-02 21:23:32 +01:00
if err == sql . ErrNoRows {
return "" , errors . New ( "invalid user Id" )
}
2023-11-05 17:42:14 +01:00
return "" , err
}
} else {
return "" , errors . New ( "neither userId nor userName specified" )
}
if err := bcrypt . CompareHashAndPassword ( [ ] byte ( hash ) , bytes . Join ( [ ] [ ] byte { db . secret , [ ] byte ( password ) } , nil ) ) ; err != nil {
return "" , err
}
return * userId , nil
}
// IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken.
func ( db CustomDB ) IssueRefreshToken ( userId string , tokenName * string ) ( refreshToken * RefreshToken , refreshTokenId string , err error ) {
numUserId , err := strconv . Atoi ( userId )
if err != nil {
return nil , "" , errors . New ( "userId not numeric" )
}
selector := make ( [ ] byte , 9 )
if _ , err := rand . Read ( selector ) ; err != nil {
return nil , "" , err
}
token := make ( [ ] byte , 33 )
if _ , err := rand . Read ( token ) ; err != nil {
return nil , "" , err
}
statement , err := db . connection . Prepare ( "INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate" )
if err != nil {
return nil , "" , err
}
encSelector := base64 . RawURLEncoding . EncodeToString ( selector )
encToken := base64 . RawURLEncoding . EncodeToString ( token )
tokenHash := sha256 . Sum256 ( token )
var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix())
var tokenId string
result := statement . QueryRow ( numUserId , encSelector , base64 . RawURLEncoding . EncodeToString ( tokenHash [ : ] ) , & tokenName )
if err := result . Scan ( & tokenId , & expiryDate ) ; err != nil {
2024-02-02 21:23:32 +01:00
if err == sql . ErrNoRows {
return nil , "" , errors . New ( "failed to add new refresh token" )
}
2023-11-05 17:42:14 +01:00
return nil , "" , err
}
return & RefreshToken { Selector : encSelector , Token : encToken , ExpiryDate : expiryDate } , tokenId , nil
}
func ( db CustomDB ) GetRefreshToken ( token * model . RefreshToken ) ( * model . RefreshToken , error ) {
numTokenId , err := strconv . Atoi ( token . ID )
if err != nil {
2024-02-02 21:23:32 +01:00
return nil , errors . New ( "malformed refresh token Id" )
2023-11-05 17:42:14 +01:00
}
statement , err := db . connection . Prepare ( "SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?" )
if err != nil {
return nil , err
}
result := statement . QueryRow ( numTokenId )
2024-02-02 21:23:32 +01:00
if err := result . Scan ( & token . ExpiryDate , & token . TokenName ) ; err != nil {
if err == sql . ErrNoRows {
return nil , errors . New ( "invalid refresh token Id" )
}
2023-11-05 17:42:14 +01:00
return nil , err
}
return token , nil
}
func ( db CustomDB ) GetRefreshTokensFrom ( userId string ) ( [ ] * model . RefreshToken , error ) {
numUserId , err := strconv . Atoi ( userId )
if err != nil {
2024-02-02 21:23:32 +01:00
return nil , errors . New ( "malformed userId" )
2023-11-05 17:42:14 +01:00
}
statement , err := db . connection . Prepare ( "SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?" )
if err != nil {
return nil , err
}
rows , err := statement . Query ( numUserId )
if err != nil {
return nil , err
}
defer rows . Close ( )
var all [ ] * model . RefreshToken
for rows . Next ( ) {
token := model . RefreshToken { }
if err := rows . Scan ( & token . ID , & token . ExpiryDate , & token . TokenName ) ; err != nil {
return nil , err
}
all = append ( all , & token )
}
return all , nil
}
// GetAllRefreshTokens returns the id, owner id, expiry date and name of every
// refresh token in the database. Ownership is not checked here.
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
	statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
	if err != nil {
		return nil, err
	}
	// A prepared statement holds driver resources until it is closed.
	defer statement.Close()
	rows, err := statement.Query()
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var all []*model.RefreshToken
	for rows.Next() {
		var token model.RefreshToken
		if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
			return nil, err
		}
		all = append(all, &token)
	}
	// rows.Next also returns false on an iteration error, not only at the end.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return all, nil
}
func ( db CustomDB ) UpdateRefreshToken ( tokenId string , changes * model . UpdateRefreshToken ) ( * model . RefreshToken , error ) {
numTokenId , err := strconv . Atoi ( tokenId )
if err != nil {
2024-02-02 21:23:32 +01:00
return nil , errors . New ( "malformed refresh token Id" )
2023-11-05 17:42:14 +01:00
}
2024-02-02 21:23:32 +01:00
statement , err := db . connection . Prepare ( "UPDATE RefreshToken SET tokenName = ? WHERE tokenId = ?" )
2023-11-05 17:42:14 +01:00
if err != nil {
return nil , err
}
rows , err := statement . Exec ( changes . TokenName , numTokenId )
if err != nil {
return nil , err
}
num , err := rows . RowsAffected ( )
if err != nil {
return nil , err
}
if num < 1 {
return nil , errors . New ( "no rows affected" )
}
token , err := db . GetRefreshToken ( & model . RefreshToken { ID : tokenId } )
if err != nil {
return nil , errors . New ( "failed to get updated token" )
}
return token , nil
}
// RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken.
func ( db CustomDB ) RevokeRefreshToken ( tokenId string ) ( * string , error ) {
// TODO: return string instead of *string
numTokenId , err := strconv . Atoi ( tokenId )
if err != nil {
2024-02-02 21:23:32 +01:00
return nil , errors . New ( "malformed refresh token Id" )
2023-11-05 17:42:14 +01:00
}
statement , err := db . connection . Prepare ( "DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId" )
if err != nil {
return nil , err
}
result := statement . QueryRow ( numTokenId )
var userId string
if err := result . Scan ( & userId ) ; err != nil {
2024-02-02 21:23:32 +01:00
if err == sql . ErrNoRows {
return nil , errors . New ( "invalid refresh token Id" )
}
2023-11-05 17:42:14 +01:00
return nil , err
}
RevokeAccessToken ( & AccessToken { UserId : userId , ExpiryDate : int ( time . Now ( ) . Add ( accessTokenLifetime ) . Unix ( ) ) } )
return & tokenId , nil
}
// IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted.
func ( db CustomDB ) IssueAccessToken ( refreshToken * RefreshToken ) ( * AccessToken , error ) {
2024-02-02 21:23:32 +01:00
statement , err := db . connection . Prepare ( "SELECT tokenHash, FK_User_userId FROM RefreshToken WHERE selector = ? AND expiryDate >= unixepoch('now')" )
2023-11-05 17:42:14 +01:00
if err != nil {
return nil , err
}
result := statement . QueryRow ( refreshToken . Selector )
var tokenHash string
var newAccessToken AccessToken
2024-02-02 21:23:32 +01:00
if err := result . Scan ( & tokenHash , & newAccessToken . UserId ) ; err != nil {
if err == sql . ErrNoRows {
return nil , errors . New ( "invalid refresh token selector" )
}
return nil , err
}
newAccessToken . IsAdmin , newAccessToken . IsUserCreator , err = db . GetUserPermissions ( newAccessToken . UserId )
if err != nil {
2023-11-05 17:42:14 +01:00
return nil , err
}
decUserToken , err := base64 . RawURLEncoding . DecodeString ( refreshToken . Token )
if err != nil {
return nil , err
}
userTokenHash := sha256 . Sum256 ( decUserToken )
encUserTokenHash := base64 . RawURLEncoding . EncodeToString ( userTokenHash [ : ] )
if encUserTokenHash != tokenHash {
return nil , errors . New ( "failed to issue access token: refreshToken does not match" )
}
newAccessToken . ExpiryDate = int ( time . Now ( ) . Add ( accessTokenLifetime ) . Unix ( ) )
return & newAccessToken , nil
}
// SignAccessToken serializes accessToken as JSON, prefixes it with a fixed
// HS256/JWT header and appends an HMAC-SHA256 signature keyed with the server
// secret. The result has the form base64url(header) "." base64url(body) "."
// base64url(signature), all without padding.
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
	payload, err := json.Marshal(accessToken)
	if err != nil {
		return "", err
	}
	enc := base64.RawURLEncoding
	// NOTE(review): this header literal must stay byte-identical to the one
	// CheckAccessToken compares against.
	signingInput := enc.EncodeToString([]byte(` { "alg":"HS256","typ":"JWT"} `)) + "." + enc.EncodeToString(payload)
	signer := hmac.New(sha256.New, db.secret)
	signer.Write([]byte(signingInput)) // hash.Hash.Write never returns an error
	encAccessToken = signingInput + "." + enc.EncodeToString(signer.Sum(nil))
	return encAccessToken, nil
}
// CheckAccessToken parses an access token produced by SignAccessToken and
// returns its claims. The token must satisfy all of the following:
//   - it consists of exactly three segments;
//   - it carries the expected header;
//   - its HMAC-SHA256 signature is valid;
//   - it has not been revoked via RevokeAccessToken;
//   - it is not expired;
//   - it does not expire before initTimeStamp + accessTokenLifetime.
//
// NOTE(review): initTimeStamp is presumably the process start time, so that
// tokens issued before a restart are rejected. Confirm at its definition.
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
	parts := strings.Split(encAccessToken, ".")
	// The input is untrusted: without this check, a token with fewer than three
	// segments would panic with an index out of range.
	if len(parts) != 3 {
		return nil, errors.New("malformed access token")
	}
	encHeader, encBody, encSignature := parts[0], parts[1], parts[2]

	header, err := base64.RawURLEncoding.DecodeString(encHeader)
	if err != nil {
		return nil, err
	}
	// Must stay byte-identical to the header literal written by SignAccessToken.
	if !bytes.Equal(header, []byte(` { "alg":"HS256","typ":"JWT"} `)) {
		return nil, errors.New("access token header wrong")
	}

	// Verify the signature before parsing or acting on any claims.
	signature, err := base64.RawURLEncoding.DecodeString(encSignature)
	if err != nil {
		return nil, err
	}
	mac := hmac.New(sha256.New, db.secret)
	mac.Write([]byte(encHeader + "." + encBody))
	if !hmac.Equal(signature, mac.Sum(nil)) {
		return nil, errors.New("access token signature does not match")
	}

	body, err := base64.RawURLEncoding.DecodeString(encBody)
	if err != nil {
		return nil, err
	}
	var data AccessToken
	if err := json.Unmarshal(body, &data); err != nil {
		return nil, err
	}

	// TODO: maybe change part of signKey instead?
	// NOTE(review): RevokedAccessTokens is also modified by the cleaner goroutine
	// without a lock. Nil entries are skipped defensively.
	for _, revoked := range RevokedAccessTokens {
		if revoked != nil && data.UserId == revoked.UserId && data.ExpiryDate <= revoked.ExpiryDate {
			return nil, errors.New("access token revoked")
		}
	}
	if data.ExpiryDate < int(time.Now().Unix()) {
		return nil, errors.New("access token expired")
	}
	if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
		return nil, errors.New("access token expired prematurely")
	}
	return &data, nil
}
// RevokeAccessToken records accessToken in RevokedAccessTokens. CheckAccessToken
// then rejects every access token with the same UserId whose ExpiryDate is not
// later than accessToken.ExpiryDate. To cover all tokens issued up to now, set
// accessToken.ExpiryDate to now + accessTokenLifetime. A nil accessToken is ignored.
func RevokeAccessToken(accessToken *AccessToken) {
	// A nil entry would later be dereferenced when the list is scanned.
	if accessToken == nil {
		return
	}
	RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
}
// CleanRevokedAccessTokensTicker starts a goroutine that, every interval,
// removes entries from RevokedAccessTokens whose ExpiryDate has passed. Once
// that time is reached, every access token such an entry covers has expired
// anyway. Intended to run with an interval greater than accessTokenLifetime.
// Send on (or close) the returned channel to stop the goroutine and its ticker.
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
	ticker := time.NewTicker(interval)
	stop := make(chan bool)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				db.logger.Println("cleaning revoked access tokens")
				now := int(time.Now().Unix())
				// DeleteFunc compacts in place and returns the shortened slice,
				// which must be stored back. Discarding the result would leave
				// nil'ed entries at the tail. Nil entries are dropped as well.
				// NOTE(review): RevokedAccessTokens is shared with
				// RevokeAccessToken/CheckAccessToken without a lock.
				RevokedAccessTokens = slices.DeleteFunc(RevokedAccessTokens, func(t *AccessToken) bool {
					return t == nil || t.ExpiryDate < now
				})
			}
		}
	}()
	return stop
}
// CleanExpiredRefreshTokensTicker starts a goroutine that, every interval,
// deletes all RefreshToken rows whose expiryDate lies in the past. Failures are
// logged and retried on the next tick. Send on (or close) the returned channel
// to stop the goroutine and its ticker.
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
	ticker := time.NewTicker(interval)
	stop := make(chan bool)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				db.logger.Println("cleaning expired refresh tokens")
				if _, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')"); err != nil {
					db.logger.Println("failed to clean expired refresh tokens:", err)
				}
			}
		}
	}()
	return stop
}