Files
drive/apps/backend/internal/sharing/service.go

368 lines
8.5 KiB
Go
Raw Normal View History

2025-12-27 19:27:08 +00:00
package sharing
import (
"context"
"crypto/rand"
"database/sql"
"encoding/binary"
"errors"
"time"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid"
"github.com/sqids/sqids-go"
"github.com/uptrace/bun"
)
type Service struct {
sqid *sqids.Sqids
vfs *virtualfs.VirtualFS
}
// CreateShareOptions configures CreateShare.
type CreateShareOptions struct {
	// Items are the nodes to share. At least one is required, and all of them
	// must have the same, non-root parent directory.
	Items []*virtualfs.Node
	// ExpiresAt is when the share stops being valid. The zero time means the
	// share has no expiration.
	ExpiresAt time.Time
}
// ListSharesOptions holds the options for ListShares.
type ListSharesOptions struct {
	// Items restricts the result to shares that contain at least one of these
	// nodes. If empty (nil or zero length), all shares for the drive are returned.
	Items []*virtualfs.Node
	// IncludesExpired includes shares that have expired in the results.
	IncludesExpired bool
}
// UpdateShareOptions describes the changes UpdateShare applies to a share.
// A nil field leaves the corresponding attribute unchanged.
type UpdateShareOptions struct {
	// ExpiresAt sets the expiration time for the share. If nil, the expiration
	// time is not changed. If it points to the zero time, the expiration time is
	// cleared so the share no longer expires.
	ExpiresAt *time.Time
}
2025-12-27 19:27:08 +00:00
// Sentinel errors returned by Service methods. Compare with errors.Is.
var (
	// ErrShareNotFound is returned when no share matches the given identifier.
	ErrShareNotFound = errors.New("share not found")
	// ErrShareExpired is returned when a share (or the permission being used
	// to access it) is past its expiration time.
	ErrShareExpired = errors.New("share expired")
	// ErrShareRevoked is returned when a share has been revoked.
	ErrShareRevoked = errors.New("share revoked")
	// ErrNoPermissions is returned when a share has no permission that
	// applies to the caller.
	ErrNoPermissions = errors.New("no permissions found")

	// ErrShareNoItems is returned when a share is created without items or
	// a stored share has no items.
	ErrShareNoItems = errors.New("share has no items")
	// ErrNotSameParent is returned by CreateShare when the items do not all
	// live in the same directory.
	ErrNotSameParent = errors.New("items to be shared must be in the same parent directory")
	// ErrCannotShareRoot is returned by CreateShare when asked to share a
	// root directory (a node without a parent).
	ErrCannotShareRoot = errors.New("cannot share root directory")
)
// NewService returns a Service backed by vfs. It fails only when the sqids
// encoder used for public share IDs cannot be constructed.
func NewService(vfs *virtualfs.VirtualFS) (*Service, error) {
	encoder, err := sqids.New()
	if err != nil {
		return nil, err
	}
	svc := &Service{
		vfs:  vfs,
		sqid: encoder,
	}
	return svc, nil
}
// CreateShare creates a share record for its allowed items.
// A share is a partial share of a directory: the share root is always the common parent directory of all items.
2026-01-01 18:29:52 +00:00
func (s *Service) CreateShare(ctx context.Context, db bun.IDB, driveID uuid.UUID, createdByAccountID uuid.UUID, opts CreateShareOptions) (*Share, error) {
if len(opts.Items) == 0 {
return nil, ErrShareNoItems
}
2025-12-27 19:27:08 +00:00
id, err := generateInternalID()
if err != nil {
return nil, err
}
pid, err := s.generatePublicID()
if err != nil {
return nil, err
}
sharedDirectoryID := opts.Items[0].ParentID
if sharedDirectoryID == uuid.Nil {
// Root directories have no parent; they are never shareable.
return nil, ErrCannotShareRoot
}
for _, item := range opts.Items[1:] {
if item.ParentID != sharedDirectoryID {
2025-12-27 19:27:08 +00:00
return nil, ErrNotSameParent
}
}
now := time.Now()
sh := &Share{
2026-01-01 18:29:52 +00:00
ID: id,
DriveID: driveID,
CreatedByAccountID: createdByAccountID,
PublicID: pid,
SharedDirectoryID: sharedDirectoryID,
CreatedAt: now,
UpdatedAt: now,
2025-12-27 19:27:08 +00:00
}
if !opts.ExpiresAt.IsZero() {
sh.ExpiresAt = &opts.ExpiresAt
}
shareItems := make([]ShareItem, len(opts.Items))
for i, item := range opts.Items {
si := &shareItems[i]
si.ID, err = generateInternalID()
if err != nil {
return nil, err
}
si.ShareID = sh.ID
si.NodeID = item.ID
si.CreatedAt = now
si.UpdatedAt = now
}
permID, err := generateInternalID()
if err != nil {
return nil, err
}
perm := SharePermission{
ID: permID,
ShareID: sh.ID,
// public access
AccountID: nil,
CanRead: true,
CanWrite: false,
CanDelete: false,
CanUpload: false,
CreatedAt: now,
UpdatedAt: now,
}
_, err = db.NewInsert().Model(sh).Returning("*").Exec(ctx)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(&shareItems).On("CONFLICT DO NOTHING").Returning("*").Exec(ctx)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(&perm).Returning("*").Exec(ctx)
if err != nil {
return nil, err
}
return sh, nil
}
// FindShareByPublicID returns the share whose public_id equals publicID.
// It yields ErrShareNotFound when no such row exists and passes any other
// database error through unchanged. Expiry and revocation are not checked.
func (s *Service) FindShareByPublicID(ctx context.Context, db bun.IDB, publicID string) (*Share, error) {
	var found Share
	scanErr := db.NewSelect().
		Model(&found).
		Where("public_id = ?", publicID).
		Scan(ctx)
	switch {
	case scanErr == nil:
		return &found, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrShareNotFound
	default:
		return nil, scanErr
	}
}
2026-01-01 18:29:52 +00:00
func (s *Service) ListShares(ctx context.Context, db bun.IDB, driveID uuid.UUID, opts ListSharesOptions) ([]Share, error) {
2025-12-27 19:27:08 +00:00
var shares []Share
q := db.NewSelect().Model(&shares).
2026-01-01 18:29:52 +00:00
Where("drive_id = ?", driveID)
2025-12-27 19:27:08 +00:00
if !opts.IncludesExpired {
q = q.Where("expires_at IS NULL OR expires_at > NOW()")
}
if len(opts.Items) > 0 {
ids := make([]uuid.UUID, 0, len(opts.Items))
for _, item := range opts.Items {
ids = append(ids, item.ID)
}
q = q.Where(
"EXISTS (SELECT 1 FROM share_items si WHERE si.share_id = share.id AND si.node_id IN (?))",
bun.In(ids),
)
}
err := q.Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return make([]Share, 0), nil
}
return nil, err
}
if len(shares) == 0 {
return make([]Share, 0), nil
}
return shares, nil
}
// ResolveScopeForShare validates a share and derives a VFS scope from its permissions and items.
// share_items must contain the allowed nodes (including a directory node to share its subtree).
func (s *Service) ResolveScopeForShare(ctx context.Context, db bun.IDB, consumerAccount *account.Account, share *Share) (*virtualfs.Scope, error) {
now := time.Now()
if share.ExpiresAt != nil && share.ExpiresAt.Before(now) {
return nil, ErrShareExpired
}
if share.RevokedAt != nil {
return nil, ErrShareRevoked
}
var permList []*SharePermission
err := db.NewSelect().Model(&permList).
Where("share_id = ?", share.ID).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrShareNotFound
}
return nil, err
}
if len(permList) == 0 {
return nil, ErrNoPermissions
}
var publicPerm *SharePermission
var accPerm *SharePermission
if len(permList) == 1 {
accPerm = permList[0]
if accPerm.AccountID == nil {
publicPerm = accPerm
}
} else {
for _, p := range permList {
if p.AccountID != nil && consumerAccount != nil && *p.AccountID == consumerAccount.ID {
accPerm = p
} else if p.AccountID == nil {
publicPerm = p
}
}
}
if accPerm == nil && publicPerm == nil {
return nil, ErrNoPermissions
}
var perm *SharePermission
if accPerm != nil {
perm = accPerm
} else {
perm = publicPerm
}
if perm.ExpiresAt != nil && perm.ExpiresAt.Before(now) {
return nil, ErrShareExpired
}
scope := &virtualfs.Scope{
2026-01-01 18:29:52 +00:00
DriveID: share.DriveID,
2025-12-27 19:27:08 +00:00
RootNodeID: share.SharedDirectoryID,
}
if perm.AccountID == nil {
// the share is public
scope.ActorKind = virtualfs.ScopeActorShare
scope.ActorID = share.ID
} else {
scope.ActorKind = virtualfs.ScopeActorAccount
scope.ActorID = *perm.AccountID
}
scope.AllowedOps = map[virtualfs.Operation]bool{
virtualfs.OperationRead: perm.CanRead,
virtualfs.OperationWrite: perm.CanWrite,
virtualfs.OperationDelete: perm.CanDelete,
virtualfs.OperationUpload: perm.CanUpload,
}
var items []*ShareItem
err = db.NewSelect().Model(&items).
Where("share_id = ?", share.ID).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrShareNoItems
}
return nil, err
}
if len(items) == 0 {
return nil, ErrShareNoItems
}
scope.AllowedNodes = make(map[uuid.UUID]struct{})
for _, item := range items {
scope.AllowedNodes[item.NodeID] = struct{}{}
}
return scope, nil
}
// DeleteShareByPublicID deletes the share identified by publicID. The share
// is no longer accessible afterwards.
//
// It returns ErrShareNotFound if no share has that public ID, and passes any
// other database error through unchanged.
//
// NOTE(review): only the share row itself is deleted here; removal of its
// share_items and permission rows presumably relies on ON DELETE CASCADE in
// the schema — confirm.
func (s *Service) DeleteShareByPublicID(ctx context.Context, db bun.IDB, publicID string) error {
	res, err := db.NewDelete().Model(&Share{}).Where("public_id = ?", publicID).Exec(ctx)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return ErrShareNotFound
		}
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrShareNotFound
	}
	return nil
}
// generatePublicID returns a fresh public share identifier: 8 bytes from
// crypto/rand, interpreted as a big-endian uint64 and encoded with the
// service's sqids encoder.
func (s *Service) generatePublicID() (string, error) {
	raw := make([]byte, 8)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return s.sqid.Encode([]uint64{binary.BigEndian.Uint64(raw)})
}
// UpdateShare applies opts to share and persists only the columns that
// actually changed. share is modified in place; updated_at is bumped only
// when something changed, and no query is issued when nothing did.
//
// opts.ExpiresAt: nil leaves the expiration alone, a pointer to the zero time
// clears it, and any other value sets it. The time value is copied, so later
// writes through the caller's pointer do not affect share.
func (s *Service) UpdateShare(ctx context.Context, db bun.IDB, share *Share, opts UpdateShareOptions) error {
	now := time.Now()
	cols := make([]string, 0, 2)
	if opts.ExpiresAt != nil {
		var newExpiresAt *time.Time
		if !opts.ExpiresAt.IsZero() {
			// Copy so share.ExpiresAt does not alias the caller's value.
			t := *opts.ExpiresAt
			newExpiresAt = &t
		}
		if !timePtrEqual(share.ExpiresAt, newExpiresAt) {
			share.ExpiresAt = newExpiresAt
			share.UpdatedAt = now
			cols = append(cols, "expires_at", "updated_at")
		}
	}
	if len(cols) == 0 {
		return nil
	}
	_, err := db.NewUpdate().Model(share).Column(cols...).WherePK().Exec(ctx)
	return err
}
// timePtrEqual reports whether two optional instants are the same: both nil,
// or both non-nil and equal per time.Time.Equal (location-insensitive).
func timePtrEqual(a, b *time.Time) bool {
	switch {
	case a == nil && b == nil:
		return true
	case a == nil || b == nil:
		return false
	default:
		return a.Equal(*b)
	}
}