fix(backend): optional auth ret 401 if token invalid

This commit is contained in:
2025-12-29 01:01:28 +00:00
parent a956481215
commit 294fadfe4c
4 changed files with 198 additions and 88 deletions

View File

@@ -62,3 +62,38 @@ func SetAuthCookies(c *fiber.Ctx, accessToken, refreshToken string, cfg CookieCo
c.Cookie(accessTokenCookie) c.Cookie(accessTokenCookie)
c.Cookie(refreshTokenCookie) c.Cookie(refreshTokenCookie)
} }
// ClearAuthCookies clears the HTTP-only auth cookies by setting them to an expired value.
func ClearAuthCookies(c *fiber.Ctx, cfg CookieConfig) {
secure := c.Protocol() == "https"
expired := time.Unix(0, 0)
accessTokenCookie := &fiber.Cookie{
Name: cookieKeyAccessToken,
Value: "",
Path: "/",
Expires: expired,
SameSite: fiber.CookieSameSiteLaxMode,
HTTPOnly: true,
Secure: secure,
}
if cfg.Domain != "" {
accessTokenCookie.Domain = cfg.Domain
}
refreshTokenCookie := &fiber.Cookie{
Name: cookieKeyRefreshToken,
Value: "",
Path: "/",
Expires: expired,
SameSite: fiber.CookieSameSiteLaxMode,
HTTPOnly: true,
Secure: secure,
}
if cfg.Domain != "" {
refreshTokenCookie.Domain = cfg.Domain
}
c.Cookie(accessTokenCookie)
c.Cookie(refreshTokenCookie)
}

View File

