diff --git a/apps/backend/internal/auth/cookies.go b/apps/backend/internal/auth/cookies.go index ddde4f3..a018c7c 100644 --- a/apps/backend/internal/auth/cookies.go +++ b/apps/backend/internal/auth/cookies.go @@ -62,3 +62,38 @@ func SetAuthCookies(c *fiber.Ctx, accessToken, refreshToken string, cfg CookieCo c.Cookie(accessTokenCookie) c.Cookie(refreshTokenCookie) } + +// ClearAuthCookies clears the HTTP-only auth cookies by setting them to an expired value. +func ClearAuthCookies(c *fiber.Ctx, cfg CookieConfig) { + secure := c.Protocol() == "https" + expired := time.Unix(0, 0) + + accessTokenCookie := &fiber.Cookie{ + Name: cookieKeyAccessToken, + Value: "", + Path: "/", + Expires: expired, + SameSite: fiber.CookieSameSiteLaxMode, + HTTPOnly: true, + Secure: secure, + } + if cfg.Domain != "" { + accessTokenCookie.Domain = cfg.Domain + } + + refreshTokenCookie := &fiber.Cookie{ + Name: cookieKeyRefreshToken, + Value: "", + Path: "/", + Expires: expired, + SameSite: fiber.CookieSameSiteLaxMode, + HTTPOnly: true, + Secure: secure, + } + if cfg.Domain != "" { + refreshTokenCookie.Domain = cfg.Domain + } + + c.Cookie(accessTokenCookie) + c.Cookie(refreshTokenCookie) +} diff --git a/apps/backend/internal/auth/middleware.go b/apps/backend/internal/auth/middleware.go index d158cc6..d29f99f 100644 --- a/apps/backend/internal/auth/middleware.go +++ b/apps/backend/internal/auth/middleware.go @@ -13,98 +13,157 @@ import ( "github.com/uptrace/bun" ) +var errUnauthorized = errors.New("unauthorized") + +// authAttempt is the normalized result of parsing and (optionally) authenticating +// a request. +// +// It intentionally separates: +// - whether the request used cookie-based auth (SetCookies == true), which enables +// setting/refreshing HTTP-only cookies, from +// - what credentials were available (RefreshToken), and +// - the authenticated principal (AuthResult). +// +// This is shared by: +// - NewAuthMiddleware (strict): rejects unauthenticated requests with 401. +// - NewOptionalAuthMiddleware (best-effort): proceeds as anonymous on auth failure, +// and clears auth cookies when the failure came from cookies to avoid repeated +// failing auth attempts on public endpoints. +type authAttempt struct { + // AuthResult is populated when access token authentication succeeds. + AuthResult *AccessTokenAuthentication + + // SetCookies is true when the request used cookie-based auth (no Authorization header). + // When true, the middleware may set/refresh cookies on successful auth or refresh flows. + SetCookies bool + + // RefreshToken is the refresh token extracted from cookies when SetCookies is true. + // It is used for auto-refreshing near-expiry access tokens. + RefreshToken string +} + +func authenticateRequest(c *fiber.Ctx, s *Service, db *bun.DB, cookieConfig CookieConfig) (*authAttempt, error) { + var at string + var rt string + var setCookies bool + + authHeader := c.Get("Authorization") + if authHeader != "" { + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, errUnauthorized + } + at = parts[1] + setCookies = false + } else { + cookies := authCookies(c) + at = cookies[cookieKeyAccessToken] + rt = cookies[cookieKeyRefreshToken] + setCookies = true + } + + if at == "" && rt == "" { + return nil, errUnauthorized + } + + if at == "" { + // if there is no access token, attempt to get new access token using the refresh token. + tx, err := db.BeginTx(c.Context(), nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt) + if err != nil { + return nil, errUnauthorized + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig) + at = newTokens.AccessToken + rt = newTokens.RefreshToken + } + + authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at) + if err != nil { + var e *InvalidAccessTokenError + if errors.As(err, &e) { + return nil, errUnauthorized + } + + var nf *user.NotFoundError + if errors.As(err, &nf) { + return nil, errUnauthorized + } + + return nil, err + } + + return &authAttempt{ + AuthResult: authResult, + SetCookies: setCookies, + RefreshToken: rt, + }, nil +} + +func maybeAutoRefresh(c *fiber.Ctx, s *Service, db *bun.DB, cookieConfig CookieConfig, attempt *authAttempt) { + if attempt == nil || attempt.AuthResult == nil { + return + } + if !attempt.SetCookies || attempt.RefreshToken == "" { + return + } + if time.Until(attempt.AuthResult.Claims.ExpiresAt.Time) >= 5*time.Minute { + return + } + + tx, txErr := db.BeginTx(c.Context(), nil) + if txErr != nil { + return + } + defer tx.Rollback() + + newTokens, err := s.RefreshAccessToken(c.Context(), tx, attempt.RefreshToken) + if err != nil { + slog.Debug("auto-refresh failed", "error", err) + return + } + + if err := tx.Commit(); err != nil { + return + } + + SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig) +} + // NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies. // To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser. func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler { return func(c *fiber.Ctx) error { - var at string - var rt string - var setCookies bool - authHeader := c.Get("Authorization") - if authHeader != "" { - parts := strings.Split(authHeader, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - slog.Info("invalid auth header") - return c.SendStatus(fiber.StatusUnauthorized) - } - at = parts[1] - setCookies = false - } else { - cookies := authCookies(c) - at = cookies[cookieKeyAccessToken] - rt = cookies[cookieKeyRefreshToken] - setCookies = true - } - - if at == "" && rt == "" { - slog.Info("no access token or refresh token") - return c.SendStatus(fiber.StatusUnauthorized) - } - - if at == "" { - // if there is no access token, attempt to get new access token using the refresh token. - tx, err := db.BeginTx(c.Context(), nil) - if err != nil { - return c.SendStatus(fiber.StatusUnauthorized) - } - defer tx.Rollback() - - newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt) - if err != nil { - return c.SendStatus(fiber.StatusUnauthorized) - } - - if err := tx.Commit(); err != nil { - return c.SendStatus(fiber.StatusUnauthorized) - } - - SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig) - at = newTokens.AccessToken - rt = newTokens.RefreshToken - } - - authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at) + attempt, err := authenticateRequest(c, s, db, cookieConfig) if err != nil { - var e *InvalidAccessTokenError - if errors.As(err, &e) { - slog.Info("invalid access token") + if errors.Is(err, errUnauthorized) { return c.SendStatus(fiber.StatusUnauthorized) } - - var nf *user.NotFoundError - if errors.As(err, &nf) { - slog.Info("user not found") - return c.SendStatus(fiber.StatusUnauthorized) - } - return httperr.Internal(err) } - reqctx.SetAuthenticatedUser(c, authResult.User) - - // if cookie based auth and access token is about to expire (within 5 minutes), - // attempt to refresh the access token. if there is any error, ignore it and let the request continue. - if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" { - tx, txErr := db.BeginTx(c.Context(), nil) - if txErr == nil { - newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt) - if err == nil { - if commitErr := tx.Commit(); commitErr == nil { - SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig) - } - } else { - _ = tx.Rollback() - slog.Debug("auto-refresh failed", "error", err) - } - } - } + reqctx.SetAuthenticatedUser(c, attempt.AuthResult.User) + maybeAutoRefresh(c, s, db, cookieConfig, attempt) return c.Next() } } -// NewOptionalAuthMiddleware conditionally runs the given auth middleware only when -// the request contains auth credentials (Authorization header or auth cookies). +// NewOptionalAuthMiddleware attempts to authenticate a request if it contains auth +// credentials (Authorization header or auth cookies). +// +// Unlike NewAuthMiddleware, this middleware never rejects the request with 401. If +// authentication fails, it proceeds as an anonymous request. // // This is useful for endpoints that are publicly accessible but can also behave // differently for authenticated callers (e.g. share consumption routes that may @@ -112,15 +171,31 @@ func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber. // present). // // Usage: -// authMiddleware := auth.NewAuthMiddleware(...) -// router.Use(auth.NewOptionalAuthMiddleware(authMiddleware)) -func NewOptionalAuthMiddleware(authMiddleware fiber.Handler) fiber.Handler { +// +// router.Use(auth.NewOptionalAuthMiddleware(authService, db, cookieConfig)) +func NewOptionalAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler { return func(c *fiber.Ctx) error { - hasAuthHeader := c.Get("Authorization") != "" + authHeader := c.Get("Authorization") + hasAuthHeader := authHeader != "" hasAuthCookies := c.Cookies(cookieKeyAccessToken) != "" || c.Cookies(cookieKeyRefreshToken) != "" if !hasAuthHeader && !hasAuthCookies { return c.Next() } - return authMiddleware(c) + + attempt, err := authenticateRequest(c, s, db, cookieConfig) + if err != nil { + if errors.Is(err, errUnauthorized) { + if !hasAuthHeader && hasAuthCookies { + ClearAuthCookies(c, cookieConfig) + } + return c.Next() + } + return httperr.Internal(err) + } + + reqctx.SetAuthenticatedUser(c, attempt.AuthResult.User) + maybeAutoRefresh(c, s, db, cookieConfig, attempt) + + return c.Next() } } diff --git a/apps/backend/internal/drexa/server.go b/apps/backend/internal/drexa/server.go index d3e52f4..79dd7dc 100644 --- a/apps/backend/internal/drexa/server.go +++ b/apps/backend/internal/drexa/server.go @@ -113,6 +113,7 @@ func NewServer(c Config) (*Server, error) { } authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig) + optionalAuthMiddleware := auth.NewOptionalAuthMiddleware(authService, db, cookieConfig) api := app.Group("/api") auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api) @@ -121,7 +122,7 @@ func NewServer(c Config) (*Server, error) { accountRouter := account.NewHTTPHandler(accountService, authService, vfs, db, authMiddleware, cookieConfig).RegisterRoutes(api) upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter.VFSRouter()) - shareHTTP := sharing.NewHTTPHandler(sharingService, accountService, vfs, db, authMiddleware) + shareHTTP := sharing.NewHTTPHandler(sharingService, accountService, vfs, db, optionalAuthMiddleware) shareRoutes := shareHTTP.RegisterShareConsumeRoutes(api) shareHTTP.RegisterShareManagementRoutes(accountRouter) diff --git a/apps/backend/internal/sharing/http.go b/apps/backend/internal/sharing/http.go index c979281..0f33b99 100644 --- a/apps/backend/internal/sharing/http.go +++ b/apps/backend/internal/sharing/http.go @@ -5,7 +5,6 @@ import ( "time" "github.com/get-drexa/drexa/internal/account" - "github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/httperr" "github.com/get-drexa/drexa/internal/nullable" "github.com/get-drexa/drexa/internal/reqctx" @@ -21,7 +20,7 @@ type HTTPHandler struct { accountService *account.Service vfs *virtualfs.VirtualFS db *bun.DB - authMiddleware fiber.Handler + optionalAuthMiddleware fiber.Handler } // createShareRequest represents a request to create a share link @@ -40,13 +39,13 @@ type patchShareRequest struct { ExpiresAt nullable.Time `json:"expiresAt" example:"2025-01-15T00:00:00Z"` } -func NewHTTPHandler(sharingService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler { +func NewHTTPHandler(sharingService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, optionalAuthMiddleware fiber.Handler) *HTTPHandler { return &HTTPHandler{ sharingService: sharingService, accountService: accountService, vfs: vfs, db: db, - authMiddleware: authMiddleware, + optionalAuthMiddleware: optionalAuthMiddleware, } } @@ -54,7 +53,7 @@ func (h *HTTPHandler) RegisterShareConsumeRoutes(r fiber.Router) *virtualfs.Scop // Public shares should be accessible without authentication. However, if the client provides auth // credentials (cookies or Authorization header), attempt auth so share scopes can be resolved for // account-scoped shares. - g := r.Group("/shares/:shareID", auth.NewOptionalAuthMiddleware(h.authMiddleware), h.shareMiddleware) + g := r.Group("/shares/:shareID", h.optionalAuthMiddleware, h.shareMiddleware) return &virtualfs.ScopedRouter{Router: g} }