Compare commits

...

7 Commits

Author SHA1 Message Date
629d56b5ab fix: registration endpoint and db auto close issue 2025-11-29 20:51:56 +00:00
5e4e08c255 fix: migration code not working
- read database config from config file
- rename migration file to expected file name format
2025-11-29 20:32:32 +00:00
42b805fbd1 feat: add blob store Initialize method 2025-11-29 18:39:21 +00:00
ab4c14bc09 feat: impl config loading 2025-11-29 18:09:41 +00:00
fd3b2d3908 fix: handle NewServer error 2025-11-29 17:28:53 +00:00
39824e45d9 feat: impl upload api endpoints 2025-11-29 17:25:11 +00:00
6aee150a59 build: remove devcontainer name 2025-11-29 17:24:49 +00:00
25 changed files with 569 additions and 191 deletions

View File

@@ -1,5 +1,4 @@
{ {
"name": "React + Bun + Convex Development",
"build": { "build": {
"context": ".", "context": ".",
"dockerfile": "Dockerfile" "dockerfile": "Dockerfile"

View File

@@ -1,22 +1,34 @@
package main package main
import ( import (
"flag"
"fmt" "fmt"
"log" "log"
"os"
"github.com/get-drexa/drexa/internal/drexa" "github.com/get-drexa/drexa/internal/drexa"
"github.com/joho/godotenv"
) )
func main() { func main() {
_ = godotenv.Load() configPath := flag.String("config", "", "path to config file (required)")
flag.Parse()
config, err := drexa.ServerConfigFromEnv() if *configPath == "" {
fmt.Fprintln(os.Stderr, "error: --config is required")
flag.Usage()
os.Exit(1)
}
config, err := drexa.ConfigFromFile(*configPath)
if err != nil {
log.Fatalf("failed to load config: %v", err)
}
server, err := drexa.NewServer(*config)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
server := drexa.NewServer(*config) log.Printf("starting server on :%d", config.Server.Port)
log.Fatal(server.Listen(fmt.Sprintf(":%d", config.Server.Port)))
log.Fatal(server.Listen(fmt.Sprintf(":%d", config.Port)))
} }

View File

@@ -1,13 +1,37 @@
package main package main
import ( import (
"context"
"flag"
"fmt"
"log" "log"
"os"
"github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/drexa"
) )
func main() { func main() {
if err := database.RunMigrations(); err != nil { configPath := flag.String("config", "", "path to config file (required)")
log.Fatalf("Failed to run migrations: %v", err) flag.Parse()
if *configPath == "" {
fmt.Fprintln(os.Stderr, "error: --config is required")
flag.Usage()
os.Exit(1)
} }
config, err := drexa.ConfigFromFile(*configPath)
if err != nil {
log.Fatalf("failed to load config: %v", err)
}
db := database.NewFromPostgres(config.Database.PostgresURL)
defer db.Close()
log.Println("running migrations...")
if err := database.RunMigrations(context.Background(), db); err != nil {
log.Fatalf("failed to run migrations: %v", err)
}
log.Println("migrations completed successfully")
} }

View File

@@ -0,0 +1,30 @@
# Drexa Backend Configuration
# Copy this file to config.yaml and adjust values for your environment.
server:
port: 8080
database:
postgres_url: postgres://user:password@localhost:5432/drexa?sslmode=disable
jwt:
issuer: drexa
audience: drexa-api
# Secret key can be provided via (in order of precedence):
# 1. JWT_SECRET_KEY environment variable (base64 encoded)
# 2. secret_key_base64 below (base64 encoded)
# 3. secret_key_path below (file with base64 encoded content)
# secret_key_base64: "base64encodedkey"
secret_key_path: /run/secrets/jwt_secret_key
storage:
# Mode: "flat" (UUID-based keys) or "hierarchical" (path-based keys)
# Note: S3 backend only supports "flat" mode
mode: flat
# Backend: "fs" (filesystem) or "s3" (not yet implemented)
backend: fs
# Required when backend is "fs"
root_path: /var/lib/drexa/blobs
# Required when backend is "s3"
# bucket: my-drexa-bucket

15
apps/backend/config.yaml Normal file
View File

@@ -0,0 +1,15 @@
server:
port: 8080
database:
postgres_url: postgres://drexa:hunter2@helian:5433/drexa?sslmode=disable
jwt:
issuer: drexa
audience: drexa-api
secret_key_base64: "pNeUExoqdakfecZLFL53NJpY4iB9zFot9EuEBItlYKY="
storage:
mode: hierarchical
backend: fs
root_path: ./data

View File

@@ -3,16 +3,16 @@ module github.com/get-drexa/drexa
go 1.25.4 go 1.25.4
require ( require (
github.com/gabriel-vasile/mimetype v1.4.11
github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/fiber/v2 v2.52.9
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/sqids/sqids-go v0.4.1
github.com/uptrace/bun v1.2.15 github.com/uptrace/bun v1.2.15
golang.org/x/crypto v0.40.0 golang.org/x/crypto v0.40.0
gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/gabriel-vasile/mimetype v1.4.11 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/sqids/sqids-go v0.4.1 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect
mellium.im/sasl v0.3.2 // indirect mellium.im/sasl v0.3.2 // indirect
@@ -20,7 +20,7 @@ require (
require ( require (
github.com/andybalholm/brotli v1.1.0 // indirect github.com/andybalholm/brotli v1.1.0 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/compress v1.17.9 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
@@ -28,7 +28,6 @@ require (
github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect
github.com/stretchr/testify v1.10.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/uptrace/bun/dialect/pgdialect v1.2.15 github.com/uptrace/bun/dialect/pgdialect v1.2.15
github.com/uptrace/bun/driver/pgdriver v1.2.15 github.com/uptrace/bun/driver/pgdriver v1.2.15

View File

@@ -8,20 +8,24 @@ github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5
github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
@@ -59,6 +63,9 @@ golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+Zdx
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
mellium.im/sasl v0.3.2 h1:PT6Xp7ccn9XaXAnJ03FcEjmAn7kK1x7aoXV6F+Vmrl0= mellium.im/sasl v0.3.2 h1:PT6Xp7ccn9XaXAnJ03FcEjmAn7kK1x7aoXV6F+Vmrl0=

View File

@@ -2,13 +2,13 @@ package auth
import ( import (
"errors" "errors"
"log/slog"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
) )
const authServiceKey = "authService"
type loginRequest struct { type loginRequest struct {
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
@@ -26,29 +26,35 @@ type loginResponse struct {
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken"`
} }
func RegisterAPIRoutes(api fiber.Router, s *Service) { type HTTPHandler struct {
auth := api.Group("/auth", func(c *fiber.Ctx) error { service *Service
c.Locals(authServiceKey, s) db *bun.DB
return c.Next()
})
auth.Post("/login", login)
auth.Post("/register", register)
} }
func mustAuthService(c *fiber.Ctx) *Service { func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return c.Locals(authServiceKey).(*Service) return &HTTPHandler{service: s, db: db}
} }
func login(c *fiber.Ctx) error { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
s := mustAuthService(c) auth := api.Group("/auth")
auth.Post("/login", h.Login)
auth.Post("/register", h.Register)
}
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
req := new(loginRequest) req := new(loginRequest)
if err := c.BodyParser(req); err != nil { if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
result, err := s.LoginWithEmailAndPassword(c.Context(), req.Email, req.Password) tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil { if err != nil {
if errors.Is(err, ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
@@ -56,6 +62,10 @@ func login(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
} }
if err := tx.Commit(); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(loginResponse{ return c.JSON(loginResponse{
User: *result.User, User: *result.User,
AccessToken: result.AccessToken, AccessToken: result.AccessToken,
@@ -63,15 +73,20 @@ func login(c *fiber.Ctx) error {
}) })
} }
func register(c *fiber.Ctx) error { func (h *HTTPHandler) Register(c *fiber.Ctx) error {
s := mustAuthService(c)
req := new(registerRequest) req := new(registerRequest)
if err := c.BodyParser(req); err != nil { if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
result, err := s.Register(c.Context(), registerOptions{ tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
slog.Error("failed to begin transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{
email: req.Email, email: req.Email,
password: req.Password, password: req.Password,
displayName: req.DisplayName, displayName: req.DisplayName,
@@ -81,6 +96,12 @@ func register(c *fiber.Ctx) error {
if errors.As(err, &ae) { if errors.As(err, &ae) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"}) return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
} }
slog.Error("failed to register user", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
} }

View File

@@ -6,13 +6,14 @@ import (
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
) )
const authenticatedUserKey = "authenticatedUser" const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token. // NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser. // To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service) fiber.Handler { func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization") authHeader := c.Get("Authorization")
if authHeader == "" { if authHeader == "" {
@@ -25,7 +26,7 @@ func NewBearerAuthMiddleware(s *Service) fiber.Handler {
} }
token := parts[1] token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), token) u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
if err != nil { if err != nil {
var e *InvalidAccessTokenError var e *InvalidAccessTokenError
if errors.As(err, &e) { if errors.As(err, &e) {

View File

@@ -20,7 +20,6 @@ type LoginResult struct {
var ErrInvalidCredentials = errors.New("invalid credentials") var ErrInvalidCredentials = errors.New("invalid credentials")
type Service struct { type Service struct {
db *bun.DB
userService *user.Service userService *user.Service
tokenConfig TokenConfig tokenConfig TokenConfig
} }
@@ -31,16 +30,15 @@ type registerOptions struct {
password string password string
} }
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service { func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{ return &Service{
db: db,
userService: userService, userService: userService,
tokenConfig: tokenConfig, tokenConfig: tokenConfig,
} }
} }
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) { func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, email) u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil { if err != nil {
var nf *user.NotFoundError var nf *user.NotFoundError
if errors.As(err, &nf) { if errors.As(err, &nf) {
@@ -64,7 +62,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
return nil, err return nil, err
} }
_, err = s.db.NewInsert().Model(rt).Exec(ctx) _, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -76,13 +74,13 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
}, nil }, nil
} }
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) { func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password) hashed, err := password.Hash(opts.password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{ u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email, Email: opts.email,
DisplayName: opts.displayName, DisplayName: opts.displayName,
Password: hashed, Password: hashed,
@@ -101,7 +99,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
return nil, err return nil, err
} }
_, err = s.db.NewInsert().Model(rt).Exec(ctx) _, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -113,7 +111,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
}, nil }, nil
} }
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) { func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig) claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -124,5 +122,5 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string)
return nil, newInvalidAccessTokenError(err) return nil, newInvalidAccessTokenError(err)
} }
return s.userService.UserByID(ctx, id) return s.userService.UserByID(ctx, db, id)
} }

View File

@@ -36,6 +36,10 @@ type RefreshToken struct {
CreatedAt time.Time `bun:"created_at,notnull"` CreatedAt time.Time `bun:"created_at,notnull"`
} }
func newTokenID() (uuid.UUID, error) {
return uuid.NewV7()
}
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) { func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
now := time.Now() now := time.Now()
@@ -63,9 +67,9 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
return nil, fmt.Errorf("failed to generate refresh token: %w", err) return nil, fmt.Errorf("failed to generate refresh token: %w", err)
} }
id, err := uuid.NewV7() id, err := newTokenID()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err) return nil, fmt.Errorf("failed to generate token ID: %w", err)
} }
h := sha256.Sum256(buf) h := sha256.Sum256(buf)