@@ -13,98 +13,157 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
var errUnauthorized = errors.New("unauthorized")
// authAttempt is the normalized result of parsing and (optionally) authenticating
// a request.
//
// It intentionally separates:
// - whether the request used cookie-based auth (SetCookies == true), which enables
// setting/refreshing HTTP-only cookies, from
// - what credentials were available (RefreshToken), and
// - the authenticated principal (AuthResult).
//
// This is shared by:
// - NewAuthMiddleware (strict): rejects unauthenticated requests with 401.
// - NewOptionalAuthMiddleware (best-effort): proceeds as anonymous on auth failure,
// and clears auth cookies when the failure came from cookies to avoid repeated
// failing auth attempts on public endpoints.
type authAttempt struct {
// AuthResult is populated when access token authentication succeeds.
AuthResult *AccessTokenAuthentication
// SetCookies is true when the request used cookie-based auth (no Authorization header).
// When true, the middleware may set/refresh cookies on successful auth or refresh flows.
SetCookies bool
// RefreshToken is the refresh token extracted from cookies when SetCookies is true.
// It is used for auto-refreshing near-expiry access tokens.
RefreshToken string
}
func authenticateRequest(c *fiber.Ctx, s *Service, db *bun.DB, cookieConfig CookieConfig) (*authAttempt, error) {
var at string
var rt string
var setCookies bool
authHeader := c.Get("Authorization")
if authHeader != "" {
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return nil, errUnauthorized
}
at = parts[1]
setCookies = false
} else {
cookies := authCookies(c)
at = cookies[cookieKeyAccessToken]
rt = cookies[cookieKeyRefreshToken]
setCookies = true
}
if at == "" && rt == "" {
return nil, errUnauthorized
}
if at == "" {
// if there is no access token, attempt to get new access token using the refresh token.
tx, err := db.BeginTx(c.Context(), nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
if err != nil {
return nil, errUnauthorized
}
if err := tx.Commit(); err != nil {
return nil, err
}
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
at = newTokens.AccessToken
rt = newTokens.RefreshToken
}
authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at)
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {
return nil, errUnauthorized
}
var nf *user.NotFoundError
if errors.As(err, &nf) {
return nil, errUnauthorized
}
return nil, err
}
return &authAttempt{
AuthResult: authResult,
SetCookies: setCookies,
RefreshToken: rt,
}, nil
}
func maybeAutoRefresh(c *fiber.Ctx, s *Service, db *bun.DB, cookieConfig CookieConfig, attempt *authAttempt) {
if attempt == nil || attempt.AuthResult == nil {
return
}
if !attempt.SetCookies || attempt.RefreshToken == "" {
return
}
if time.Until(attempt.AuthResult.Claims.ExpiresAt.Time) >= 5*time.Minute {
return
}
tx, txErr := db.BeginTx(c.Context(), nil)
if txErr != nil {
return
}
defer tx.Rollback()
newTokens, err := s.RefreshAccessToken(c.Context(), tx, attempt.RefreshToken)
if err != nil {
slog.Debug("auto-refresh failed", "error", err)
return
}
if err := tx.Commit(); err != nil {
return
}
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
}
// NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies. // NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
// To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser. // To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser.
func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler { func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
var at string attempt, err := authenticateRequest(c, s, db, cookieConfig)
var rt string
var setCookies bool
authHeader := c.Get("Authorization")
if authHeader != "" {
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
slog.Info("invalid auth header")
return c.SendStatus(fiber.StatusUnauthorized)
}
at = parts[1]
setCookies = false
} else {
cookies := authCookies(c)
at = cookies[cookieKeyAccessToken]
rt = cookies[cookieKeyRefreshToken]
setCookies = true
}
if at == "" && rt == "" {
slog.Info("no access token or refresh token")
return c.SendStatus(fiber.StatusUnauthorized)
}
if at == "" {
// if there is no access token, attempt to get new access token using the refresh token.
tx, err := db.BeginTx(c.Context(), nil)
if err != nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
defer tx.Rollback()
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
if err != nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
if err := tx.Commit(); err != nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
at = newTokens.AccessToken
rt = newTokens.RefreshToken
}
authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at)
if err != nil { if err != nil {
var e *InvalidAccessTokenError if errors.Is(err, errUnauthorized) {
if errors.As(err, &e) {
slog.Info("invalid access token")
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
var nf *user.NotFoundError
if errors.As(err, &nf) {
slog.Info("user not found")
return c.SendStatus(fiber.StatusUnauthorized)
}
return httperr.Internal(err) return httperr.Internal(err)
} }
reqctx.SetAuthenticatedUser(c, authResult.User) reqctx.SetAuthenticatedUser(c, attempt.AuthResult.User)
maybeAutoRefresh(c, s, db, cookieConfig, attempt)
// if cookie based auth and access token is about to expire (within 5 minutes),
// attempt to refresh the access token. if there is any error, ignore it and let the request continue.
if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" {
tx, txErr := db.BeginTx(c.Context(), nil)
if txErr == nil {
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
if err == nil {
if commitErr := tx.Commit(); commitErr == nil {
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
}
} else {
_ = tx.Rollback()
slog.Debug("auto-refresh failed", "error", err)
}
}
}
return c.Next() return c.Next()
} }
} }
// NewOptionalAuthMiddleware conditionally runs the given auth middleware only when // NewOptionalAuthMiddleware attempts to authenticate a request if it contains auth
// the request contains auth credentials (Authorization header or auth cookies). // credentials (Authorization header or auth cookies).
//
// Unlike NewAuthMiddleware, this middleware never rejects the request with 401. If
// authentication fails, it proceeds as an anonymous request.
// //
// This is useful for endpoints that are publicly accessible but can also behave // This is useful for endpoints that are publicly accessible but can also behave
// differently for authenticated callers (e.g. share consumption routes that may // differently for authenticated callers (e.g. share consumption routes that may
@@ -112,15 +171,31 @@ func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.
// present). // present).
// //
// Usage: // Usage:
// authMiddleware := auth.NewAuthMiddleware(...) //
// router.Use(auth.NewOptionalAuthMiddleware(authMiddleware)) // router.Use(auth.NewOptionalAuthMiddleware(authService, db, cookieConfig))
func NewOptionalAuthMiddleware(authMiddleware fiber.Handler) fiber.Handler { func NewOptionalAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
hasAuthHeader := c.Get("Authorization") != "" authHeader := c.Get("Authorization")
hasAuthHeader := authHeader != ""
hasAuthCookies := c.Cookies(cookieKeyAccessToken) != "" || c.Cookies(cookieKeyRefreshToken) != "" hasAuthCookies := c.Cookies(cookieKeyAccessToken) != "" || c.Cookies(cookieKeyRefreshToken) != ""
if !hasAuthHeader && !hasAuthCookies { if !hasAuthHeader && !hasAuthCookies {
return c.Next() return c.Next()
} }
return authMiddleware(c)
attempt, err := authenticateRequest(c, s, db, cookieConfig)
if err != nil {
if errors.Is(err, errUnauthorized) {
if !hasAuthHeader && hasAuthCookies {
ClearAuthCookies(c, cookieConfig)
}
return c.Next()
}
return httperr.Internal(err)
}
reqctx.SetAuthenticatedUser(c, attempt.AuthResult.User)
maybeAutoRefresh(c, s, db, cookieConfig, attempt)
return c.Next()
} }
} }

