fix: registration endpoint and db auto close issue

This commit is contained in:
2025-11-29 20:51:56 +00:00
parent 5e4e08c255
commit 629d56b5ab
12 changed files with 136 additions and 82 deletions

15
apps/backend/config.yaml Normal file
View File

@@ -0,0 +1,15 @@
server:
port: 8080
database:
postgres_url: postgres://drexa:hunter2@helian:5433/drexa?sslmode=disable
jwt:
issuer: drexa
audience: drexa-api
secret_key_base64: "pNeUExoqdakfecZLFL53NJpY4iB9zFot9EuEBItlYKY="
storage:
mode: hierarchical
backend: fs
root_path: ./data

View File

@@ -6,7 +6,6 @@ require (
github.com/gabriel-vasile/mimetype v1.4.11
github.com/gofiber/fiber/v2 v2.52.9
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/sqids/sqids-go v0.4.1
github.com/uptrace/bun v1.2.15
golang.org/x/crypto v0.40.0

View File

@@ -14,8 +14,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=

View File

@@ -2,13 +2,13 @@ package auth
import (
"errors"
"log/slog"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
const authServiceKey = "authService"
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
@@ -26,29 +26,35 @@ type loginResponse struct {
RefreshToken string `json:"refreshToken"`
}
func RegisterAPIRoutes(api fiber.Router, s *Service) {
auth := api.Group("/auth", func(c *fiber.Ctx) error {
c.Locals(authServiceKey, s)
return c.Next()
})
auth.Post("/login", login)
auth.Post("/register", register)
type HTTPHandler struct {
service *Service
db *bun.DB
}
func mustAuthService(c *fiber.Ctx) *Service {
return c.Locals(authServiceKey).(*Service)
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return &HTTPHandler{service: s, db: db}
}
func login(c *fiber.Ctx) error {
s := mustAuthService(c)
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
auth := api.Group("/auth")
auth.Post("/login", h.Login)
auth.Post("/register", h.Register)
}
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
req := new(loginRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
result, err := s.LoginWithEmailAndPassword(c.Context(), req.Email, req.Password)
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil {
if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
@@ -56,6 +62,10 @@ func login(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
@@ -63,15 +73,20 @@ func login(c *fiber.Ctx) error {
})
}
func register(c *fiber.Ctx) error {
s := mustAuthService(c)
func (h *HTTPHandler) Register(c *fiber.Ctx) error {
req := new(registerRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
result, err := s.Register(c.Context(), registerOptions{
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
slog.Error("failed to begin transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{
email: req.Email,
password: req.Password,
displayName: req.DisplayName,
@@ -81,6 +96,12 @@ func register(c *fiber.Ctx) error {
if errors.As(err, &ae) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
}
slog.Error("failed to register user", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}

View File

@@ -6,13 +6,14 @@ import (
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service) fiber.Handler {
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
if authHeader == "" {
@@ -25,7 +26,7 @@ func NewBearerAuthMiddleware(s *Service) fiber.Handler {
}
token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), token)
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {

View File

@@ -20,7 +20,6 @@ type LoginResult struct {
var ErrInvalidCredentials = errors.New("invalid credentials")
type Service struct {
db *bun.DB
userService *user.Service
tokenConfig TokenConfig
}
@@ -31,16 +30,15 @@ type registerOptions struct {
password string
}
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service {
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{
db: db,
userService: userService,
tokenConfig: tokenConfig,
}
}
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, email)
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
@@ -64,7 +62,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
return nil, err
}
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
@@ -76,13 +74,13 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
}, nil
}
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password)
if err != nil {
return nil, err
}
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email,
DisplayName: opts.displayName,
Password: hashed,
@@ -101,7 +99,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
return nil, err
}
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
@@ -113,7 +111,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
}, nil
}
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) {
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
return nil, err
@@ -124,5 +122,5 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string)
return nil, newInvalidAccessTokenError(err)
}
return s.userService.UserByID(ctx, id)
return s.userService.UserByID(ctx, db, id)
}

View File

@@ -36,6 +36,10 @@ type RefreshToken struct {
CreatedAt time.Time `bun:"created_at,notnull"`
}
func newTokenID() (uuid.UUID, error) {
return uuid.NewV7()
}
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
now := time.Now()
@@ -63,9 +67,9 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
id, err := uuid.NewV7()
id, err := newTokenID()
if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
return nil, fmt.Errorf("failed to generate token ID: %w", err)
}
h := sha256.Sum256(buf)

View File

@@ -2,6 +2,7 @@ package database
import (
"database/sql"
"time"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
@@ -10,6 +11,17 @@ import (
func NewFromPostgres(url string) *bun.DB {
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url)))
// Configure connection pool to prevent "database closed" errors
// SetMaxOpenConns sets the maximum number of open connections to the database
sqldb.SetMaxOpenConns(25)
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool
sqldb.SetMaxIdleConns(5)
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused
sqldb.SetConnMaxLifetime(5 * time.Minute)
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle
sqldb.SetConnMaxIdleTime(10 * time.Minute)
db := bun.NewDB(sqldb, pgdialect.New())
return db
}

View File