View File

@@ -24,6 +24,10 @@ func NewFSStore(config FSStoreConfig) *FSStore {
return &FSStore{config: config} return &FSStore{config: config}
} }
func (s *FSStore) Initialize(ctx context.Context) error {
return os.MkdirAll(s.config.Root, 0755)
}
func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) { func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
return s.config.UploadURL, nil return s.config.UploadURL, nil
} }

View File

@@ -15,6 +15,7 @@ type UpdateOptions struct {
} }
type Store interface { type Store interface {
Initialize(ctx context.Context) error
GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
Put(ctx context.Context, key Key, reader io.Reader) error Put(ctx context.Context, key Key, reader io.Reader) error
Update(ctx context.Context, key Key, opts UpdateOptions) error Update(ctx context.Context, key Key, opts UpdateOptions) error

View File

@@ -1,17 +1,28 @@
package database package database
import ( import (
"context"
"embed" "embed"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate" "github.com/uptrace/bun/migrate"
) )
//go:embed migrations/*.sql //go:embed migrations/*.sql
var sqlMigrations embed.FS var sqlMigrations embed.FS
// RunMigrations discovers and runs all migrations in the migrations directory. // RunMigrations discovers and runs all migrations against the database.
// Currently, the migrations directory is in internal/db/migrations. func RunMigrations(ctx context.Context, db *bun.DB) error {
func RunMigrations() error { migrations := migrate.NewMigrations()
m := migrate.NewMigrations() if err := migrations.Discover(sqlMigrations); err != nil {
return m.Discover(sqlMigrations) return err
}
migrator := migrate.NewMigrator(db, migrations)
if err := migrator.Init(ctx); err != nil {
return err
}
_, err := migrator.Migrate(ctx)
return err
} }

View File

@@ -1,32 +1,9 @@
-- Enable UUID extension for UUIDv7 support
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
-- UUIDv7 generation function (timestamp-ordered UUIDs)
-- Based on the draft RFC: https://datatracker.ietf.org/doc/html/draft-peabody-dispatch-new-uuid-format
CREATE OR REPLACE FUNCTION uuid_generate_v7()
RETURNS UUID
AS $$
DECLARE
unix_ts_ms BIGINT;
uuid_bytes BYTEA;
BEGIN
unix_ts_ms = (EXTRACT(EPOCH FROM CLOCK_TIMESTAMP()) * 1000)::BIGINT;
uuid_bytes = OVERLAY(gen_random_bytes(16) PLACING
SUBSTRING(INT8SEND(unix_ts_ms) FROM 3) FROM 1 FOR 6
);
-- Set version (7) and variant bits
uuid_bytes = SET_BYTE(uuid_bytes, 6, (GET_BYTE(uuid_bytes, 6) & 15) | 112);
uuid_bytes = SET_BYTE(uuid_bytes, 8, (GET_BYTE(uuid_bytes, 8) & 63) | 128);
RETURN ENCODE(uuid_bytes, 'hex')::UUID;
END;
$$ LANGUAGE plpgsql VOLATILE;
-- ============================================================================ -- ============================================================================
-- Application Tables -- Application Tables
-- ============================================================================ -- ============================================================================
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), id UUID PRIMARY KEY,
display_name TEXT, display_name TEXT,
email TEXT NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE,
password TEXT NOT NULL, password TEXT NOT NULL,
@@ -39,7 +16,7 @@ CREATE TABLE IF NOT EXISTS users (
CREATE INDEX idx_users_email ON users(email); CREATE INDEX idx_users_email ON users(email);
CREATE TABLE IF NOT EXISTS refresh_tokens ( CREATE TABLE IF NOT EXISTS refresh_tokens (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ NOT NULL, expires_at TIMESTAMPTZ NOT NULL,
@@ -52,7 +29,7 @@ CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
-- Virtual filesystem nodes (unified files + directories) -- Virtual filesystem nodes (unified files + directories)
CREATE TABLE IF NOT EXISTS vfs_nodes ( CREATE TABLE IF NOT EXISTS vfs_nodes (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), id UUID PRIMARY KEY,
public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak) public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak)
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory
@@ -83,7 +60,7 @@ CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_i
CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job
CREATE TABLE IF NOT EXISTS node_shares ( CREATE TABLE IF NOT EXISTS node_shares (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), id UUID PRIMARY KEY,
node_id UUID NOT NULL REFERENCES vfs_nodes(id) ON DELETE CASCADE, node_id UUID NOT NULL REFERENCES vfs_nodes(id) ON DELETE CASCADE,
share_token TEXT NOT NULL UNIQUE, share_token TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ, expires_at TIMESTAMPTZ,

View File

@@ -2,6 +2,7 @@ package database
import ( import (
"database/sql" "database/sql"
"time"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/dialect/pgdialect"
@@ -10,6 +11,17 @@ import (
func NewFromPostgres(url string) *bun.DB { func NewFromPostgres(url string) *bun.DB {
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url))) sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url)))
// Configure connection pool to prevent "database closed" errors
// SetMaxOpenConns sets the maximum number of open connections to the database
sqldb.SetMaxOpenConns(25)
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool
sqldb.SetMaxIdleConns(5)
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused
sqldb.SetConnMaxLifetime(5 * time.Minute)
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle
sqldb.SetConnMaxIdleTime(10 * time.Minute)
db := bun.NewDB(sqldb, pgdialect.New()) db := bun.NewDB(sqldb, pgdialect.New())
return db return db
} }