View File

@@ -113,6 +113,7 @@ func NewServer(c Config) (*Server, error) {
} }
authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig) authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig)
optionalAuthMiddleware := auth.NewOptionalAuthMiddleware(authService, db, cookieConfig)
api := app.Group("/api") api := app.Group("/api")
auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api) auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
@@ -121,7 +122,7 @@ func NewServer(c Config) (*Server, error) {
accountRouter := account.NewHTTPHandler(accountService, authService, vfs, db, authMiddleware, cookieConfig).RegisterRoutes(api) accountRouter := account.NewHTTPHandler(accountService, authService, vfs, db, authMiddleware, cookieConfig).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter.VFSRouter()) upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter.VFSRouter())
shareHTTP := sharing.NewHTTPHandler(sharingService, accountService, vfs, db, authMiddleware) shareHTTP := sharing.NewHTTPHandler(sharingService, accountService, vfs, db, optionalAuthMiddleware)
shareRoutes := shareHTTP.RegisterShareConsumeRoutes(api) shareRoutes := shareHTTP.RegisterShareConsumeRoutes(api)
shareHTTP.RegisterShareManagementRoutes(accountRouter) shareHTTP.RegisterShareManagementRoutes(accountRouter)

View File

@@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/get-drexa/drexa/internal/account" "github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/httperr" "github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/nullable" "github.com/get-drexa/drexa/internal/nullable"
"github.com/get-drexa/drexa/internal/reqctx" "github.com/get-drexa/drexa/internal/reqctx"
@@ -21,7 +20,7 @@ type HTTPHandler struct {
accountService *account.Service accountService *account.Service
vfs *virtualfs.VirtualFS vfs *virtualfs.VirtualFS
db *bun.DB db *bun.DB
authMiddleware fiber.Handler optionalAuthMiddleware fiber.Handler
} }
// createShareRequest represents a request to create a share link // createShareRequest represents a request to create a share link
@@ -40,13 +39,13 @@ type patchShareRequest struct {
ExpiresAt nullable.Time `json:"expiresAt" example:"2025-01-15T00:00:00Z"` ExpiresAt nullable.Time `json:"expiresAt" example:"2025-01-15T00:00:00Z"`
} }
func NewHTTPHandler(sharingService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler { func NewHTTPHandler(sharingService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, optionalAuthMiddleware fiber.Handler) *HTTPHandler {
return &HTTPHandler{ return &HTTPHandler{
sharingService: sharingService, sharingService: sharingService,
accountService: accountService, accountService: accountService,
vfs: vfs, vfs: vfs,
db: db, db: db,
authMiddleware: authMiddleware, optionalAuthMiddleware: optionalAuthMiddleware,
} }
} }
@@ -54,7 +53,7 @@ func (h *HTTPHandler) RegisterShareConsumeRoutes(r fiber.Router) *virtualfs.Scop
// Public shares should be accessible without authentication. However, if the client provides auth // Public shares should be accessible without authentication. However, if the client provides auth
// credentials (cookies or Authorization header), attempt auth so share scopes can be resolved for // credentials (cookies or Authorization header), attempt auth so share scopes can be resolved for
// account-scoped shares. // account-scoped shares.
g := r.Group("/shares/:shareID", auth.NewOptionalAuthMiddleware(h.authMiddleware), h.shareMiddleware) g := r.Group("/shares/:shareID", h.optionalAuthMiddleware, h.shareMiddleware)
return &virtualfs.ScopedRouter{Router: g} return &virtualfs.ScopedRouter{Router: g}
} }