refactor: make vfs methods accept bun.IDB

This commit is contained in:
2025-11-30 17:31:24 +00:00
parent 89b62f6d8a
commit 0b8aee5d60
6 changed files with 82 additions and 56 deletions

View File

@@ -8,12 +8,14 @@ import (
"github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type Service struct { type Service struct {
userService user.Service userService user.Service
vfs *virtualfs.VirtualFS
} }
type RegisterOptions struct { type RegisterOptions struct {
@@ -27,9 +29,10 @@ type CreateAccountOptions struct {
QuotaBytes int64 QuotaBytes int64
} }
func NewService(userService *user.Service) *Service { func NewService(userService *user.Service, vfs *virtualfs.VirtualFS) *Service {
return &Service{ return &Service{
userService: *userService, userService: *userService,
vfs: vfs,
} }
} }

View File

@@ -50,7 +50,7 @@ func NewServer(c Config) (*fiber.App, error) {
return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode) return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode)
} }
vfs, err := virtualfs.NewVirtualFS(db, blobStore, keyResolver) vfs, err := virtualfs.NewVirtualFS(blobStore, keyResolver)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create virtual file system: %w", err) return nil, fmt.Errorf("failed to create virtual file system: %w", err)
} }
@@ -62,7 +62,7 @@ func NewServer(c Config) (*fiber.App, error) {
SecretKey: c.JWT.SecretKey, SecretKey: c.JWT.SecretKey,
}) })
uploadService := upload.NewService(vfs, blobStore) uploadService := upload.NewService(vfs, blobStore)
accountService := account.NewService(userService) accountService := account.NewService(userService, vfs)
authMiddleware := auth.NewBearerAuthMiddleware(authService, db) authMiddleware := auth.NewBearerAuthMiddleware(authService, db)
@@ -71,7 +71,7 @@ func NewServer(c Config) (*fiber.App, error) {
accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api) accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
auth.NewHTTPHandler(authService, db).RegisterRoutes(api) auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService).RegisterRoutes(accRouter) upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accRouter)
return app, nil return app, nil
} }

View File

