mirror of
https://github.com/get-drexa/drive.git
synced 2025-12-06 16:21:39 +00:00
feat: impl cookie-based auth tokens exchange
implement access/refresh token exchange via cookies as well as automatic access token refresh
This commit is contained in:
@@ -28,3 +28,8 @@ storage:
|
||||
# Required when backend is "s3"
|
||||
# bucket: my-drexa-bucket
|
||||
|
||||
cookie:
|
||||
# Domain for cross-subdomain auth cookies.
|
||||
# Set this when frontend and API are on different subdomains (e.g., "app.com" for web.app.com + api.app.com).
|
||||
# Leave empty for single-domain or localhost setups.
|
||||
# domain: app.com
|
||||
|
||||
56
apps/backend/internal/auth/cookies.go
Normal file
56
apps/backend/internal/auth/cookies.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// CookieConfig controls auth cookie behavior.
|
||||
type CookieConfig struct {
|
||||
// Domain for cross-subdomain cookies (e.g., "app.com" for web.app.com + api.app.com).
|
||||
// Leave empty for same-host cookies (localhost, single domain).
|
||||
Domain string
|
||||
}
|
||||
|
||||
// authCookies returns auth cookies from the given fiber context.
|
||||
// Returns a map with the cookie names as keys and the cookie values as values.
|
||||
func authCookies(c *fiber.Ctx) map[string]string {
|
||||
m := make(map[string]string)
|
||||
at := c.Cookies(cookieKeyAccessToken)
|
||||
if at != "" {
|
||||
m[cookieKeyAccessToken] = at
|
||||
}
|
||||
rt := c.Cookies(cookieKeyRefreshToken)
|
||||
if rt != "" {
|
||||
m[cookieKeyRefreshToken] = rt
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// setAuthCookies sets HTTP-only auth cookies with security settings derived from the request.
|
||||
// Secure flag is based on actual protocol (works automatically with proxies/tunnels).
|
||||
func setAuthCookies(c *fiber.Ctx, accessToken, refreshToken string, cfg CookieConfig) {
|
||||
secure := c.Protocol() == "https"
|
||||
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: cookieKeyAccessToken,
|
||||
Value: accessToken,
|
||||
Path: "/",
|
||||
Domain: cfg.Domain,
|
||||
Expires: time.Now().Add(accessTokenValidFor),
|
||||
SameSite: fiber.CookieSameSiteLaxMode,
|
||||
HTTPOnly: true,
|
||||
Secure: secure,
|
||||
})
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: cookieKeyRefreshToken,
|
||||
Value: refreshToken,
|
||||
Path: "/",
|
||||
Domain: cfg.Domain,
|
||||
Expires: time.Now().Add(refreshTokenValidFor),
|
||||
SameSite: fiber.CookieSameSiteLaxMode,
|
||||
HTTPOnly: true,
|
||||
Secure: secure,
|
||||
})
|
||||
}
|
||||
@@ -10,15 +10,26 @@ import (
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenDeliveryCookie = "cookie"
|
||||
tokenDeliveryBody = "body"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieKeyAccessToken = "access_token"
|
||||
cookieKeyRefreshToken = "refresh_token"
|
||||
)
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
TokenDelivery string `json:"tokenDelivery"`
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
User user.User `json:"user"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
AccessToken string `json:"accessToken,omitempty"`
|
||||
RefreshToken string `json:"refreshToken,omitempty"`
|
||||
}
|
||||
|
||||
type refreshAccessTokenRequest struct {
|
||||
@@ -31,12 +42,13 @@ type tokenResponse struct {
|
||||
}
|
||||
|
||||
type HTTPHandler struct {
|
||||
service *Service
|
||||
db *bun.DB
|
||||
service *Service
|
||||
db *bun.DB
|
||||
cookieConfig CookieConfig
|
||||
}
|
||||
|
||||
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
|
||||
return &HTTPHandler{service: s, db: db}
|
||||
func NewHTTPHandler(s *Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
|
||||
return &HTTPHandler{service: s, db: db, cookieConfig: cookieConfig}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||
@@ -69,11 +81,23 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
switch req.TokenDelivery {
|
||||
default:
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
|
||||
|
||||
case tokenDeliveryCookie:
|
||||
setAuthCookies(c, result.AccessToken, result.RefreshToken, h.cookieConfig)
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
})
|
||||
|
||||
case tokenDeliveryBody:
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) refreshAccessToken(c *fiber.Ctx) error {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/httperr"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
@@ -13,24 +14,35 @@ import (
|
||||
|
||||
const authenticatedUserKey = "authenticatedUser"
|
||||
|
||||
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
|
||||
// NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
|
||||
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
|
||||
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
|
||||
func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var at string
|
||||
var rt string
|
||||
var setCookies bool
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
slog.Info("no auth header")
|
||||
if authHeader != "" {
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
slog.Info("invalid auth header")
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
at = parts[1]
|
||||
setCookies = false
|
||||
} else {
|
||||
cookies := authCookies(c)
|
||||
at = cookies[cookieKeyAccessToken]
|
||||
rt = cookies[cookieKeyRefreshToken]
|
||||
setCookies = true
|
||||
}
|
||||
|
||||
if at == "" {
|
||||
slog.Info("no access token")
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
slog.Info("invalid auth header")
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
|
||||
authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at)
|
||||
if err != nil {
|
||||
var e *InvalidAccessTokenError
|
||||
if errors.As(err, &e) {
|
||||
@@ -47,7 +59,24 @@ func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
c.Locals(authenticatedUserKey, u)
|
||||
c.Locals(authenticatedUserKey, authResult.User)
|
||||
|
||||
// if cookie based auth and access token is about to expire (within 5 minutes),
|
||||
// attempt to refresh the access token. if there is any error, ignore it and let the request continue.
|
||||
if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" {
|
||||
tx, txErr := db.BeginTx(c.Context(), nil)
|
||||
if txErr == nil {
|
||||
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
|
||||
if err == nil {
|
||||
if commitErr := tx.Commit(); commitErr == nil {
|
||||
setAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
||||
}
|
||||
} else {
|
||||
_ = tx.Rollback()
|
||||
slog.Debug("auto-refresh failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/get-drexa/drexa/internal/password"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
@@ -20,6 +21,11 @@ type AuthenticationTokens struct {
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
type AccessTokenAuthentication struct {
|
||||
Claims *jwt.RegisteredClaims
|
||||
User *user.User
|
||||
}
|
||||
|
||||
var ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
|
||||
type Service struct {
|
||||
@@ -65,7 +71,7 @@ func (s *Service) GrantForUser(ctx context.Context, db bun.IDB, user *user.User)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
|
||||
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*AccessTokenAuthentication, error) {
|
||||
claims, err := ParseAccessToken(token, &s.tokenConfig)
|
||||
if err != nil {
|
||||
slog.Info("failed to parse access token", "error", err)
|
||||
@@ -78,7 +84,19 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t
|
||||
return nil, newInvalidAccessTokenError(err)
|
||||
}
|
||||
|
||||
return s.userService.UserByID(ctx, db, id)
|
||||
u, err := s.userService.UserByID(ctx, db, id)
|
||||
if err != nil {
|
||||
var nf *user.NotFoundError
|
||||
if errors.As(err, &nf) {
|
||||
return nil, newInvalidAccessTokenError(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccessTokenAuthentication{
|
||||
Claims: claims,
|
||||
User: u,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {
|
||||
|
||||
@@ -27,6 +27,7 @@ type Config struct {
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
JWT JWTConfig `yaml:"jwt"`
|
||||
Storage StorageConfig `yaml:"storage"`
|
||||
Cookie CookieConfig `yaml:"cookie"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -52,6 +53,13 @@ type StorageConfig struct {
|
||||
Bucket string `yaml:"bucket"`
|
||||
}
|
||||
|
||||
// CookieConfig controls auth cookie behavior.
|
||||
// Domain is optional - only needed for cross-subdomain setups (e.g., "app.com" for web.app.com + api.app.com).
|
||||
// Secure flag is derived from the request protocol automatically.
|
||||
type CookieConfig struct {
|
||||
Domain string `yaml:"domain"`
|
||||
}
|
||||
|
||||
// ConfigFromFile loads configuration from a YAML file.
|
||||
// JWT secret key is loaded from JWT_SECRET_KEY env var (base64 encoded),
|
||||
// falling back to the file path specified in jwt.secret_key_path.
|
||||
|
||||
@@ -36,6 +36,11 @@ func NewServer(c Config) (*Server, error) {
|
||||
app := fiber.New(fiber.Config{
|
||||
ErrorHandler: httperr.ErrorHandler,
|
||||
StreamRequestBody: true,
|
||||
// Trust proxy headers (X-Forwarded-Proto, X-Forwarded-For) for proper
|
||||
// protocol detection behind reverse proxies, tunnels (ngrok, cloudflare), etc.
|
||||
EnableTrustedProxyCheck: true,
|
||||
TrustedProxies: []string{"127.0.0.1", "::1"},
|
||||
ProxyHeader: fiber.HeaderXForwardedFor,
|
||||
})
|
||||
app.Use(logger.New())
|
||||
|
||||
@@ -85,10 +90,14 @@ func NewServer(c Config) (*Server, error) {
|
||||
uploadService := upload.NewService(vfs, blobStore)
|
||||
accountService := account.NewService(userService, vfs)
|
||||
|
||||
authMiddleware := auth.NewBearerAuthMiddleware(authService, db)
|
||||
cookieConfig := auth.CookieConfig{
|
||||
Domain: c.Cookie.Domain,
|
||||
}
|
||||
|
||||
authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig)
|
||||
|
||||
api := app.Group("/api")
|
||||
auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
|
||||
auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
|
||||
|
||||
accountRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
|
||||
upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter)
|
||||
|
||||
Reference in New Issue
Block a user