View File

@@ -0,0 +1,155 @@
package drexa
import (
"encoding/base64"
"errors"
"fmt"
"os"
"gopkg.in/yaml.v3"
)
type StorageMode string
type StorageBackend string
const (
StorageModeFlat StorageMode = "flat"
StorageModeHierarchical StorageMode = "hierarchical"
)
const (
StorageBackendFS StorageBackend = "fs"
StorageBackendS3 StorageBackend = "s3"
)
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
JWT JWTConfig `yaml:"jwt"`
Storage StorageConfig `yaml:"storage"`
}
type ServerConfig struct {
Port int `yaml:"port"`
}
type DatabaseConfig struct {
PostgresURL string `yaml:"postgres_url"`
}
type JWTConfig struct {
Issuer string `yaml:"issuer"`
Audience string `yaml:"audience"`
SecretKeyBase64 string `yaml:"secret_key_base64"`
SecretKeyPath string `yaml:"secret_key_path"`
SecretKey []byte `yaml:"-"`
}
type StorageConfig struct {
Mode StorageMode `yaml:"mode"`
Backend StorageBackend `yaml:"backend"`
RootPath string `yaml:"root_path"`
Bucket string `yaml:"bucket"`
}
// ConfigFromFile loads configuration from a YAML file.
// JWT secret key is loaded from JWT_SECRET_KEY env var (base64 encoded),
// falling back to the file path specified in jwt.secret_key_path.
func ConfigFromFile(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("config file not found: %s", path)
}
return nil, err
}
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, err
}
// Load JWT secret key (priority: env var > config base64 > config file path)
if envKey := os.Getenv("JWT_SECRET_KEY"); envKey != "" {
key, err := base64.StdEncoding.DecodeString(envKey)
if err != nil {
return nil, errors.New("JWT_SECRET_KEY env var is not valid base64")
}
config.JWT.SecretKey = key
} else if config.JWT.SecretKeyBase64 != "" {
key, err := base64.StdEncoding.DecodeString(config.JWT.SecretKeyBase64)
if err != nil {
return nil, errors.New("jwt.secret_key_base64 is not valid base64")
}
config.JWT.SecretKey = key
} else if config.JWT.SecretKeyPath != "" {
keyData, err := os.ReadFile(config.JWT.SecretKeyPath)
if err != nil {
return nil, err
}
key, err := base64.StdEncoding.DecodeString(string(keyData))
if err != nil {
return nil, errors.New("jwt.secret_key_path file content is not valid base64")
}
config.JWT.SecretKey = key
}
if errs := config.Validate(); len(errs) > 0 {
return nil, NewConfigError(errs...)
}
return &config, nil
}
// Validate checks for required configuration fields.
func (c *Config) Validate() []error {
var errs []error
// Server
if c.Server.Port == 0 {
errs = append(errs, errors.New("server.port is required"))
}
// Database
if c.Database.PostgresURL == "" {
errs = append(errs, errors.New("database.postgres_url is required"))
}
// JWT
if c.JWT.Issuer == "" {
errs = append(errs, errors.New("jwt.issuer is required"))
}
if c.JWT.Audience == "" {
errs = append(errs, errors.New("jwt.audience is required"))
}
if len(c.JWT.SecretKey) == 0 {
errs = append(errs, errors.New("jwt secret key is required (set JWT_SECRET_KEY env var, jwt.secret_key_base64, or jwt.secret_key_path)"))
}
// Storage
if c.Storage.Mode == "" {
errs = append(errs, errors.New("storage.mode is required"))
} else if c.Storage.Mode != StorageModeFlat && c.Storage.Mode != StorageModeHierarchical {
errs = append(errs, errors.New("storage.mode must be 'flat' or 'hierarchical'"))
}
if c.Storage.Backend == "" {
errs = append(errs, errors.New("storage.backend is required"))
} else if c.Storage.Backend != StorageBackendFS && c.Storage.Backend != StorageBackendS3 {
errs = append(errs, errors.New("storage.backend must be 'fs' or 's3'"))
}
if c.Storage.Backend == StorageBackendFS && c.Storage.RootPath == "" {
errs = append(errs, errors.New("storage.root_path is required when backend is 'fs'"))
}
if c.Storage.Backend == StorageBackendS3 {
if c.Storage.Bucket == "" {
errs = append(errs, errors.New("storage.bucket is required when backend is 's3'"))
}
if c.Storage.Mode == StorageModeHierarchical {
errs = append(errs, errors.New("storage.mode must be 'flat' when backend is 's3'"))
}
}
return errs
}

