YetAnotherToDoList/database/crypto_helpers.go

502 lines
14 KiB
Go

/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package database
import (
	"bytes"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode"

	"golang.org/x/crypto/bcrypt"
	"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/graph/model"
)
// RefreshToken is the client-facing form of a refresh token: a selector used to
// look the stored token up, the secret token value, and the expiry as a Unix
// timestamp in seconds.
type RefreshToken struct {
	Selector   string `json:"selector"`
	Token      string `json:"token"`
	ExpiryDate int    `json:"expiryDate"`
}

// AccessToken is the payload that SignAccessToken encodes and CheckAccessToken
// decodes. ExpiryDate is a Unix timestamp in seconds.
type AccessToken struct {
	UserId        string `json:"userId"`
	IsAdmin       bool   `json:"isAdmin"`
	IsUserCreator bool   `json:"isUserCreator"`
	ExpiryDate    int    `json:"expiryDate"`
}

const (
	// refreshTokenLifetime is an SQLite date modifier passed to
	// unixepoch('now', ...) when a refresh token row is inserted.
	refreshTokenLifetime = "+10 day" // TODO: add to viper
	// accessTokenLifetime is how long a freshly issued access token is valid.
	accessTokenLifetime = time.Minute * 10
	// minPasswordLength is the minimum number of runes accepted by ValidatePassword.
	minPasswordLength = 10
	// minUserNameLength is the minimum number of runes accepted by ValidateUserName.
	minUserNameLength = 10
)

// RevokedAccessTokens is the in-memory revocation list consulted by
// CheckAccessToken: an entry revokes every access token with the same UserId
// whose ExpiryDate is not later than the entry's ExpiryDate.
var RevokedAccessTokens []*AccessToken
// ValidatePassword checks password against the password policy: at least
// minPasswordLength runes, every one of them a Unicode letter or number.
// The error for an invalid character reports its rune (character) index.
func ValidatePassword(password string) error {
	return validateAlphanumericField("password", password, minPasswordLength)
}

// ValidateUserName checks userName against the user name policy: at least
// minUserNameLength runes, every one of them a Unicode letter or number.
// The error for an invalid character reports its rune (character) index.
func ValidateUserName(userName string) error {
	return validateAlphanumericField("userName", userName, minUserNameLength)
}

