/* YetAnotherToDoList Copyright © 2023 gilex-dev gilex-dev@proton.me This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 3. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ package auth import ( "context" "encoding/json" "fmt" "net/http" "strings" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/database" "somepi.ddns.net/gitea/gilex-dev/YetAnotherToDoList/globals" ) var userCtxKey = &contextKey{"user"} type contextKey struct { name string } func Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { const headerPrefix = "Bearer " authField := r.Header.Get("Authorization") if !strings.HasPrefix(authField, headerPrefix) { http.Error(w, fmt.Sprintf(`{"error":"wrong token type, expect '%s'"}`, headerPrefix), http.StatusBadRequest) return } // get the user from the database user, err := globals.DB.CheckAccessToken(strings.TrimPrefix(authField, headerPrefix)) if err != nil { http.Error(w, strings.ReplaceAll(fmt.Sprintf(`{"error":"failed token check: '%s'"}`, err), "\\", "\\\\"), http.StatusInternalServerError) return } // put it in context ctx := context.WithValue(r.Context(), userCtxKey, user) // and call the next with our new context r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } } func ForContext(ctx context.Context) *database.AccessToken { raw, _ := ctx.Value(userCtxKey).(*database.AccessToken) return raw } type userAuth struct { UserId *string `json:"userId"` UserName *string `json:"userName"` Password string `json:"password"` } func IssueRefreshTokenHandler(w http.ResponseWriter, r *http.Request) { const headerPrefix = "Login " authField := r.Header.Get("Authorization") if !strings.HasPrefix(authField, headerPrefix) { http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest) return } userCredentials := userAuth{} err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &userCredentials) if err != nil { http.Error(w, "malformed or missing user credentials", http.StatusBadRequest) return } userId, err := globals.DB.ValidateUserCredentials(userCredentials.UserId, userCredentials.UserName, userCredentials.Password) if err != nil { http.Error(w, "invalid credentials", http.StatusUnauthorized) return } refreshToken, _, err := globals.DB.IssueRefreshToken(userId, nil) if err != nil { http.Error(w, "failed to issue refresh token", http.StatusInternalServerError) return } jsonRefreshToken, err := json.Marshal(refreshToken) if err != nil { http.Error(w, "internal server error", http.StatusInternalServerError) return } w.Write([]byte(jsonRefreshToken)) } func IssueAccessTokenHandler(w http.ResponseWriter, r *http.Request) { const headerPrefix = "Refresh " authField := r.Header.Get("Authorization") if !strings.HasPrefix(authField, headerPrefix) { http.Error(w, fmt.Sprintf("wrong token type, expect %q", headerPrefix), http.StatusBadRequest) return } refreshToken := database.RefreshToken{} err := json.Unmarshal([]byte(strings.TrimPrefix(authField, headerPrefix)), &refreshToken) if err != nil { http.Error(w, "malformed or missing refresh token", http.StatusBadRequest) return } accessToken, err := globals.DB.IssueAccessToken(&refreshToken) if err != nil { http.Error(w, "invalid access token", http.StatusUnauthorized) return } encAccessToken, err := globals.DB.SignAccessToken(*accessToken) if err != nil { http.Error(w, "internal server error", http.StatusInternalServerError) return } w.Write([]byte(encAccessToken)) }