@@ -11,12 +11,15 @@ import (
"github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
)
func NewServer(c Config) (*fiber.App, error) {
app := fiber.New()
db := database.NewFromPostgres(c.Database.PostgresURL)
app.Use(logger.New())
// Initialize blob store based on config
var blobStore blob.Store
switch c.Storage.Backend {
@@ -51,8 +54,8 @@ func NewServer(c Config) (*fiber.App, error) {
return nil, fmt.Errorf("failed to create virtual file system: %w", err)
}
userService := user.NewService(db)
authService := auth.NewService(db, userService, auth.TokenConfig{
userService := user.NewService()
authService := auth.NewService(userService, auth.TokenConfig{
Issuer: c.JWT.Issuer,
Audience: c.JWT.Audience,
SecretKey: c.JWT.SecretKey,
@@ -60,8 +63,8 @@ func NewServer(c Config) (*fiber.App, error) {
uploadService := upload.NewService(vfs, blobStore)
api := app.Group("/api")
auth.RegisterAPIRoutes(api, authService)
upload.RegisterAPIRoutes(api, uploadService)
auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService).RegisterRoutes(api)
return app, nil
}

View File

@@ -7,8 +7,6 @@ import (
"github.com/gofiber/fiber/v2"
)
const uploadServiceKey = "uploadService"
type createUploadRequest struct {
ParentID string `json:"parentId"`
Name string `json:"name"`
@@ -18,35 +16,34 @@ type updateUploadRequest struct {
Status Status `json:"status"`
}
func RegisterAPIRoutes(api fiber.Router, s *Service) {
upload := api.Group("/uploads", func(c *fiber.Ctx) error {
c.Locals(uploadServiceKey, s)
return c.Next()
})
upload.Post("/", createUpload)
upload.Put("/:uploadID/content", receiveUpload)
upload.Patch("/:uploadID", updateUpload)
type HTTPHandler struct {
service *Service
}
func mustUploadService(c *fiber.Ctx) *Service {
return c.Locals(uploadServiceKey).(*Service)
func NewHTTPHandler(s *Service) *HTTPHandler {
return &HTTPHandler{service: s}
}
func createUpload(c *fiber.Ctx) error {
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
upload := api.Group("/uploads")
upload.Post("/", h.Create)
upload.Put("/:uploadID/content", h.ReceiveContent)
upload.Patch("/:uploadID", h.Update)
}
func (h *HTTPHandler) Create(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
s := mustUploadService(c)
req := new(createUploadRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
upload, err := s.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
ParentID: req.ParentID,
Name: req.Name,
})
@@ -57,17 +54,15 @@ func createUpload(c *fiber.Ctx) error {
return c.JSON(upload)
}
func receiveUpload(c *fiber.Ctx) error {
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
s := mustUploadService(c)
uploadID := c.Params("uploadID")
err = s.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
defer c.Request().CloseBodyStream()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
@@ -76,21 +71,19 @@ func receiveUpload(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
}
func updateUpload(c *fiber.Ctx) error {
func (h *HTTPHandler) Update(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
s := mustUploadService(c)
req := new(updateUploadRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
if req.Status == StatusCompleted {
upload, err := s.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
if err != nil {
if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)

View File

@@ -11,9 +11,7 @@ import (
"github.com/uptrace/bun"
)
type Service struct {
db *bun.DB
}
type Service struct{}
type UserRegistrationOptions struct {
Email string
@@ -21,20 +19,24 @@ type UserRegistrationOptions struct {
Password password.Hashed
}
func NewService(db *bun.DB) *Service {
return &Service{
db: db,
}
func NewService() *Service {
return &Service{}
}
func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) {
func (s *Service) RegisterUser(ctx context.Context, db bun.IDB, opts UserRegistrationOptions) (*User, error) {
uid, err := newUserID()
if err != nil {
return nil, err
}
u := User{
ID: uid,
Email: opts.Email,
DisplayName: opts.DisplayName,
Password: opts.Password,
}
_, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx)
_, err = db.NewInsert().Model(&u).Returning("*").Exec(ctx)
if err != nil {
if database.IsUniqueViolation(err) {
return nil, newAlreadyExistsError(u.Email)
@@ -45,9 +47,9 @@ func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions
return &u, nil
}
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
func (s *Service) UserByID(ctx context.Context, db bun.IDB, id uuid.UUID) (*User, error) {
var user User
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
err := db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(id, "")
@@ -57,9 +59,9 @@ func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
return &user, nil
}
func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) {
func (s *Service) UserByEmail(ctx context.Context, db bun.IDB, email string) (*User, error) {
var user User
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
err := db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(uuid.Nil, email)
@@ -69,6 +71,6 @@ func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error)
return &user, nil
}
func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) {
return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
func (s *Service) UserExistsByEmail(ctx context.Context, db bun.IDB, email string) (bool, error) {
return db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
}

View File

@@ -1,6 +1,8 @@
package user
import (
"time"
"github.com/get-drexa/drexa/internal/password"
"github.com/google/uuid"
"github.com/uptrace/bun"
@@ -10,9 +12,15 @@ type User struct {
bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name,notnull" json:"displayName"`
DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
}
func newUserID() (uuid.UUID, error) {
return uuid.NewV7()
}