// validateAlphanumericField enforces the shared policy of ValidatePassword and
// ValidateUserName; field is the subject used in the error messages.
func validateAlphanumericField(field, value string, minLength int) error {
	if len([]rune(value)) < minLength {
		return fmt.Errorf("%s must be at least %d characters long", field, minLength)
	}
	// Count runes explicitly: the index yielded by ranging over a string is a
	// byte offset, which would not match the rune-based length check above.
	runeIndex := 0
	for _, c := range value {
		if !unicode.IsLetter(c) && !unicode.IsNumber(c) {
			return fmt.Errorf("%s contains non-alphanumeric symbol at index %d", field, runeIndex)
		}
		runeIndex++
	}
	return nil
}
// GenerateHashFromPassword returns the bcrypt hash (cost bcrypt.DefaultCost) of
// db.secret followed by password, i.e. the password peppered with the server
// secret. bcrypt stores its random salt inside the returned hash, so no
// separate salt is returned. ValidateUserCredentials must pepper the same way.
//
// NOTE(review): bcrypt only uses the first 72 bytes of its input (and, depending
// on the golang.org/x/crypto version, may reject longer input with an error).
// Since the pepper is prepended, len(db.secret) reduces the usable password
// length — confirm the secret is well below 72 bytes.
func (db CustomDB) GenerateHashFromPassword(password string) (passwordHash []byte, err error) {
	peppered := make([]byte, 0, len(db.secret)+len(password))
	peppered = append(peppered, db.secret...)
	peppered = append(peppered, password...)

	passwordHash, err = bcrypt.GenerateFromPassword(peppered, bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}
	return passwordHash, nil
}
// GetRefreshTokenOwner returns the userId of the user owning the refresh token
// tokenId. Call it before UpdateRefreshToken/GetRefreshToken/RevokeRefreshToken
// when the caller is not an admin, to enforce ownership.
//
// It returns an error if tokenId is not numeric or no such token exists.
func (db CustomDB) GetRefreshTokenOwner(tokenId string) (ownerId string, err error) {
	numTokenId, err := strconv.Atoi(tokenId)
	if err != nil {
		return "", errors.New("malformed refresh token Id")
	}
	statement, err := db.connection.Prepare("SELECT FK_User_userId FROM RefreshToken WHERE tokenId = ?")
	if err != nil {
		return "", err
	}
	// A prepared statement holds driver resources until it is closed.
	defer statement.Close()

	if err := statement.QueryRow(numTokenId).Scan(&ownerId); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return "", errors.New("invalid refresh token Id")
		}
		return "", err
	}
	return ownerId, nil
}
// ValidateUserCredentials checks password against the stored bcrypt hash of a
// user, identified by userId if it is non-nil, otherwise by userName. The
// password is peppered with db.secret exactly as in GenerateHashFromPassword.
//
// On success it returns the user's id (the caller's *userId when looking up by
// id). For a wrong password it returns the error from
// bcrypt.CompareHashAndPassword; other failures return descriptive errors.
func (db CustomDB) ValidateUserCredentials(userId *string, userName *string, password string) (validUserId string, err error) {
	var hash string
	switch {
	case userId != nil:
		numUserId, err := strconv.Atoi(*userId)
		if err != nil {
			return "", errors.New("userId not numeric")
		}
		statement, err := db.connection.Prepare("SELECT passwordHash FROM User WHERE userId = ?")
		if err != nil {
			return "", err
		}
		defer statement.Close()
		if err := statement.QueryRow(numUserId).Scan(&hash); err != nil {
			if errors.Is(err, sql.ErrNoRows) {
				return "", errors.New("invalid user Id")
			}
			return "", err
		}
		validUserId = *userId
	case userName != nil:
		statement, err := db.connection.Prepare("SELECT userId, passwordHash FROM User WHERE userName = ?")
		if err != nil {
			return "", err
		}
		defer statement.Close()
		// Pass the name itself (not a pointer to the pointer) and scan the id
		// into a local result instead of through the nil userId parameter.
		if err := statement.QueryRow(*userName).Scan(&validUserId, &hash); err != nil {
			if errors.Is(err, sql.ErrNoRows) {
				return "", errors.New("invalid user name")
			}
			return "", err
		}
	default:
		return "", errors.New("neither userId nor userName specified")
	}

	peppered := bytes.Join([][]byte{db.secret, []byte(password)}, nil)
	if err := bcrypt.CompareHashAndPassword([]byte(hash), peppered); err != nil {
		return "", err
	}
	return validUserId, nil
}
// IssueRefreshToken creates and stores a new refresh token for userId. It does
// NOT check any credentials; callers must authenticate the user first (e.g.
// with ValidateUserCredentials). tokenName is optional: nil or "" is stored as
// NULL. The token expires refreshTokenLifetime after creation.
//
// Only the SHA-256 hash of the secret token part is persisted. The returned
// RefreshToken (selector and token, both unpadded base64url) can be passed to
// IssueAccessToken; refreshTokenId is the database id of the new row.
func (db CustomDB) IssueRefreshToken(userId string, tokenName *string) (refreshToken *RefreshToken, refreshTokenId string, err error) {
	numUserId, err := strconv.Atoi(userId)
	if err != nil {
		return nil, "", errors.New("userId not numeric")
	}
	// The selector locates the row; the token is the secret. Both lengths are
	// multiples of 3, so their base64 encodings need no padding.
	selector := make([]byte, 9)
	if _, err := rand.Read(selector); err != nil {
		return nil, "", err
	}
	token := make([]byte, 33)
	if _, err := rand.Read(token); err != nil {
		return nil, "", err
	}
	statement, err := db.connection.Prepare("INSERT INTO RefreshToken (FK_User_userId, selector, tokenHash, expiryDate, tokenName) VALUES (?, ?, ?, unixepoch('now','" + refreshTokenLifetime + "'), NULLIF(?, '')) RETURNING tokenId, expiryDate")
	if err != nil {
		return nil, "", err
	}
	defer statement.Close()

	encSelector := base64.RawURLEncoding.EncodeToString(selector)
	encToken := base64.RawURLEncoding.EncodeToString(token)
	tokenHash := sha256.Sum256(token)
	encTokenHash := base64.RawURLEncoding.EncodeToString(tokenHash[:])

	var expiryDate int
	var tokenId string
	if err := statement.QueryRow(numUserId, encSelector, encTokenHash, tokenName).Scan(&tokenId, &expiryDate); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, "", errors.New("failed to add new refresh token")
		}
		return nil, "", err
	}
	return &RefreshToken{Selector: encSelector, Token: encToken, ExpiryDate: expiryDate}, tokenId, nil
}
// GetRefreshToken looks up the refresh token whose ID is set in token, fills in
// its ExpiryDate and TokenName, and returns the same (mutated) token.
//
// It returns an error if token is nil, token.ID is not numeric, or no such
// token exists.
func (db CustomDB) GetRefreshToken(token *model.RefreshToken) (*model.RefreshToken, error) {
	if token == nil {
		return nil, errors.New("refresh token is nil")
	}
	numTokenId, err := strconv.Atoi(token.ID)
	if err != nil {
		return nil, errors.New("malformed refresh token Id")
	}
	statement, err := db.connection.Prepare("SELECT expiryDate, tokenName FROM RefreshToken WHERE tokenId = ?")
	if err != nil {
		return nil, err
	}
	defer statement.Close()

	if err := statement.QueryRow(numTokenId).Scan(&token.ExpiryDate, &token.TokenName); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errors.New("invalid refresh token Id")
		}
		return nil, err
	}
	return token, nil
}
// GetRefreshTokensFrom lists the refresh tokens owned by userId (ID, UserID,
// ExpiryDate and TokenName of each). It returns a nil slice and no error when
// there are none.
func (db CustomDB) GetRefreshTokensFrom(userId string) ([]*model.RefreshToken, error) {
	numUserId, err := strconv.Atoi(userId)
	if err != nil {
		return nil, errors.New("malformed userId")
	}
	// Column names must match the RefreshToken schema used by the other
	// queries in this file (expiryDate, FK_User_userId).
	statement, err := db.connection.Prepare("SELECT tokenId, FK_User_userId, expiryDate, tokenName FROM RefreshToken WHERE FK_User_userId = ?")
	if err != nil {
		return nil, err
	}
	defer statement.Close()

	rows, err := statement.Query(numUserId)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var all []*model.RefreshToken
	for rows.Next() {
		var token model.RefreshToken
		if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
			return nil, err
		}
		all = append(all, &token)
	}
	// rows.Next also returns false on an iteration error; surface it.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return all, nil
}
// GetAllRefreshTokens lists every stored refresh token regardless of owner
// (ID, UserID, ExpiryDate and TokenName of each). It returns a nil slice and no
// error when there are none.
func (db CustomDB) GetAllRefreshTokens() ([]*model.RefreshToken, error) {
	statement, err := db.connection.Prepare("SELECT tokenID, FK_User_userId, expiryDate, tokenName FROM RefreshToken")
	if err != nil {
		return nil, err
	}
	defer statement.Close()

	rows, err := statement.Query()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var all []*model.RefreshToken
	for rows.Next() {
		var token model.RefreshToken
		if err := rows.Scan(&token.ID, &token.UserID, &token.ExpiryDate, &token.TokenName); err != nil {
			return nil, err
		}
		all = append(all, &token)
	}
	// rows.Next also returns false on an iteration error; surface it.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return all, nil
}
// UpdateRefreshToken applies changes to the refresh token tokenId and returns
// the token as read back via GetRefreshToken. Only the token name can be
// changed. An empty name is stored as NULL, as in IssueRefreshToken.
// Ownership must be checked by the caller (see GetRefreshTokenOwner).
//
// It returns an error if changes is nil, tokenId is not numeric, or no row was
// updated.
func (db CustomDB) UpdateRefreshToken(tokenId string, changes *model.UpdateRefreshToken) (*model.RefreshToken, error) {
	if changes == nil {
		return nil, errors.New("no changes specified")
	}
	numTokenId, err := strconv.Atoi(tokenId)
	if err != nil {
		return nil, errors.New("malformed refresh token Id")
	}
	statement, err := db.connection.Prepare("UPDATE RefreshToken SET tokenName = NULLIF(?, '') WHERE tokenId = ?")
	if err != nil {
		return nil, err
	}
	defer statement.Close()

	result, err := statement.Exec(changes.TokenName, numTokenId)
	if err != nil {
		return nil, err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return nil, err
	}
	if affected < 1 {
		return nil, errors.New("no rows affected")
	}

	token, err := db.GetRefreshToken(&model.RefreshToken{ID: tokenId})
	if err != nil {
		return nil, fmt.Errorf("failed to get updated token: %w", err)
	}
	return token, nil
}
// RevokeRefreshToken deletes the refresh token tokenId. It then calls
// RevokeAccessToken so that every access token of the token's owner expiring no
// later than now + accessTokenLifetime is rejected by CheckAccessToken.
// It returns tokenId on success.
func (db CustomDB) RevokeRefreshToken(tokenId string) (*string, error) {
	// TODO: return string instead of *string
	numTokenId, err := strconv.Atoi(tokenId)
	if err != nil {
		return nil, errors.New("malformed refresh token Id")
	}
	statement, err := db.connection.Prepare("DELETE FROM RefreshToken WHERE tokenId = ? RETURNING FK_User_userId")
	if err != nil {
		return nil, err
	}
	defer statement.Close()

	var userId string
	if err := statement.QueryRow(numTokenId).Scan(&userId); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errors.New("invalid refresh token Id")
		}
		return nil, err
	}
	// Any access token issued so far expires at or before now + lifetime, so
	// this entry invalidates all of the user's currently valid access tokens.
	RevokeAccessToken(&AccessToken{UserId: userId, ExpiryDate: int(time.Now().Add(accessTokenLifetime).Unix())})
	return &tokenId, nil
}
// IssueAccessToken validates refreshToken and returns a new, unsigned
// AccessToken for its owner. The selector must name a stored, unexpired
// refresh token, and the SHA-256 hash of the token part must equal the stored
// hash. The access token carries the owner's current permissions and expires
// accessTokenLifetime from now. Pass it to SignAccessToken before handing it
// to a client.
func (db CustomDB) IssueAccessToken(refreshToken *RefreshToken) (*AccessToken, error) {
	if refreshToken == nil {
		return nil, errors.New("refresh token is nil")
	}
	statement, err := db.connection.Prepare("SELECT tokenHash, FK_User_userId FROM RefreshToken WHERE selector = ? AND expiryDate >= unixepoch('now')")
	if err != nil {
		return nil, err
	}
	defer statement.Close()

	var storedTokenHash string
	var newAccessToken AccessToken
	if err := statement.QueryRow(refreshToken.Selector).Scan(&storedTokenHash, &newAccessToken.UserId); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errors.New("invalid refresh token selector")
		}
		return nil, err
	}

	decUserToken, err := base64.RawURLEncoding.DecodeString(refreshToken.Token)
	if err != nil {
		return nil, err
	}
	userTokenHash := sha256.Sum256(decUserToken)
	encUserTokenHash := base64.RawURLEncoding.EncodeToString(userTokenHash[:])
	// Constant-time comparison: a plain != can leak, through timing, how much
	// of the expected hash an attacker has guessed.
	if !hmac.Equal([]byte(encUserTokenHash), []byte(storedTokenHash)) {
		return nil, errors.New("failed to issue access token: refreshToken does not match")
	}

	// Only query permissions once the refresh token is known to be genuine.
	newAccessToken.IsAdmin, newAccessToken.IsUserCreator, err = db.GetUserPermissions(newAccessToken.UserId)
	if err != nil {
		return nil, err
	}
	newAccessToken.ExpiryDate = int(time.Now().Add(accessTokenLifetime).Unix())
	return &newAccessToken, nil
}
// SignAccessToken serializes accessToken as a compact JWT. The result has three
// dot-separated parts: base64url(header), base64url(JSON payload), and
// base64url(HMAC-SHA256 of the first two parts). All parts are unpadded and the
// HMAC is keyed with db.secret. The header is always {"alg":"HS256","typ":"JWT"}.
func (db CustomDB) SignAccessToken(accessToken AccessToken) (encAccessToken string, err error) {
	payload, err := json.Marshal(accessToken)
	if err != nil {
		return "", err
	}

	enc := base64.RawURLEncoding
	signingInput := enc.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + "." + enc.EncodeToString(payload)

	signer := hmac.New(sha256.New, db.secret)
	signer.Write([]byte(signingInput))

	encAccessToken = signingInput + "." + enc.EncodeToString(signer.Sum(nil))
	return encAccessToken, nil
}
// CheckAccessToken verifies an access token produced by SignAccessToken and
// returns its decoded payload. It rejects the token if any of these hold:
//   - it does not have exactly three dot-separated parts;
//   - its header is not the fixed HS256 header;
//   - its HMAC-SHA256 signature (keyed with db.secret) does not match;
//   - its payload is not valid JSON;
//   - it matches an entry in RevokedAccessTokens;
//   - it has expired;
//   - it expires before initTimeStamp + accessTokenLifetime.
func (db CustomDB) CheckAccessToken(encAccessToken string) (accessToken *AccessToken, err error) {
	accessTokenBlob := strings.Split(encAccessToken, ".")
	// The token is untrusted input; indexing below would panic without this.
	if len(accessTokenBlob) != 3 {
		return nil, errors.New("malformed access token")
	}
	header, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[0])
	if err != nil {
		return nil, err
	}
	if !bytes.Equal(header, []byte(`{"alg":"HS256","typ":"JWT"}`)) {
		return nil, errors.New("access token header wrong")
	}

	// Authenticate the token before parsing or trusting its payload.
	signature, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[2])
	if err != nil {
		return nil, err
	}
	mac := hmac.New(sha256.New, db.secret)
	mac.Write([]byte(accessTokenBlob[0]))
	mac.Write([]byte("."))
	mac.Write([]byte(accessTokenBlob[1]))
	if !hmac.Equal(signature, mac.Sum(nil)) {
		return nil, errors.New("access token signature does not match")
	}

	body, err := base64.RawURLEncoding.DecodeString(accessTokenBlob[1])
	if err != nil {
		return nil, err
	}
	var data AccessToken
	if err := json.Unmarshal(body, &data); err != nil {
		return nil, err
	}

	// TODO: maybe change part of signKey instead?
	if isAccessTokenRevoked(&data) {
		return nil, errors.New("access token revoked")
	}
	if data.ExpiryDate < int(time.Now().Unix()) {
		return nil, errors.New("access token expired")
	}
	// NOTE(review): presumably initTimeStamp is the process start time; tokens
	// issued before it predate the in-memory revocation list and are refused.
	if data.ExpiryDate < initTimeStamp+int(accessTokenLifetime.Seconds()) {
		return nil, errors.New("access token expired prematurely")
	}
	return &data, nil
}