View File

@@ -5,17 +5,17 @@ import (
"strings" "strings"
) )
type ServerConfigError struct { type ConfigError struct {
Errors []error Errors []error
} }
func NewServerConfigError(errs ...error) *ServerConfigError { func NewConfigError(errs ...error) *ConfigError {
return &ServerConfigError{Errors: errs} return &ConfigError{Errors: errs}
} }
func (e *ServerConfigError) Error() string { func (e *ConfigError) Error() string {
sb := strings.Builder{} sb := strings.Builder{}
sb.WriteString("invalid server config:\n") sb.WriteString("invalid config:\n")
for _, err := range e.Errors { for _, err := range e.Errors {
sb.WriteString(fmt.Sprintf(" - %s\n", err.Error())) sb.WriteString(fmt.Sprintf(" - %s\n", err.Error()))
} }

View File

@@ -1,86 +1,70 @@
package drexa package drexa
import ( import (
"encoding/hex" "context"
"errors"
"fmt" "fmt"
"os"
"strconv"
"github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/upload"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
) )
type ServerConfig struct { func NewServer(c Config) (*fiber.App, error) {
Port int
PostgresURL string
JWTIssuer string
JWTAudience string
JWTSecretKey []byte
}
func NewServer(c ServerConfig) *fiber.App {
app := fiber.New() app := fiber.New()
db := database.NewFromPostgres(c.PostgresURL) db := database.NewFromPostgres(c.Database.PostgresURL)
userService := user.NewService(db) app.Use(logger.New())
authService := auth.NewService(db, userService, auth.TokenConfig{
Issuer: c.JWTIssuer, // Initialize blob store based on config
Audience: c.JWTAudience, var blobStore blob.Store
SecretKey: c.JWTSecretKey, switch c.Storage.Backend {
case StorageBackendFS:
blobStore = blob.NewFSStore(blob.FSStoreConfig{
Root: c.Storage.RootPath,
})
case StorageBackendS3:
return nil, fmt.Errorf("s3 storage backend not yet implemented")
default:
return nil, fmt.Errorf("unknown storage backend: %s", c.Storage.Backend)
}
err := blobStore.Initialize(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to initialize blob store: %w", err)
}
// Initialize key resolver based on config
var keyResolver virtualfs.BlobKeyResolver
switch c.Storage.Mode {
case StorageModeFlat:
keyResolver = virtualfs.NewFlatKeyResolver()
case StorageModeHierarchical:
keyResolver = virtualfs.NewHierarchicalKeyResolver(db)
default:
return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode)
}
vfs, err := virtualfs.NewVirtualFS(db, blobStore, keyResolver)
if err != nil {
return nil, fmt.Errorf("failed to create virtual file system: %w", err)
}
userService := user.NewService()
authService := auth.NewService(userService, auth.TokenConfig{
Issuer: c.JWT.Issuer,
Audience: c.JWT.Audience,
SecretKey: c.JWT.SecretKey,
}) })
uploadService := upload.NewService(vfs, blobStore)
api := app.Group("/api") api := app.Group("/api")
auth.RegisterAPIRoutes(api, authService) auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService).RegisterRoutes(api)
return app return app, nil
}
// ServerConfigFromEnv creates a ServerConfig from environment variables.
func ServerConfigFromEnv() (*ServerConfig, error) {
c := ServerConfig{
PostgresURL: os.Getenv("POSTGRES_URL"),
JWTIssuer: os.Getenv("JWT_ISSUER"),
JWTAudience: os.Getenv("JWT_AUDIENCE"),
}
errs := []error{}
keyHex := os.Getenv("JWT_SECRET_KEY")
if keyHex == "" {
errs = append(errs, errors.New("JWT_SECRET_KEY is required"))
} else {
k, err := hex.DecodeString(keyHex)
if err != nil {
errs = append(errs, fmt.Errorf("failed to decode JWT_SECRET_KEY: %w", err))
}
c.JWTSecretKey = k
}
p, err := strconv.Atoi(os.Getenv("PORT"))
if err != nil {
errs = append(errs, fmt.Errorf("failed to parse PORT: %w", err))
}
c.Port = p
if c.PostgresURL == "" {
errs = append(errs, errors.New("POSTGRES_URL is required"))
}
if c.JWTIssuer == "" {
errs = append(errs, errors.New("JWT_ISSUER is required"))
}
if c.JWTAudience == "" {
errs = append(errs, errors.New("JWT_AUDIENCE is required"))
}
if len(c.JWTSecretKey) == 0 {
errs = append(errs, errors.New("JWT_SECRET_KEY is required"))
}
if len(errs) > 0 {
return nil, NewServerConfigError(errs...)
}
return &c, nil
} }

