fix: registration endpoint and db auto close issue

This commit is contained in:
2025-11-29 20:51:56 +00:00
parent 5e4e08c255
commit 629d56b5ab
12 changed files with 136 additions and 82 deletions

15
apps/backend/config.yaml Normal file
View File

@@ -0,0 +1,15 @@
server:
port: 8080
database:
postgres_url: postgres://drexa:hunter2@helian:5433/drexa?sslmode=disable
jwt:
issuer: drexa
audience: drexa-api
secret_key_base64: "pNeUExoqdakfecZLFL53NJpY4iB9zFot9EuEBItlYKY="
storage:
mode: hierarchical
backend: fs
root_path: ./data

View File

@@ -6,7 +6,6 @@ require (
github.com/gabriel-vasile/mimetype v1.4.11 github.com/gabriel-vasile/mimetype v1.4.11
github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/fiber/v2 v2.52.9
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/sqids/sqids-go v0.4.1 github.com/sqids/sqids-go v0.4.1
github.com/uptrace/bun v1.2.15 github.com/uptrace/bun v1.2.15
golang.org/x/crypto v0.40.0 golang.org/x/crypto v0.40.0

View File

@@ -14,8 +14,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=

View File

@@ -2,13 +2,13 @@ package auth
import ( import (
"errors" "errors"
"log/slog"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
) )
const authServiceKey = "authService"
type loginRequest struct { type loginRequest struct {
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
@@ -26,29 +26,35 @@ type loginResponse struct {
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken"`
} }
func RegisterAPIRoutes(api fiber.Router, s *Service) { type HTTPHandler struct {
auth := api.Group("/auth", func(c *fiber.Ctx) error { service *Service
c.Locals(authServiceKey, s) db *bun.DB
return c.Next()
})
auth.Post("/login", login)
auth.Post("/register", register)
} }
func mustAuthService(c *fiber.Ctx) *Service { func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return c.Locals(authServiceKey).(*Service) return &HTTPHandler{service: s, db: db}
} }
func login(c *fiber.Ctx) error { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
s := mustAuthService(c) auth := api.Group("/auth")
auth.Post("/login", h.Login)
auth.Post("/register", h.Register)
}
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
req := new(loginRequest) req := new(loginRequest)
if err := c.BodyParser(req); err != nil { if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
result, err := s.LoginWithEmailAndPassword(c.Context(), req.Email, req.Password) tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil { if err != nil {
if errors.Is(err, ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
@@ -56,6 +62,10 @@ func login(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
} }
if err := tx.Commit(); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(loginResponse{ return c.JSON(loginResponse{
User: *result.User, User: *result.User,
AccessToken: result.AccessToken, AccessToken: result.AccessToken,
@@ -63,15 +73,20 @@ func login(c *fiber.Ctx) error {
}) })
} }
func register(c *fiber.Ctx) error { func (h *HTTPHandler) Register(c *fiber.Ctx) error {
s := mustAuthService(c)
req := new(registerRequest) req := new(registerRequest)
if err := c.BodyParser(req); err != nil { if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
result, err := s.Register(c.Context(), registerOptions{ tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
slog.Error("failed to begin transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{
email: req.Email, email: req.Email,
password: req.Password, password: req.Password,
displayName: req.DisplayName, displayName: req.DisplayName,
@@ -81,6 +96,12 @@ func register(c *fiber.Ctx) error {
if errors.As(err, &ae) { if errors.As(err, &ae) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"}) return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
} }
slog.Error("failed to register user", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
} }

View File

@@ -6,13 +6,14 @@ import (
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
) )
const authenticatedUserKey = "authenticatedUser" const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token. // NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser. // To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service) fiber.Handler { func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization") authHeader := c.Get("Authorization")
if authHeader == "" { if authHeader == "" {
@@ -25,7 +26,7 @@ func NewBearerAuthMiddleware(s *Service) fiber.Handler {
} }
token := parts[1] token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), token) u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
if err != nil { if err != nil {
var e *InvalidAccessTokenError var e *InvalidAccessTokenError
if errors.As(err, &e) { if errors.As(err, &e) {

View File

@@ -20,7 +20,6 @@ type LoginResult struct {
var ErrInvalidCredentials = errors.New("invalid credentials") var ErrInvalidCredentials = errors.New("invalid credentials")
type Service struct { type Service struct {
db *bun.DB
userService *user.Service userService *user.Service
tokenConfig TokenConfig tokenConfig TokenConfig
} }
@@ -31,16 +30,15 @@ type registerOptions struct {
password string password string
} }
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service { func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{ return &Service{
db: db,
userService: userService, userService: userService,
tokenConfig: tokenConfig, tokenConfig: tokenConfig,
} }
} }
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) { func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, email) u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil { if err != nil {
var nf *user.NotFoundError var nf *user.NotFoundError
if errors.As(err, &nf) { if errors.As(err, &nf) {
@@ -64,7 +62,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
return nil, err return nil, err
} }
_, err = s.db.NewInsert().Model(rt).Exec(ctx) _, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -76,13 +74,13 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
}, nil }, nil
} }
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) { func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password) hashed, err := password.Hash(opts.password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{ u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email, Email: opts.email,
DisplayName: opts.displayName, DisplayName: opts.displayName,
Password: hashed, Password: hashed,
@@ -101,7 +99,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
return nil, err return nil, err
} }
_, err = s.db.NewInsert().Model(rt).Exec(ctx) _, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -113,7 +111,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
}, nil }, nil
} }
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) { func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig) claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -124,5 +122,5 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string)
return nil, newInvalidAccessTokenError(err) return nil, newInvalidAccessTokenError(err)
} }
return s.userService.UserByID(ctx, id) return s.userService.UserByID(ctx, db, id)
} }

