diff --git a/apps/backend/internal/account/service.go b/apps/backend/internal/account/service.go index 9f4c468..749f313 100644 --- a/apps/backend/internal/account/service.go +++ b/apps/backend/internal/account/service.go @@ -8,12 +8,14 @@ import ( "github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/user" + "github.com/get-drexa/drexa/internal/virtualfs" "github.com/google/uuid" "github.com/uptrace/bun" ) type Service struct { userService user.Service + vfs *virtualfs.VirtualFS } type RegisterOptions struct { @@ -27,9 +29,10 @@ type CreateAccountOptions struct { QuotaBytes int64 } -func NewService(userService *user.Service) *Service { +func NewService(userService *user.Service, vfs *virtualfs.VirtualFS) *Service { return &Service{ userService: *userService, + vfs: vfs, } } diff --git a/apps/backend/internal/drexa/server.go b/apps/backend/internal/drexa/server.go index 25f7aae..e26e2c3 100644 --- a/apps/backend/internal/drexa/server.go +++ b/apps/backend/internal/drexa/server.go @@ -50,7 +50,7 @@ func NewServer(c Config) (*fiber.App, error) { return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode) } - vfs, err := virtualfs.NewVirtualFS(db, blobStore, keyResolver) + vfs, err := virtualfs.NewVirtualFS(blobStore, keyResolver) if err != nil { return nil, fmt.Errorf("failed to create virtual file system: %w", err) } @@ -62,7 +62,7 @@ func NewServer(c Config) (*fiber.App, error) { SecretKey: c.JWT.SecretKey, }) uploadService := upload.NewService(vfs, blobStore) - accountService := account.NewService(userService) + accountService := account.NewService(userService, vfs) authMiddleware := auth.NewBearerAuthMiddleware(authService, db) @@ -71,7 +71,7 @@ func NewServer(c Config) (*fiber.App, error) { accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api) auth.NewHTTPHandler(authService, db).RegisterRoutes(api) - upload.NewHTTPHandler(uploadService).RegisterRoutes(accRouter) + upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accRouter) return app, nil } diff --git a/apps/backend/internal/upload/http.go b/apps/backend/internal/upload/http.go index e59a4a5..dab07a3 100644 --- a/apps/backend/internal/upload/http.go +++ b/apps/backend/internal/upload/http.go @@ -5,6 +5,7 @@ import ( "github.com/get-drexa/drexa/internal/account" "github.com/gofiber/fiber/v2" + "github.com/uptrace/bun" ) type createUploadRequest struct { @@ -18,10 +19,11 @@ type updateUploadRequest struct { type HTTPHandler struct { service *Service + db *bun.DB } -func NewHTTPHandler(s *Service) *HTTPHandler { - return &HTTPHandler{service: s} +func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler { + return &HTTPHandler{service: s, db: db} } func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { @@ -43,7 +45,7 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) } - upload, err := h.service.CreateUpload(c.Context(), account.ID, CreateUploadOptions{ + upload, err := h.service.CreateUpload(c.Context(), h.db, account.ID, CreateUploadOptions{ ParentID: req.ParentID, Name: req.Name, }) @@ -62,7 +64,7 @@ func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error { uploadID := c.Params("uploadID") - err := h.service.ReceiveUpload(c.Context(), account.ID, uploadID, c.Request().BodyStream()) + err := h.service.ReceiveUpload(c.Context(), h.db, account.ID, uploadID, c.Request().BodyStream()) defer c.Request().CloseBodyStream() if err != nil { return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) @@ -83,7 +85,7 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error { } if req.Status == StatusCompleted { - upload, err := h.service.CompleteUpload(c.Context(), account.ID, c.Params("uploadID")) + upload, err := h.service.CompleteUpload(c.Context(), h.db, account.ID, c.Params("uploadID")) if err != nil { if errors.Is(err, ErrNotFound) { return c.SendStatus(fiber.StatusNotFound) diff --git a/apps/backend/internal/upload/service.go b/apps/backend/internal/upload/service.go index 58c2160..0b7f607 100644 --- a/apps/backend/internal/upload/service.go +++ b/apps/backend/internal/upload/service.go @@ -10,6 +10,7 @@ import ( "github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/virtualfs" "github.com/google/uuid" + "github.com/uptrace/bun" ) type Service struct { @@ -33,8 +34,8 @@ type CreateUploadOptions struct { Name string } -func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { - parentNode, err := s.vfs.FindNodeByPublicID(ctx, accountID, opts.ParentID) +func (s *Service) CreateUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { + parentNode, err := s.vfs.FindNodeByPublicID(ctx, db, accountID, opts.ParentID) if err != nil { if errors.Is(err, virtualfs.ErrNodeNotFound) { return nil, ErrNotFound @@ -46,7 +47,7 @@ func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts Cr return nil, ErrParentNotDirectory } - node, err := s.vfs.CreateFile(ctx, accountID, virtualfs.CreateFileOptions{ + node, err := s.vfs.CreateFile(ctx, db, accountID, virtualfs.CreateFileOptions{ ParentID: parentNode.ID, Name: opts.Name, }) @@ -61,7 +62,7 @@ func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts Cr Duration: 1 * time.Hour, }) if err != nil { - _ = s.vfs.PermanentlyDeleteNode(ctx, node) + _ = s.vfs.PermanentlyDeleteNode(ctx, db, node) return nil, err } @@ -77,7 +78,7 @@ func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts Cr return upload, nil } -func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, uploadID string, reader io.Reader) error { +func (s *Service) ReceiveUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string, reader io.Reader) error { n, ok := s.pendingUploads.Load(uploadID) if !ok { return ErrNotFound @@ -92,7 +93,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, upload return ErrNotFound } - err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromReader(reader)) + err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromReader(reader)) if err != nil { return err } @@ -102,7 +103,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, upload return nil } -func (s *Service) CompleteUpload(ctx context.Context, accountID uuid.UUID, uploadID string) (*Upload, error) { +func (s *Service) CompleteUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string) (*Upload, error) { n, ok := s.pendingUploads.Load(uploadID) if !ok { return nil, ErrNotFound @@ -121,7 +122,7 @@ func (s *Service) CompleteUpload(ctx context.Context, accountID uuid.UUID, uploa return upload, nil } - err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey)) + err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey)) if err != nil { return nil, err } diff --git a/apps/backend/internal/virtualfs/path.go b/apps/backend/internal/virtualfs/path.go index 2f9ebd2..23b725e 100644 --- a/apps/backend/internal/virtualfs/path.go +++ b/apps/backend/internal/virtualfs/path.go @@ -29,7 +29,7 @@ func JoinPath(parts ...string) string { return strings.Join(parts, "/") } -func buildNodeAbsolutePath(ctx context.Context, db *bun.DB, nodeID uuid.UUID) (string, error) { +func buildNodeAbsolutePath(ctx context.Context, db bun.IDB, nodeID uuid.UUID) (string, error) { var path []string err := db.NewRaw(absolutePathQuery, nodeID).Scan(ctx, &path) if err != nil { diff --git a/apps/backend/internal/virtualfs/vfs.go b/apps/backend/internal/virtualfs/vfs.go index 2ea96bb..990bff2 100644 --- a/apps/backend/internal/virtualfs/vfs.go +++ b/apps/backend/internal/virtualfs/vfs.go @@ -19,7 +19,6 @@ import ( ) type VirtualFS struct { - db *bun.DB blobStore blob.Store keyResolver BlobKeyResolver @@ -50,22 +49,21 @@ func FileContentFromBlobKey(blobKey blob.Key) FileContent { return FileContent{blobKey: blobKey} } -func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) { +func NewVirtualFS(blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) { sqid, err := sqids.New() if err != nil { return nil, err } return &VirtualFS{ - db: db, blobStore: blobStore, keyResolver: keyResolver, sqid: sqid, }, nil } -func (vfs *VirtualFS) FindNode(ctx context.Context, accountID, fileID string) (*Node, error) { +func (vfs *VirtualFS) FindNode(ctx context.Context, db bun.IDB, accountID, fileID string) (*Node, error) { var node Node - err := vfs.db.NewSelect().Model(&node). + err := db.NewSelect().Model(&node). Where("account_id = ?", accountID). Where("id = ?", fileID). Where("status = ?", NodeStatusReady). @@ -80,9 +78,9 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, accountID, fileID string) (* return &node, nil } -func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, accountID uuid.UUID, publicID string) (*Node, error) { +func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, db bun.IDB, accountID uuid.UUID, publicID string) (*Node, error) { var node Node - err := vfs.db.NewSelect().Model(&node). + err := db.NewSelect().Model(&node). Where("account_id = ?", accountID). Where("public_id = ?", publicID). Where("status = ?", NodeStatusReady). @@ -97,13 +95,13 @@ func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, accountID uuid.UUI return &node, nil } -func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, error) { +func (vfs *VirtualFS) ListChildren(ctx context.Context, db bun.IDB, node *Node) ([]*Node, error) { if !node.IsAccessible() { return nil, ErrNodeNotFound } var nodes []*Node - err := vfs.db.NewSelect().Model(&nodes). + err := db.NewSelect().Model(&nodes). Where("account_id = ?", node.AccountID). Where("parent_id = ?", node.ID). Where("status = ?", NodeStatusReady). @@ -119,7 +117,7 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er return nodes, nil } -func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) { +func (vfs *VirtualFS) CreateFile(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) { pid, err := vfs.generatePublicID() if err != nil { return nil, err @@ -141,7 +139,7 @@ func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts } } - _, err = vfs.db.NewInsert().Model(&node).Returning("*").Exec(ctx) + _, err = db.NewInsert().Model(&node).Returning("*").Exec(ctx) if err != nil { if database.IsUniqueViolation(err) { return nil, ErrNodeConflict @@ -152,7 +150,7 @@ func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts return &node, nil } -func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileContent) error { +func (vfs *VirtualFS) WriteFile(ctx context.Context, db bun.IDB, node *Node, content FileContent) error { if content.reader == nil && content.blobKey.IsNil() { return blob.ErrInvalidFileContent } @@ -222,7 +220,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon setCols = append(setCols, "mime_type", "blob_key", "size", "status") } - _, err := vfs.db.NewUpdate().Model(&node). + _, err := db.NewUpdate().Model(&node). Column(setCols...). WherePK(). Exec(ctx) @@ -233,7 +231,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon return nil } -func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) { +func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) { pid, err := vfs.generatePublicID() if err != nil { return nil, err @@ -248,7 +246,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID, Name: name, } - _, err = vfs.db.NewInsert().Model(&node).Exec(ctx) + _, err = db.NewInsert().Model(&node).Exec(ctx) if err != nil { if database.IsUniqueViolation(err) { return nil, ErrNodeConflict @@ -259,12 +257,12 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID, return &node, nil } -func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error { +func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error { if !node.IsAccessible() { return ErrNodeNotFound } - _, err := vfs.db.NewUpdate().Model(node). + _, err := db.NewUpdate().Model(node). WherePK(). Where("deleted_at IS NULL"). Where("status = ?", NodeStatusReady). @@ -281,12 +279,12 @@ func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error { return nil } -func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error { +func (vfs *VirtualFS) RestoreNode(ctx context.Context, db bun.IDB, node *Node) error { if node.Status != NodeStatusReady { return ErrNodeNotFound } - _, err := vfs.db.NewUpdate().Model(node). + _, err := db.NewUpdate().Model(node). WherePK(). Where("deleted_at IS NOT NULL"). Set("deleted_at = NULL"). @@ -302,12 +300,12 @@ func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error { return nil } -func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) error { +func (vfs *VirtualFS) RenameNode(ctx context.Context, db bun.IDB, node *Node, name string) error { if !node.IsAccessible() { return ErrNodeNotFound } - _, err := vfs.db.NewUpdate().Model(node). + _, err := db.NewUpdate().Model(node). WherePK(). Where("status = ?", NodeStatusReady). Where("deleted_at IS NULL"). @@ -323,7 +321,7 @@ func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) e return nil } -func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UUID) error { +func (vfs *VirtualFS) MoveNode(ctx context.Context, db bun.IDB, node *Node, parentID uuid.UUID) error { if !node.IsAccessible() { return ErrNodeNotFound } @@ -333,7 +331,7 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU return err } - _, err = vfs.db.NewUpdate().Model(node). + _, err = db.NewUpdate().Model(node). WherePK(). Where("status = ?", NodeStatusReady). Where("deleted_at IS NULL"). @@ -362,7 +360,7 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU if vfs.keyResolver.ShouldPersistKey() { node.BlobKey = newKey - _, err = vfs.db.NewUpdate().Model(node). + _, err = db.NewUpdate().Model(node). WherePK(). Set("blob_key = ?", newKey). Exec(ctx) @@ -374,34 +372,34 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU return nil } -func (vfs *VirtualFS) AbsolutePath(ctx context.Context, node *Node) (string, error) { +func (vfs *VirtualFS) AbsolutePath(ctx context.Context, db bun.IDB, node *Node) (string, error) { if !node.IsAccessible() { return "", ErrNodeNotFound } - return buildNodeAbsolutePath(ctx, vfs.db, node.ID) + return buildNodeAbsolutePath(ctx, db, node.ID) } -func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, node *Node) error { +func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, db bun.IDB, node *Node) error { if !node.IsAccessible() { return ErrNodeNotFound } switch node.Kind { case NodeKindFile: - return vfs.permanentlyDeleteFileNode(ctx, node) + return vfs.permanentlyDeleteFileNode(ctx, db, node) case NodeKindDirectory: - return vfs.permanentlyDeleteDirectoryNode(ctx, node) + return vfs.permanentlyDeleteDirectoryNode(ctx, db, node) default: return ErrUnsupportedOperation } } -func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node) error { +func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, db bun.IDB, node *Node) error { err := vfs.blobStore.Delete(ctx, node.BlobKey) if err != nil { return err } - _, err = vfs.db.NewDelete().Model(node).WherePK().Exec(ctx) + _, err = db.NewDelete().Model(node).WherePK().Exec(ctx) if err != nil { return err } @@ -409,7 +407,7 @@ func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node) return nil } -func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *Node) error { +func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, db bun.IDB, node *Node) error { const descendantsQuery = `WITH RECURSIVE descendants AS ( SELECT id, blob_key FROM vfs_nodes WHERE id = ? UNION ALL @@ -423,14 +421,29 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node * BlobKey blob.Key `bun:"blob_key"` } - tx, err := vfs.db.BeginTx(ctx, nil) - if err != nil { - return err + // If db is already a transaction, use it directly; otherwise start a new transaction + var tx bun.IDB + var startedTx *bun.Tx + switch v := db.(type) { + case *bun.DB: + newTx, err := v.BeginTx(ctx, nil) + if err != nil { + return err + } + startedTx = &newTx + tx = newTx + defer func() { + if startedTx != nil { + (*startedTx).Rollback() + } + }() + default: + // Assume it's already a transaction + tx = db } - defer tx.Rollback() var records []nodeRecord - err = tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records) + err := tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrNodeNotFound @@ -472,7 +485,14 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node * } } - return tx.Commit() + // Only commit if we started the transaction + if startedTx != nil { + err := (*startedTx).Commit() + startedTx = nil // Prevent defer from rolling back + return err + } + + return nil } func (vfs *VirtualFS) generatePublicID() (string, error) {