Compare commits

...

7 Commits

Author SHA1 Message Date
629d56b5ab fix: registration endpoint and db auto close issue 2025-11-29 20:51:56 +00:00
5e4e08c255 fix: migration code not working
- read database config from config file
- rename migration file to expected file name format
2025-11-29 20:32:32 +00:00
42b805fbd1 feat: add blob store Initialize method 2025-11-29 18:39:21 +00:00
ab4c14bc09 feat: impl config loading 2025-11-29 18:09:41 +00:00
fd3b2d3908 fix: handle NewServer error 2025-11-29 17:28:53 +00:00
39824e45d9 feat: impl upload api endpoints 2025-11-29 17:25:11 +00:00
6aee150a59 build: remove devcontainer name 2025-11-29 17:24:49 +00:00
25 changed files with 569 additions and 191 deletions

View File

@@ -1,5 +1,4 @@
{
"name": "React + Bun + Convex Development",
"build": {
"context": ".",
"dockerfile": "Dockerfile"

View File

@@ -1,22 +1,34 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"github.com/get-drexa/drexa/internal/drexa"
"github.com/joho/godotenv"
)
func main() {
_ = godotenv.Load()
configPath := flag.String("config", "", "path to config file (required)")
flag.Parse()
config, err := drexa.ServerConfigFromEnv()
if *configPath == "" {
fmt.Fprintln(os.Stderr, "error: --config is required")
flag.Usage()
os.Exit(1)
}
config, err := drexa.ConfigFromFile(*configPath)
if err != nil {
log.Fatalf("failed to load config: %v", err)
}
server, err := drexa.NewServer(*config)
if err != nil {
log.Fatal(err)
}
server := drexa.NewServer(*config)
log.Fatal(server.Listen(fmt.Sprintf(":%d", config.Port)))
log.Printf("starting server on :%d", config.Server.Port)
log.Fatal(server.Listen(fmt.Sprintf(":%d", config.Server.Port)))
}

View File

@@ -1,13 +1,37 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/drexa"
)
func main() {
if err := database.RunMigrations(); err != nil {
log.Fatalf("Failed to run migrations: %v", err)
configPath := flag.String("config", "", "path to config file (required)")
flag.Parse()
if *configPath == "" {
fmt.Fprintln(os.Stderr, "error: --config is required")
flag.Usage()
os.Exit(1)
}
config, err := drexa.ConfigFromFile(*configPath)
if err != nil {
log.Fatalf("failed to load config: %v", err)
}
db := database.NewFromPostgres(config.Database.PostgresURL)
defer db.Close()
log.Println("running migrations...")
if err := database.RunMigrations(context.Background(), db); err != nil {
log.Fatalf("failed to run migrations: %v", err)
}
log.Println("migrations completed successfully")
}

View File

@@ -0,0 +1,30 @@
# Drexa Backend Configuration
# Copy this file to config.yaml and adjust values for your environment.
server:
port: 8080
database:
postgres_url: postgres://user:password@localhost:5432/drexa?sslmode=disable
jwt:
issuer: drexa
audience: drexa-api
# Secret key can be provided via (in order of precedence):
# 1. JWT_SECRET_KEY environment variable (base64 encoded)
# 2. secret_key_base64 below (base64 encoded)
# 3. secret_key_path below (file with base64 encoded content)
# secret_key_base64: "base64encodedkey"
secret_key_path: /run/secrets/jwt_secret_key
storage:
# Mode: "flat" (UUID-based keys) or "hierarchical" (path-based keys)
# Note: S3 backend only supports "flat" mode
mode: flat
# Backend: "fs" (filesystem) or "s3" (not yet implemented)
backend: fs
# Required when backend is "fs"
root_path: /var/lib/drexa/blobs
# Required when backend is "s3"
# bucket: my-drexa-bucket

15
apps/backend/config.yaml Normal file
View File

@@ -0,0 +1,15 @@
server:
port: 8080
database:
postgres_url: postgres://drexa:hunter2@helian:5433/drexa?sslmode=disable
jwt:
issuer: drexa
audience: drexa-api
secret_key_base64: "pNeUExoqdakfecZLFL53NJpY4iB9zFot9EuEBItlYKY="
storage:
mode: hierarchical
backend: fs
root_path: ./data

View File

@@ -3,16 +3,16 @@ module github.com/get-drexa/drexa
go 1.25.4
require (
github.com/gabriel-vasile/mimetype v1.4.11
github.com/gofiber/fiber/v2 v2.52.9
github.com/google/uuid v1.6.0
github.com/sqids/sqids-go v0.4.1
github.com/uptrace/bun v1.2.15
golang.org/x/crypto v0.40.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/gabriel-vasile/mimetype v1.4.11 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/sqids/sqids-go v0.4.1 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
mellium.im/sasl v0.3.2 // indirect
@@ -20,7 +20,7 @@ require (
require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
@@ -28,7 +28,6 @@ require (
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/stretchr/testify v1.10.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/uptrace/bun/dialect/pgdialect v1.2.15
github.com/uptrace/bun/driver/pgdriver v1.2.15

View File

@@ -8,20 +8,24 @@ github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5
github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
@@ -59,6 +63,9 @@ golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+Zdx
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
mellium.im/sasl v0.3.2 h1:PT6Xp7ccn9XaXAnJ03FcEjmAn7kK1x7aoXV6F+Vmrl0=

View File

@@ -2,13 +2,13 @@ package auth
import (
"errors"
"log/slog"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
const authServiceKey = "authService"
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
@@ -26,29 +26,35 @@ type loginResponse struct {
RefreshToken string `json:"refreshToken"`
}
func RegisterAPIRoutes(api fiber.Router, s *Service) {
auth := api.Group("/auth", func(c *fiber.Ctx) error {
c.Locals(authServiceKey, s)
return c.Next()
})
auth.Post("/login", login)
auth.Post("/register", register)
type HTTPHandler struct {
service *Service
db *bun.DB
}
func mustAuthService(c *fiber.Ctx) *Service {
return c.Locals(authServiceKey).(*Service)
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return &HTTPHandler{service: s, db: db}
}
func login(c *fiber.Ctx) error {
s := mustAuthService(c)
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
auth := api.Group("/auth")
auth.Post("/login", h.Login)
auth.Post("/register", h.Register)
}
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
req := new(loginRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
result, err := s.LoginWithEmailAndPassword(c.Context(), req.Email, req.Password)
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil {
if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
@@ -56,6 +62,10 @@ func login(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
@@ -63,15 +73,20 @@ func login(c *fiber.Ctx) error {
})
}
func register(c *fiber.Ctx) error {
s := mustAuthService(c)
func (h *HTTPHandler) Register(c *fiber.Ctx) error {
req := new(registerRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
result, err := s.Register(c.Context(), registerOptions{
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
slog.Error("failed to begin transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{
email: req.Email,
password: req.Password,
displayName: req.DisplayName,
@@ -81,6 +96,12 @@ func register(c *fiber.Ctx) error {
if errors.As(err, &ae) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
}
slog.Error("failed to register user", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}

View File

@@ -6,13 +6,14 @@ import (
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service) fiber.Handler {
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
if authHeader == "" {
@@ -25,7 +26,7 @@ func NewBearerAuthMiddleware(s *Service) fiber.Handler {
}
token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), token)
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {

View File

@@ -20,7 +20,6 @@ type LoginResult struct {
var ErrInvalidCredentials = errors.New("invalid credentials")
type Service struct {
db *bun.DB
userService *user.Service
tokenConfig TokenConfig
}
@@ -31,16 +30,15 @@ type registerOptions struct {
password string
}
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service {
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{
db: db,
userService: userService,
tokenConfig: tokenConfig,
}
}
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, email)
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
@@ -64,7 +62,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
return nil, err
}
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
@@ -76,13 +74,13 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain st
}, nil
}
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password)
if err != nil {
return nil, err
}
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email,
DisplayName: opts.displayName,
Password: hashed,
@@ -101,7 +99,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
return nil, err
}
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
@@ -113,7 +111,7 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
}, nil
}
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) {
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
return nil, err
@@ -124,5 +122,5 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string)
return nil, newInvalidAccessTokenError(err)
}
return s.userService.UserByID(ctx, id)
return s.userService.UserByID(ctx, db, id)
}

View File

@@ -36,6 +36,10 @@ type RefreshToken struct {
CreatedAt time.Time `bun:"created_at,notnull"`
}
func newTokenID() (uuid.UUID, error) {
return uuid.NewV7()
}
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
now := time.Now()
@@ -63,9 +67,9 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
id, err := uuid.NewV7()
id, err := newTokenID()
if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
return nil, fmt.Errorf("failed to generate token ID: %w", err)
}
h := sha256.Sum256(buf)

View File

@@ -24,6 +24,10 @@ func NewFSStore(config FSStoreConfig) *FSStore {
return &FSStore{config: config}
}
func (s *FSStore) Initialize(ctx context.Context) error {
return os.MkdirAll(s.config.Root, 0755)
}
func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
return s.config.UploadURL, nil
}

View File

@@ -15,6 +15,7 @@ type UpdateOptions struct {
}
type Store interface {
Initialize(ctx context.Context) error
GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
Put(ctx context.Context, key Key, reader io.Reader) error
Update(ctx context.Context, key Key, opts UpdateOptions) error

View File

@@ -1,17 +1,28 @@
package database
import (
"context"
"embed"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate"
)
//go:embed migrations/*.sql
var sqlMigrations embed.FS
// RunMigrations discovers and runs all migrations in the migrations directory.
// Currently, the migrations directory is in internal/db/migrations.
func RunMigrations() error {
m := migrate.NewMigrations()
return m.Discover(sqlMigrations)
// RunMigrations discovers and runs all migrations against the database.
func RunMigrations(ctx context.Context, db *bun.DB) error {
migrations := migrate.NewMigrations()
if err := migrations.Discover(sqlMigrations); err != nil {
return err
}
migrator := migrate.NewMigrator(db, migrations)
if err := migrator.Init(ctx); err != nil {
return err
}
_, err := migrator.Migrate(ctx)
return err
}

View File

@@ -1,32 +1,9 @@
-- Enable UUID extension for UUIDv7 support
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
-- UUIDv7 generation function (timestamp-ordered UUIDs)
-- Based on the draft RFC: https://datatracker.ietf.org/doc/html/draft-peabody-dispatch-new-uuid-format
CREATE OR REPLACE FUNCTION uuid_generate_v7()
RETURNS UUID
AS $$
DECLARE
unix_ts_ms BIGINT;
uuid_bytes BYTEA;
BEGIN
unix_ts_ms = (EXTRACT(EPOCH FROM CLOCK_TIMESTAMP()) * 1000)::BIGINT;
uuid_bytes = OVERLAY(gen_random_bytes(16) PLACING
SUBSTRING(INT8SEND(unix_ts_ms) FROM 3) FROM 1 FOR 6
);
-- Set version (7) and variant bits
uuid_bytes = SET_BYTE(uuid_bytes, 6, (GET_BYTE(uuid_bytes, 6) & 15) | 112);
uuid_bytes = SET_BYTE(uuid_bytes, 8, (GET_BYTE(uuid_bytes, 8) & 63) | 128);
RETURN ENCODE(uuid_bytes, 'hex')::UUID;
END;
$$ LANGUAGE plpgsql VOLATILE;
-- ============================================================================
-- Application Tables
-- ============================================================================
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(),
id UUID PRIMARY KEY,
display_name TEXT,
email TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
@@ -39,7 +16,7 @@ CREATE TABLE IF NOT EXISTS users (
CREATE INDEX idx_users_email ON users(email);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(),
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ NOT NULL,
@@ -52,7 +29,7 @@ CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
-- Virtual filesystem nodes (unified files + directories)
CREATE TABLE IF NOT EXISTS vfs_nodes (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(),
id UUID PRIMARY KEY,
public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak)
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory
@@ -83,7 +60,7 @@ CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_i
CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job
CREATE TABLE IF NOT EXISTS node_shares (
id UUID PRIMARY KEY DEFAULT uuid_generate_v7(),
id UUID PRIMARY KEY,
node_id UUID NOT NULL REFERENCES vfs_nodes(id) ON DELETE CASCADE,
share_token TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ,

View File

@@ -2,6 +2,7 @@ package database
import (
"database/sql"
"time"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
@@ -10,6 +11,17 @@ import (
func NewFromPostgres(url string) *bun.DB {
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(url)))
// Configure connection pool to prevent "database closed" errors
// SetMaxOpenConns sets the maximum number of open connections to the database
sqldb.SetMaxOpenConns(25)
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool
sqldb.SetMaxIdleConns(5)
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused
sqldb.SetConnMaxLifetime(5 * time.Minute)
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle
sqldb.SetConnMaxIdleTime(10 * time.Minute)
db := bun.NewDB(sqldb, pgdialect.New())
return db
}

View File

@@ -0,0 +1,155 @@
package drexa
import (
"encoding/base64"
"errors"
"fmt"
"os"
"gopkg.in/yaml.v3"
)
type StorageMode string
type StorageBackend string
const (
StorageModeFlat StorageMode = "flat"
StorageModeHierarchical StorageMode = "hierarchical"
)
const (
StorageBackendFS StorageBackend = "fs"
StorageBackendS3 StorageBackend = "s3"
)
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
JWT JWTConfig `yaml:"jwt"`
Storage StorageConfig `yaml:"storage"`
}
type ServerConfig struct {
Port int `yaml:"port"`
}
type DatabaseConfig struct {
PostgresURL string `yaml:"postgres_url"`
}
type JWTConfig struct {
Issuer string `yaml:"issuer"`
Audience string `yaml:"audience"`
SecretKeyBase64 string `yaml:"secret_key_base64"`
SecretKeyPath string `yaml:"secret_key_path"`
SecretKey []byte `yaml:"-"`
}
type StorageConfig struct {
Mode StorageMode `yaml:"mode"`
Backend StorageBackend `yaml:"backend"`
RootPath string `yaml:"root_path"`
Bucket string `yaml:"bucket"`
}
// ConfigFromFile loads configuration from a YAML file.
// JWT secret key is loaded from JWT_SECRET_KEY env var (base64 encoded),
// falling back to the file path specified in jwt.secret_key_path.
func ConfigFromFile(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("config file not found: %s", path)
}
return nil, err
}
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, err
}
// Load JWT secret key (priority: env var > config base64 > config file path)
if envKey := os.Getenv("JWT_SECRET_KEY"); envKey != "" {
key, err := base64.StdEncoding.DecodeString(envKey)
if err != nil {
return nil, errors.New("JWT_SECRET_KEY env var is not valid base64")
}
config.JWT.SecretKey = key
} else if config.JWT.SecretKeyBase64 != "" {
key, err := base64.StdEncoding.DecodeString(config.JWT.SecretKeyBase64)
if err != nil {
return nil, errors.New("jwt.secret_key_base64 is not valid base64")
}
config.JWT.SecretKey = key
} else if config.JWT.SecretKeyPath != "" {
keyData, err := os.ReadFile(config.JWT.SecretKeyPath)
if err != nil {
return nil, err
}
key, err := base64.StdEncoding.DecodeString(string(keyData))
if err != nil {
return nil, errors.New("jwt.secret_key_path file content is not valid base64")
}
config.JWT.SecretKey = key
}
if errs := config.Validate(); len(errs) > 0 {
return nil, NewConfigError(errs...)
}
return &config, nil
}
// Validate checks for required configuration fields.
func (c *Config) Validate() []error {
var errs []error
// Server
if c.Server.Port == 0 {
errs = append(errs, errors.New("server.port is required"))
}
// Database
if c.Database.PostgresURL == "" {
errs = append(errs, errors.New("database.postgres_url is required"))
}
// JWT
if c.JWT.Issuer == "" {
errs = append(errs, errors.New("jwt.issuer is required"))
}
if c.JWT.Audience == "" {
errs = append(errs, errors.New("jwt.audience is required"))
}
if len(c.JWT.SecretKey) == 0 {
errs = append(errs, errors.New("jwt secret key is required (set JWT_SECRET_KEY env var, jwt.secret_key_base64, or jwt.secret_key_path)"))
}
// Storage
if c.Storage.Mode == "" {
errs = append(errs, errors.New("storage.mode is required"))
} else if c.Storage.Mode != StorageModeFlat && c.Storage.Mode != StorageModeHierarchical {
errs = append(errs, errors.New("storage.mode must be 'flat' or 'hierarchical'"))
}
if c.Storage.Backend == "" {
errs = append(errs, errors.New("storage.backend is required"))
} else if c.Storage.Backend != StorageBackendFS && c.Storage.Backend != StorageBackendS3 {
errs = append(errs, errors.New("storage.backend must be 'fs' or 's3'"))
}
if c.Storage.Backend == StorageBackendFS && c.Storage.RootPath == "" {
errs = append(errs, errors.New("storage.root_path is required when backend is 'fs'"))
}
if c.Storage.Backend == StorageBackendS3 {
if c.Storage.Bucket == "" {
errs = append(errs, errors.New("storage.bucket is required when backend is 's3'"))
}
if c.Storage.Mode == StorageModeHierarchical {
errs = append(errs, errors.New("storage.mode must be 'flat' when backend is 's3'"))
}
}
return errs
}

View File

@@ -5,17 +5,17 @@ import (
"strings"
)
type ServerConfigError struct {
type ConfigError struct {
Errors []error
}
func NewServerConfigError(errs ...error) *ServerConfigError {
return &ServerConfigError{Errors: errs}
func NewConfigError(errs ...error) *ConfigError {
return &ConfigError{Errors: errs}
}
func (e *ServerConfigError) Error() string {
func (e *ConfigError) Error() string {
sb := strings.Builder{}
sb.WriteString("invalid server config:\n")
sb.WriteString("invalid config:\n")
for _, err := range e.Errors {
sb.WriteString(fmt.Sprintf(" - %s\n", err.Error()))
}

View File

@@ -1,86 +1,70 @@
package drexa
import (
"encoding/hex"
"errors"
"context"
"fmt"
"os"
"strconv"
"github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/upload"
"github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
)
type ServerConfig struct {
Port int
PostgresURL string
JWTIssuer string
JWTAudience string
JWTSecretKey []byte
func NewServer(c Config) (*fiber.App, error) {
app := fiber.New()
db := database.NewFromPostgres(c.Database.PostgresURL)
app.Use(logger.New())
// Initialize blob store based on config
var blobStore blob.Store
switch c.Storage.Backend {
case StorageBackendFS:
blobStore = blob.NewFSStore(blob.FSStoreConfig{
Root: c.Storage.RootPath,
})
case StorageBackendS3:
return nil, fmt.Errorf("s3 storage backend not yet implemented")
default:
return nil, fmt.Errorf("unknown storage backend: %s", c.Storage.Backend)
}
func NewServer(c ServerConfig) *fiber.App {
app := fiber.New()
db := database.NewFromPostgres(c.PostgresURL)
err := blobStore.Initialize(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to initialize blob store: %w", err)
}
userService := user.NewService(db)
authService := auth.NewService(db, userService, auth.TokenConfig{
Issuer: c.JWTIssuer,
Audience: c.JWTAudience,
SecretKey: c.JWTSecretKey,
// Initialize key resolver based on config
var keyResolver virtualfs.BlobKeyResolver
switch c.Storage.Mode {
case StorageModeFlat:
keyResolver = virtualfs.NewFlatKeyResolver()
case StorageModeHierarchical:
keyResolver = virtualfs.NewHierarchicalKeyResolver(db)
default:
return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode)
}
vfs, err := virtualfs.NewVirtualFS(db, blobStore, keyResolver)
if err != nil {
return nil, fmt.Errorf("failed to create virtual file system: %w", err)
}
userService := user.NewService()
authService := auth.NewService(userService, auth.TokenConfig{
Issuer: c.JWT.Issuer,
Audience: c.JWT.Audience,
SecretKey: c.JWT.SecretKey,
})
uploadService := upload.NewService(vfs, blobStore)
api := app.Group("/api")
auth.RegisterAPIRoutes(api, authService)
auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService).RegisterRoutes(api)
return app
}
// ServerConfigFromEnv creates a ServerConfig from environment variables.
func ServerConfigFromEnv() (*ServerConfig, error) {
c := ServerConfig{
PostgresURL: os.Getenv("POSTGRES_URL"),
JWTIssuer: os.Getenv("JWT_ISSUER"),
JWTAudience: os.Getenv("JWT_AUDIENCE"),
}
errs := []error{}
keyHex := os.Getenv("JWT_SECRET_KEY")
if keyHex == "" {
errs = append(errs, errors.New("JWT_SECRET_KEY is required"))
} else {
k, err := hex.DecodeString(keyHex)
if err != nil {
errs = append(errs, fmt.Errorf("failed to decode JWT_SECRET_KEY: %w", err))
}
c.JWTSecretKey = k
}
p, err := strconv.Atoi(os.Getenv("PORT"))
if err != nil {
errs = append(errs, fmt.Errorf("failed to parse PORT: %w", err))
}
c.Port = p
if c.PostgresURL == "" {
errs = append(errs, errors.New("POSTGRES_URL is required"))
}
if c.JWTIssuer == "" {
errs = append(errs, errors.New("JWT_ISSUER is required"))
}
if c.JWTAudience == "" {
errs = append(errs, errors.New("JWT_AUDIENCE is required"))
}
if len(c.JWTSecretKey) == 0 {
errs = append(errs, errors.New("JWT_SECRET_KEY is required"))
}
if len(errs) > 0 {
return nil, NewServerConfigError(errs...)
}
return &c, nil
return app, nil
}

View File

@@ -0,0 +1,97 @@
package upload
import (
"errors"
"github.com/get-drexa/drexa/internal/auth"
"github.com/gofiber/fiber/v2"
)
type createUploadRequest struct {
ParentID string `json:"parentId"`
Name string `json:"name"`
}
type updateUploadRequest struct {
Status Status `json:"status"`
}
type HTTPHandler struct {
service *Service
}
func NewHTTPHandler(s *Service) *HTTPHandler {
return &HTTPHandler{service: s}
}
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
upload := api.Group("/uploads")
upload.Post("/", h.Create)
upload.Put("/:uploadID/content", h.ReceiveContent)
upload.Patch("/:uploadID", h.Update)
}
func (h *HTTPHandler) Create(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
req := new(createUploadRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
ParentID: req.ParentID,
Name: req.Name,
})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(upload)
}
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
uploadID := c.Params("uploadID")
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
defer c.Request().CloseBodyStream()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.SendStatus(fiber.StatusNoContent)
}
func (h *HTTPHandler) Update(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
}
req := new(updateUploadRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
if req.Status == StatusCompleted {
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
if err != nil {
if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(upload)
}
return c.SendStatus(fiber.StatusBadRequest)
}

View File

@@ -33,17 +33,6 @@ type CreateUploadOptions struct {
Name string
}
type Upload struct {
ID string
TargetNode *virtualfs.Node
UploadURL string
}
type CreateUploadResult struct {
UploadID string
UploadURL string
}
func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID)
if err != nil {
@@ -77,6 +66,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
upload := &Upload{
ID: node.PublicID,
Status: StatusPending,
TargetNode: node,
UploadURL: uploadURL,
}
@@ -106,36 +96,37 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return err
}
s.pendingUploads.Delete(uploadID)
upload.Status = StatusCompleted
return nil
}
func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) error {
func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) {
n, ok := s.pendingUploads.Load(uploadID)
if !ok {
return ErrNotFound
return nil, ErrNotFound
}
upload, ok := n.(*Upload)
if !ok {
return ErrNotFound
return nil, ErrNotFound
}
if upload.TargetNode.UserID != userID {
return ErrNotFound
return nil, ErrNotFound
}
if upload.TargetNode.Status == virtualfs.NodeStatusReady {
return nil
if upload.TargetNode.Status == virtualfs.NodeStatusReady && upload.Status == StatusCompleted {
return upload, nil
}
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
if err != nil {
return err
return nil, err
}
upload.Status = StatusCompleted
s.pendingUploads.Delete(uploadID)
return nil
return upload, nil
}

View File

@@ -0,0 +1,18 @@
package upload
import "github.com/get-drexa/drexa/internal/virtualfs"
type Status string
const (
StatusPending Status = "pending"
StatusCompleted Status = "completed"
StatusFailed Status = "failed"
)
type Upload struct {
ID string `json:"id"`
Status Status `json:"status"`
TargetNode *virtualfs.Node `json:"-"`
UploadURL string `json:"uploadUrl"`
}

View File

@@ -11,9 +11,7 @@ import (
"github.com/uptrace/bun"
)
type Service struct {
db *bun.DB
}
type Service struct{}
type UserRegistrationOptions struct {
Email string
@@ -21,20 +19,24 @@ type UserRegistrationOptions struct {
Password password.Hashed
}
func NewService(db *bun.DB) *Service {
return &Service{
db: db,
}
func NewService() *Service {
return &Service{}
}
func (s *Service) RegisterUser(ctx context.Context, db bun.IDB, opts UserRegistrationOptions) (*User, error) {
uid, err := newUserID()
if err != nil {
return nil, err
}
func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) {
u := User{
ID: uid,
Email: opts.Email,
DisplayName: opts.DisplayName,
Password: opts.Password,
}
_, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx)
_, err = db.NewInsert().Model(&u).Returning("*").Exec(ctx)
if err != nil {
if database.IsUniqueViolation(err) {
return nil, newAlreadyExistsError(u.Email)
@@ -45,9 +47,9 @@ func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions
return &u, nil
}
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
func (s *Service) UserByID(ctx context.Context, db bun.IDB, id uuid.UUID) (*User, error) {
var user User
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
err := db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(id, "")
@@ -57,9 +59,9 @@ func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
return &user, nil
}
func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) {
func (s *Service) UserByEmail(ctx context.Context, db bun.IDB, email string) (*User, error) {
var user User
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
err := db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(uuid.Nil, email)
@@ -69,6 +71,6 @@ func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error)
return &user, nil
}
func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) {
return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
func (s *Service) UserExistsByEmail(ctx context.Context, db bun.IDB, email string) (bool, error) {
return db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
}

View File

@@ -1,6 +1,8 @@
package user
import (
"time"
"github.com/get-drexa/drexa/internal/password"
"github.com/google/uuid"
"github.com/uptrace/bun"
@@ -10,9 +12,15 @@ type User struct {
bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name,notnull" json:"displayName"`
DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
}
func newUserID() (uuid.UUID, error) {
return uuid.NewV7()
}

View File

@@ -15,6 +15,10 @@ type BlobKeyResolver interface {
type FlatKeyResolver struct{}
func NewFlatKeyResolver() *FlatKeyResolver {
return &FlatKeyResolver{}
}
func (r *FlatKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeStable
}
@@ -34,6 +38,10 @@ type HierarchicalKeyResolver struct {
db *bun.DB
}
func NewHierarchicalKeyResolver(db *bun.DB) *HierarchicalKeyResolver {
return &HierarchicalKeyResolver{db: db}
}
func (r *HierarchicalKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeDerived
}