@@ -5,6 +5,7 @@ import (
"github.com/get-drexa/drexa/internal/account" "github.com/get-drexa/drexa/internal/account"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
) )
type createUploadRequest struct { type createUploadRequest struct {
@@ -18,10 +19,11 @@ type updateUploadRequest struct {
type HTTPHandler struct { type HTTPHandler struct {
service *Service service *Service
db *bun.DB
} }
func NewHTTPHandler(s *Service) *HTTPHandler { func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return &HTTPHandler{service: s} return &HTTPHandler{service: s, db: db}
} }
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
@@ -43,7 +45,7 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
upload, err := h.service.CreateUpload(c.Context(), account.ID, CreateUploadOptions{ upload, err := h.service.CreateUpload(c.Context(), h.db, account.ID, CreateUploadOptions{
ParentID: req.ParentID, ParentID: req.ParentID,
Name: req.Name, Name: req.Name,
}) })
@@ -62,7 +64,7 @@ func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
uploadID := c.Params("uploadID") uploadID := c.Params("uploadID")
err := h.service.ReceiveUpload(c.Context(), account.ID, uploadID, c.Request().BodyStream()) err := h.service.ReceiveUpload(c.Context(), h.db, account.ID, uploadID, c.Request().BodyStream())
defer c.Request().CloseBodyStream() defer c.Request().CloseBodyStream()
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
@@ -83,7 +85,7 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error {
} }
if req.Status == StatusCompleted { if req.Status == StatusCompleted {
upload, err := h.service.CompleteUpload(c.Context(), account.ID, c.Params("uploadID")) upload, err := h.service.CompleteUpload(c.Context(), h.db, account.ID, c.Params("uploadID"))
if err != nil { if err != nil {
if errors.Is(err, ErrNotFound) { if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound) return c.SendStatus(fiber.StatusNotFound)

View File

@@ -10,6 +10,7 @@ import (
"github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/virtualfs" "github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun"
) )
type Service struct { type Service struct {
@@ -33,8 +34,8 @@ type CreateUploadOptions struct {
Name string Name string
} }
func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { func (s *Service) CreateUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
parentNode, err := s.vfs.FindNodeByPublicID(ctx, accountID, opts.ParentID) parentNode, err := s.vfs.FindNodeByPublicID(ctx, db, accountID, opts.ParentID)
if err != nil { if err != nil {
if errors.Is(err, virtualfs.ErrNodeNotFound) { if errors.Is(err, virtualfs.ErrNodeNotFound) {
return nil, ErrNotFound return nil, ErrNotFound
@@ -46,7 +47,7 @@ func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts Cr
return nil, ErrParentNotDirectory return nil, ErrParentNotDirectory
} }
node, err := s.vfs.CreateFile(ctx, accountID, virtualfs.CreateFileOptions{ node, err := s.vfs.CreateFile(ctx, db, accountID, virtualfs.CreateFileOptions{
ParentID: parentNode.ID, ParentID: parentNode.ID,
Name: opts.Name, Name: opts.Name,
}) })
@@ -61,7 +62,7 @@ func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts Cr
Duration: 1 * time.Hour, Duration: 1 * time.Hour,
}) })
if err != nil { if err != nil {
_ = s.vfs.PermanentlyDeleteNode(ctx, node) _ = s.vfs.PermanentlyDeleteNode(ctx, db, node)
return nil, err return nil, err
} }
@@ -77,7 +78,7 @@ func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts Cr
return upload, nil return upload, nil
} }
func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, uploadID string, reader io.Reader) error { func (s *Service) ReceiveUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string, reader io.Reader) error {
n, ok := s.pendingUploads.Load(uploadID) n, ok := s.pendingUploads.Load(uploadID)
if !ok { if !ok {
return ErrNotFound return ErrNotFound
@@ -92,7 +93,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, upload
return ErrNotFound return ErrNotFound
} }
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromReader(reader)) err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromReader(reader))
if err != nil { if err != nil {
return err return err
} }
@@ -102,7 +103,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, upload
return nil return nil
} }
func (s *Service) CompleteUpload(ctx context.Context, accountID uuid.UUID, uploadID string) (*Upload, error) { func (s *Service) CompleteUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string) (*Upload, error) {
n, ok := s.pendingUploads.Load(uploadID) n, ok := s.pendingUploads.Load(uploadID)
if !ok { if !ok {
return nil, ErrNotFound return nil, ErrNotFound
@@ -121,7 +122,7 @@ func (s *Service) CompleteUpload(ctx context.Context, accountID uuid.UUID, uploa
return upload, nil return upload, nil
} }
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey)) err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -29,7 +29,7 @@ func JoinPath(parts ...string) string {
return strings.Join(parts, "/") return strings.Join(parts, "/")
} }
func buildNodeAbsolutePath(ctx context.Context, db *bun.DB, nodeID uuid.UUID) (string, error) { func buildNodeAbsolutePath(ctx context.Context, db bun.IDB, nodeID uuid.UUID) (string, error) {
var path []string var path []string
err := db.NewRaw(absolutePathQuery, nodeID).Scan(ctx, &path) err := db.NewRaw(absolutePathQuery, nodeID).Scan(ctx, &path)
if err != nil { if err != nil {

View File

@@ -19,7 +19,6 @@ import (
) )
type VirtualFS struct { type VirtualFS struct {
db *bun.DB
blobStore blob.Store blobStore blob.Store
keyResolver BlobKeyResolver keyResolver BlobKeyResolver
@@ -50,22 +49,21 @@ func FileContentFromBlobKey(blobKey blob.Key) FileContent {
return FileContent{blobKey: blobKey} return FileContent{blobKey: blobKey}
} }
func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) { func NewVirtualFS(blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) {
sqid, err := sqids.New() sqid, err := sqids.New()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &VirtualFS{ return &VirtualFS{
db: db,
blobStore: blobStore, blobStore: blobStore,
keyResolver: keyResolver, keyResolver: keyResolver,
sqid: sqid, sqid: sqid,
}, nil }, nil
} }
func (vfs *VirtualFS) FindNode(ctx context.Context, accountID, fileID string) (*Node, error) { func (vfs *VirtualFS) FindNode(ctx context.Context, db bun.IDB, accountID, fileID string) (*Node, error) {
var node Node var node Node
err := vfs.db.NewSelect().Model(&node). err := db.NewSelect().Model(&node).
Where("account_id = ?", accountID). Where("account_id = ?", accountID).
Where("id = ?", fileID). Where("id = ?", fileID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
@@ -80,9 +78,9 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, accountID, fileID string) (*
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, accountID uuid.UUID, publicID string) (*Node, error) { func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, db bun.IDB, accountID uuid.UUID, publicID string) (*Node, error) {
var node Node var node Node
err := vfs.db.NewSelect().Model(&node). err := db.NewSelect().Model(&node).
Where("account_id = ?", accountID). Where("account_id = ?", accountID).
Where("public_id = ?", publicID). Where("public_id = ?", publicID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
@@ -97,13 +95,13 @@ func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, accountID uuid.UUI
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, error) { func (vfs *VirtualFS) ListChildren(ctx context.Context, db bun.IDB, node *Node) ([]*Node, error) {
if !node.IsAccessible() { if !node.IsAccessible() {
return nil, ErrNodeNotFound return nil, ErrNodeNotFound
} }
var nodes []*Node var nodes []*Node
err := vfs.db.NewSelect().Model(&nodes). err := db.NewSelect().Model(&nodes).
Where("account_id = ?", node.AccountID). Where("account_id = ?", node.AccountID).
Where("parent_id = ?", node.ID). Where("parent_id = ?", node.ID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
@@ -119,7 +117,7 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
return nodes, nil return nodes, nil
} }
func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) { func (vfs *VirtualFS) CreateFile(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) {
pid, err := vfs.generatePublicID() pid, err := vfs.generatePublicID()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -141,7 +139,7 @@ func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts
} }
} }
_, err = vfs.db.NewInsert().Model(&node).Returning("*").Exec(ctx) _, err = db.NewInsert().Model(&node).Returning("*").Exec(ctx)
if err != nil { if err != nil {
if database.IsUniqueViolation(err) { if database.IsUniqueViolation(err) {
return nil, ErrNodeConflict return nil, ErrNodeConflict
@@ -152,7 +150,7 @@ func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileContent) error { func (vfs *VirtualFS) WriteFile(ctx context.Context, db bun.IDB, node *Node, content FileContent) error {
if content.reader == nil && content.blobKey.IsNil() { if content.reader == nil && content.blobKey.IsNil() {
return blob.ErrInvalidFileContent return blob.ErrInvalidFileContent
} }
@@ -222,7 +220,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
setCols = append(setCols, "mime_type", "blob_key", "size", "status") setCols = append(setCols, "mime_type", "blob_key", "size", "status")
} }
_, err := vfs.db.NewUpdate().Model(&node). _, err := db.NewUpdate().Model(&node).
Column(setCols...). Column(setCols...).
WherePK(). WherePK().
Exec(ctx) Exec(ctx)
@@ -233,7 +231,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
return nil return nil
} }
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) { func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
pid, err := vfs.generatePublicID() pid, err := vfs.generatePublicID()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -248,7 +246,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID,
Name: name, Name: name,
} }
_, err = vfs.db.NewInsert().Model(&node).Exec(ctx) _, err = db.NewInsert().Model(&node).Exec(ctx)
if err != nil { if err != nil {
if database.IsUniqueViolation(err) { if database.IsUniqueViolation(err) {
return nil, ErrNodeConflict return nil, ErrNodeConflict
@@ -259,12 +257,12 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID,
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
if !node.IsAccessible() { if !node.IsAccessible() {
return ErrNodeNotFound return ErrNodeNotFound
} }
_, err := vfs.db.NewUpdate().Model(node). _, err := db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
@@ -281,12 +279,12 @@ func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error {
return nil return nil
} }
func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) RestoreNode(ctx context.Context, db bun.IDB, node *Node) error {
if node.Status != NodeStatusReady { if node.Status != NodeStatusReady {
return ErrNodeNotFound return ErrNodeNotFound
} }
_, err := vfs.db.NewUpdate().Model(node). _, err := db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("deleted_at IS NOT NULL"). Where("deleted_at IS NOT NULL").
Set("deleted_at = NULL"). Set("deleted_at = NULL").
@@ -302,12 +300,12 @@ func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error {
return nil return nil
} }
func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) error { func (vfs *VirtualFS) RenameNode(ctx context.Context, db bun.IDB, node *Node, name string) error {
if !node.IsAccessible() { if !node.IsAccessible() {
return ErrNodeNotFound return ErrNodeNotFound
} }
_, err := vfs.db.NewUpdate().Model(node). _, err := db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -323,7 +321,7 @@ func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) e
return nil return nil
} }
func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UUID) error { func (vfs *VirtualFS) MoveNode(ctx context.Context, db bun.IDB, node *Node, parentID uuid.UUID) error {
if !node.IsAccessible() { if !node.IsAccessible() {
return ErrNodeNotFound return ErrNodeNotFound
} }
@@ -333,7 +331,7 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return err return err
} }
_, err = vfs.db.NewUpdate().Model(node). _, err = db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -362,7 +360,7 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
if vfs.keyResolver.ShouldPersistKey() { if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey = newKey node.BlobKey = newKey
_, err = vfs.db.NewUpdate().Model(node). _, err = db.NewUpdate().Model(node).
WherePK(). WherePK().
Set("blob_key = ?", newKey). Set("blob_key = ?", newKey).
Exec(ctx) Exec(ctx)
@@ -374,34 +372,34 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return nil return nil
} }
func (vfs *VirtualFS) AbsolutePath(ctx context.Context, node *Node) (string, error) { func (vfs *VirtualFS) AbsolutePath(ctx context.Context, db bun.IDB, node *Node) (string, error) {
if !node.IsAccessible() { if !node.IsAccessible() {
return "", ErrNodeNotFound return "", ErrNodeNotFound
} }
return buildNodeAbsolutePath(ctx, vfs.db, node.ID) return buildNodeAbsolutePath(ctx, db, node.ID)
} }
func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
if !node.IsAccessible() { if !node.IsAccessible() {
return ErrNodeNotFound return ErrNodeNotFound
} }
switch node.Kind { switch node.Kind {
case NodeKindFile: case NodeKindFile:
return vfs.permanentlyDeleteFileNode(ctx, node) return vfs.permanentlyDeleteFileNode(ctx, db, node)
case NodeKindDirectory: case NodeKindDirectory:
return vfs.permanentlyDeleteDirectoryNode(ctx, node) return vfs.permanentlyDeleteDirectoryNode(ctx, db, node)
default: default:
return ErrUnsupportedOperation return ErrUnsupportedOperation
} }
} }
func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, db bun.IDB, node *Node) error {
err := vfs.blobStore.Delete(ctx, node.BlobKey) err := vfs.blobStore.Delete(ctx, node.BlobKey)
if err != nil { if err != nil {
return err return err
} }
_, err = vfs.db.NewDelete().Model(node).WherePK().Exec(ctx) _, err = db.NewDelete().Model(node).WherePK().Exec(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -409,7 +407,7 @@ func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node)
return nil return nil
} }
func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, db bun.IDB, node *Node) error {
const descendantsQuery = `WITH RECURSIVE descendants AS ( const descendantsQuery = `WITH RECURSIVE descendants AS (
SELECT id, blob_key FROM vfs_nodes WHERE id = ? SELECT id, blob_key FROM vfs_nodes WHERE id = ?
UNION ALL UNION ALL
@@ -423,14 +421,29 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *
BlobKey blob.Key `bun:"blob_key"` BlobKey blob.Key `bun:"blob_key"`
} }
tx, err := vfs.db.BeginTx(ctx, nil) // If db is already a transaction, use it directly; otherwise start a new transaction
var tx bun.IDB
var startedTx *bun.Tx
switch v := db.(type) {
case *bun.DB:
newTx, err := v.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() startedTx = &newTx
tx = newTx
defer func() {
if startedTx != nil {
(*startedTx).Rollback()
}
}()
default:
// Assume it's already a transaction
tx = db
}
var records []nodeRecord var records []nodeRecord
err = tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records) err := tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return ErrNodeNotFound return ErrNodeNotFound
@@ -472,7 +485,14 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *
} }
} }
return tx.Commit() // Only commit if we started the transaction
if startedTx != nil {
err := (*startedTx).Commit()
startedTx = nil // Prevent defer from rolling back
return err
}
return nil
} }
func (vfs *VirtualFS) generatePublicID() (string, error) { func (vfs *VirtualFS) generatePublicID() (string, error) {