feat: introduce account

This commit is contained in:
2025-11-30 17:12:50 +00:00
parent 1c1392a0a1
commit 89b62f6d8a
13 changed files with 380 additions and 166 deletions

View File

@@ -0,0 +1,23 @@
package account
import (
"time"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type Account struct {
bun.BaseModel `bun:"accounts"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
UserID uuid.UUID `bun:"user_id,notnull,type:uuid" json:"userId"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
}
func newAccountID() (uuid.UUID, error) {
return uuid.NewV7()
}

View File

@@ -0,0 +1,8 @@
package account
import "errors"
var (
ErrAccountNotFound = errors.New("account not found")
ErrAccountAlreadyExists = errors.New("account already exists")
)

View File

@@ -0,0 +1,131 @@
package account
import (
"errors"
"github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type HTTPHandler struct {
accountService *Service
authService *auth.Service
db *bun.DB
authMiddleware fiber.Handler
}
type registerAccountRequest struct {
Email string `json:"email"`
Password string `json:"password"`
DisplayName string `json:"displayName"`
}
type registerAccountResponse struct {
Account *Account `json:"account"`
User *user.User `json:"user"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
const currentAccountKey = "currentAccount"
func CurrentAccount(c *fiber.Ctx) *Account {
return c.Locals(currentAccountKey).(*Account)
}
func NewHTTPHandler(accountService *Service, authService *auth.Service, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
return &HTTPHandler{accountService: accountService, authService: authService, db: db, authMiddleware: authMiddleware}
}
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) fiber.Router {
api.Post("/accounts", h.registerAccount)
account := api.Group("/accounts/:accountID")
account.Use(h.authMiddleware)
account.Use(h.accountMiddleware)
account.Get("/", h.getAccount)
return account
}
func (h *HTTPHandler) accountMiddleware(c *fiber.Ctx) error {
user, err := auth.AuthenticatedUser(c)
if err != nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
accountID, err := uuid.Parse(c.Params("accountID"))
if err != nil {
return c.SendStatus(fiber.StatusNotFound)
}
account, err := h.accountService.AccountByID(c.Context(), h.db, user.ID, accountID)
if err != nil {
if errors.Is(err, ErrAccountNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return c.SendStatus(fiber.StatusInternalServerError)
}
c.Locals(currentAccountKey, account)
return c.Next()
}
func (h *HTTPHandler) getAccount(c *fiber.Ctx) error {
account := CurrentAccount(c)
if account == nil {
return c.SendStatus(fiber.StatusNotFound)
}
return c.JSON(account)
}
func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error {
req := new(registerAccountRequest)
if err := c.BodyParser(req); err != nil {
return c.SendStatus(fiber.StatusBadRequest)
}
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
defer tx.Rollback()
acc, u, err := h.accountService.Register(c.Context(), tx, RegisterOptions{
Email: req.Email,
Password: req.Password,
DisplayName: req.DisplayName,
})
if err != nil {
var ae *user.AlreadyExistsError
if errors.As(err, &ae) {
return c.SendStatus(fiber.StatusConflict)
}
if errors.Is(err, ErrAccountAlreadyExists) {
return c.SendStatus(fiber.StatusConflict)
}
return c.SendStatus(fiber.StatusBadRequest)
}
result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u)
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
err = tx.Commit()
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.JSON(registerAccountResponse{
Account: acc,
User: u,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}

View File

@@ -0,0 +1,107 @@
package account
import (
"context"
"database/sql"
"errors"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type Service struct {
userService user.Service
}
type RegisterOptions struct {
Email string
Password string
DisplayName string
}
type CreateAccountOptions struct {
OrganizationID uuid.UUID
QuotaBytes int64
}
func NewService(userService *user.Service) *Service {
return &Service{
userService: *userService,
}
}
func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) {
hashed, err := password.Hash(opts.Password)
if err != nil {
return nil, nil, err
}
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.Email,
Password: hashed,
DisplayName: opts.DisplayName,
})
if err != nil {
return nil, nil, err
}
acc, err := s.CreateAccount(ctx, db, u.ID, CreateAccountOptions{
// TODO: make quota configurable
QuotaBytes: 1024 * 1024 * 1024, // 1GB
})
if err != nil {
return nil, nil, err
}
return acc, u, nil
}
func (s *Service) CreateAccount(ctx context.Context, db bun.IDB, userID uuid.UUID, opts CreateAccountOptions) (*Account, error) {
id, err := newAccountID()
if err != nil {
return nil, err
}
account := &Account{
ID: id,
UserID: userID,
StorageQuotaBytes: opts.QuotaBytes,
}
_, err = db.NewInsert().Model(account).Returning("*").Exec(ctx)
if err != nil {
if database.IsUniqueViolation(err) {
return nil, ErrAccountAlreadyExists
}
return nil, err
}
return account, nil
}
func (s *Service) AccountByUserID(ctx context.Context, db bun.IDB, userID uuid.UUID) (*Account, error) {
var account Account
err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrAccountNotFound
}
return nil, err
}
return &account, nil
}
func (s *Service) AccountByID(ctx context.Context, db bun.IDB, userID uuid.UUID, id uuid.UUID) (*Account, error) {
var account Account
err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Where("id = ?", id).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrAccountNotFound
}
return nil, err
}
return &account, nil
}

View File

@@ -2,7 +2,6 @@ package auth
import ( import (
"errors" "errors"
"log/slog"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@@ -14,12 +13,6 @@ type loginRequest struct {
Password string `json:"password"` Password string `json:"password"`
} }
type registerRequest struct {
Email string `json:"email"`
Password string `json:"password"`
DisplayName string `json:"displayName"`
}
type loginResponse struct { type loginResponse struct {
User user.User `json:"user"` User user.User `json:"user"`
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
@@ -37,9 +30,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
auth := api.Group("/auth") auth := api.Group("/auth")
auth.Post("/login", h.Login) auth.Post("/login", h.Login)
auth.Post("/register", h.Register)
} }
func (h *HTTPHandler) Login(c *fiber.Ctx) error { func (h *HTTPHandler) Login(c *fiber.Ctx) error {
@@ -54,7 +45,7 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
} }
defer tx.Rollback() defer tx.Rollback()
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password) result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil { if err != nil {
if errors.Is(err, ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
@@ -72,42 +63,3 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
RefreshToken: result.RefreshToken, RefreshToken: result.RefreshToken,
}) })
} }
func (h *HTTPHandler) Register(c *fiber.Ctx) error {
req := new(registerRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
slog.Error("failed to begin transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{
email: req.Email,
password: req.Password,
displayName: req.DisplayName,
})
if err != nil {
var ae *user.AlreadyExistsError
if errors.As(err, &ae) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
}
slog.Error("failed to register user", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type LoginResult struct { type AuthenticationResult struct {
User *user.User User *user.User
AccessToken string AccessToken string
RefreshToken string RefreshToken string
@@ -24,12 +24,6 @@ type Service struct {
tokenConfig TokenConfig tokenConfig TokenConfig
} }
type registerOptions struct {
displayName string
email string
password string
}
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service { func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{ return &Service{
userService: userService, userService: userService,
@@ -37,7 +31,30 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
} }
} }
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) { func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) {
at, err := GenerateAccessToken(user, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(user, &s.tokenConfig)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &AuthenticationResult{
User: user,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) {
u, err := s.userService.UserByEmail(ctx, db, email) u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil { if err != nil {
var nf *user.NotFoundError var nf *user.NotFoundError
@@ -67,44 +84,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, ema
return nil, err return nil, err
} }
return &LoginResult{ return &AuthenticationResult{
User: u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password)
if err != nil {
return nil, err
}
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email,
DisplayName: opts.displayName,
Password: hashed,
})
if err != nil {
return nil, err
}
at, err := GenerateAccessToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &LoginResult{
User: u, User: u,
AccessToken: at, AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token), RefreshToken: hex.EncodeToString(rt.Token),

View File

@@ -7,14 +7,23 @@ CREATE TABLE IF NOT EXISTS users (
display_name TEXT, display_name TEXT,
email TEXT NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE,
password TEXT NOT NULL, password TEXT NOT NULL,
storage_usage_bytes BIGINT NOT NULL,
storage_quota_bytes BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX idx_users_email ON users(email); CREATE INDEX idx_users_email ON users(email);
CREATE TABLE IF NOT EXISTS accounts (
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
storage_usage_bytes BIGINT NOT NULL DEFAULT 0,
storage_quota_bytes BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_accounts_user_id ON accounts(user_id);
CREATE TABLE IF NOT EXISTS refresh_tokens ( CREATE TABLE IF NOT EXISTS refresh_tokens (
id UUID PRIMARY KEY, id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
@@ -31,7 +40,7 @@ CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
CREATE TABLE IF NOT EXISTS vfs_nodes ( CREATE TABLE IF NOT EXISTS vfs_nodes (
id UUID PRIMARY KEY, id UUID PRIMARY KEY,
public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak) public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak)
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory
kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')), kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')),
status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')), status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')),
@@ -46,17 +55,17 @@ CREATE TABLE IF NOT EXISTS vfs_nodes (
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ, -- soft delete for trash deleted_at TIMESTAMPTZ, -- soft delete for trash
-- No duplicate names in same parent (per user, excluding deleted) -- No duplicate names in same parent (per account, excluding deleted)
CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (user_id, parent_id, name, deleted_at) CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (account_id, parent_id, name, deleted_at)
); );
CREATE INDEX idx_vfs_nodes_user_id ON vfs_nodes(user_id) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_account_id ON vfs_nodes(account_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_user_parent ON vfs_nodes(user_id, parent_id) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_account_parent ON vfs_nodes(account_id, parent_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(user_id, kind) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(account_id, kind) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(user_id, deleted_at) WHERE deleted_at IS NOT NULL; CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(account_id, deleted_at) WHERE deleted_at IS NOT NULL;
CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id); CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id);
CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_id IS NULL; -- one root per user CREATE UNIQUE INDEX idx_vfs_nodes_account_root ON vfs_nodes(account_id) WHERE parent_id IS NULL; -- one root per account
CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job
CREATE TABLE IF NOT EXISTS node_shares ( CREATE TABLE IF NOT EXISTS node_shares (
@@ -92,3 +101,6 @@ CREATE TRIGGER update_vfs_nodes_updated_at BEFORE UPDATE ON vfs_nodes
CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/database"
@@ -61,12 +62,16 @@ func NewServer(c Config) (*fiber.App, error) {
SecretKey: c.JWT.SecretKey, SecretKey: c.JWT.SecretKey,
}) })
uploadService := upload.NewService(vfs, blobStore) uploadService := upload.NewService(vfs, blobStore)
accountService := account.NewService(userService)
authMiddleware := auth.NewBearerAuthMiddleware(authService, db) authMiddleware := auth.NewBearerAuthMiddleware(authService, db)
api := app.Group("/api") api := app.Group("/api")
accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
auth.NewHTTPHandler(authService, db).RegisterRoutes(api) auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService, authMiddleware).RegisterRoutes(api, authMiddleware) upload.NewHTTPHandler(uploadService).RegisterRoutes(accRouter)
return app, nil return app, nil
} }

View File

@@ -3,7 +3,7 @@ package upload
import ( import (
"errors" "errors"
"github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/account"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@@ -17,17 +17,15 @@ type updateUploadRequest struct {
} }
type HTTPHandler struct { type HTTPHandler struct {
service *Service service *Service
authMiddleware fiber.Handler
} }
func NewHTTPHandler(s *Service, authMiddleware fiber.Handler) *HTTPHandler { func NewHTTPHandler(s *Service) *HTTPHandler {
return &HTTPHandler{service: s, authMiddleware: authMiddleware} return &HTTPHandler{service: s}
} }
func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Handler) { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
upload := api.Group("/uploads") upload := api.Group("/uploads")
upload.Use(authMiddleware)
upload.Post("/", h.Create) upload.Post("/", h.Create)
upload.Put("/:uploadID/content", h.ReceiveContent) upload.Put("/:uploadID/content", h.ReceiveContent)
@@ -35,9 +33,9 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Hand
} }
func (h *HTTPHandler) Create(c *fiber.Ctx) error { func (h *HTTPHandler) Create(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) account := account.CurrentAccount(c)
if err != nil { if account == nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.SendStatus(fiber.StatusUnauthorized)
} }
req := new(createUploadRequest) req := new(createUploadRequest)
@@ -45,7 +43,7 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{ upload, err := h.service.CreateUpload(c.Context(), account.ID, CreateUploadOptions{
ParentID: req.ParentID, ParentID: req.ParentID,
Name: req.Name, Name: req.Name,
}) })
@@ -57,14 +55,14 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
} }
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error { func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) account := account.CurrentAccount(c)
if err != nil { if account == nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.SendStatus(fiber.StatusUnauthorized)
} }
uploadID := c.Params("uploadID") uploadID := c.Params("uploadID")
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream()) err := h.service.ReceiveUpload(c.Context(), account.ID, uploadID, c.Request().BodyStream())
defer c.Request().CloseBodyStream() defer c.Request().CloseBodyStream()
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
@@ -74,9 +72,9 @@ func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
} }
func (h *HTTPHandler) Update(c *fiber.Ctx) error { func (h *HTTPHandler) Update(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) account := account.CurrentAccount(c)
if err != nil { if account == nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.SendStatus(fiber.StatusUnauthorized)
} }
req := new(updateUploadRequest) req := new(updateUploadRequest)
@@ -85,7 +83,7 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error {
} }
if req.Status == StatusCompleted { if req.Status == StatusCompleted {
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID")) upload, err := h.service.CompleteUpload(c.Context(), account.ID, c.Params("uploadID"))
if err != nil { if err != nil {
if errors.Is(err, ErrNotFound) { if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound) return c.SendStatus(fiber.StatusNotFound)

View File

@@ -33,8 +33,8 @@ type CreateUploadOptions struct {
Name string Name string
} }
func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID) parentNode, err := s.vfs.FindNodeByPublicID(ctx, accountID, opts.ParentID)
if err != nil { if err != nil {
if errors.Is(err, virtualfs.ErrNodeNotFound) { if errors.Is(err, virtualfs.ErrNodeNotFound) {
return nil, ErrNotFound return nil, ErrNotFound
@@ -46,7 +46,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return nil, ErrParentNotDirectory return nil, ErrParentNotDirectory
} }
node, err := s.vfs.CreateFile(ctx, userID, virtualfs.CreateFileOptions{ node, err := s.vfs.CreateFile(ctx, accountID, virtualfs.CreateFileOptions{
ParentID: parentNode.ID, ParentID: parentNode.ID,
Name: opts.Name, Name: opts.Name,
}) })
@@ -77,7 +77,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return upload, nil return upload, nil
} }
func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID string, reader io.Reader) error { func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, uploadID string, reader io.Reader) error {
n, ok := s.pendingUploads.Load(uploadID) n, ok := s.pendingUploads.Load(uploadID)
if !ok { if !ok {
return ErrNotFound return ErrNotFound
@@ -88,7 +88,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return ErrNotFound return ErrNotFound
} }
if upload.TargetNode.UserID != userID { if upload.TargetNode.AccountID != accountID {
return ErrNotFound return ErrNotFound
} }
@@ -102,7 +102,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return nil return nil
} }
func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) { func (s *Service) CompleteUpload(ctx context.Context, accountID uuid.UUID, uploadID string) (*Upload, error) {
n, ok := s.pendingUploads.Load(uploadID) n, ok := s.pendingUploads.Load(uploadID)
if !ok { if !ok {
return nil, ErrNotFound return nil, ErrNotFound
@@ -113,7 +113,7 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
return nil, ErrNotFound return nil, ErrNotFound
} }
if upload.TargetNode.UserID != userID { if upload.TargetNode.AccountID != accountID {
return nil, ErrNotFound return nil, ErrNotFound
} }

View File

@@ -11,14 +11,12 @@ import (
type User struct { type User struct {
bun.BaseModel `bun:"users"` bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"` ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name" json:"displayName"` DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"` Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"` Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
} }
func newUserID() (uuid.UUID, error) { func newUserID() (uuid.UUID, error) {

View File

@@ -25,13 +25,13 @@ const (
type Node struct { type Node struct {
bun.BaseModel `bun:"vfs_nodes"` bun.BaseModel `bun:"vfs_nodes"`
ID uuid.UUID `bun:",pk,type:uuid"` ID uuid.UUID `bun:",pk,type:uuid"`
PublicID string `bun:"public_id,notnull"` PublicID string `bun:"public_id,notnull"`
UserID uuid.UUID `bun:"user_id,notnull"` AccountID uuid.UUID `bun:"account_id,notnull,type:uuid"`
ParentID uuid.UUID `bun:"parent_id,notnull"` ParentID uuid.UUID `bun:"parent_id"`
Kind NodeKind `bun:"kind,notnull"` Kind NodeKind `bun:"kind,notnull"`
Status NodeStatus `bun:"status,notnull"` Status NodeStatus `bun:"status,notnull"`
Name string `bun:"name,notnull"` Name string `bun:"name,notnull"`
BlobKey blob.Key `bun:"blob_key"` BlobKey blob.Key `bun:"blob_key"`
Size int64 `bun:"size"` Size int64 `bun:"size"`

View File

@@ -63,10 +63,10 @@ func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver)
}, nil }, nil
} }
func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Node, error) { func (vfs *VirtualFS) FindNode(ctx context.Context, accountID, fileID string) (*Node, error) {
var node Node var node Node
err := vfs.db.NewSelect().Model(&node). err := vfs.db.NewSelect().Model(&node).
Where("user_id = ?", userID). Where("account_id = ?", accountID).
Where("id = ?", fileID). Where("id = ?", fileID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -80,10 +80,10 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Nod
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID, publicID string) (*Node, error) { func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, accountID uuid.UUID, publicID string) (*Node, error) {
var node Node var node Node
err := vfs.db.NewSelect().Model(&node). err := vfs.db.NewSelect().Model(&node).
Where("user_id = ?", userID). Where("account_id = ?", accountID).
Where("public_id = ?", publicID). Where("public_id = ?", publicID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -104,7 +104,7 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
var nodes []*Node var nodes []*Node
err := vfs.db.NewSelect().Model(&nodes). err := vfs.db.NewSelect().Model(&nodes).
Where("user_id = ?", node.UserID). Where("account_id = ?", node.AccountID).
Where("parent_id = ?", node.ID). Where("parent_id = ?", node.ID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -119,19 +119,19 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
return nodes, nil return nodes, nil
} }
func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts CreateFileOptions) (*Node, error) { func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) {
pid, err := vfs.generatePublicID() pid, err := vfs.generatePublicID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
node := Node{ node := Node{
PublicID: pid, PublicID: pid,
UserID: userID, AccountID: accountID,
ParentID: opts.ParentID, ParentID: opts.ParentID,
Kind: NodeKindFile, Kind: NodeKindFile,
Status: NodeStatusPending, Status: NodeStatusPending,
Name: opts.Name, Name: opts.Name,
} }
if vfs.keyResolver.ShouldPersistKey() { if vfs.keyResolver.ShouldPersistKey() {
@@ -233,19 +233,19 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
return nil return nil
} }
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) { func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
pid, err := vfs.generatePublicID() pid, err := vfs.generatePublicID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
node := Node{ node := Node{
PublicID: pid, PublicID: pid,
UserID: userID, AccountID: accountID,
ParentID: parentID, ParentID: parentID,
Kind: NodeKindDirectory, Kind: NodeKindDirectory,
Status: NodeStatusReady, Status: NodeStatusReady,
Name: name, Name: name,
} }
_, err = vfs.db.NewInsert().Model(&node).Exec(ctx) _, err = vfs.db.NewInsert().Model(&node).Exec(ctx)