diff --git a/apps/backend/internal/account/account.go b/apps/backend/internal/account/account.go new file mode 100644 index 0000000..487aa22 --- /dev/null +++ b/apps/backend/internal/account/account.go @@ -0,0 +1,23 @@ +package account + +import ( + "time" + + "github.com/google/uuid" + "github.com/uptrace/bun" +) + +type Account struct { + bun.BaseModel `bun:"accounts"` + + ID uuid.UUID `bun:",pk,type:uuid" json:"id"` + UserID uuid.UUID `bun:"user_id,notnull,type:uuid" json:"userId"` + StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` + StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` + CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"` + UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"` +} + +func newAccountID() (uuid.UUID, error) { + return uuid.NewV7() +} diff --git a/apps/backend/internal/account/err.go b/apps/backend/internal/account/err.go new file mode 100644 index 0000000..e2e2b3f --- /dev/null +++ b/apps/backend/internal/account/err.go @@ -0,0 +1,8 @@ +package account + +import "errors" + +var ( + ErrAccountNotFound = errors.New("account not found") + ErrAccountAlreadyExists = errors.New("account already exists") +) diff --git a/apps/backend/internal/account/http.go b/apps/backend/internal/account/http.go new file mode 100644 index 0000000..b702fea --- /dev/null +++ b/apps/backend/internal/account/http.go @@ -0,0 +1,131 @@ +package account + +import ( + "errors" + + "github.com/get-drexa/drexa/internal/auth" + "github.com/get-drexa/drexa/internal/user" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/uptrace/bun" +) + +type HTTPHandler struct { + accountService *Service + authService *auth.Service + db *bun.DB + authMiddleware fiber.Handler +} + +type registerAccountRequest struct { + Email string `json:"email"` + Password string `json:"password"` + DisplayName string `json:"displayName"` +} + +type registerAccountResponse struct { + Account *Account `json:"account"` + User *user.User `json:"user"` + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` +} + +const currentAccountKey = "currentAccount" + +func CurrentAccount(c *fiber.Ctx) *Account { + return c.Locals(currentAccountKey).(*Account) +} + +func NewHTTPHandler(accountService *Service, authService *auth.Service, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler { + return &HTTPHandler{accountService: accountService, authService: authService, db: db, authMiddleware: authMiddleware} +} + +func (h *HTTPHandler) RegisterRoutes(api fiber.Router) fiber.Router { + api.Post("/accounts", h.registerAccount) + + account := api.Group("/accounts/:accountID") + account.Use(h.authMiddleware) + account.Use(h.accountMiddleware) + + account.Get("/", h.getAccount) + + return account +} + +func (h *HTTPHandler) accountMiddleware(c *fiber.Ctx) error { + user, err := auth.AuthenticatedUser(c) + if err != nil { + return c.SendStatus(fiber.StatusUnauthorized) + } + + accountID, err := uuid.Parse(c.Params("accountID")) + if err != nil { + return c.SendStatus(fiber.StatusNotFound) + } + + account, err := h.accountService.AccountByID(c.Context(), h.db, user.ID, accountID) + if err != nil { + if errors.Is(err, ErrAccountNotFound) { + return c.SendStatus(fiber.StatusNotFound) + } + return c.SendStatus(fiber.StatusInternalServerError) + } + + c.Locals(currentAccountKey, account) + + return c.Next() +} + +func (h *HTTPHandler) getAccount(c *fiber.Ctx) error { + account := CurrentAccount(c) + if account == nil { + return c.SendStatus(fiber.StatusNotFound) + } + return c.JSON(account) +} + +func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error { + req := new(registerAccountRequest) + if err := c.BodyParser(req); err != nil { + return c.SendStatus(fiber.StatusBadRequest) + } + + tx, err := h.db.BeginTx(c.Context(), nil) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + defer tx.Rollback() + + acc, u, err := h.accountService.Register(c.Context(), tx, RegisterOptions{ + Email: req.Email, + Password: req.Password, + DisplayName: req.DisplayName, + }) + if err != nil { + var ae *user.AlreadyExistsError + if errors.As(err, &ae) { + return c.SendStatus(fiber.StatusConflict) + } + if errors.Is(err, ErrAccountAlreadyExists) { + return c.SendStatus(fiber.StatusConflict) + } + return c.SendStatus(fiber.StatusBadRequest) + } + + result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + err = tx.Commit() + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + return c.JSON(registerAccountResponse{ + Account: acc, + User: u, + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + }) +} diff --git a/apps/backend/internal/account/service.go b/apps/backend/internal/account/service.go new file mode 100644 index 0000000..9f4c468 --- /dev/null +++ b/apps/backend/internal/account/service.go @@ -0,0 +1,107 @@ +package account + +import ( + "context" + "database/sql" + "errors" + + "github.com/get-drexa/drexa/internal/database" + "github.com/get-drexa/drexa/internal/password" + "github.com/get-drexa/drexa/internal/user" + "github.com/google/uuid" + "github.com/uptrace/bun" +) + +type Service struct { + userService user.Service +} + +type RegisterOptions struct { + Email string + Password string + DisplayName string +} + +type CreateAccountOptions struct { + OrganizationID uuid.UUID + QuotaBytes int64 +} + +func NewService(userService *user.Service) *Service { + return &Service{ + userService: *userService, + } +} + +func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) { + hashed, err := password.Hash(opts.Password) + if err != nil { + return nil, nil, err + } + + u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{ + Email: opts.Email, + Password: hashed, + DisplayName: opts.DisplayName, + }) + if err != nil { + return nil, nil, err + } + + acc, err := s.CreateAccount(ctx, db, u.ID, CreateAccountOptions{ + // TODO: make quota configurable + QuotaBytes: 1024 * 1024 * 1024, // 1GB + }) + if err != nil { + return nil, nil, err + } + + return acc, u, nil +} + +func (s *Service) CreateAccount(ctx context.Context, db bun.IDB, userID uuid.UUID, opts CreateAccountOptions) (*Account, error) { + id, err := newAccountID() + if err != nil { + return nil, err + } + + account := &Account{ + ID: id, + UserID: userID, + StorageQuotaBytes: opts.QuotaBytes, + } + + _, err = db.NewInsert().Model(account).Returning("*").Exec(ctx) + if err != nil { + if database.IsUniqueViolation(err) { + return nil, ErrAccountAlreadyExists + } + return nil, err + } + + return account, nil +} + +func (s *Service) AccountByUserID(ctx context.Context, db bun.IDB, userID uuid.UUID) (*Account, error) { + var account Account + err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrAccountNotFound + } + return nil, err + } + return &account, nil +} + +func (s *Service) AccountByID(ctx context.Context, db bun.IDB, userID uuid.UUID, id uuid.UUID) (*Account, error) { + var account Account + err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Where("id = ?", id).Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrAccountNotFound + } + return nil, err + } + return &account, nil +} diff --git a/apps/backend/internal/auth/http.go b/apps/backend/internal/auth/http.go index b9459ca..64b607f 100644 --- a/apps/backend/internal/auth/http.go +++ b/apps/backend/internal/auth/http.go @@ -2,7 +2,6 @@ package auth import ( "errors" - "log/slog" "github.com/get-drexa/drexa/internal/user" "github.com/gofiber/fiber/v2" @@ -14,12 +13,6 @@ type loginRequest struct { Password string `json:"password"` } -type registerRequest struct { - Email string `json:"email"` - Password string `json:"password"` - DisplayName string `json:"displayName"` -} - type loginResponse struct { User user.User `json:"user"` AccessToken string `json:"accessToken"` @@ -37,9 +30,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { auth := api.Group("/auth") - auth.Post("/login", h.Login) - auth.Post("/register", h.Register) } func (h *HTTPHandler) Login(c *fiber.Ctx) error { @@ -54,7 +45,7 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error { } defer tx.Rollback() - result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password) + result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password) if err != nil { if errors.Is(err, ErrInvalidCredentials) { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"}) @@ -72,42 +63,3 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error { RefreshToken: result.RefreshToken, }) } - -func (h *HTTPHandler) Register(c *fiber.Ctx) error { - req := new(registerRequest) - if err := c.BodyParser(req); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) - } - - tx, err := h.db.BeginTx(c.Context(), nil) - if err != nil { - slog.Error("failed to begin transaction", "error", err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) - } - defer tx.Rollback() - - result, err := h.service.Register(c.Context(), tx, registerOptions{ - email: req.Email, - password: req.Password, - displayName: req.DisplayName, - }) - if err != nil { - var ae *user.AlreadyExistsError - if errors.As(err, &ae) { - return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"}) - } - slog.Error("failed to register user", "error", err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) - } - - if err := tx.Commit(); err != nil { - slog.Error("failed to commit transaction", "error", err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) - } - - return c.JSON(loginResponse{ - User: *result.User, - AccessToken: result.AccessToken, - RefreshToken: result.RefreshToken, - }) -} diff --git a/apps/backend/internal/auth/service.go b/apps/backend/internal/auth/service.go index 9ef9c58..24671d1 100644 --- a/apps/backend/internal/auth/service.go +++ b/apps/backend/internal/auth/service.go @@ -11,7 +11,7 @@ import ( "github.com/uptrace/bun" ) -type LoginResult struct { +type AuthenticationResult struct { User *user.User AccessToken string RefreshToken string @@ -24,12 +24,6 @@ type Service struct { tokenConfig TokenConfig } -type registerOptions struct { - displayName string - email string - password string -} - func NewService(userService *user.Service, tokenConfig TokenConfig) *Service { return &Service{ userService: userService, @@ -37,7 +31,30 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service { } } -func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) { +func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) { + at, err := GenerateAccessToken(user, &s.tokenConfig) + if err != nil { + return nil, err + } + + rt, err := GenerateRefreshToken(user, &s.tokenConfig) + if err != nil { + return nil, err + } + + _, err = db.NewInsert().Model(rt).Exec(ctx) + if err != nil { + return nil, err + } + + return &AuthenticationResult{ + User: user, + AccessToken: at, + RefreshToken: hex.EncodeToString(rt.Token), + }, nil +} + +func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) { u, err := s.userService.UserByEmail(ctx, db, email) if err != nil { var nf *user.NotFoundError @@ -67,44 +84,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, ema return nil, err } - return &LoginResult{ - User: u, - AccessToken: at, - RefreshToken: hex.EncodeToString(rt.Token), - }, nil -} - -func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) { - hashed, err := password.Hash(opts.password) - if err != nil { - return nil, err - } - - u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{ - Email: opts.email, - DisplayName: opts.displayName, - Password: hashed, - }) - if err != nil { - return nil, err - } - - at, err := GenerateAccessToken(u, &s.tokenConfig) - if err != nil { - return nil, err - } - - rt, err := GenerateRefreshToken(u, &s.tokenConfig) - if err != nil { - return nil, err - } - - _, err = db.NewInsert().Model(rt).Exec(ctx) - if err != nil { - return nil, err - } - - return &LoginResult{ + return &AuthenticationResult{ User: u, AccessToken: at, RefreshToken: hex.EncodeToString(rt.Token), diff --git a/apps/backend/internal/database/migrations/001_initial.up.sql b/apps/backend/internal/database/migrations/001_initial.up.sql index 5bda205..096bdc5 100644 --- a/apps/backend/internal/database/migrations/001_initial.up.sql +++ b/apps/backend/internal/database/migrations/001_initial.up.sql @@ -7,14 +7,23 @@ CREATE TABLE IF NOT EXISTS users ( display_name TEXT, email TEXT NOT NULL UNIQUE, password TEXT NOT NULL, - storage_usage_bytes BIGINT NOT NULL, - storage_quota_bytes BIGINT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); CREATE INDEX idx_users_email ON users(email); +CREATE TABLE IF NOT EXISTS accounts ( + id UUID PRIMARY KEY, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + storage_usage_bytes BIGINT NOT NULL DEFAULT 0, + storage_quota_bytes BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX idx_accounts_user_id ON accounts(user_id); + CREATE TABLE IF NOT EXISTS refresh_tokens ( id UUID PRIMARY KEY, user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, @@ -31,7 +40,7 @@ CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at); CREATE TABLE IF NOT EXISTS vfs_nodes ( id UUID PRIMARY KEY, public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak) - user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')), status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')), @@ -46,17 +55,17 @@ CREATE TABLE IF NOT EXISTS vfs_nodes ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), deleted_at TIMESTAMPTZ, -- soft delete for trash - -- No duplicate names in same parent (per user, excluding deleted) - CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (user_id, parent_id, name, deleted_at) + -- No duplicate names in same parent (per account, excluding deleted) + CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (account_id, parent_id, name, deleted_at) ); -CREATE INDEX idx_vfs_nodes_user_id ON vfs_nodes(user_id) WHERE deleted_at IS NULL; +CREATE INDEX idx_vfs_nodes_account_id ON vfs_nodes(account_id) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL; -CREATE INDEX idx_vfs_nodes_user_parent ON vfs_nodes(user_id, parent_id) WHERE deleted_at IS NULL; -CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(user_id, kind) WHERE deleted_at IS NULL; -CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(user_id, deleted_at) WHERE deleted_at IS NOT NULL; +CREATE INDEX idx_vfs_nodes_account_parent ON vfs_nodes(account_id, parent_id) WHERE deleted_at IS NULL; +CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(account_id, kind) WHERE deleted_at IS NULL; +CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(account_id, deleted_at) WHERE deleted_at IS NOT NULL; CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id); -CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_id IS NULL; -- one root per user +CREATE UNIQUE INDEX idx_vfs_nodes_account_root ON vfs_nodes(account_id) WHERE parent_id IS NULL; -- one root per account CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job CREATE TABLE IF NOT EXISTS node_shares ( @@ -91,4 +100,7 @@ CREATE TRIGGER update_vfs_nodes_updated_at BEFORE UPDATE ON vfs_nodes FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); + +CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); \ No newline at end of file diff --git a/apps/backend/internal/drexa/server.go b/apps/backend/internal/drexa/server.go index 3cd0d0e..25f7aae 100644 --- a/apps/backend/internal/drexa/server.go +++ b/apps/backend/internal/drexa/server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/get-drexa/drexa/internal/account" "github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/database" @@ -61,12 +62,16 @@ func NewServer(c Config) (*fiber.App, error) { SecretKey: c.JWT.SecretKey, }) uploadService := upload.NewService(vfs, blobStore) + accountService := account.NewService(userService) authMiddleware := auth.NewBearerAuthMiddleware(authService, db) api := app.Group("/api") + + accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api) + auth.NewHTTPHandler(authService, db).RegisterRoutes(api) - upload.NewHTTPHandler(uploadService, authMiddleware).RegisterRoutes(api, authMiddleware) + upload.NewHTTPHandler(uploadService).RegisterRoutes(accRouter) return app, nil } diff --git a/apps/backend/internal/upload/http.go b/apps/backend/internal/upload/http.go index fd7aa7f..e59a4a5 100644 --- a/apps/backend/internal/upload/http.go +++ b/apps/backend/internal/upload/http.go @@ -3,7 +3,7 @@ package upload import ( "errors" - "github.com/get-drexa/drexa/internal/auth" + "github.com/get-drexa/drexa/internal/account" "github.com/gofiber/fiber/v2" ) @@ -17,17 +17,15 @@ type updateUploadRequest struct { } type HTTPHandler struct { - service *Service - authMiddleware fiber.Handler + service *Service } -func NewHTTPHandler(s *Service, authMiddleware fiber.Handler) *HTTPHandler { - return &HTTPHandler{service: s, authMiddleware: authMiddleware} +func NewHTTPHandler(s *Service) *HTTPHandler { + return &HTTPHandler{service: s} } -func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Handler) { +func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { upload := api.Group("/uploads") - upload.Use(authMiddleware) upload.Post("/", h.Create) upload.Put("/:uploadID/content", h.ReceiveContent) @@ -35,9 +33,9 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Hand } func (h *HTTPHandler) Create(c *fiber.Ctx) error { - u, err := auth.AuthenticatedUser(c) - if err != nil { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) + account := account.CurrentAccount(c) + if account == nil { + return c.SendStatus(fiber.StatusUnauthorized) } req := new(createUploadRequest) @@ -45,7 +43,7 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) } - upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{ + upload, err := h.service.CreateUpload(c.Context(), account.ID, CreateUploadOptions{ ParentID: req.ParentID, Name: req.Name, }) @@ -57,14 +55,14 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error { } func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error { - u, err := auth.AuthenticatedUser(c) - if err != nil { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) + account := account.CurrentAccount(c) + if account == nil { + return c.SendStatus(fiber.StatusUnauthorized) } uploadID := c.Params("uploadID") - err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream()) + err := h.service.ReceiveUpload(c.Context(), account.ID, uploadID, c.Request().BodyStream()) defer c.Request().CloseBodyStream() if err != nil { return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) @@ -74,9 +72,9 @@ func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error { } func (h *HTTPHandler) Update(c *fiber.Ctx) error { - u, err := auth.AuthenticatedUser(c) - if err != nil { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) + account := account.CurrentAccount(c) + if account == nil { + return c.SendStatus(fiber.StatusUnauthorized) } req := new(updateUploadRequest) @@ -85,7 +83,7 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error { } if req.Status == StatusCompleted { - upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID")) + upload, err := h.service.CompleteUpload(c.Context(), account.ID, c.Params("uploadID")) if err != nil { if errors.Is(err, ErrNotFound) { return c.SendStatus(fiber.StatusNotFound) diff --git a/apps/backend/internal/upload/service.go b/apps/backend/internal/upload/service.go index 0a5a742..58c2160 100644 --- a/apps/backend/internal/upload/service.go +++ b/apps/backend/internal/upload/service.go @@ -33,8 +33,8 @@ type CreateUploadOptions struct { Name string } -func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { - parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID) +func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { + parentNode, err := s.vfs.FindNodeByPublicID(ctx, accountID, opts.ParentID) if err != nil { if errors.Is(err, virtualfs.ErrNodeNotFound) { return nil, ErrNotFound @@ -46,7 +46,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat return nil, ErrParentNotDirectory } - node, err := s.vfs.CreateFile(ctx, userID, virtualfs.CreateFileOptions{ + node, err := s.vfs.CreateFile(ctx, accountID, virtualfs.CreateFileOptions{ ParentID: parentNode.ID, Name: opts.Name, }) @@ -77,7 +77,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat return upload, nil } -func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID string, reader io.Reader) error { +func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, uploadID string, reader io.Reader) error { n, ok := s.pendingUploads.Load(uploadID) if !ok { return ErrNotFound @@ -88,7 +88,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID return ErrNotFound } - if upload.TargetNode.UserID != userID { + if upload.TargetNode.AccountID != accountID { return ErrNotFound } @@ -102,7 +102,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID return nil } -func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) { +func (s *Service) CompleteUpload(ctx context.Context, accountID uuid.UUID, uploadID string) (*Upload, error) { n, ok := s.pendingUploads.Load(uploadID) if !ok { return nil, ErrNotFound @@ -113,7 +113,7 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID return nil, ErrNotFound } - if upload.TargetNode.UserID != userID { + if upload.TargetNode.AccountID != accountID { return nil, ErrNotFound } diff --git a/apps/backend/internal/user/user.go b/apps/backend/internal/user/user.go index e4469b5..f1bb44f 100644 --- a/apps/backend/internal/user/user.go +++ b/apps/backend/internal/user/user.go @@ -11,14 +11,12 @@ import ( type User struct { bun.BaseModel `bun:"users"` - ID uuid.UUID `bun:",pk,type:uuid" json:"id"` - DisplayName string `bun:"display_name" json:"displayName"` - Email string `bun:"email,unique,notnull" json:"email"` - Password password.Hashed `bun:"password,notnull" json:"-"` - StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` - StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` - CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"` - UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"` + ID uuid.UUID `bun:",pk,type:uuid" json:"id"` + DisplayName string `bun:"display_name" json:"displayName"` + Email string `bun:"email,unique,notnull" json:"email"` + Password password.Hashed `bun:"password,notnull" json:"-"` + CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"` + UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"` } func newUserID() (uuid.UUID, error) { diff --git a/apps/backend/internal/virtualfs/node.go b/apps/backend/internal/virtualfs/node.go index 3aab117..9d5a044 100644 --- a/apps/backend/internal/virtualfs/node.go +++ b/apps/backend/internal/virtualfs/node.go @@ -25,13 +25,13 @@ const ( type Node struct { bun.BaseModel `bun:"vfs_nodes"` - ID uuid.UUID `bun:",pk,type:uuid"` - PublicID string `bun:"public_id,notnull"` - UserID uuid.UUID `bun:"user_id,notnull"` - ParentID uuid.UUID `bun:"parent_id,notnull"` - Kind NodeKind `bun:"kind,notnull"` - Status NodeStatus `bun:"status,notnull"` - Name string `bun:"name,notnull"` + ID uuid.UUID `bun:",pk,type:uuid"` + PublicID string `bun:"public_id,notnull"` + AccountID uuid.UUID `bun:"account_id,notnull,type:uuid"` + ParentID uuid.UUID `bun:"parent_id"` + Kind NodeKind `bun:"kind,notnull"` + Status NodeStatus `bun:"status,notnull"` + Name string `bun:"name,notnull"` BlobKey blob.Key `bun:"blob_key"` Size int64 `bun:"size"` diff --git a/apps/backend/internal/virtualfs/vfs.go b/apps/backend/internal/virtualfs/vfs.go index 44b8857..2ea96bb 100644 --- a/apps/backend/internal/virtualfs/vfs.go +++ b/apps/backend/internal/virtualfs/vfs.go @@ -63,10 +63,10 @@ func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver) }, nil } -func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Node, error) { +func (vfs *VirtualFS) FindNode(ctx context.Context, accountID, fileID string) (*Node, error) { var node Node err := vfs.db.NewSelect().Model(&node). - Where("user_id = ?", userID). + Where("account_id = ?", accountID). Where("id = ?", fileID). Where("status = ?", NodeStatusReady). Where("deleted_at IS NULL"). @@ -80,10 +80,10 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Nod return &node, nil } -func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID, publicID string) (*Node, error) { +func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, accountID uuid.UUID, publicID string) (*Node, error) { var node Node err := vfs.db.NewSelect().Model(&node). - Where("user_id = ?", userID). + Where("account_id = ?", accountID). Where("public_id = ?", publicID). Where("status = ?", NodeStatusReady). Where("deleted_at IS NULL"). @@ -104,7 +104,7 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er var nodes []*Node err := vfs.db.NewSelect().Model(&nodes). - Where("user_id = ?", node.UserID). + Where("account_id = ?", node.AccountID). Where("parent_id = ?", node.ID). Where("status = ?", NodeStatusReady). Where("deleted_at IS NULL"). @@ -119,19 +119,19 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er return nodes, nil } -func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts CreateFileOptions) (*Node, error) { +func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) { pid, err := vfs.generatePublicID() if err != nil { return nil, err } node := Node{ - PublicID: pid, - UserID: userID, - ParentID: opts.ParentID, - Kind: NodeKindFile, - Status: NodeStatusPending, - Name: opts.Name, + PublicID: pid, + AccountID: accountID, + ParentID: opts.ParentID, + Kind: NodeKindFile, + Status: NodeStatusPending, + Name: opts.Name, } if vfs.keyResolver.ShouldPersistKey() { @@ -233,19 +233,19 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon return nil } -func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) { +func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) { pid, err := vfs.generatePublicID() if err != nil { return nil, err } node := Node{ - PublicID: pid, - UserID: userID, - ParentID: parentID, - Kind: NodeKindDirectory, - Status: NodeStatusReady, - Name: name, + PublicID: pid, + AccountID: accountID, + ParentID: parentID, + Kind: NodeKindDirectory, + Status: NodeStatusReady, + Name: name, } _, err = vfs.db.NewInsert().Model(&node).Exec(ctx)