diff --git a/apps/backend/internal/account/http.go b/apps/backend/internal/account/http.go index bbe8127..f737d2e 100644 --- a/apps/backend/internal/account/http.go +++ b/apps/backend/internal/account/http.go @@ -5,6 +5,7 @@ import ( "github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/httperr" + "github.com/get-drexa/drexa/internal/reqctx" "github.com/get-drexa/drexa/internal/user" "github.com/gofiber/fiber/v2" "github.com/google/uuid" @@ -54,17 +55,14 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router) fiber.Router { } func (h *HTTPHandler) accountMiddleware(c *fiber.Ctx) error { - user, err := auth.AuthenticatedUser(c) - if err != nil { - return c.SendStatus(fiber.StatusUnauthorized) - } + u := reqctx.AuthenticatedUser(c).(*user.User) accountID, err := uuid.Parse(c.Params("accountID")) if err != nil { return c.SendStatus(fiber.StatusNotFound) } - account, err := h.accountService.AccountByID(c.Context(), h.db, user.ID, accountID) + account, err := h.accountService.AccountByID(c.Context(), h.db, u.ID, accountID) if err != nil { if errors.Is(err, ErrAccountNotFound) { return c.SendStatus(fiber.StatusNotFound) diff --git a/apps/backend/internal/auth/err.go b/apps/backend/internal/auth/err.go index 5d72eb7..93e5481 100644 --- a/apps/backend/internal/auth/err.go +++ b/apps/backend/internal/auth/err.go @@ -5,7 +5,6 @@ import ( "fmt" ) -var ErrUnauthenticatedRequest = errors.New("unauthenticated request") var ErrInvalidRefreshToken = errors.New("invalid refresh token") var ErrRefreshTokenExpired = errors.New("refresh token expired") var ErrRefreshTokenReused = errors.New("refresh token reused") diff --git a/apps/backend/internal/auth/middleware.go b/apps/backend/internal/auth/middleware.go index b1a2109..237118f 100644 --- a/apps/backend/internal/auth/middleware.go +++ b/apps/backend/internal/auth/middleware.go @@ -7,15 +7,14 @@ import ( "time" "github.com/get-drexa/drexa/internal/httperr" + "github.com/get-drexa/drexa/internal/reqctx" "github.com/get-drexa/drexa/internal/user" "github.com/gofiber/fiber/v2" "github.com/uptrace/bun" ) -const authenticatedUserKey = "authenticatedUser" - // NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies. -// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser. +// To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser. func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler { return func(c *fiber.Ctx) error { var at string @@ -59,7 +58,7 @@ func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber. return httperr.Internal(err) } - c.Locals(authenticatedUserKey, authResult.User) + reqctx.SetAuthenticatedUser(c, authResult.User) // if cookie based auth and access token is about to expire (within 5 minutes), // attempt to refresh the access token. if there is any error, ignore it and let the request continue. @@ -81,12 +80,3 @@ func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber. return c.Next() } } - -// AuthenticatedUser returns the authenticated user from the given fiber context. -// Returns ErrUnauthenticatedRequest if not authenticated. -func AuthenticatedUser(c *fiber.Ctx) (*user.User, error) { - if u, ok := c.Locals(authenticatedUserKey).(*user.User); ok { - return u, nil - } - return nil, ErrUnauthenticatedRequest -} diff --git a/apps/backend/internal/reqctx/reqctx.go b/apps/backend/internal/reqctx/reqctx.go new file mode 100644 index 0000000..c29cc7b --- /dev/null +++ b/apps/backend/internal/reqctx/reqctx.go @@ -0,0 +1,23 @@ +package reqctx + +import ( + "errors" + + "github.com/gofiber/fiber/v2" +) + +const authenticatedUserKey = "authenticatedUser" + +var ErrUnauthenticatedRequest = errors.New("unauthenticated request") + +// AuthenticatedUser returns the authenticated user from the given fiber context. +// Returns ErrUnauthenticatedRequest if not authenticated. +// The caller must type assert the returned value to the appropriate user type. +func AuthenticatedUser(c *fiber.Ctx) any { + return c.Locals(authenticatedUserKey) +} + +// SetAuthenticatedUser sets the authenticated user in the fiber context. +func SetAuthenticatedUser(c *fiber.Ctx, user interface{}) { + c.Locals(authenticatedUserKey, user) +}