mirror of
https://github.com/get-drexa/drive.git
synced 2025-11-30 21:41:39 +00:00
fix: registration endpoint and db auto close issue
This commit is contained in:
15
apps/backend/config.yaml
Normal file
15
apps/backend/config.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
server:
|
||||||
|
port: 8080
|
||||||
|
|
||||||
|
database:
|
||||||
|
postgres_url: postgres://drexa:hunter2@helian:5433/drexa?sslmode=disable
|
||||||
|
|
||||||
|
jwt:
|
||||||
|
issuer: drexa
|
||||||
|
audience: drexa-api
|
||||||
|
secret_key_base64: "pNeUExoqdakfecZLFL53NJpY4iB9zFot9EuEBItlYKY="
|
||||||
|
|
||||||
|
storage:
|
||||||
|
mode: hierarchical
|
||||||
|
backend: fs
|
||||||
|
root_path: ./data
|
||||||
@@ -6,7 +6,6 @@ require (
|
|||||||
github.com/gabriel-vasile/mimetype v1.4.11
|
github.com/gabriel-vasile/mimetype v1.4.11
|
||||||
github.com/gofiber/fiber/v2 v2.52.9
|
github.com/gofiber/fiber/v2 v2.52.9
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/joho/godotenv v1.5.1
|
|
||||||
github.com/sqids/sqids-go v0.4.1
|
github.com/sqids/sqids-go v0.4.1
|
||||||
github.com/uptrace/bun v1.2.15
|
github.com/uptrace/bun v1.2.15
|
||||||
golang.org/x/crypto v0.40.0
|
golang.org/x/crypto v0.40.0
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
|
||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
|
||||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
|||||||
@@ -2,13 +2,13 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"github.com/get-drexa/drexa/internal/user"
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const authServiceKey = "authService"
|
|
||||||
|
|
||||||
type loginRequest struct {
|
type loginRequest struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
@@ -26,29 +26,35 @@ type loginResponse struct {
|
|||||||
RefreshToken string `json:"refreshToken"`
|
RefreshToken string `json:"refreshToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterAPIRoutes(api fiber.Router, s *Service) {
|
type HTTPHandler struct {
|
||||||
auth := api.Group("/auth", func(c *fiber.Ctx) error {
|
service *Service
|
||||||
c.Locals(authServiceKey, s)
|
db *bun.DB
|
||||||
return c.Next()
|
|
||||||
})
|
|
||||||
|
|
||||||
auth.Post("/login", login)
|
|
||||||
auth.Post("/register", register)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustAuthService(c *fiber.Ctx) *Service {
|
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
|
||||||
return c.Locals(authServiceKey).(*Service)
|
return &HTTPHandler{service: s, db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func login(c *fiber.Ctx) error {
|
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||||
s := mustAuthService(c)
|
auth := api.Group("/auth")
|
||||||
|
|
||||||
|
auth.Post("/login", h.Login)
|
||||||
|
auth.Post("/register", h.Register)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||||
req := new(loginRequest)
|
req := new(loginRequest)
|
||||||
if err := c.BodyParser(req); err != nil {
|
if err := c.BodyParser(req); err != nil {
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := s.LoginWithEmailAndPassword(c.Context(), req.Email, req.Password)
|
tx, err := h.db.BeginTx(c.Context(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrInvalidCredentials) {
|
if errors.Is(err, ErrInvalidCredentials) {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
|
||||||
@@ -56,6 +62,10 @@ func login(c *fiber.Ctx) error {
|
|||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||||
|
}
|
||||||
|
|
||||||
return c.JSON(loginResponse{
|
return c.JSON(loginResponse{
|
||||||
User: *result.User,
|
User: *result.User,
|
||||||
AccessToken: result.AccessToken,
|
AccessToken: result.AccessToken,
|
||||||
@@ -63,15 +73,20 @@ func login(c *fiber.Ctx) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func register(c *fiber.Ctx) error {
|
func (h *HTTPHandler) Register(c *fiber.Ctx) error {
|
||||||
s := mustAuthService(c)
|
|
||||||
|
|
||||||
req := new(registerRequest)
|
req := new(registerRequest)
|
||||||
if err := c.BodyParser(req); err != nil {
|
if err := c.BodyParser(req); err != nil {
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := s.Register(c.Context(), registerOptions{
|
tx, err := h.db.BeginTx(c.Context(), nil)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to begin transaction", "error", err)
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
result, err := h.service.Register(c.Context(), tx, registerOptions{
|
||||||
email: req.Email,
|
email: req.Email,
|
||||||
password: req.Password,
|
password: req.Password,
|
||||||
displayName: req.DisplayName,
|
displayName: req.DisplayName,
|
||||||
@@ -81,6 +96,12 @@ func register(c *fiber.Ctx) error {
|
|||||||
if errors.As(err, &ae) {
|
if errors.As(err, &ae) {
|
||||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
|
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
|
||||||
}
|
}
|
||||||
|
slog.Error("failed to register user", "error", err)
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
slog.Error("failed to commit transaction", "error", err)
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,14 @@ import (
|
|||||||
|
|
||||||
"github.com/get-drexa/drexa/internal/user"
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const authenticatedUserKey = "authenticatedUser"
|
const authenticatedUserKey = "authenticatedUser"
|
||||||
|
|
||||||
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
|
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
|
||||||
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
|
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
|
||||||
func NewBearerAuthMiddleware(s *Service) fiber.Handler {
|
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
authHeader := c.Get("Authorization")
|
authHeader := c.Get("Authorization")
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
@@ -25,7 +26,7 @@ func NewBearerAuthMiddleware(s *Service) fiber.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
token := parts[1]
|
token := parts[1]
|
||||||
u, err := s.AuthenticateWithAccessToken(c.Context(), token)
|
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var e *InvalidAccessTokenError
|
var e *InvalidAccessTokenError
|
||||||
if errors.As(err, &e) {
|
if errors.As(err, &e) {
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ type LoginResult struct {
|
|||||||
var ErrInvalidCredentials = errors.New("invalid credentials")
|
var ErrInvalidCredentials = errors.New("invalid credentials")
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
db *bun.DB
|
|
||||||
userService *user.Service
|
userService *user.Service
|
||||||
tokenConfig TokenConfig
|
tokenConfig TokenConfig
|
||||||
}
|
}
|
||||||
@@ -31,16 +30,15 @@ type registerOptions struct {
|
|||||||
password string
|
password string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service {
|
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
db: db,
|
|
||||||
userService: userService,
|
userService: userService,
|
||||||
tokenConfig: tokenConfig,
|
tokenConfig: tokenConfig,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) {
|
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
|
||||||
u, err := s.userService.UserByEmail(ctx, email)
|
u, err := s.userService.UserByEmail(ctx, db, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var nf *user.NotFoundError
|
var nf *user.NotFoundError
|
||||||
if errors.As(err, &nf) {
|
if errors.As(err, &nf) {
|
||||||
@@ -64,7 +62,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
|
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -76,13 +74,13 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
|
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
|
||||||
hashed, err := password.Hash(opts.password)
|
hashed, err := password.Hash(opts.password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{
|
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
|
||||||
Email: opts.email,
|
Email: opts.email,
|
||||||
DisplayName: opts.displayName,
|
DisplayName: opts.displayName,
|
||||||
Password: hashed,
|
Password: hashed,
|
||||||
@@ -101,7 +99,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
|
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -113,7 +111,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) {
|
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
|
||||||
claims, err := ParseAccessToken(token, &s.tokenConfig)
|
claims, err := ParseAccessToken(token, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -124,5 +122,5 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string)
|
|||||||
return nil, newInvalidAccessTokenError(err)
|
return nil, newInvalidAccessTokenError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.userService.UserByID(ctx, id)
|
return s.userService.UserByID(ctx, db, id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ type RefreshToken struct {
|
|||||||
CreatedAt time.Time `bun:"created_at,notnull"`
|
CreatedAt time.Time `bun:"created_at,notnull"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTokenID() (uuid.UUID, error) {
|
||||||
|
return uuid.NewV7()
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
|
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
@@ -63,9 +67,9 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
|
|||||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := uuid.NewV7()
|
id, err := newTokenID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
return nil, fmt.Errorf("failed to generate token ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h := sha256.Sum256(buf)
|
h := sha256.Sum256(buf)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package database
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
"github.com/uptrace/bun/dialect/pgdialect"
|
"github.com/uptrace/bun/dialect/pgdialect"
|
||||||
@@ -10,6 +11,17 @@ import (
|
|||||||
|
|
||||||
func NewFromPostgres(url string) *bun.DB {
|
func NewFromPostgres(url string) *bun.DB {
|
||||||
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url)))
|
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url)))
|
||||||
|
|
||||||
|
// Configure connection pool to prevent "database closed" errors
|
||||||
|
// SetMaxOpenConns sets the maximum number of open connections to the database
|
||||||
|
sqldb.SetMaxOpenConns(25)
|
||||||
|
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool
|
||||||
|
sqldb.SetMaxIdleConns(5)
|
||||||
|
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused
|
||||||
|
sqldb.SetConnMaxLifetime(5 * time.Minute)
|
||||||
|
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle
|
||||||
|
sqldb.SetConnMaxIdleTime(10 * time.Minute)
|
||||||
|
|
||||||
db := bun.NewDB(sqldb, pgdialect.New())
|
db := bun.NewDB(sqldb, pgdialect.New())
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,12 +11,15 @@ import (
|
|||||||
"github.com/get-drexa/drexa/internal/user"
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/get-drexa/drexa/internal/virtualfs"
|
"github.com/get-drexa/drexa/internal/virtualfs"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewServer(c Config) (*fiber.App, error) {
|
func NewServer(c Config) (*fiber.App, error) {
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
db := database.NewFromPostgres(c.Database.PostgresURL)
|
db := database.NewFromPostgres(c.Database.PostgresURL)
|
||||||
|
|
||||||
|
app.Use(logger.New())
|
||||||
|
|
||||||
// Initialize blob store based on config
|
// Initialize blob store based on config
|
||||||
var blobStore blob.Store
|
var blobStore blob.Store
|
||||||
switch c.Storage.Backend {
|
switch c.Storage.Backend {
|
||||||
@@ -51,8 +54,8 @@ func NewServer(c Config) (*fiber.App, error) {
|
|||||||
return nil, fmt.Errorf("failed to create virtual file system: %w", err)
|
return nil, fmt.Errorf("failed to create virtual file system: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userService := user.NewService(db)
|
userService := user.NewService()
|
||||||
authService := auth.NewService(db, userService, auth.TokenConfig{
|
authService := auth.NewService(userService, auth.TokenConfig{
|
||||||
Issuer: c.JWT.Issuer,
|
Issuer: c.JWT.Issuer,
|
||||||
Audience: c.JWT.Audience,
|
Audience: c.JWT.Audience,
|
||||||
SecretKey: c.JWT.SecretKey,
|
SecretKey: c.JWT.SecretKey,
|
||||||
@@ -60,8 +63,8 @@ func NewServer(c Config) (*fiber.App, error) {
|
|||||||
uploadService := upload.NewService(vfs, blobStore)
|
uploadService := upload.NewService(vfs, blobStore)
|
||||||
|
|
||||||
api := app.Group("/api")
|
api := app.Group("/api")
|
||||||
auth.RegisterAPIRoutes(api, authService)
|
auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
|
||||||
upload.RegisterAPIRoutes(api, uploadService)
|
upload.NewHTTPHandler(uploadService).RegisterRoutes(api)
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ import (
|
|||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const uploadServiceKey = "uploadService"
|
|
||||||
|
|
||||||
type createUploadRequest struct {
|
type createUploadRequest struct {
|
||||||
ParentID string `json:"parentId"`
|
ParentID string `json:"parentId"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -18,35 +16,34 @@ type updateUploadRequest struct {
|
|||||||
Status Status `json:"status"`
|
Status Status `json:"status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterAPIRoutes(api fiber.Router, s *Service) {
|
type HTTPHandler struct {
|
||||||
upload := api.Group("/uploads", func(c *fiber.Ctx) error {
|
service *Service
|
||||||
c.Locals(uploadServiceKey, s)
|
|
||||||
return c.Next()
|
|
||||||
})
|
|
||||||
|
|
||||||
upload.Post("/", createUpload)
|
|
||||||
upload.Put("/:uploadID/content", receiveUpload)
|
|
||||||
upload.Patch("/:uploadID", updateUpload)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustUploadService(c *fiber.Ctx) *Service {
|
func NewHTTPHandler(s *Service) *HTTPHandler {
|
||||||
return c.Locals(uploadServiceKey).(*Service)
|
return &HTTPHandler{service: s}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUpload(c *fiber.Ctx) error {
|
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||||
|
upload := api.Group("/uploads")
|
||||||
|
|
||||||
|
upload.Post("/", h.Create)
|
||||||
|
upload.Put("/:uploadID/content", h.ReceiveContent)
|
||||||
|
upload.Patch("/:uploadID", h.Update)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HTTPHandler) Create(c *fiber.Ctx) error {
|
||||||
u, err := auth.AuthenticatedUser(c)
|
u, err := auth.AuthenticatedUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
||||||
}
|
}
|
||||||
|
|
||||||
s := mustUploadService(c)
|
|
||||||
|
|
||||||
req := new(createUploadRequest)
|
req := new(createUploadRequest)
|
||||||
if err := c.BodyParser(req); err != nil {
|
if err := c.BodyParser(req); err != nil {
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||||
}
|
}
|
||||||
|
|
||||||
upload, err := s.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
|
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
|
||||||
ParentID: req.ParentID,
|
ParentID: req.ParentID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
})
|
})
|
||||||
@@ -57,17 +54,15 @@ func createUpload(c *fiber.Ctx) error {
|
|||||||
return c.JSON(upload)
|
return c.JSON(upload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func receiveUpload(c *fiber.Ctx) error {
|
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
|
||||||
u, err := auth.AuthenticatedUser(c)
|
u, err := auth.AuthenticatedUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
||||||
}
|
}
|
||||||
|
|
||||||
s := mustUploadService(c)
|
|
||||||
|
|
||||||
uploadID := c.Params("uploadID")
|
uploadID := c.Params("uploadID")
|
||||||
|
|
||||||
err = s.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
|
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
|
||||||
defer c.Request().CloseBodyStream()
|
defer c.Request().CloseBodyStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||||
@@ -76,21 +71,19 @@ func receiveUpload(c *fiber.Ctx) error {
|
|||||||
return c.SendStatus(fiber.StatusNoContent)
|
return c.SendStatus(fiber.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUpload(c *fiber.Ctx) error {
|
func (h *HTTPHandler) Update(c *fiber.Ctx) error {
|
||||||
u, err := auth.AuthenticatedUser(c)
|
u, err := auth.AuthenticatedUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
||||||
}
|
}
|
||||||
|
|
||||||
s := mustUploadService(c)
|
|
||||||
|
|
||||||
req := new(updateUploadRequest)
|
req := new(updateUploadRequest)
|
||||||
if err := c.BodyParser(req); err != nil {
|
if err := c.BodyParser(req); err != nil {
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Status == StatusCompleted {
|
if req.Status == StatusCompleted {
|
||||||
upload, err := s.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
|
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrNotFound) {
|
if errors.Is(err, ErrNotFound) {
|
||||||
return c.SendStatus(fiber.StatusNotFound)
|
return c.SendStatus(fiber.StatusNotFound)
|
||||||
|
|||||||
@@ -11,9 +11,7 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service struct {
|
type Service struct{}
|
||||||
db *bun.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserRegistrationOptions struct {
|
type UserRegistrationOptions struct {
|
||||||
Email string
|
Email string
|
||||||
@@ -21,20 +19,24 @@ type UserRegistrationOptions struct {
|
|||||||
Password password.Hashed
|
Password password.Hashed
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(db *bun.DB) *Service {
|
func NewService() *Service {
|
||||||
return &Service{
|
return &Service{}
|
||||||
db: db,
|
}
|
||||||
}
|
|
||||||
|
func (s *Service) RegisterUser(ctx context.Context, db bun.IDB, opts UserRegistrationOptions) (*User, error) {
|
||||||
|
uid, err := newUserID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) {
|
|
||||||
u := User{
|
u := User{
|
||||||
|
ID: uid,
|
||||||
Email: opts.Email,
|
Email: opts.Email,
|
||||||
DisplayName: opts.DisplayName,
|
DisplayName: opts.DisplayName,
|
||||||
Password: opts.Password,
|
Password: opts.Password,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx)
|
_, err = db.NewInsert().Model(&u).Returning("*").Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if database.IsUniqueViolation(err) {
|
if database.IsUniqueViolation(err) {
|
||||||
return nil, newAlreadyExistsError(u.Email)
|
return nil, newAlreadyExistsError(u.Email)
|
||||||
@@ -45,9 +47,9 @@ func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions
|
|||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
|
func (s *Service) UserByID(ctx context.Context, db bun.IDB, id uuid.UUID) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
|
err := db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, newNotFoundError(id, "")
|
return nil, newNotFoundError(id, "")
|
||||||
@@ -57,9 +59,9 @@ func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) {
|
func (s *Service) UserByEmail(ctx context.Context, db bun.IDB, email string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
|
err := db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, newNotFoundError(uuid.Nil, email)
|
return nil, newNotFoundError(uuid.Nil, email)
|
||||||
@@ -69,6 +71,6 @@ func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error)
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (s *Service) UserExistsByEmail(ctx context.Context, db bun.IDB, email string) (bool, error) {
|
||||||
return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
|
return db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/get-drexa/drexa/internal/password"
|
"github.com/get-drexa/drexa/internal/password"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -10,9 +12,15 @@ type User struct {
|
|||||||
bun.BaseModel `bun:"users"`
|
bun.BaseModel `bun:"users"`
|
||||||
|
|
||||||
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
|
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
|
||||||
DisplayName string `bun:"display_name,notnull" json:"displayName"`
|
DisplayName string `bun:"display_name" json:"displayName"`
|
||||||
Email string `bun:"email,unique,notnull" json:"email"`
|
Email string `bun:"email,unique,notnull" json:"email"`
|
||||||
Password password.Hashed `bun:"password,notnull" json:"-"`
|
Password password.Hashed `bun:"password,notnull" json:"-"`
|
||||||
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
|
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
|
||||||
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
|
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
|
||||||
|
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserID() (uuid.UUID, error) {
|
||||||
|
return uuid.NewV7()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user