View File

@@ -0,0 +1,97 @@
package upload
import (
"errors"
"github.com/get-drexa/drexa/internal/auth"
"github.com/gofiber/fiber/v2"
)
type createUploadRequest struct {
ParentID string `json:"parentId"`
Name string `json:"name"`
}
type updateUploadRequest struct {
Status Status `json:"status"`
}
type HTTPHandler struct {
service *Service
}
func NewHTTPHandler(s *Service) *HTTPHandler {
return &HTTPHandler{service: s}
}
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
upload := api.Group("/uploads")
upload.Post("/", h.Create)
upload.Put("/:uploadID/content", h.ReceiveContent)
upload.Patch("/:uploadID", h.Update)
}
func (h *HTTPHandler) Create(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
req := new(createUploadRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
ParentID: req.ParentID,
Name: req.Name,
})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(upload)
}
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
uploadID := c.Params("uploadID")
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
defer c.Request().CloseBodyStream()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.SendStatus(fiber.StatusNoContent)
}
func (h *HTTPHandler) Update(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
req := new(updateUploadRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
if req.Status == StatusCompleted {
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
if err != nil {
if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(upload)
}
return c.SendStatus(fiber.StatusBadRequest)
}

View File

@@ -33,17 +33,6 @@ type CreateUploadOptions struct {
Name string Name string
} }
type Upload struct {
ID string
TargetNode *virtualfs.Node
UploadURL string
}
type CreateUploadResult struct {
UploadID string
UploadURL string
}
func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID) parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID)
if err != nil { if err != nil {
@@ -77,6 +66,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
upload := &Upload{ upload := &Upload{
ID: node.PublicID, ID: node.PublicID,
Status: StatusPending,
TargetNode: node, TargetNode: node,
UploadURL: uploadURL, UploadURL: uploadURL,
} }
@@ -106,36 +96,37 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return err return err
} }
s.pendingUploads.Delete(uploadID) upload.Status = StatusCompleted
return nil return nil
} }
func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) error { func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) {
n, ok := s.pendingUploads.Load(uploadID) n, ok := s.pendingUploads.Load(uploadID)
if !ok { if !ok {
return ErrNotFound return nil, ErrNotFound
} }
upload, ok := n.(*Upload) upload, ok := n.(*Upload)
if !ok { if !ok {
return ErrNotFound return nil, ErrNotFound
} }
if upload.TargetNode.UserID != userID { if upload.TargetNode.UserID != userID {
return ErrNotFound return nil, ErrNotFound
} }
if upload.TargetNode.Status == virtualfs.NodeStatusReady { if upload.TargetNode.Status == virtualfs.NodeStatusReady && upload.Status == StatusCompleted {
return nil return upload, nil
} }
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey)) err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
if err != nil { if err != nil {
return err return nil, err
} }
upload.Status = StatusCompleted
s.pendingUploads.Delete(uploadID) s.pendingUploads.Delete(uploadID)
return nil return upload, nil
} }

