140 lines
4.1 KiB
Go
140 lines
4.1 KiB
Go
|
/*
YetAnotherToDoList
Copyright © 2023 gilex-dev gilex-dev@proton.me

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
|
||
|
package auth
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
|
||
|
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/database"
|
||
|
"somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals"
|
||
|
)
|
||
|
|
||
|
// contextKey is the unexported type used for context keys in this package.
// Because the type is unexported, keys of this type cannot collide with
// context keys defined by other packages.
type contextKey struct {
	name string
}

// userCtxKey is the context key under which Middleware stores the value
// returned by the access-token check, and from which ForContext reads it.
var userCtxKey = &contextKey{name: "user"}
|
||
|
|
||
|
func Middleware() func(http.Handler) http.Handler {
|
||
|
return func(next http.Handler) http.Handler {
|
||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
const headerPrefix = "Bearer "
|
||
|
authField := r.Header.Get("Authorization")
|
||
|
if !strings.HasPrefix(authField, headerPrefix) {
|
||
|
http.Error(w, fmt.Sprintf(`{"error":"wrong token type, expect '%s'"}`, headerPrefix), http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// get the user from the database
|
||
|
user, err := globals.DB.CheckAccessToken(strings.TrimPrefix(authField, headerPrefix))
|
||
|
if err != nil {
|
||
|
http.Error(w, strings.ReplaceAll(fmt.Sprintf(`{"error":"failed token check: '%s'"}`, err), "\\", "\\\\"), http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// put it in context
|
||
|
ctx := context.WithValue(r.Context(), userCtxKey, user)
|
||
|
// and call the next with our new context
|
||
|
r = r.WithContext(ctx)
|
||
|
next.ServeHTTP(w, r)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func ForContext(ctx context.Context) *database.AccessToken {
|
||
|
raw, _ := ctx.Value(userCtxKey).(*database.AccessToken)
|
||
|
return raw
|
||
|
}
|
||
|
|
||
|
// userAuth is the credential payload decoded from the JSON that follows the
// "Login " prefix of the Authorization header in IssueRefreshTokenHandler.
// UserId and UserName are pointers, so a field absent from the JSON decodes
// to nil; presumably a client supplies one of the two — confirm against
// ValidateUserCredentials.
type userAuth struct {
	UserId   *string `json:"userId"`   // optional user identifier
	UserName *string `json:"userName"` // optional user name
	Password string  `json:"password"` // password exactly as sent by the client
}
|
||
|
|
||
|
func IssueRefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||
|
const headerPrefix = "Login "
|
||
|
authField := r.Header.Get("Authorization")
|
||
|
if !strings.HasPrefix(authField, headerPrefix) {
|
||
|
http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
userCredentials := userAuth{}
|
||
|
|
||
|
err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &userCredentials)
|
||
|
if err != nil {
|
||
|
http.Error(w, "malformed or missing user credentials", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
userId, err := globals.DB.ValidateUserCredentials(userCredentials.UserId, userCredentials.UserName, userCredentials.Password)
|
||
|
if err != nil {
|
||
|
http.Error(w, "invalid credentials", http.StatusUnauthorized)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
refreshToken, _, err := globals.DB.IssueRefreshToken(userId, nil)
|
||
|
if err != nil {
|
||
|
http.Error(w, "failed to issue refresh token", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
jsonRefreshToken, err := json.Marshal(refreshToken)
|
||
|
if err != nil {
|
||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
w.Write([]byte(jsonRefreshToken))
|
||
|
}
|
||
|
|
||
|
func IssueAccessTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||
|
const headerPrefix = "Refresh "
|
||
|
authField := r.Header.Get("Authorization")
|
||
|
if !strings.HasPrefix(authField, headerPrefix) {
|
||
|
http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
refreshToken := database.RefreshToken{}
|
||
|
|
||
|
err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &refreshToken)
|
||
|
if err != nil {
|
||
|
http.Error(w, "malformed or missing refresh token", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
accessToken, err := globals.DB.IssueAccessToken(&refreshToken)
|
||
|
if err != nil {
|
||
|
http.Error(w, "invalid access token", http.StatusUnauthorized)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
encAccessToken, err := globals.DB.SignAccessToken(*accessToken)
|
||
|
if err != nil {
|
||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
w.Write([]byte(encAccessToken))
|
||
|
}
|