// revokedAccessTokensMu guards RevokedAccessTokens. The list is read by
// CheckAccessToken, appended to by RevokeAccessToken, and pruned by the
// CleanRevokedAccessTokensTicker goroutine, possibly all at the same time.
// NOTE(review): RevokedAccessTokens is exported; any direct access from
// outside these functions bypasses this lock.
var revokedAccessTokensMu sync.RWMutex

// isAccessTokenRevoked reports whether token matches a revocation entry: the
// same UserId and an ExpiryDate no later than the entry's.
func isAccessTokenRevoked(token *AccessToken) bool {
	revokedAccessTokensMu.RLock()
	defer revokedAccessTokensMu.RUnlock()
	for _, revoked := range RevokedAccessTokens {
		if revoked != nil && token.UserId == revoked.UserId && token.ExpiryDate <= revoked.ExpiryDate {
			return true
		}
	}
	return false
}

// RevokeAccessToken revokes all access tokens whose UserId matches accessToken
// and whose ExpiryDate is not later than accessToken.ExpiryDate.
// accessToken.ExpiryDate should be set to now + accessTokenLifetime.
// A nil accessToken is ignored. Safe for concurrent use.
func RevokeAccessToken(accessToken *AccessToken) {
	if accessToken == nil {
		return
	}
	revokedAccessTokensMu.Lock()
	defer revokedAccessTokensMu.Unlock()
	RevokedAccessTokens = append(RevokedAccessTokens, accessToken)
}

