mirror of
https://github.com/get-drexa/drive.git
synced 2025-12-08 17:11:40 +00:00
feat: impl cookie-based auth tokens exchange
implement access/refresh token exchange via cookies as well as automatic access token refresh
This commit is contained in:
56
apps/backend/internal/auth/cookies.go
Normal file
56
apps/backend/internal/auth/cookies.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// CookieConfig controls auth cookie behavior.
|
||||
type CookieConfig struct {
|
||||
// Domain for cross-subdomain cookies (e.g., "app.com" for web.app.com + api.app.com).
|
||||
// Leave empty for same-host cookies (localhost, single domain).
|
||||
Domain string
|
||||
}
|
||||
|
||||
// authCookies returns auth cookies from the given fiber context.
|
||||
// Returns a map with the cookie names as keys and the cookie values as values.
|
||||
func authCookies(c *fiber.Ctx) map[string]string {
|
||||
m := make(map[string]string)
|
||||
at := c.Cookies(cookieKeyAccessToken)
|
||||
if at != "" {
|
||||
m[cookieKeyAccessToken] = at
|
||||
}
|
||||
rt := c.Cookies(cookieKeyRefreshToken)
|
||||
if rt != "" {
|
||||
m[cookieKeyRefreshToken] = rt
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// setAuthCookies sets HTTP-only auth cookies with security settings derived from the request.
|
||||
// Secure flag is based on actual protocol (works automatically with proxies/tunnels).
|
||||
func setAuthCookies(c *fiber.Ctx, accessToken, refreshToken string, cfg CookieConfig) {
|
||||
secure := c.Protocol() == "https"
|
||||
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: cookieKeyAccessToken,
|
||||
Value: accessToken,
|
||||
Path: "/",
|
||||
Domain: cfg.Domain,
|
||||
Expires: time.Now().Add(accessTokenValidFor),
|
||||
SameSite: fiber.CookieSameSiteLaxMode,
|
||||
HTTPOnly: true,
|
||||
Secure: secure,
|
||||
})
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: cookieKeyRefreshToken,
|
||||
Value: refreshToken,
|
||||
Path: "/",
|
||||
Domain: cfg.Domain,
|
||||
Expires: time.Now().Add(refreshTokenValidFor),
|
||||
SameSite: fiber.CookieSameSiteLaxMode,
|
||||
HTTPOnly: true,
|
||||
Secure: secure,
|
||||
})
|
||||
}
|
||||
@@ -10,15 +10,26 @@ import (
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenDeliveryCookie = "cookie"
|
||||
tokenDeliveryBody = "body"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieKeyAccessToken = "access_token"
|
||||
cookieKeyRefreshToken = "refresh_token"
|
||||
)
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
TokenDelivery string `json:"tokenDelivery"`
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
User user.User `json:"user"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
AccessToken string `json:"accessToken,omitempty"`
|
||||
RefreshToken string `json:"refreshToken,omitempty"`
|
||||
}
|
||||
|
||||
type refreshAccessTokenRequest struct {
|
||||
@@ -31,12 +42,13 @@ type tokenResponse struct {
|
||||
}
|
||||
|
||||
type HTTPHandler struct {
|
||||
service *Service
|
||||
db *bun.DB
|
||||
service *Service
|
||||
db *bun.DB
|
||||
cookieConfig CookieConfig
|
||||
}
|
||||
|
||||
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
|
||||
return &HTTPHandler{service: s, db: db}
|
||||
func NewHTTPHandler(s *Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
|
||||
return &HTTPHandler{service: s, db: db, cookieConfig: cookieConfig}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||
@@ -69,11 +81,23 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
switch req.TokenDelivery {
|
||||
default:
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
|
||||
|
||||
case tokenDeliveryCookie:
|
||||
setAuthCookies(c, result.AccessToken, result.RefreshToken, h.cookieConfig)
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
})
|
||||
|
||||
case tokenDeliveryBody:
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) refreshAccessToken(c *fiber.Ctx) error {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/httperr"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
@@ -13,24 +14,35 @@ import (
|
||||
|
||||
const authenticatedUserKey = "authenticatedUser"
|
||||
|
||||
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
|
||||
// NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
|
||||
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
|
||||
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
|
||||
func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var at string
|
||||
var rt string
|
||||
var setCookies bool
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
slog.Info("no auth header")
|
||||
if authHeader != "" {
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
slog.Info("invalid auth header")
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
at = parts[1]
|
||||
setCookies = false
|
||||
} else {
|
||||
cookies := authCookies(c)
|
||||
at = cookies[cookieKeyAccessToken]
|
||||
rt = cookies[cookieKeyRefreshToken]
|
||||
setCookies = true
|
||||
}
|
||||
|
||||
if at == "" {
|
||||
slog.Info("no access token")
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
slog.Info("invalid auth header")
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
|
||||
authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at)
|
||||
if err != nil {
|
||||
var e *InvalidAccessTokenError
|
||||
if errors.As(err, &e) {
|
||||
@@ -47,7 +59,24 @@ func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
c.Locals(authenticatedUserKey, u)
|
||||
c.Locals(authenticatedUserKey, authResult.User)
|
||||
|
||||
// if cookie based auth and access token is about to expire (within 5 minutes),
|
||||
// attempt to refresh the access token. if there is any error, ignore it and let the request continue.
|
||||
if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" {
|
||||
tx, txErr := db.BeginTx(c.Context(), nil)
|
||||
if txErr == nil {
|
||||
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
|
||||
if err == nil {
|
||||
if commitErr := tx.Commit(); commitErr == nil {
|
||||
setAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
||||
}
|
||||
} else {
|
||||
_ = tx.Rollback()
|
||||
slog.Debug("auto-refresh failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/get-drexa/drexa/internal/password"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
@@ -20,6 +21,11 @@ type AuthenticationTokens struct {
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
type AccessTokenAuthentication struct {
|
||||
Claims *jwt.RegisteredClaims
|
||||
User *user.User
|
||||
}
|
||||
|
||||
var ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
|
||||
type Service struct {
|
||||
@@ -65,7 +71,7 @@ func (s *Service) GrantForUser(ctx context.Context, db bun.IDB, user *user.User)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
|
||||
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*AccessTokenAuthentication, error) {
|
||||
claims, err := ParseAccessToken(token, &s.tokenConfig)
|
||||
if err != nil {
|
||||
slog.Info("failed to parse access token", "error", err)
|
||||
@@ -78,7 +84,19 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t
|
||||
return nil, newInvalidAccessTokenError(err)
|
||||
}
|
||||
|
||||
return s.userService.UserByID(ctx, db, id)
|
||||
u, err := s.userService.UserByID(ctx, db, id)
|
||||
if err != nil {
|
||||
var nf *user.NotFoundError
|
||||
if errors.As(err, &nf) {
|
||||
return nil, newInvalidAccessTokenError(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccessTokenAuthentication{
|
||||
Claims: claims,
|
||||
User: u,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {
|
||||
|
||||
Reference in New Issue
Block a user