mirror of
https://github.com/get-drexa/drive.git
synced 2026-02-02 19:31:17 +00:00
fix(backend): optional auth ret 401 if token invalid
This commit is contained in:
@@ -62,3 +62,38 @@ func SetAuthCookies(c *fiber.Ctx, accessToken, refreshToken string, cfg CookieCo
|
|||||||
c.Cookie(accessTokenCookie)
|
c.Cookie(accessTokenCookie)
|
||||||
c.Cookie(refreshTokenCookie)
|
c.Cookie(refreshTokenCookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClearAuthCookies clears the HTTP-only auth cookies by setting them to an expired value.
|
||||||
|
func ClearAuthCookies(c *fiber.Ctx, cfg CookieConfig) {
|
||||||
|
secure := c.Protocol() == "https"
|
||||||
|
expired := time.Unix(0, 0)
|
||||||
|
|
||||||
|
accessTokenCookie := &fiber.Cookie{
|
||||||
|
Name: cookieKeyAccessToken,
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
Expires: expired,
|
||||||
|
SameSite: fiber.CookieSameSiteLaxMode,
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
}
|
||||||
|
if cfg.Domain != "" {
|
||||||
|
accessTokenCookie.Domain = cfg.Domain
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshTokenCookie := &fiber.Cookie{
|
||||||
|
Name: cookieKeyRefreshToken,
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
Expires: expired,
|
||||||
|
SameSite: fiber.CookieSameSiteLaxMode,
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
}
|
||||||
|
if cfg.Domain != "" {
|
||||||
|
refreshTokenCookie.Domain = cfg.Domain
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Cookie(accessTokenCookie)
|
||||||
|
c.Cookie(refreshTokenCookie)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,19 +13,45 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
|
var errUnauthorized = errors.New("unauthorized")
|
||||||
// To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser.
|
|
||||||
func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
|
// authAttempt is the normalized result of parsing and (optionally) authenticating
|
||||||
return func(c *fiber.Ctx) error {
|
// a request.
|
||||||
|
//
|
||||||
|
// It intentionally separates:
|
||||||
|
// - whether the request used cookie-based auth (SetCookies == true), which enables
|
||||||
|
// setting/refreshing HTTP-only cookies, from
|
||||||
|
// - what credentials were available (RefreshToken), and
|
||||||
|
// - the authenticated principal (AuthResult).
|
||||||
|
//
|
||||||
|
// This is shared by:
|
||||||
|
// - NewAuthMiddleware (strict): rejects unauthenticated requests with 401.
|
||||||
|
// - NewOptionalAuthMiddleware (best-effort): proceeds as anonymous on auth failure,
|
||||||
|
// and clears auth cookies when the failure came from cookies to avoid repeated
|
||||||
|
// failing auth attempts on public endpoints.
|
||||||
|
type authAttempt struct {
|
||||||
|
// AuthResult is populated when access token authentication succeeds.
|
||||||
|
AuthResult *AccessTokenAuthentication
|
||||||
|
|
||||||
|
// SetCookies is true when the request used cookie-based auth (no Authorization header).
|
||||||
|
// When true, the middleware may set/refresh cookies on successful auth or refresh flows.
|
||||||
|
SetCookies bool
|
||||||
|
|
||||||
|
// RefreshToken is the refresh token extracted from cookies when SetCookies is true.
|
||||||
|
// It is used for auto-refreshing near-expiry access tokens.
|
||||||
|
RefreshToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
func authenticateRequest(c *fiber.Ctx, s *Service, db *bun.DB, cookieConfig CookieConfig) (*authAttempt, error) {
|
||||||
var at string
|
var at string
|
||||||
var rt string
|
var rt string
|
||||||
var setCookies bool
|
var setCookies bool
|
||||||
|
|
||||||
authHeader := c.Get("Authorization")
|
authHeader := c.Get("Authorization")
|
||||||
if authHeader != "" {
|
if authHeader != "" {
|
||||||
parts := strings.Split(authHeader, " ")
|
parts := strings.Split(authHeader, " ")
|
||||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
slog.Info("invalid auth header")
|
return nil, errUnauthorized
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
|
||||||
}
|
}
|
||||||
at = parts[1]
|
at = parts[1]
|
||||||
setCookies = false
|
setCookies = false
|
||||||
@@ -37,25 +63,24 @@ func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if at == "" && rt == "" {
|
if at == "" && rt == "" {
|
||||||
slog.Info("no access token or refresh token")
|
return nil, errUnauthorized
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if at == "" {
|
if at == "" {
|
||||||
// if there is no access token, attempt to get new access token using the refresh token.
|
// if there is no access token, attempt to get new access token using the refresh token.
|
||||||
tx, err := db.BeginTx(c.Context(), nil)
|
tx, err := db.BeginTx(c.Context(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
|
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
||||||
@@ -67,44 +92,78 @@ func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var e *InvalidAccessTokenError
|
var e *InvalidAccessTokenError
|
||||||
if errors.As(err, &e) {
|
if errors.As(err, &e) {
|
||||||
slog.Info("invalid access token")
|
return nil, errUnauthorized
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var nf *user.NotFoundError
|
var nf *user.NotFoundError
|
||||||
if errors.As(err, &nf) {
|
if errors.As(err, &nf) {
|
||||||
slog.Info("user not found")
|
return nil, errUnauthorized
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &authAttempt{
|
||||||
|
AuthResult: authResult,
|
||||||
|
SetCookies: setCookies,
|
||||||
|
RefreshToken: rt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func maybeAutoRefresh(c *fiber.Ctx, s *Service, db *bun.DB, cookieConfig CookieConfig, attempt *authAttempt) {
|
||||||
|
if attempt == nil || attempt.AuthResult == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !attempt.SetCookies || attempt.RefreshToken == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Until(attempt.AuthResult.Claims.ExpiresAt.Time) >= 5*time.Minute {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, txErr := db.BeginTx(c.Context(), nil)
|
||||||
|
if txErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
newTokens, err := s.RefreshAccessToken(c.Context(), tx, attempt.RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Debug("auto-refresh failed", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
|
||||||
|
// To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser.
|
||||||
|
func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
attempt, err := authenticateRequest(c, s, db, cookieConfig)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errUnauthorized) {
|
||||||
|
return c.SendStatus(fiber.StatusUnauthorized)
|
||||||
|
}
|
||||||
return httperr.Internal(err)
|
return httperr.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqctx.SetAuthenticatedUser(c, authResult.User)
|
reqctx.SetAuthenticatedUser(c, attempt.AuthResult.User)
|
||||||
|
maybeAutoRefresh(c, s, db, cookieConfig, attempt)
|
||||||
// if cookie based auth and access token is about to expire (within 5 minutes),
|
|
||||||
// attempt to refresh the access token. if there is any error, ignore it and let the request continue.
|
|
||||||
if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" {
|
|
||||||
tx, txErr := db.BeginTx(c.Context(), nil)
|
|
||||||
if txErr == nil {
|
|
||||||
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
|
|
||||||
if err == nil {
|
|
||||||
if commitErr := tx.Commit(); commitErr == nil {
|
|
||||||
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_ = tx.Rollback()
|
|
||||||
slog.Debug("auto-refresh failed", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOptionalAuthMiddleware conditionally runs the given auth middleware only when
|
// NewOptionalAuthMiddleware attempts to authenticate a request if it contains auth
|
||||||
// the request contains auth credentials (Authorization header or auth cookies).
|
// credentials (Authorization header or auth cookies).
|
||||||
|
//
|
||||||
|
// Unlike NewAuthMiddleware, this middleware never rejects the request with 401. If
|
||||||
|
// authentication fails, it proceeds as an anonymous request.
|
||||||
//
|
//
|
||||||
// This is useful for endpoints that are publicly accessible but can also behave
|
// This is useful for endpoints that are publicly accessible but can also behave
|
||||||
// differently for authenticated callers (e.g. share consumption routes that may
|
// differently for authenticated callers (e.g. share consumption routes that may
|
||||||
@@ -112,15 +171,31 @@ func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.
|
|||||||
// present).
|
// present).
|
||||||
//
|
//
|
||||||
// Usage:
|
// Usage:
|
||||||
// authMiddleware := auth.NewAuthMiddleware(...)
|
//
|
||||||
// router.Use(auth.NewOptionalAuthMiddleware(authMiddleware))
|
// router.Use(auth.NewOptionalAuthMiddleware(authService, db, cookieConfig))
|
||||||
func NewOptionalAuthMiddleware(authMiddleware fiber.Handler) fiber.Handler {
|
func NewOptionalAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
hasAuthHeader := c.Get("Authorization") != ""
|
authHeader := c.Get("Authorization")
|
||||||
|
hasAuthHeader := authHeader != ""
|
||||||
hasAuthCookies := c.Cookies(cookieKeyAccessToken) != "" || c.Cookies(cookieKeyRefreshToken) != ""
|
hasAuthCookies := c.Cookies(cookieKeyAccessToken) != "" || c.Cookies(cookieKeyRefreshToken) != ""
|
||||||
if !hasAuthHeader && !hasAuthCookies {
|
if !hasAuthHeader && !hasAuthCookies {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
return authMiddleware(c)
|
|
||||||
|
attempt, err := authenticateRequest(c, s, db, cookieConfig)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errUnauthorized) {
|
||||||
|
if !hasAuthHeader && hasAuthCookies {
|
||||||
|
ClearAuthCookies(c, cookieConfig)
|
||||||
|
}
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
return httperr.Internal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reqctx.SetAuthenticatedUser(c, attempt.AuthResult.User)
|
||||||
|
maybeAutoRefresh(c, s, db, cookieConfig, attempt)
|
||||||
|
|
||||||
|
return c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ func NewServer(c Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig)
|
authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig)
|
||||||
|
optionalAuthMiddleware := auth.NewOptionalAuthMiddleware(authService, db, cookieConfig)
|
||||||
|
|
||||||
api := app.Group("/api")
|
api := app.Group("/api")
|
||||||
auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
|
auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
|
||||||
@@ -121,7 +122,7 @@ func NewServer(c Config) (*Server, error) {
|
|||||||
accountRouter := account.NewHTTPHandler(accountService, authService, vfs, db, authMiddleware, cookieConfig).RegisterRoutes(api)
|
accountRouter := account.NewHTTPHandler(accountService, authService, vfs, db, authMiddleware, cookieConfig).RegisterRoutes(api)
|
||||||
upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter.VFSRouter())
|
upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter.VFSRouter())
|
||||||
|
|
||||||
shareHTTP := sharing.NewHTTPHandler(sharingService, accountService, vfs, db, authMiddleware)
|
shareHTTP := sharing.NewHTTPHandler(sharingService, accountService, vfs, db, optionalAuthMiddleware)
|
||||||
shareRoutes := shareHTTP.RegisterShareConsumeRoutes(api)
|
shareRoutes := shareHTTP.RegisterShareConsumeRoutes(api)
|
||||||
shareHTTP.RegisterShareManagementRoutes(accountRouter)
|
shareHTTP.RegisterShareManagementRoutes(accountRouter)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/get-drexa/drexa/internal/account"
|
"github.com/get-drexa/drexa/internal/account"
|
||||||
"github.com/get-drexa/drexa/internal/auth"
|
|
||||||
"github.com/get-drexa/drexa/internal/httperr"
|
"github.com/get-drexa/drexa/internal/httperr"
|
||||||
"github.com/get-drexa/drexa/internal/nullable"
|
"github.com/get-drexa/drexa/internal/nullable"
|
||||||
"github.com/get-drexa/drexa/internal/reqctx"
|
"github.com/get-drexa/drexa/internal/reqctx"
|
||||||
@@ -21,7 +20,7 @@ type HTTPHandler struct {
|
|||||||
accountService *account.Service
|
accountService *account.Service
|
||||||
vfs *virtualfs.VirtualFS
|
vfs *virtualfs.VirtualFS
|
||||||
db *bun.DB
|
db *bun.DB
|
||||||
authMiddleware fiber.Handler
|
optionalAuthMiddleware fiber.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// createShareRequest represents a request to create a share link
|
// createShareRequest represents a request to create a share link
|
||||||
@@ -40,13 +39,13 @@ type patchShareRequest struct {
|
|||||||
ExpiresAt nullable.Time `json:"expiresAt" example:"2025-01-15T00:00:00Z"`
|
ExpiresAt nullable.Time `json:"expiresAt" example:"2025-01-15T00:00:00Z"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPHandler(sharingService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
|
func NewHTTPHandler(sharingService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, optionalAuthMiddleware fiber.Handler) *HTTPHandler {
|
||||||
return &HTTPHandler{
|
return &HTTPHandler{
|
||||||
sharingService: sharingService,
|
sharingService: sharingService,
|
||||||
accountService: accountService,
|
accountService: accountService,
|
||||||
vfs: vfs,
|
vfs: vfs,
|
||||||
db: db,
|
db: db,
|
||||||
authMiddleware: authMiddleware,
|
optionalAuthMiddleware: optionalAuthMiddleware,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +53,7 @@ func (h *HTTPHandler) RegisterShareConsumeRoutes(r fiber.Router) *virtualfs.Scop
|
|||||||
// Public shares should be accessible without authentication. However, if the client provides auth
|
// Public shares should be accessible without authentication. However, if the client provides auth
|
||||||
// credentials (cookies or Authorization header), attempt auth so share scopes can be resolved for
|
// credentials (cookies or Authorization header), attempt auth so share scopes can be resolved for
|
||||||
// account-scoped shares.
|
// account-scoped shares.
|
||||||
g := r.Group("/shares/:shareID", auth.NewOptionalAuthMiddleware(h.authMiddleware), h.shareMiddleware)
|
g := r.Group("/shares/:shareID", h.optionalAuthMiddleware, h.shareMiddleware)
|
||||||
return &virtualfs.ScopedRouter{Router: g}
|
return &virtualfs.ScopedRouter{Router: g}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user