fix: registration endpoint and db auto close issue

This commit is contained in:
2025-11-29 20:51:56 +00:00
parent 5e4e08c255
commit 629d56b5ab
12 changed files with 136 additions and 82 deletions

View File

@@ -2,13 +2,13 @@ package auth
import (
"errors"
"log/slog"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
const authServiceKey = "authService"
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
@@ -26,29 +26,35 @@ type loginResponse struct {
RefreshToken string `json:"refreshToken"`
}
func RegisterAPIRoutes(api fiber.Router, s *Service) {
auth := api.Group("/auth", func(c *fiber.Ctx) error {
c.Locals(authServiceKey, s)
return c.Next()
})
auth.Post("/login", login)
auth.Post("/register", register)
type HTTPHandler struct {
service *Service
db *bun.DB
}
func mustAuthService(c *fiber.Ctx) *Service {
return c.Locals(authServiceKey).(*Service)
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return &HTTPHandler{service: s, db: db}
}
func login(c *fiber.Ctx) error {
s := mustAuthService(c)
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
auth := api.Group("/auth")
auth.Post("/login", h.Login)
auth.Post("/register", h.Register)
}
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
req := new(loginRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
result, err := s.LoginWithEmailAndPassword(c.Context(), req.Email, req.Password)
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil {
if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
@@ -56,6 +62,10 @@ func login(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
@@ -63,15 +73,20 @@ func login(c *fiber.Ctx) error {
})
}
func register(c *fiber.Ctx) error {
s := mustAuthService(c)
func (h *HTTPHandler) Register(c *fiber.Ctx) error {
req := new(registerRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
result, err := s.Register(c.Context(), registerOptions{
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
slog.Error("failed to begin transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{
email: req.Email,
password: req.Password,
displayName: req.DisplayName,
@@ -81,6 +96,12 @@ func register(c *fiber.Ctx) error {
if errors.As(err, &ae) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
}
slog.Error("failed to register user", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}

View File

@@ -6,13 +6,14 @@ import (
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service) fiber.Handler {
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
if authHeader == "" {
@@ -25,7 +26,7 @@ func NewBearerAuthMiddleware(s *Service) fiber.Handler {
}
token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), token)
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {

View File

@@ -20,7 +20,6 @@ type LoginResult struct {
var ErrInvalidCredentials = errors.New("invalid credentials")
type Service struct {
db *bun.DB
userService *user.Service
tokenConfig TokenConfig
}
@@ -31,16 +30,15 @@ type registerOptions struct {
password string
}
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service {
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{
db: db,
userService: userService,
tokenConfig: tokenConfig,
}
}
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, email)
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
@@ -64,7 +62,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
return nil, err
}
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
@@ -76,13 +74,13 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
}, nil
}
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password)
if err != nil {
return nil, err
}
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email,
DisplayName: opts.displayName,
Password: hashed,
@@ -101,7 +99,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
return nil, err
}
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
@@ -113,7 +111,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
}, nil
}
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) {
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
return nil, err
@@ -124,5 +122,5 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string)
return nil, newInvalidAccessTokenError(err)
}
return s.userService.UserByID(ctx, id)
return s.userService.UserByID(ctx, db, id)
}

View File

@@ -36,6 +36,10 @@ type RefreshToken struct {
CreatedAt time.Time `bun:"created_at,notnull"`
}
func newTokenID() (uuid.UUID, error) {
return uuid.NewV7()
}
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
now := time.Now()
@@ -63,9 +67,9 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
id, err := uuid.NewV7()
id, err := newTokenID()
if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
return nil, fmt.Errorf("failed to generate token ID: %w", err)
}
h := sha256.Sum256(buf)