// CleanRevokedAccessTokensTicker starts a goroutine that drops expired entries
// from RevokedAccessTokens once per interval. Entries can only be removed once
// they have expired, so an interval longer than accessTokenLifetime is enough.
// interval must be positive (time.NewTicker panics otherwise). Send a value on,
// or close, the returned channel to stop the goroutine.
func (db CustomDB) CleanRevokedAccessTokensTicker(interval time.Duration) (stopCleaner chan bool) {
	ticker := time.NewTicker(interval)
	stop := make(chan bool)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				db.logger.Println("cleaning revoked access tokens")
				now := int(time.Now().Unix())
				revokedAccessTokensMu.Lock()
				// DeleteFunc compacts in place and returns the shortened slice,
				// which must be stored back (the old length would keep stale
				// or zeroed entries visible).
				RevokedAccessTokens = slices.DeleteFunc(RevokedAccessTokens, func(t *AccessToken) bool {
					return t == nil || t.ExpiryDate < now
				})
				revokedAccessTokensMu.Unlock()
			}
		}
	}()
	return stop
}
// CleanExpiredRefreshTokensTicker starts a goroutine that, once per interval,
// deletes refresh tokens whose expiryDate lies in the past. A failed cleanup
// is logged and retried on the next tick. interval must be positive
// (time.NewTicker panics otherwise). Send a value on, or close, the returned
// channel to stop the goroutine.
func (db CustomDB) CleanExpiredRefreshTokensTicker(interval time.Duration) (stopCleaner chan bool) {
	ticker := time.NewTicker(interval)
	stop := make(chan bool)
	go func() {
		// Release the ticker once the goroutine exits.
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				db.logger.Println("cleaning expired refresh tokens")
				if _, err := db.connection.Exec("DELETE FROM RefreshToken WHERE expiryDate < unixepoch('now')"); err != nil {
					db.logger.Println("failed to clean expired refresh tokens:", err)
				}
			}
		}
	}()
	return stop
}