2025-11-26 01:09:42 +00:00
|
|
|
package auth
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"errors"
|
2025-11-30 19:19:33 +00:00
|
|
|
"log/slog"
|
2025-11-26 01:09:42 +00:00
|
|
|
"strings"
|
2025-12-04 00:26:20 +00:00
|
|
|
"time"
|
2025-11-26 01:09:42 +00:00
|
|
|
|
2025-11-30 19:19:33 +00:00
|
|
|
"github.com/get-drexa/drexa/internal/httperr"
|
2025-12-04 00:48:43 +00:00
|
|
|
"github.com/get-drexa/drexa/internal/reqctx"
|
2025-11-26 01:09:42 +00:00
|
|
|
"github.com/get-drexa/drexa/internal/user"
|
|
|
|
|
"github.com/gofiber/fiber/v2"
|
2025-11-29 20:51:56 +00:00
|
|
|
"github.com/uptrace/bun"
|
2025-11-26 01:09:42 +00:00
|
|
|
)
|
|
|
|
|
|
2025-12-04 00:26:20 +00:00
|
|
|
// NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
|
2025-12-04 00:48:43 +00:00
|
|
|
// To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser.
|
2025-12-04 00:26:20 +00:00
|
|
|
func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
|
2025-11-26 01:09:42 +00:00
|
|
|
return func(c *fiber.Ctx) error {
|
2025-12-04 00:26:20 +00:00
|
|
|
var at string
|
|
|
|
|
var rt string
|
|
|
|
|
var setCookies bool
|
2025-11-26 01:09:42 +00:00
|
|
|
authHeader := c.Get("Authorization")
|
2025-12-04 00:26:20 +00:00
|
|
|
if authHeader != "" {
|
|
|
|
|
parts := strings.Split(authHeader, " ")
|
|
|
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
|
|
|
slog.Info("invalid auth header")
|
|
|
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
at = parts[1]
|
|
|
|
|
setCookies = false
|
|
|
|
|
} else {
|
|
|
|
|
cookies := authCookies(c)
|
|
|
|
|
at = cookies[cookieKeyAccessToken]
|
|
|
|
|
rt = cookies[cookieKeyRefreshToken]
|
|
|
|
|
setCookies = true
|
2025-11-26 01:09:42 +00:00
|
|
|
}
|
|
|
|
|
|
2025-12-15 00:38:23 +00:00
|
|
|
if at == "" && rt == "" {
|
|
|
|
|
slog.Info("no access token or refresh token")
|
2025-11-26 01:09:42 +00:00
|
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
|
2025-12-15 00:38:23 +00:00
|
|
|
if at == "" {
|
|
|
|
|
// if there is no access token, attempt to get new access token using the refresh token.
|
|
|
|
|
tx, err := db.BeginTx(c.Context(), nil)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
defer tx.Rollback()
|
|
|
|
|
|
|
|
|
|
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
|
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
|
2025-12-16 00:41:30 +00:00
|
|
|
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
2025-12-15 00:38:23 +00:00
|
|
|
at = newTokens.AccessToken
|
|
|
|
|
rt = newTokens.RefreshToken
|
|
|
|
|
}
|
|
|
|
|
|
2025-12-04 00:26:20 +00:00
|
|
|
authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at)
|
2025-11-26 01:09:42 +00:00
|
|
|
if err != nil {
|
|
|
|
|
var e *InvalidAccessTokenError
|
|
|
|
|
if errors.As(err, &e) {
|
2025-11-30 19:19:33 +00:00
|
|
|
slog.Info("invalid access token")
|
2025-11-26 01:09:42 +00:00
|
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var nf *user.NotFoundError
|
|
|
|
|
if errors.As(err, &nf) {
|
2025-11-30 19:19:33 +00:00
|
|
|
slog.Info("user not found")
|
2025-11-26 01:09:42 +00:00
|
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-30 19:19:33 +00:00
|
|
|
return httperr.Internal(err)
|
2025-11-26 01:09:42 +00:00
|
|
|
}
|
|
|
|
|
|
2025-12-04 00:48:43 +00:00
|
|
|
reqctx.SetAuthenticatedUser(c, authResult.User)
|
2025-12-04 00:26:20 +00:00
|
|
|
|
|
|
|
|
// if cookie based auth and access token is about to expire (within 5 minutes),
|
|
|
|
|
// attempt to refresh the access token. if there is any error, ignore it and let the request continue.
|
|
|
|
|
if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" {
|
|
|
|
|
tx, txErr := db.BeginTx(c.Context(), nil)
|
|
|
|
|
if txErr == nil {
|
|
|
|
|
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
|
|
|
|
|
if err == nil {
|
|
|
|
|
if commitErr := tx.Commit(); commitErr == nil {
|
2025-12-16 00:41:30 +00:00
|
|
|
SetAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
|
2025-12-04 00:26:20 +00:00
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
_ = tx.Rollback()
|
|
|
|
|
slog.Debug("auto-refresh failed", "error", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2025-11-26 01:09:42 +00:00
|
|
|
|
|
|
|
|
return c.Next()
|
|
|
|
|
}
|
|
|
|
|
}
|
2025-12-29 00:07:44 +00:00
|
|
|
|
|
|
|
|
// NewOptionalAuthMiddleware conditionally runs the given auth middleware only when
|
|
|
|
|
// the request contains auth credentials (Authorization header or auth cookies).
|
|
|
|
|
//
|
|
|
|
|
// This is useful for endpoints that are publicly accessible but can also behave
|
|
|
|
|
// differently for authenticated callers (e.g. share consumption routes that may
|
|
|
|
|
// resolve additional permissions for a specific account when credentials are
|
|
|
|
|
// present).
|
|
|
|
|
//
|
|
|
|
|
// Usage:
|
|
|
|
|
// authMiddleware := auth.NewAuthMiddleware(...)
|
|
|
|
|
// router.Use(auth.NewOptionalAuthMiddleware(authMiddleware))
|
|
|
|
|
func NewOptionalAuthMiddleware(authMiddleware fiber.Handler) fiber.Handler {
|
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
|
|
|
hasAuthHeader := c.Get("Authorization") != ""
|
|
|
|
|
hasAuthCookies := c.Cookies(cookieKeyAccessToken) != "" || c.Cookies(cookieKeyRefreshToken) != ""
|
|
|
|
|
if !hasAuthHeader && !hasAuthCookies {
|
|
|
|
|
return c.Next()
|
|
|
|
|
}
|
|
|
|
|
return authMiddleware(c)
|
|
|
|
|
}
|
|
|
|
|
}
|