11 Commits

27 changed files with 705 additions and 276 deletions

View File

@@ -7,14 +7,16 @@ require (
github.com/gofiber/fiber/v2 v2.52.9
github.com/google/uuid v1.6.0
github.com/sqids/sqids-go v0.4.1
github.com/uptrace/bun v1.2.15
golang.org/x/crypto v0.40.0
github.com/uptrace/bun v1.2.16
github.com/uptrace/bun/extra/bundebug v1.2.16
golang.org/x/crypto v0.45.0
gopkg.in/yaml.v3 v3.0.1
)
require (
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
github.com/fatih/color v1.18.0 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
mellium.im/sasl v0.3.2 // indirect
)
@@ -29,12 +31,12 @@ require (
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/uptrace/bun/dialect/pgdialect v1.2.15
github.com/uptrace/bun/driver/pgdriver v1.2.15
github.com/uptrace/bun/dialect/pgdialect v1.2.16
github.com/uptrace/bun/driver/pgdriver v1.2.16
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.34.0 // indirect
golang.org/x/sys v0.38.0 // indirect
)

View File

@@ -2,6 +2,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
@@ -34,16 +36,18 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
github.com/uptrace/bun/dialect/pgdialect v1.2.15 h1:er+/3giAIqpfrXJw+KP9B7ujyQIi5XkPnFmgjAVL6bA=
github.com/uptrace/bun/dialect/pgdialect v1.2.15/go.mod h1:QSiz6Qpy9wlGFsfpf7UMSL6mXAL1jDJhFwuOVacCnOQ=
github.com/uptrace/bun/driver/pgdriver v1.2.15 h1:eZZ60ZtUUE6jjv6VAI1pCMaTgtx3sxmChQzwbvchOOo=
github.com/uptrace/bun/driver/pgdriver v1.2.15/go.mod h1:s2zz/BAeScal4KLFDI8PURwATN8s9RDBsElEbnPAjv4=
github.com/uptrace/bun v1.2.16 h1:QlObi6ZIK5Ao7kAALnh91HWYNZUBbVwye52fmlQM9kc=
github.com/uptrace/bun v1.2.16/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
github.com/uptrace/bun/driver/pgdriver v1.2.16 h1:b1kpXKUxtTSGYow5Vlsb+dKV3z0R7aSAJNfMfKp61ZU=
github.com/uptrace/bun/driver/pgdriver v1.2.16/go.mod h1:H6lUZ9CBfp1X5Vq62YGSV7q96/v94ja9AYFjKvdoTk0=
github.com/uptrace/bun/extra/bundebug v1.2.16 h1:3OXAfHTU4ydu2+4j05oB1BxPx6+ypdWIVzTugl/7zl0=
github.com/uptrace/bun/extra/bundebug v1.2.16/go.mod h1:vk6R/1i67/S2RvUI5AH/m3P5e67mOkfDCmmCsAPUumo=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
@@ -54,15 +58,15 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -0,0 +1,23 @@
package account
import (
"time"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type Account struct {
bun.BaseModel `bun:"accounts"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
UserID uuid.UUID `bun:"user_id,notnull,type:uuid" json:"userId"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull,nullzero" json:"updatedAt"`
}
func newAccountID() (uuid.UUID, error) {
return uuid.NewV7()
}

View File

@@ -0,0 +1,8 @@
package account
import "errors"
var (
ErrAccountNotFound = errors.New("account not found")
ErrAccountAlreadyExists = errors.New("account already exists")
)

View File

@@ -0,0 +1,132 @@
package account
import (
"errors"
"github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type HTTPHandler struct {
accountService *Service
authService *auth.Service
db *bun.DB
authMiddleware fiber.Handler
}
type registerAccountRequest struct {
Email string `json:"email"`
Password string `json:"password"`
DisplayName string `json:"displayName"`
}
type registerAccountResponse struct {
Account *Account `json:"account"`
User *user.User `json:"user"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
const currentAccountKey = "currentAccount"
func CurrentAccount(c *fiber.Ctx) *Account {
return c.Locals(currentAccountKey).(*Account)
}
func NewHTTPHandler(accountService *Service, authService *auth.Service, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
return &HTTPHandler{accountService: accountService, authService: authService, db: db, authMiddleware: authMiddleware}
}
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) fiber.Router {
api.Post("/accounts", h.registerAccount)
account := api.Group("/accounts/:accountID")
account.Use(h.authMiddleware)
account.Use(h.accountMiddleware)
account.Get("/", h.getAccount)
return account
}
func (h *HTTPHandler) accountMiddleware(c *fiber.Ctx) error {
user, err := auth.AuthenticatedUser(c)
if err != nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
accountID, err := uuid.Parse(c.Params("accountID"))
if err != nil {
return c.SendStatus(fiber.StatusNotFound)
}
account, err := h.accountService.AccountByID(c.Context(), h.db, user.ID, accountID)
if err != nil {
if errors.Is(err, ErrAccountNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
}
c.Locals(currentAccountKey, account)
return c.Next()
}
func (h *HTTPHandler) getAccount(c *fiber.Ctx) error {
account := CurrentAccount(c)
if account == nil {
return c.SendStatus(fiber.StatusNotFound)
}
return c.JSON(account)
}
func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error {
req := new(registerAccountRequest)
if err := c.BodyParser(req); err != nil {
return c.SendStatus(fiber.StatusBadRequest)
}
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return httperr.Internal(err)
}
defer tx.Rollback()
acc, u, err := h.accountService.Register(c.Context(), tx, RegisterOptions{
Email: req.Email,
Password: req.Password,
DisplayName: req.DisplayName,
})
if err != nil {
var ae *user.AlreadyExistsError
if errors.As(err, &ae) {
return c.SendStatus(fiber.StatusConflict)
}
if errors.Is(err, ErrAccountAlreadyExists) {
return c.SendStatus(fiber.StatusConflict)
}
return httperr.Internal(err)
}
result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u)
if err != nil {
return httperr.Internal(err)
}
err = tx.Commit()
if err != nil {
return httperr.Internal(err)
}
return c.JSON(registerAccountResponse{
Account: acc,
User: u,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}

View File

@@ -0,0 +1,115 @@
package account
import (
"context"
"database/sql"
"errors"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type Service struct {
userService user.Service
vfs *virtualfs.VirtualFS
}
type RegisterOptions struct {
Email string
Password string
DisplayName string
}
type CreateAccountOptions struct {
OrganizationID uuid.UUID
QuotaBytes int64
}
func NewService(userService *user.Service, vfs *virtualfs.VirtualFS) *Service {
return &Service{
userService: *userService,
vfs: vfs,
}
}
func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) {
hashed, err := password.Hash(opts.Password)
if err != nil {
return nil, nil, err
}
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.Email,
Password: hashed,
DisplayName: opts.DisplayName,
})
if err != nil {
return nil, nil, err
}
acc, err := s.CreateAccount(ctx, db, u.ID, CreateAccountOptions{
// TODO: make quota configurable
QuotaBytes: 1024 * 1024 * 1024, // 1GB
})
if err != nil {
return nil, nil, err
}
_, err = s.vfs.CreateDirectory(ctx, db, acc.ID, uuid.Nil, virtualfs.RootDirectoryName)
if err != nil {
return nil, nil, err
}
return acc, u, nil
}
func (s *Service) CreateAccount(ctx context.Context, db bun.IDB, userID uuid.UUID, opts CreateAccountOptions) (*Account, error) {
id, err := newAccountID()
if err != nil {
return nil, err
}
account := &Account{
ID: id,
UserID: userID,
StorageQuotaBytes: opts.QuotaBytes,
}
_, err = db.NewInsert().Model(account).Returning("*").Exec(ctx)
if err != nil {
if database.IsUniqueViolation(err) {
return nil, ErrAccountAlreadyExists
}
return nil, err
}
return account, nil
}
func (s *Service) AccountByUserID(ctx context.Context, db bun.IDB, userID uuid.UUID) (*Account, error) {
var account Account
err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrAccountNotFound
}
return nil, err
}
return &account, nil
}
func (s *Service) AccountByID(ctx context.Context, db bun.IDB, userID uuid.UUID, id uuid.UUID) (*Account, error) {
var account Account
err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Where("id = ?", id).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrAccountNotFound
}
return nil, err
}
return &account, nil
}

View File

@@ -2,8 +2,8 @@ package auth
import (
"errors"
"log/slog"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
@@ -14,12 +14,6 @@ type loginRequest struct {
Password string `json:"password"`
}
type registerRequest struct {
Email string `json:"email"`
Password string `json:"password"`
DisplayName string `json:"displayName"`
}
type loginResponse struct {
User user.User `json:"user"`
AccessToken string `json:"accessToken"`
@@ -37,9 +31,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
auth := api.Group("/auth")
auth.Post("/login", h.Login)
auth.Post("/register", h.Register)
}
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
@@ -50,59 +42,20 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
return httperr.Internal(err)
}
defer tx.Rollback()
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil {
if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
return httperr.Internal(err)
}
if err := tx.Commit(); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}
func (h *HTTPHandler) Register(c *fiber.Ctx) error {
req := new(registerRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
slog.Error("failed to begin transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{
email: req.Email,
password: req.Password,
displayName: req.DisplayName,
})
if err != nil {
var ae *user.AlreadyExistsError
if errors.As(err, &ae) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
}
slog.Error("failed to register user", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
return httperr.Internal(err)
}
return c.JSON(loginResponse{

View File

@@ -2,8 +2,10 @@ package auth
import (
"errors"
"log/slog"
"strings"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
@@ -17,11 +19,13 @@ func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
if authHeader == "" {
slog.Info("no auth header")
return c.SendStatus(fiber.StatusUnauthorized)
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
slog.Info("invalid auth header")
return c.SendStatus(fiber.StatusUnauthorized)
}
@@ -30,15 +34,17 @@ func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {
slog.Info("invalid access token")
return c.SendStatus(fiber.StatusUnauthorized)
}
var nf *user.NotFoundError
if errors.As(err, &nf) {
slog.Info("user not found")
return c.SendStatus(fiber.StatusUnauthorized)
}
return c.SendStatus(fiber.StatusInternalServerError)
return httperr.Internal(err)
}
c.Locals(authenticatedUserKey, u)

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/hex"
"errors"
"log/slog"
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user"
@@ -11,7 +12,7 @@ import (
"github.com/uptrace/bun"
)
type LoginResult struct {
type AuthenticationResult struct {
User *user.User
AccessToken string
RefreshToken string
@@ -24,12 +25,6 @@ type Service struct {
tokenConfig TokenConfig
}
type registerOptions struct {
displayName string
email string
password string
}
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{
userService: userService,
@@ -37,7 +32,30 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
}
}
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) {
at, err := GenerateAccessToken(user, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(user, &s.tokenConfig)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &AuthenticationResult{
User: user,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) {
u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil {
var nf *user.NotFoundError
@@ -67,44 +85,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, ema
return nil, err
}
return &LoginResult{
User: u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password)
if err != nil {
return nil, err
}
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email,
DisplayName: opts.displayName,
Password: hashed,
})
if err != nil {
return nil, err
}
at, err := GenerateAccessToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &LoginResult{
return &AuthenticationResult{
User: u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
@@ -114,11 +95,13 @@ func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
slog.Info("failed to parse access token", "error", err)
return nil, err
}
id, err := uuid.Parse(claims.Subject)
if err != nil {
slog.Info("failed to parse access token subject", "error", err)
return nil, newInvalidAccessTokenError(err)
}

View File

@@ -33,7 +33,7 @@ type RefreshToken struct {
Token []byte `bun:"-"`
TokenHash string `bun:"token_hash,notnull"`
ExpiresAt time.Time `bun:"expires_at,notnull"`
CreatedAt time.Time `bun:"created_at,notnull"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
}
func newTokenID() (uuid.UUID, error) {

View File

@@ -16,8 +16,7 @@ type FSStore struct {
}
type FSStoreConfig struct {
Root string
UploadURL string
Root string
}
func NewFSStore(config FSStoreConfig) *FSStore {
@@ -28,10 +27,6 @@ func (s *FSStore) Initialize(ctx context.Context) error {
return os.MkdirAll(s.config.Root, 0755)
}
func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
return s.config.UploadURL, nil
}
func (s *FSStore) Put(ctx context.Context, key Key, reader io.Reader) error {
path := filepath.Join(s.config.Root, string(key))
@@ -149,3 +144,11 @@ func (s *FSStore) Move(ctx context.Context, srcKey, dstKey Key) error {
return nil
}
func (s *FSStore) SupportsDirectUpload() bool {
return false
}
func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
return "", nil
}

View File

@@ -2,13 +2,6 @@ package blob
type Key string
type KeyMode int
const (
KeyModeStable KeyMode = iota
KeyModeDerived
)
func (k Key) IsNil() bool {
return k == ""
}

View File

@@ -16,7 +16,6 @@ type UpdateOptions struct {
type Store interface {
Initialize(ctx context.Context) error
GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
Put(ctx context.Context, key Key, reader io.Reader) error
Update(ctx context.Context, key Key, opts UpdateOptions) error
Delete(ctx context.Context, key Key) error
@@ -25,4 +24,10 @@ type Store interface {
Read(ctx context.Context, key Key) (io.ReadCloser, error)
ReadRange(ctx context.Context, key Key, offset, length int64) (io.ReadCloser, error)
ReadSize(ctx context.Context, key Key) (int64, error)
// SupportsDirectUpload returns true if the store allows files to be uploaded directly to the blob store.
SupportsDirectUpload() bool
// GenerateUploadURL generates a URL that can be used to upload a file directly to the blob store. If unsupported, returns an empty string with no error.
GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
}

View File

@@ -7,14 +7,23 @@ CREATE TABLE IF NOT EXISTS users (
display_name TEXT,
email TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
storage_usage_bytes BIGINT NOT NULL,
storage_quota_bytes BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_users_email ON users(email);
CREATE TABLE IF NOT EXISTS accounts (
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
storage_usage_bytes BIGINT NOT NULL DEFAULT 0,
storage_quota_bytes BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_accounts_user_id ON accounts(user_id);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
@@ -31,7 +40,7 @@ CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
CREATE TABLE IF NOT EXISTS vfs_nodes (
id UUID PRIMARY KEY,
public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak)
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory
kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')),
status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')),
@@ -46,17 +55,17 @@ CREATE TABLE IF NOT EXISTS vfs_nodes (
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ, -- soft delete for trash
-- No duplicate names in same parent (per user, excluding deleted)
CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (user_id, parent_id, name, deleted_at)
-- No duplicate names in same parent (per account, excluding deleted)
CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (account_id, parent_id, name, deleted_at)
);
CREATE INDEX idx_vfs_nodes_user_id ON vfs_nodes(user_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_account_id ON vfs_nodes(account_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_user_parent ON vfs_nodes(user_id, parent_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(user_id, kind) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(user_id, deleted_at) WHERE deleted_at IS NOT NULL;
CREATE INDEX idx_vfs_nodes_account_parent ON vfs_nodes(account_id, parent_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(account_id, kind) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(account_id, deleted_at) WHERE deleted_at IS NOT NULL;
CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id);
CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_id IS NULL; -- one root per user
CREATE UNIQUE INDEX idx_vfs_nodes_account_root ON vfs_nodes(account_id) WHERE parent_id IS NULL; -- one root per account
CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job
CREATE TABLE IF NOT EXISTS node_shares (
@@ -91,4 +100,7 @@ CREATE TRIGGER update_vfs_nodes_updated_at BEFORE UPDATE ON vfs_nodes
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();

View File

@@ -4,22 +4,29 @@ import (
"context"
"fmt"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/upload"
"github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/uptrace/bun/extra/bundebug"
)
func NewServer(c Config) (*fiber.App, error) {
app := fiber.New()
db := database.NewFromPostgres(c.Database.PostgresURL)
app := fiber.New(fiber.Config{
ErrorHandler: httperr.ErrorHandler,
StreamRequestBody: true,
})
app.Use(logger.New())
db := database.NewFromPostgres(c.Database.PostgresURL)
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
// Initialize blob store based on config
var blobStore blob.Store
switch c.Storage.Backend {
@@ -49,7 +56,7 @@ func NewServer(c Config) (*fiber.App, error) {
return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode)
}
vfs, err := virtualfs.NewVirtualFS(db, blobStore, keyResolver)
vfs, err := virtualfs.NewVirtualFS(blobStore, keyResolver)
if err != nil {
return nil, fmt.Errorf("failed to create virtual file system: %w", err)
}
@@ -61,12 +68,16 @@ func NewServer(c Config) (*fiber.App, error) {
SecretKey: c.JWT.SecretKey,
})
uploadService := upload.NewService(vfs, blobStore)
accountService := account.NewService(userService, vfs)
authMiddleware := auth.NewBearerAuthMiddleware(authService, db)
api := app.Group("/api")
accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService, authMiddleware).RegisterRoutes(api, authMiddleware)
upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accRouter)
return app, nil
}

View File

@@ -0,0 +1,42 @@
package httperr
import (
"fmt"
"github.com/gofiber/fiber/v2"
)
// HTTPError represents an HTTP error with a status code and underlying error.
type HTTPError struct {
Code int
Message string
Err error
}
// Error implements the error interface.
func (e *HTTPError) Error() string {
if e.Err != nil {
return fmt.Sprintf("HTTP %d: %s: %v", e.Code, e.Message, e.Err)
}
return fmt.Sprintf("HTTP %d: %s", e.Code, e.Message)
}
// Unwrap returns the underlying error.
func (e *HTTPError) Unwrap() error {
return e.Err
}
// NewHTTPError creates a new HTTPError with the given status code, message, and underlying error.
func NewHTTPError(code int, message string, err error) *HTTPError {
return &HTTPError{
Code: code,
Message: message,
Err: err,
}
}
// Internal creates a new HTTPError with status 500.
func Internal(err error) *HTTPError {
return NewHTTPError(fiber.StatusInternalServerError, "Internal", err)
}

View File

@@ -0,0 +1,64 @@
package httperr
import (
"errors"
"log/slog"
"github.com/gofiber/fiber/v2"
)
// ErrorHandler is a global error handler for Fiber that logs errors and returns appropriate responses.
func ErrorHandler(c *fiber.Ctx, err error) error {
// Default status code
code := fiber.StatusInternalServerError
message := "Internal"
// Check if it's our custom HTTPError
var httpErr *HTTPError
if errors.As(err, &httpErr) {
code = httpErr.Code
message = httpErr.Message
// Log the error with underlying error details
if httpErr.Err != nil {
slog.Error("HTTP error",
"status", code,
"message", message,
"error", httpErr.Err.Error(),
"path", c.Path(),
"method", c.Method(),
)
} else {
slog.Warn("HTTP error",
"status", code,
"message", message,
"path", c.Path(),
"method", c.Method(),
)
}
} else {
// Check if it's a Fiber error
var fiberErr *fiber.Error
if errors.As(err, &fiberErr) {
code = fiberErr.Code
message = fiberErr.Message
} else {
// Generic error - log it
slog.Error("Unhandled error",
"status", code,
"error", err.Error(),
"path", c.Path(),
"method", c.Method(),
)
}
}
// Set Content-Type header
c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8)
// Return JSON response
return c.Status(code).JSON(fiber.Map{
"error": message,
})
}

View File

@@ -6,4 +6,5 @@ var (
ErrNotFound = errors.New("not found")
ErrParentNotDirectory = errors.New("parent is not a directory")
ErrConflict = errors.New("node conflict")
ErrContentNotUploaded = errors.New("content has not been uploaded")
)

View File

@@ -2,9 +2,12 @@ package upload
import (
"errors"
"fmt"
"github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
type createUploadRequest struct {
@@ -17,17 +20,16 @@ type updateUploadRequest struct {
}
type HTTPHandler struct {
service *Service
authMiddleware fiber.Handler
service *Service
db *bun.DB
}
func NewHTTPHandler(s *Service, authMiddleware fiber.Handler) *HTTPHandler {
return &HTTPHandler{service: s, authMiddleware: authMiddleware}
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return &HTTPHandler{service: s, db: db}
}
func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Handler) {
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
upload := api.Group("/uploads")
upload.Use(authMiddleware)
upload.Post("/", h.Create)
upload.Put("/:uploadID/content", h.ReceiveContent)
@@ -35,9 +37,9 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Hand
}
func (h *HTTPHandler) Create(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
account := account.CurrentAccount(c)
if account == nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
req := new(createUploadRequest)
@@ -45,38 +47,54 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
upload, err := h.service.CreateUpload(c.Context(), h.db, account.ID, CreateUploadOptions{
ParentID: req.ParentID,
Name: req.Name,
})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
if errors.Is(err, ErrParentNotDirectory) {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Parent is not a directory"})
}
if errors.Is(err, ErrConflict) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "A file with this name already exists"})
}
return httperr.Internal(err)
}
if upload.UploadURL == "" {
upload.UploadURL = fmt.Sprintf("%s%s/%s/content", c.BaseURL(), c.OriginalURL(), upload.ID)
}
return c.JSON(upload)
}
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
account := account.CurrentAccount(c)
if account == nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
uploadID := c.Params("uploadID")
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
defer c.Request().CloseBodyStream()
err := h.service.ReceiveUpload(c.Context(), h.db, account.ID, uploadID, c.Context().RequestBodyStream())
defer c.Context().Request.CloseBodyStream()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
}
return c.SendStatus(fiber.StatusNoContent)
}
func (h *HTTPHandler) Update(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
account := account.CurrentAccount(c)
if account == nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
req := new(updateUploadRequest)
@@ -85,12 +103,15 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error {
}
if req.Status == StatusCompleted {
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
upload, err := h.service.CompleteUpload(c.Context(), h.db, account.ID, c.Params("uploadID"))
if err != nil {
if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
if errors.Is(err, ErrContentNotUploaded) {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Content has not been uploaded"})
}
return httperr.Internal(err)
}
return c.JSON(upload)
}

View File

@@ -3,6 +3,7 @@ package upload
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
@@ -10,6 +11,7 @@ import (
"github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type Service struct {
@@ -33,8 +35,8 @@ type CreateUploadOptions struct {
Name string
}
func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID)
func (s *Service) CreateUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
parentNode, err := s.vfs.FindNodeByPublicID(ctx, db, accountID, opts.ParentID)
if err != nil {
if errors.Is(err, virtualfs.ErrNodeNotFound) {
return nil, ErrNotFound
@@ -46,7 +48,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return nil, ErrParentNotDirectory
}
node, err := s.vfs.CreateFile(ctx, userID, virtualfs.CreateFileOptions{
node, err := s.vfs.CreateFile(ctx, db, accountID, virtualfs.CreateFileOptions{
ParentID: parentNode.ID,
Name: opts.Name,
})
@@ -57,12 +59,17 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return nil, err
}
uploadURL, err := s.blobStore.GenerateUploadURL(ctx, node.BlobKey, blob.UploadURLOptions{
Duration: 1 * time.Hour,
})
if err != nil {
_ = s.vfs.PermanentlyDeleteNode(ctx, node)
return nil, err
var uploadURL string
if s.blobStore.SupportsDirectUpload() {
uploadURL, err = s.blobStore.GenerateUploadURL(ctx, node.BlobKey, blob.UploadURLOptions{
Duration: 1 * time.Hour,
})
if err != nil {
_ = s.vfs.PermanentlyDeleteNode(ctx, db, node)
return nil, err
}
} else {
uploadURL = ""
}
upload := &Upload{
@@ -77,7 +84,8 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return upload, nil
}
func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID string, reader io.Reader) error {
func (s *Service) ReceiveUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string, reader io.Reader) error {
fmt.Printf("reader: %v\n", reader)
n, ok := s.pendingUploads.Load(uploadID)
if !ok {
return ErrNotFound
@@ -88,11 +96,11 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return ErrNotFound
}
if upload.TargetNode.UserID != userID {
if upload.TargetNode.AccountID != accountID {
return ErrNotFound
}
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromReader(reader))
err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromReader(reader))
if err != nil {
return err
}
@@ -102,7 +110,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return nil
}
func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) {
func (s *Service) CompleteUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string) (*Upload, error) {
n, ok := s.pendingUploads.Load(uploadID)
if !ok {
return nil, ErrNotFound
@@ -113,7 +121,7 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
return nil, ErrNotFound
}
if upload.TargetNode.UserID != userID {
if upload.TargetNode.AccountID != accountID {
return nil, ErrNotFound
}
@@ -121,8 +129,11 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
return upload, nil
}
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
if err != nil {
if errors.Is(err, blob.ErrNotFound) {
return nil, ErrContentNotUploaded
}
return nil, err
}

View File

@@ -11,14 +11,12 @@ import (
type User struct {
bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull,nullzero" json:"updatedAt"`
}
func newUserID() (uuid.UUID, error) {

View File

@@ -15,8 +15,8 @@ func NewFlatKeyResolver() *FlatKeyResolver {
return &FlatKeyResolver{}
}
func (r *FlatKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeStable
func (r *FlatKeyResolver) ShouldPersistKey() bool {
return true
}
func (r *FlatKeyResolver) Resolve(ctx context.Context, node *Node) (blob.Key, error) {

View File

@@ -2,6 +2,7 @@ package virtualfs
import (
"context"
"fmt"
"github.com/get-drexa/drexa/internal/blob"
"github.com/uptrace/bun"
@@ -17,8 +18,8 @@ func NewHierarchicalKeyResolver(db *bun.DB) *HierarchicalKeyResolver {
return &HierarchicalKeyResolver{db: db}
}
func (r *HierarchicalKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeDerived
func (r *HierarchicalKeyResolver) ShouldPersistKey() bool {
return false
}
func (r *HierarchicalKeyResolver) Resolve(ctx context.Context, node *Node) (blob.Key, error) {
@@ -27,7 +28,7 @@ func (r *HierarchicalKeyResolver) Resolve(ctx context.Context, node *Node) (blob
return "", err
}
return blob.Key(path), nil
return blob.Key(fmt.Sprintf("%s/%s", node.AccountID, path)), nil
}
func (r *HierarchicalKeyResolver) ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error) {

View File

@@ -7,7 +7,10 @@ import (
)
type BlobKeyResolver interface {
KeyMode() blob.KeyMode
// ShouldPersistKey returns true if the resolved key should be stored in node.BlobKey.
// Flat keys (e.g. UUIDs) return true - key is generated once and stored.
// Hierarchical keys return false - key is derived from path each time.
ShouldPersistKey() bool
Resolve(ctx context.Context, node *Node) (blob.Key, error)
ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error)
}

View File

@@ -25,21 +25,25 @@ const (
type Node struct {
bun.BaseModel `bun:"vfs_nodes"`
ID uuid.UUID `bun:",pk,type:uuid"`
PublicID string `bun:"public_id,notnull"`
UserID uuid.UUID `bun:"user_id,notnull"`
ParentID uuid.UUID `bun:"parent_id,notnull"`
Kind NodeKind `bun:"kind,notnull"`
Status NodeStatus `bun:"status,notnull"`
Name string `bun:"name,notnull"`
ID uuid.UUID `bun:",pk,type:uuid"`
PublicID string `bun:"public_id,notnull"`
AccountID uuid.UUID `bun:"account_id,notnull,type:uuid"`
ParentID uuid.UUID `bun:"parent_id,nullzero"`
Kind NodeKind `bun:"kind,notnull"`
Status NodeStatus `bun:"status,notnull"`
Name string `bun:"name,notnull"`
BlobKey blob.Key `bun:"blob_key"`
BlobKey blob.Key `bun:"blob_key,nullzero"`
Size int64 `bun:"size"`
MimeType string `bun:"mime_type"`
MimeType string `bun:"mime_type,nullzero"`
CreatedAt time.Time `bun:"created_at,notnull"`
UpdatedAt time.Time `bun:"updated_at,notnull"`
DeletedAt time.Time `bun:"deleted_at"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
UpdatedAt time.Time `bun:"updated_at,notnull,nullzero"`
DeletedAt time.Time `bun:"deleted_at,nullzero"`
}
func newNodeID() (uuid.UUID, error) {
return uuid.NewV7()
}
// IsAccessible returns true if the node can be accessed.

View File

@@ -12,7 +12,7 @@ import (
const absolutePathQuery = `WITH RECURSIVE path AS (
SELECT id, parent_id, name, 1 as depth
FROM vfs_nodes WHERE id = $1 AND deleted_at IS NULL
FROM vfs_nodes WHERE id = ? AND deleted_at IS NULL
UNION ALL
@@ -29,7 +29,7 @@ func JoinPath(parts ...string) string {
return strings.Join(parts, "/")
}
func buildNodeAbsolutePath(ctx context.Context, db *bun.DB, nodeID uuid.UUID) (string, error) {
func buildNodeAbsolutePath(ctx context.Context, db bun.IDB, nodeID uuid.UUID) (string, error) {
var path []string
err := db.NewRaw(absolutePathQuery, nodeID).Scan(ctx, &path)
if err != nil {

View File

@@ -19,7 +19,6 @@ import (
)
type VirtualFS struct {
db *bun.DB
blobStore blob.Store
keyResolver BlobKeyResolver
@@ -42,6 +41,8 @@ type FileContent struct {
blobKey blob.Key
}
const RootDirectoryName = "root"
func FileContentFromReader(reader io.Reader) FileContent {
return FileContent{reader: reader}
}
@@ -50,23 +51,22 @@ func FileContentFromBlobKey(blobKey blob.Key) FileContent {
return FileContent{blobKey: blobKey}
}
func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) {
func NewVirtualFS(blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) {
sqid, err := sqids.New()
if err != nil {
return nil, err
}
return &VirtualFS{
db: db,
blobStore: blobStore,
keyResolver: keyResolver,
sqid: sqid,
}, nil
}
func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Node, error) {
func (vfs *VirtualFS) FindNode(ctx context.Context, db bun.IDB, accountID, fileID string) (*Node, error) {
var node Node
err := vfs.db.NewSelect().Model(&node).
Where("user_id = ?", userID).
err := db.NewSelect().Model(&node).
Where("account_id = ?", accountID).
Where("id = ?", fileID).
Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL").
@@ -80,10 +80,10 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Nod
return &node, nil
}
func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID, publicID string) (*Node, error) {
func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, db bun.IDB, accountID uuid.UUID, publicID string) (*Node, error) {
var node Node
err := vfs.db.NewSelect().Model(&node).
Where("user_id = ?", userID).
err := db.NewSelect().Model(&node).
Where("account_id = ?", accountID).
Where("public_id = ?", publicID).
Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL").
@@ -97,14 +97,14 @@ func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID,
return &node, nil
}
func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, error) {
func (vfs *VirtualFS) ListChildren(ctx context.Context, db bun.IDB, node *Node) ([]*Node, error) {
if !node.IsAccessible() {
return nil, ErrNodeNotFound
}
var nodes []*Node
err := vfs.db.NewSelect().Model(&nodes).
Where("user_id = ?", node.UserID).
err := db.NewSelect().Model(&nodes).
Where("account_id = ?", node.AccountID).
Where("parent_id = ?", node.ID).
Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL").
@@ -119,29 +119,35 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
return nodes, nil
}
func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts CreateFileOptions) (*Node, error) {
func (vfs *VirtualFS) CreateFile(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) {
pid, err := vfs.generatePublicID()
if err != nil {
return nil, err
}
node := Node{
PublicID: pid,
UserID: userID,
ParentID: opts.ParentID,
Kind: NodeKindFile,
Status: NodeStatusPending,
Name: opts.Name,
id, err := newNodeID()
if err != nil {
return nil, err
}
if vfs.keyResolver.KeyMode() == blob.KeyModeStable {
node := Node{
ID: id,
PublicID: pid,
AccountID: accountID,
ParentID: opts.ParentID,
Kind: NodeKindFile,
Status: NodeStatusPending,
Name: opts.Name,
}
if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey, err = vfs.keyResolver.Resolve(ctx, &node)
if err != nil {
return nil, err
}
}
_, err = vfs.db.NewInsert().Model(&node).Returning("*").Exec(ctx)
_, err = db.NewInsert().Model(&node).Returning("*").Exec(ctx)
if err != nil {
if database.IsUniqueViolation(err) {
return nil, ErrNodeConflict
@@ -152,7 +158,7 @@ func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts Cre
return &node, nil
}
func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileContent) error {
func (vfs *VirtualFS) WriteFile(ctx context.Context, db bun.IDB, node *Node, content FileContent) error {
if content.reader == nil && content.blobKey.IsNil() {
return blob.ErrInvalidFileContent
}
@@ -184,7 +190,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
return err
}
if vfs.keyResolver.KeyMode() == blob.KeyModeStable {
if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey = key
setCols = append(setCols, "blob_key")
}
@@ -222,7 +228,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
setCols = append(setCols, "mime_type", "blob_key", "size", "status")
}
_, err := vfs.db.NewUpdate().Model(&node).
_, err := db.NewUpdate().Model(node).
Column(setCols...).
WherePK().
Exec(ctx)
@@ -233,22 +239,28 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
return nil
}
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
pid, err := vfs.generatePublicID()
if err != nil {
return nil, err
}
node := Node{
PublicID: pid,
UserID: userID,
ParentID: parentID,
Kind: NodeKindDirectory,
Status: NodeStatusReady,
Name: name,
id, err := newNodeID()
if err != nil {
return nil, err
}
_, err = vfs.db.NewInsert().Model(&node).Exec(ctx)
node := Node{
ID: id,
PublicID: pid,
AccountID: accountID,
ParentID: parentID,
Kind: NodeKindDirectory,
Status: NodeStatusReady,
Name: name,
}
_, err = db.NewInsert().Model(node).Exec(ctx)
if err != nil {
if database.IsUniqueViolation(err) {
return nil, ErrNodeConflict
@@ -259,12 +271,12 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, par
return &node, nil
}
func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error {
func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
if !node.IsAccessible() {
return ErrNodeNotFound
}
_, err := vfs.db.NewUpdate().Model(node).
_, err := db.NewUpdate().Model(node).
WherePK().
Where("deleted_at IS NULL").
Where("status = ?", NodeStatusReady).
@@ -281,12 +293,12 @@ func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error {
return nil
}
func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error {
func (vfs *VirtualFS) RestoreNode(ctx context.Context, db bun.IDB, node *Node) error {
if node.Status != NodeStatusReady {
return ErrNodeNotFound
}
_, err := vfs.db.NewUpdate().Model(node).
_, err := db.NewUpdate().Model(node).
WherePK().
Where("deleted_at IS NOT NULL").
Set("deleted_at = NULL").
@@ -302,12 +314,12 @@ func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error {
return nil
}
func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) error {
func (vfs *VirtualFS) RenameNode(ctx context.Context, db bun.IDB, node *Node, name string) error {
if !node.IsAccessible() {
return ErrNodeNotFound
}
_, err := vfs.db.NewUpdate().Model(node).
_, err := db.NewUpdate().Model(node).
WherePK().
Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL").
@@ -323,7 +335,7 @@ func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) e
return nil
}
func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UUID) error {
func (vfs *VirtualFS) MoveNode(ctx context.Context, db bun.IDB, node *Node, parentID uuid.UUID) error {
if !node.IsAccessible() {
return ErrNodeNotFound
}
@@ -333,7 +345,7 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return err
}
_, err = vfs.db.NewUpdate().Model(node).
_, err = db.NewUpdate().Model(node).
WherePK().
Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL").
@@ -360,9 +372,9 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return err
}
if vfs.keyResolver.KeyMode() == blob.KeyModeStable {
if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey = newKey
_, err = vfs.db.NewUpdate().Model(node).
_, err = db.NewUpdate().Model(node).
WherePK().
Set("blob_key = ?", newKey).
Exec(ctx)
@@ -374,34 +386,34 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return nil
}
func (vfs *VirtualFS) AbsolutePath(ctx context.Context, node *Node) (string, error) {
func (vfs *VirtualFS) AbsolutePath(ctx context.Context, db bun.IDB, node *Node) (string, error) {
if !node.IsAccessible() {
return "", ErrNodeNotFound
}
return buildNodeAbsolutePath(ctx, vfs.db, node.ID)
return buildNodeAbsolutePath(ctx, db, node.ID)
}
func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, node *Node) error {
func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
if !node.IsAccessible() {
return ErrNodeNotFound
}
switch node.Kind {
case NodeKindFile:
return vfs.permanentlyDeleteFileNode(ctx, node)
return vfs.permanentlyDeleteFileNode(ctx, db, node)
case NodeKindDirectory:
return vfs.permanentlyDeleteDirectoryNode(ctx, node)
return vfs.permanentlyDeleteDirectoryNode(ctx, db, node)
default:
return ErrUnsupportedOperation
}
}
func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node) error {
func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, db bun.IDB, node *Node) error {
err := vfs.blobStore.Delete(ctx, node.BlobKey)
if err != nil {
return err
}
_, err = vfs.db.NewDelete().Model(node).WherePK().Exec(ctx)
_, err = db.NewDelete().Model(node).WherePK().Exec(ctx)
if err != nil {
return err
}
@@ -409,7 +421,7 @@ func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node)
return nil
}
func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *Node) error {
func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, db bun.IDB, node *Node) error {
const descendantsQuery = `WITH RECURSIVE descendants AS (
SELECT id, blob_key FROM vfs_nodes WHERE id = ?
UNION ALL
@@ -423,14 +435,29 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *
BlobKey blob.Key `bun:"blob_key"`
}
tx, err := vfs.db.BeginTx(ctx, nil)
if err != nil {
return err
// If db is already a transaction, use it directly; otherwise start a new transaction
var tx bun.IDB
var startedTx *bun.Tx
switch v := db.(type) {
case *bun.DB:
newTx, err := v.BeginTx(ctx, nil)
if err != nil {
return err
}
startedTx = &newTx
tx = newTx
defer func() {
if startedTx != nil {
(*startedTx).Rollback()
}
}()
default:
// Assume it's already a transaction
tx = db
}
defer tx.Rollback()
var records []nodeRecord
err = tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records)
err := tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrNodeNotFound
@@ -472,7 +499,14 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *
}
}
return tx.Commit()
// Only commit if we started the transaction
if startedTx != nil {
err := (*startedTx).Commit()
startedTx = nil // Prevent defer from rolling back
return err
}
return nil
}
func (vfs *VirtualFS) generatePublicID() (string, error) {