476 lines
14 KiB
Go
476 lines
14 KiB
Go
/*
|
|
YetAnotherToDoList
|
|
Copyright © 2023 gilex-dev gilex-dev@proton.me
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, version 3.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
package database
|
|
|
|
import (
	"bytes"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode"

	"golang.org/x/crypto/bcrypt"
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
|
|
|
|
type (
	// RefreshToken is the long-lived credential returned to a client.
	// Selector locates the stored row, Token is the secret part (only its
	// SHA-256 hash is stored by IssueRefreshToken) and ExpiryDate is a Unix
	// timestamp in seconds.
	RefreshToken struct {
		Selector   string `json:"selector"`
		Token      string `json:"token"`
		ExpiryDate int    `json:"expiryDate"`
	}

	// AccessToken holds the claims that SignAccessToken serializes as the JSON
	// body of a signed token. ExpiryDate is a Unix timestamp in seconds.
	AccessToken struct {
		UserId        string `json:"userId"`
		IsAdmin       bool   `json:"isAdmin"`
		IsUserCreator bool   `json:"isUserCreator"`
		ExpiryDate    int    `json:"expiryDate"`
	}
)

const (
	// refreshTokenLifetime is an SQLite date/time modifier spliced into the
	// unixepoch('now', …) call of the INSERT in IssueRefreshToken.
	refreshTokenLifetime = "+10 day" // TODO: add to viper
	// accessTokenLifetime is how long a newly issued access token is valid.
	accessTokenLifetime = time.Minute * 10

	// Minimum lengths, counted in runes, enforced by the validators below.
	minPasswordLength = 10
	minUserNameLength = 10
)

// RevokedAccessTokens holds revocation markers. CheckAccessToken rejects any
// access token whose UserId equals a marker's UserId and whose ExpiryDate is
// not later than the marker's ExpiryDate. Markers are appended by
// RevokeAccessToken and pruned by CleanRevokedAccessTokensTicker.
var RevokedAccessTokens []*AccessToken

// ValidatePassword reports an error if password is shorter than
// minPasswordLength runes or contains anything other than letters and numbers.
func ValidatePassword(password string) error {
	return validateAlphanumeric("password", password, minPasswordLength)
}

// ValidateUserName reports an error if userName is shorter than
// minUserNameLength runes or contains anything other than letters and numbers.
func ValidateUserName(userName string) error {
	return validateAlphanumeric("userName", userName, minUserNameLength)
}

// validateAlphanumeric checks that value has at least minRunes runes and only
// contains letters and numbers. label names the value in error messages. The
// index in the "symbol" error is the byte offset of the first offending rune.
func validateAlphanumeric(label, value string, minRunes int) error {
	if len([]rune(value)) < minRunes {
		return fmt.Errorf("%s must be at least %d characters long", label, minRunes)
	}
	for offset, r := range value {
		if unicode.IsLetter(r) || unicode.IsNumber(r) {
			continue
		}
		return fmt.Errorf("%s contains none-alphanumeric symbol at index %d", label, offset)
	}
	return nil
}
|
|
|
|
// GenerateHashFromPassword generates a hash of the passed password. Returns salt and salted & peppered hash.
|
|
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
|
|
hashBytes, err := bcrypt.GenerateFromPassword(bytes.Join([][]byte{db.secret, []byte(password)}, nil), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return hashBytes, nil
|
|
}
|
|
|
|
// GetRefreshTokenOwner takes a tokenId and return the owner's userId. Call before Update/Get/DeleteRefreshToken when not IS_admin.
|
|
func (db CustomDB) GetRefreshTokenOwner(tokenId string) (string, error) {
|
|
numTokenId, err := strconv.Atoi(tokenId)
|
|
if err != nil {
|
|
return "", errors.New("invalid tokenId")
|
|
}
|
|
|
|
statement, err := db.connection.Prepare("SELECT FK_User_userId, FROM RefreshToken WHERE tokenId = ?")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
result := statement.QueryRow(numTokenId)
|
|
var owner string
|
|
if err := result.Scan(&owner); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return owner, nil
|
|
}
|
|
|
|
func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) {
|
|
var result *sql.Row
|
|
var hash string
|
|
|
|
if userId != nil { // user userId
|
|
numUserId, err := strconv.Atoi(*userId)
|
|
if err != nil {
|
|
return "", errors.New("userId not numeric")
|
|
}
|
|
|
|
statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
result = statement.QueryRow(numUserId)
|
|
if err := result.Scan(&hash); err != nil {
|
|
return "", err
|
|
}
|
|
} else if userName != nil { // user userName
|
|
statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
result = statement.QueryRow(&userName)
|
|
if err := result.Scan(&userId, &hash); err != nil {
|
|
return "", err
|
|
}
|
|
} else {
|
|
return "", errors.New("neither userId nor userName specified")
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), bytes.Join([][]byte{db.secret, []byte(password)}, nil)); err != nil {
|
|
return "", err
|
|
}
|
|
return *userId, nil
|
|
}
|
|
|
|
// IssueRefreshToken issues a refresh token if the passed user credentials are valid. Returned refresh token can be passed to IssueAccessToken.
|
|
func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) {
|
|
numUserId, err := strconv.Atoi(userId)
|
|
if err != nil {
|
|
return nil, "", errors.New("userId not numeric")
|
|
}
|
|
selector := make([]byte, 9)
|
|
if _, err := rand.Read(selector); err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
token := make([]byte, 33)
|
|
if _, err := rand.Read(token); err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate")
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
encSelector := base64.RawURLEncoding.EncodeToString(selector)
|
|
encToken := base64.RawURLEncoding.EncodeToString(token)
|
|
|
|
tokenHash := sha256.Sum256(token)
|
|
|
|
var expiryDate int // int(time.Now().AddDate(0, 1, 0).Unix())
|
|
var tokenId string
|
|
result := statement.QueryRow(numUserId, encSelector, base64.RawURLEncoding.EncodeToString(tokenHash[:]), &tokenName)
|
|
|
|
if err := result.Scan(&tokenId, &expiryDate); err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil
|
|
}
|
|
|
|
func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) {
|
|
numTokenId, err := strconv.Atoi(token.ID)
|
|
if err != nil {
|
|
return nil, errors.New("invalid tokenId")
|
|
}
|
|
|
|
statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := statement.QueryRow(numTokenId)
|
|
if err := result.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) {
|
|
numUserId, err := strconv.Atoi(userId)
|
|
if err != nil {
|
|
return nil, errors.New("invalid userId")
|
|
}
|
|
statement, err := db.connection.Prepare("SELECT tokenId, expiaryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := statement.Query(numUserId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
var all []*model.RefreshToken
|
|
for rows.Next() {
|
|
token := model.RefreshToken{}
|
|
if err := rows.Scan(&token.ID, &token.ExpiryDate, &token.TokenName); err != nil {
|
|
return nil, err
|
|
}
|
|
all = append(all, &token)
|
|
}
|
|
return all, nil
|
|
}
|
|
|
|
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
|
|
|
|
statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := statement.Query()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var all []*model.RefreshToken
|
|
for rows.Next() {
|
|
var token model.RefreshToken
|
|
if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
|
|
return nil, err
|
|
}
|
|
all = append(all, &token)
|
|
}
|
|
return all, nil
|
|
}
|
|
|
|
func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) {
|
|
numTokenId, err := strconv.Atoi(tokenId)
|
|
if err != nil {
|
|
return nil, errors.New("invalid tokenId")
|
|
}
|
|
|
|
statement, err := db.connection.Prepare("UPDATE AuthToken SET tokenName = ? WHERE tokenId = ?")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := statement.Exec(changes.TokenName, numTokenId)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
num, err := rows.RowsAffected()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if num < 1 {
|
|
return nil, errors.New("no rows affected")
|
|
}
|
|
|
|
token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId})
|
|
if err != nil {
|
|
return nil, errors.New("failed to get updated token")
|
|
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
// RevokeRefreshToken revokes the access token matching the tokenId. Also calls RevokeAccessToken.
|
|
func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) {
|
|
// TODO: return string instead of *string
|
|
numTokenId, err := strconv.Atoi(tokenId)
|
|
if err != nil {
|
|
return nil, errors.New("invalid tokenId")
|
|
}
|
|
|
|
statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := statement.QueryRow(numTokenId)
|
|
|
|
var userId string
|
|
if err := result.Scan(&userId); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
|
|
return &tokenId, nil
|
|
}
|
|
|
|
// IssueAccessToken issues an access token if the passed refresh token is valid. Returned access token must be passed to SignAccessToken to be accepted.
|
|
func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) {
|
|
statement, err := db.connection.Prepare("SELECT RefreshToken.tokenHash, RefreshToken.FK_User_userId, Role.IS_admin, ROLE.IS_userCreator FROM RefreshToken INNER JOIN R_User_Role ON RefreshToken.FK_User_userId = R_User_Role.FK_User_userId INNER JOIN Role ON R_User_Role.FK_Role_roleId = Role.roleId WHERE RefreshToken.selector = ? AND RefreshToken.expiryDate >= unixepoch('now')")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := statement.QueryRow(refreshToken.Selector)
|
|
|
|
var tokenHash string
|
|
var newAccessToken AccessToken
|
|
if err := result.Scan(&tokenHash, &newAccessToken.UserId, &newAccessToken.IsAdmin, &newAccessToken.IsUserCreator); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userTokenHash := sha256.Sum256(decUserToken)
|
|
encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:])
|
|
if encUserTokenHash != tokenHash {
|
|
return nil, errors.New("failed to issue access token: refreshToken does not match")
|
|
}
|
|
|
|
newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix())
|
|
return &newAccessToken, nil
|
|
}
|
|
|
|
// SignAccessToken signs an access token and attaches a header. Returns access token encoded as jwt.
|
|
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
|
|
|
|
data, err := json.Marshal(accessToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
body := base64.RawURLEncoding.EncodeToString(data)
|
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
|
combined := header + "." + body
|
|
|
|
mac := hmac.New(sha256.New, db.secret)
|
|
mac.Write([]byte(combined))
|
|
signature := mac.Sum(nil)
|
|
|
|
return combined + "." + base64.RawURLEncoding.EncodeToString(signature), nil
|
|
}
|
|
|
|
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
|
|
accessTokenBlob := strings.Split(encAccessToken, ".")
|
|
|
|
header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) {
|
|
return nil, errors.New("access token header wrong")
|
|
}
|
|
|
|
body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var data AccessToken
|
|
err = json.Unmarshal(body, &data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// TODO: maybe change part of signKey instead?
|
|
for _, revokedAccessToken := range RevokedAccessTokens {
|
|
if data.UserId == revokedAccessToken.UserId && data.ExpiryDate <= revokedAccessToken.ExpiryDate {
|
|
return nil, errors.New("access token revoked")
|
|
}
|
|
}
|
|
|
|
if data.ExpiryDate < int(time.Now().Unix()) {
|
|
return nil, errors.New("access token expired")
|
|
}
|
|
|
|
if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
|
|
return nil, errors.New("access token expired prematurely")
|
|
}
|
|
|
|
signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mac := hmac.New(sha256.New, db.secret)
|
|
mac.Write([]byte(accessTokenBlob[0]))
|
|
mac.Write([]byte("."))
|
|
mac.Write([]byte(accessTokenBlob[1]))
|
|
expectedSignature := mac.Sum(nil)
|
|
if !hmac.Equal(signature, expectedSignature) {
|
|
return nil, errors.New("access token signature does not match")
|
|
}
|
|
|
|
return &data, nil
|
|
}
|
|
|
|
// RevokeAccessToken revokes all access tokens with matching UserId and UserRole that don't have a later ExpiryDate.
|
|
// revokedAccessToken.ExpiryDate should be set to now + token-lifetime.
|
|
func RevokeAccessToken(accessToken *AccessToken) {
|
|
RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
|
|
}
|
|
|
|
// CleanRevokedTokensTicker removes expired tokens from the list. This should be called in an interval of > accessTokenLifetime.
|
|
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
|
|
ticker := time.NewTicker(interval)
|
|
stop := make(chan bool)
|
|
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-stop:
|
|
return
|
|
case <-ticker.C:
|
|
db.logger.Println("cleaning revoked access tokens")
|
|
for i, revokedAccessToken := range RevokedAccessTokens {
|
|
if revokedAccessToken.ExpiryDate < int(time.Now().Unix()) {
|
|
revokedAccessToken = nil
|
|
slices.Delete(RevokedAccessTokens, i, i+1)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
return stop
|
|
}
|
|
|
|
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
|
|
ticker := time.NewTicker(interval)
|
|
stop := make(chan bool)
|
|
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-stop:
|
|
return
|
|
case <-ticker.C:
|
|
db.logger.Println("cleaning expired refresh tokens")
|
|
_, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')")
|
|
if err != nil {
|
|
db.logger.Println("failed to clean expired refresh tokens")
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
return stop
|
|
}
|