View File

@@ -0,0 +1,18 @@
package upload
import "github.com/get-drexa/drexa/internal/virtualfs"
type Status string
const (
StatusPending Status = "pending"
StatusCompleted Status = "completed"
StatusFailed Status = "failed"
)
type Upload struct {
ID string `json:"id"`
Status Status `json:"status"`
TargetNode *virtualfs.Node `json:"-"`
UploadURL string `json:"uploadUrl"`
}

View File

@@ -11,9 +11,7 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type Service struct { type Service struct{}
db *bun.DB
}
type UserRegistrationOptions struct { type UserRegistrationOptions struct {
Email string Email string
@@ -21,20 +19,24 @@ type UserRegistrationOptions struct {
Password password.Hashed Password password.Hashed
} }
func NewService(db *bun.DB) *Service { func NewService() *Service {
return &Service{ return &Service{}
db: db,
}
} }
func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) { func (s *Service) RegisterUser(ctx context.Context, db bun.IDB, opts UserRegistrationOptions) (*User, error) {
uid, err := newUserID()
if err != nil {
return nil, err
}
u := User{ u := User{
ID: uid,
Email: opts.Email, Email: opts.Email,
DisplayName: opts.DisplayName, DisplayName: opts.DisplayName,
Password: opts.Password, Password: opts.Password,
} }
_, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx) _, err = db.NewInsert().Model(&u).Returning("*").Exec(ctx)
if err != nil { if err != nil {
if database.IsUniqueViolation(err) { if database.IsUniqueViolation(err) {
return nil, newAlreadyExistsError(u.Email) return nil, newAlreadyExistsError(u.Email)
@@ -45,9 +47,9 @@ func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions
return &u, nil return &u, nil
} }
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) { func (s *Service) UserByID(ctx context.Context, db bun.IDB, id uuid.UUID) (*User, error) {
var user User var user User
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx) err := db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(id, "") return nil, newNotFoundError(id, "")
@@ -57,9 +59,9 @@ func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
return &user, nil return &user, nil
} }
func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) { func (s *Service) UserByEmail(ctx context.Context, db bun.IDB, email string) (*User, error) {
var user User var user User
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx) err := db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(uuid.Nil, email) return nil, newNotFoundError(uuid.Nil, email)
@@ -69,6 +71,6 @@ func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error)
return &user, nil return &user, nil
} }
func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) { func (s *Service) UserExistsByEmail(ctx context.Context, db bun.IDB, email string) (bool, error) {
return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx) return db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
} }

View File

@@ -1,6 +1,8 @@
package user package user
import ( import (
"time"
"github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/password"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -10,9 +12,15 @@ type User struct {
bun.BaseModel `bun:"users"` bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"` ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name,notnull" json:"displayName"` DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"` Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"` Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
}
func newUserID() (uuid.UUID, error) {
return uuid.NewV7()
} }

View File

@@ -15,6 +15,10 @@ type BlobKeyResolver interface {
type FlatKeyResolver struct{} type FlatKeyResolver struct{}
func NewFlatKeyResolver() *FlatKeyResolver {
return &FlatKeyResolver{}
}
func (r *FlatKeyResolver) KeyMode() blob.KeyMode { func (r *FlatKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeStable return blob.KeyModeStable
} }
@@ -34,6 +38,10 @@ type HierarchicalKeyResolver struct {
db *bun.DB db *bun.DB
} }
func NewHierarchicalKeyResolver(db *bun.DB) *HierarchicalKeyResolver {
return &HierarchicalKeyResolver{db: db}
}
func (r *HierarchicalKeyResolver) KeyMode() blob.KeyMode { func (r *HierarchicalKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeDerived return blob.KeyModeDerived
} }