feat: impl cookie-based auth tokens exchange

implement access/refresh token exchange via cookies as well as automatic
access token refresh
This commit is contained in:
2025-12-04 00:26:20 +00:00
parent d4c4e84fbf
commit 57167d5715
7 changed files with 179 additions and 30 deletions

View File

@@ -28,3 +28,8 @@ storage:
# Required when backend is "s3"
# bucket: my-drexa-bucket
cookie:
# Domain for cross-subdomain auth cookies.
# Set this when frontend and API are on different subdomains (e.g., "app.com" for web.app.com + api.app.com).
# Leave empty for single-domain or localhost setups.
# domain: app.com

View File

@@ -0,0 +1,56 @@
package auth
import (
"time"
"github.com/gofiber/fiber/v2"
)
// CookieConfig controls auth cookie behavior.
type CookieConfig struct {
// Domain for cross-subdomain cookies (e.g., "app.com" for web.app.com + api.app.com).
// Leave empty for same-host cookies (localhost, single domain).
Domain string
}
// authCookies returns auth cookies from the given fiber context.
// Returns a map with the cookie names as keys and the cookie values as values.
func authCookies(c *fiber.Ctx) map[string]string {
m := make(map[string]string)
at := c.Cookies(cookieKeyAccessToken)
if at != "" {
m[cookieKeyAccessToken] = at
}
rt := c.Cookies(cookieKeyRefreshToken)
if rt != "" {
m[cookieKeyRefreshToken] = rt
}
return m
}
// setAuthCookies sets HTTP-only auth cookies with security settings derived from the request.
// Secure flag is based on actual protocol (works automatically with proxies/tunnels).
func setAuthCookies(c *fiber.Ctx, accessToken, refreshToken string, cfg CookieConfig) {
secure := c.Protocol() == "https"
c.Cookie(&fiber.Cookie{
Name: cookieKeyAccessToken,
Value: accessToken,
Path: "/",
Domain: cfg.Domain,
Expires: time.Now().Add(accessTokenValidFor),
SameSite: fiber.CookieSameSiteLaxMode,
HTTPOnly: true,
Secure: secure,
})
c.Cookie(&fiber.Cookie{
Name: cookieKeyRefreshToken,
Value: refreshToken,
Path: "/",
Domain: cfg.Domain,
Expires: time.Now().Add(refreshTokenValidFor),
SameSite: fiber.CookieSameSiteLaxMode,
HTTPOnly: true,
Secure: secure,
})
}

View File