View File

@@ -36,6 +36,10 @@ type RefreshToken struct {
CreatedAt time.Time `bun:"created_at,notnull"` CreatedAt time.Time `bun:"created_at,notnull"`
} }
func newTokenID() (uuid.UUID, error) {
return uuid.NewV7()
}
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) { func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
now := time.Now() now := time.Now()
@@ -63,9 +67,9 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
return nil, fmt.Errorf("failed to generate refresh token: %w", err) return nil, fmt.Errorf("failed to generate refresh token: %w", err)
} }
id, err := uuid.NewV7() id, err := newTokenID()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err) return nil, fmt.Errorf("failed to generate token ID: %w", err)
} }
h := sha256.Sum256(buf) h := sha256.Sum256(buf)

View File

@@ -2,6 +2,7 @@ package database
import ( import (
"database/sql" "database/sql"
"time"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/dialect/pgdialect"
@@ -10,6 +11,17 @@ import (
func NewFromPostgres(url string) *bun.DB { func NewFromPostgres(url string) *bun.DB {
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url))) sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url)))
// Configure connection pool to prevent "database closed" errors
// SetMaxOpenConns sets the maximum number of open connections to the database
sqldb.SetMaxOpenConns(25)
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool
sqldb.SetMaxIdleConns(5)
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused
sqldb.SetConnMaxLifetime(5 * time.Minute)
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle
sqldb.SetConnMaxIdleTime(10 * time.Minute)
db := bun.NewDB(sqldb, pgdialect.New()) db := bun.NewDB(sqldb, pgdialect.New())
return db return db
} }

View File

@@ -11,12 +11,15 @@ import (
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs" "github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
) )
func NewServer(c Config) (*fiber.App, error) { func NewServer(c Config) (*fiber.App, error) {
app := fiber.New() app := fiber.New()
db := database.NewFromPostgres(c.Database.PostgresURL) db := database.NewFromPostgres(c.Database.PostgresURL)
app.Use(logger.New())
// Initialize blob store based on config // Initialize blob store based on config
var blobStore blob.Store var blobStore blob.Store
switch c.Storage.Backend { switch c.Storage.Backend {
@@ -51,8 +54,8 @@ func NewServer(c Config) (*fiber.App, error) {
return nil, fmt.Errorf("failed to create virtual file system: %w", err) return nil, fmt.Errorf("failed to create virtual file system: %w", err)
} }
userService := user.NewService(db) userService := user.NewService()
authService := auth.NewService(db, userService, auth.TokenConfig{ authService := auth.NewService(userService, auth.TokenConfig{
Issuer: c.JWT.Issuer, Issuer: c.JWT.Issuer,
Audience: c.JWT.Audience, Audience: c.JWT.Audience,
SecretKey: c.JWT.SecretKey, SecretKey: c.JWT.SecretKey,
@@ -60,8 +63,8 @@ func NewServer(c Config) (*fiber.App, error) {
uploadService := upload.NewService(vfs, blobStore) uploadService := upload.NewService(vfs, blobStore)
api := app.Group("/api") api := app.Group("/api")
auth.RegisterAPIRoutes(api, authService) auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
upload.RegisterAPIRoutes(api, uploadService) upload.NewHTTPHandler(uploadService).RegisterRoutes(api)
return app, nil return app, nil
} }

View File

@@ -7,8 +7,6 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
const uploadServiceKey = "uploadService"
type createUploadRequest struct { type createUploadRequest struct {
ParentID string `json:"parentId"` ParentID string `json:"parentId"`
Name string `json:"name"` Name string `json:"name"`
@@ -18,35 +16,34 @@ type updateUploadRequest struct {
Status Status `json:"status"` Status Status `json:"status"`
} }
func RegisterAPIRoutes(api fiber.Router, s *Service) { type HTTPHandler struct {
upload := api.Group("/uploads", func(c *fiber.Ctx) error { service *Service
c.Locals(uploadServiceKey, s)
return c.Next()
})
upload.Post("/", createUpload)
upload.Put("/:uploadID/content", receiveUpload)
upload.Patch("/:uploadID", updateUpload)
} }
func mustUploadService(c *fiber.Ctx) *Service { func NewHTTPHandler(s *Service) *HTTPHandler {
return c.Locals(uploadServiceKey).(*Service) return &HTTPHandler{service: s}
} }
func createUpload(c *fiber.Ctx) error { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
upload := api.Group("/uploads")
upload.Post("/", h.Create)
upload.Put("/:uploadID/content", h.ReceiveContent)
upload.Patch("/:uploadID", h.Update)
}
func (h *HTTPHandler) Create(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) u, err := auth.AuthenticatedUser(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
} }
s := mustUploadService(c)
req := new(createUploadRequest) req := new(createUploadRequest)
if err := c.BodyParser(req); err != nil { if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
upload, err := s.CreateUpload(c.Context(), u.ID, CreateUploadOptions{ upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
ParentID: req.ParentID, ParentID: req.ParentID,
Name: req.Name, Name: req.Name,
}) })
@@ -57,17 +54,15 @@ func createUpload(c *fiber.Ctx) error {
return c.JSON(upload) return c.JSON(upload)
} }
func receiveUpload(c *fiber.Ctx) error { func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) u, err := auth.AuthenticatedUser(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
} }
s := mustUploadService(c)
uploadID := c.Params("uploadID") uploadID := c.Params("uploadID")
err = s.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream()) err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
defer c.Request().CloseBodyStream() defer c.Request().CloseBodyStream()
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
@@ -76,21 +71,19 @@ func receiveUpload(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent) return c.SendStatus(fiber.StatusNoContent)
} }
func updateUpload(c *fiber.Ctx) error { func (h *HTTPHandler) Update(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) u, err := auth.AuthenticatedUser(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
} }
s := mustUploadService(c)
req := new(updateUploadRequest) req := new(updateUploadRequest)
if err := c.BodyParser(req); err != nil { if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
if req.Status == StatusCompleted { if req.Status == StatusCompleted {
upload, err := s.CompleteUpload(c.Context(), u.ID, c.Params("uploadID")) upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
if err != nil { if err != nil {
if errors.Is(err, ErrNotFound) { if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound) return c.SendStatus(fiber.StatusNotFound)

View File

@@ -11,9 +11,7 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type Service struct { type Service struct{}
db *bun.DB
}
type UserRegistrationOptions struct { type UserRegistrationOptions struct {
Email string Email string
@@ -21,20 +19,24 @@ type UserRegistrationOptions struct {
Password password.Hashed Password password.Hashed
} }
func NewService(db *bun.DB) *Service { func NewService() *Service {
return &Service{ return &Service{}
db: db,
}
} }
func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) { func (s *Service) RegisterUser(ctx context.Context, db bun.IDB, opts UserRegistrationOptions) (*User, error) {
uid, err := newUserID()
if err != nil {
return nil, err
}
u := User{ u := User{
ID: uid,
Email: opts.Email, Email: opts.Email,
DisplayName: opts.DisplayName, DisplayName: opts.DisplayName,
Password: opts.Password, Password: opts.Password,
} }
_, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx) _, err = db.NewInsert().Model(&u).Returning("*").Exec(ctx)
if err != nil { if err != nil {
if database.IsUniqueViolation(err) { if database.IsUniqueViolation(err) {
return nil, newAlreadyExistsError(u.Email) return nil, newAlreadyExistsError(u.Email)
@@ -45,9 +47,9 @@ func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions
return &u, nil return &u, nil
} }
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) { func (s *Service) UserByID(ctx context.Context, db bun.IDB, id uuid.UUID) (*User, error) {
var user User var user User
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx) err := db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(id, "") return nil, newNotFoundError(id, "")
@@ -57,9 +59,9 @@ func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
return &user, nil return &user, nil
} }
func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) { func (s *Service) UserByEmail(ctx context.Context, db bun.IDB, email string) (*User, error) {
var user User var user User
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx) err := db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(uuid.Nil, email) return nil, newNotFoundError(uuid.Nil, email)
@@ -69,6 +71,6 @@ func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error)
return &user, nil return &user, nil
} }
func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) { func (s *Service) UserExistsByEmail(ctx context.Context, db bun.IDB, email string) (bool, error) {
return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx) return db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
} }

View File

@@ -1,6 +1,8 @@
package user package user
import ( import (
"time"
"github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/password"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -10,9 +12,15 @@ type User struct {
bun.BaseModel `bun:"users"` bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"` ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name,notnull" json:"displayName"` DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"` Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"` Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
}
func newUserID() (uuid.UUID, error) {
return uuid.NewV7()
} }