@@ -10,15 +10,26 @@ import (
"github.com/uptrace/bun"
)
const (
tokenDeliveryCookie = "cookie"
tokenDeliveryBody = "body"
)
const (
cookieKeyAccessToken = "access_token"
cookieKeyRefreshToken = "refresh_token"
)
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
Email string `json:"email"`
Password string `json:"password"`
TokenDelivery string `json:"tokenDelivery"`
}
type loginResponse struct {
User user.User `json:"user"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
AccessToken string `json:"accessToken,omitempty"`
RefreshToken string `json:"refreshToken,omitempty"`
}
type refreshAccessTokenRequest struct {
@@ -31,12 +42,13 @@ type tokenResponse struct {
}
type HTTPHandler struct {
service *Service
db *bun.DB
service *Service
db *bun.DB
cookieConfig CookieConfig
}
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return &HTTPHandler{service: s, db: db}
func NewHTTPHandler(s *Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
return &HTTPHandler{service: s, db: db, cookieConfig: cookieConfig}
}
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
@@ -69,11 +81,23 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
return httperr.Internal(err)
}
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
switch req.TokenDelivery {
default:
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
case tokenDeliveryCookie:
setAuthCookies(c, result.AccessToken, result.RefreshToken, h.cookieConfig)
return c.JSON(loginResponse{
User: *result.User,
})
case tokenDeliveryBody:
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}
}
func (h *HTTPHandler) refreshAccessToken(c *fiber.Ctx) error {

View File

@@ -4,6 +4,7 @@ import (
"errors"
"log/slog"
"strings"
"time"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/user"
@@ -13,24 +14,35 @@ import (
const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
var at string
var rt string
var setCookies bool
authHeader := c.Get("Authorization")
if authHeader == "" {
slog.Info("no auth header")
if authHeader != "" {
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
slog.Info("invalid auth header")
return c.SendStatus(fiber.StatusUnauthorized)
}
at = parts[1]
setCookies = false
} else {
cookies := authCookies(c)
at = cookies[cookieKeyAccessToken]
rt = cookies[cookieKeyRefreshToken]
setCookies = true
}
if at == "" {
slog.Info("no access token")
return c.SendStatus(fiber.StatusUnauthorized)
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
slog.Info("invalid auth header")
return c.SendStatus(fiber.StatusUnauthorized)
}
token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at)
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {
@@ -47,7 +59,24 @@ func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return httperr.Internal(err)
}
c.Locals(authenticatedUserKey, u)
c.Locals(authenticatedUserKey, authResult.User)
// if cookie based auth and access token is about to expire (within 5 minutes),
// attempt to refresh the access token. if there is any error, ignore it and let the request continue.
if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" {
tx, txErr := db.BeginTx(c.Context(), nil)
if txErr == nil {
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
if err == nil {
if commitErr := tx.Commit(); commitErr == nil {
setAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
}
} else {
_ = tx.Rollback()
slog.Debug("auto-refresh failed", "error", err)
}
}
}
return c.Next()
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
@@ -20,6 +21,11 @@ type AuthenticationTokens struct {
RefreshToken string
}
type AccessTokenAuthentication struct {
Claims *jwt.RegisteredClaims
User *user.User
}
var ErrInvalidCredentials = errors.New("invalid credentials")
type Service struct {
@@ -65,7 +71,7 @@ func (s *Service) GrantForUser(ctx context.Context, db bun.IDB, user *user.User)
return result, nil
}
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*AccessTokenAuthentication, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
slog.Info("failed to parse access token", "error", err)
@@ -78,7 +84,19 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t
return nil, newInvalidAccessTokenError(err)
}
return s.userService.UserByID(ctx, db, id)
u, err := s.userService.UserByID(ctx, db, id)
if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
return nil, newInvalidAccessTokenError(err)
}
return nil, err
}
return &AccessTokenAuthentication{
Claims: claims,
User: u,
}, nil
}
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {

View File

@@ -27,6 +27,7 @@ type Config struct {
Database DatabaseConfig `yaml:"database"`
JWT JWTConfig `yaml:"jwt"`
Storage StorageConfig `yaml:"storage"`
Cookie CookieConfig `yaml:"cookie"`
}
type ServerConfig struct {
@@ -52,6 +53,13 @@ type StorageConfig struct {
Bucket string `yaml:"bucket"`
}
// CookieConfig controls auth cookie behavior.
// Domain is optional - only needed for cross-subdomain setups (e.g., "app.com" for web.app.com + api.app.com).
// Secure flag is derived from the request protocol automatically.
type CookieConfig struct {
Domain string `yaml:"domain"`
}
// ConfigFromFile loads configuration from a YAML file.
// JWT secret key is loaded from JWT_SECRET_KEY env var (base64 encoded),
// falling back to the file path specified in jwt.secret_key_path.

View File

@@ -36,6 +36,11 @@ func NewServer(c Config) (*Server, error) {
app := fiber.New(fiber.Config{
ErrorHandler: httperr.ErrorHandler,
StreamRequestBody: true,
// Trust proxy headers (X-Forwarded-Proto, X-Forwarded-For) for proper
// protocol detection behind reverse proxies, tunnels (ngrok, cloudflare), etc.
EnableTrustedProxyCheck: true,
TrustedProxies: []string{"127.0.0.1", "::1"},
ProxyHeader: fiber.HeaderXForwardedFor,
})
app.Use(logger.New())
@@ -85,10 +90,14 @@ func NewServer(c Config) (*Server, error) {
uploadService := upload.NewService(vfs, blobStore)
accountService := account.NewService(userService, vfs)
authMiddleware := auth.NewBearerAuthMiddleware(authService, db)
cookieConfig := auth.CookieConfig{
Domain: c.Cookie.Domain,
}
authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig)
api := app.Group("/api")
auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
accountRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter)