feat: initial sharing impl

This commit is contained in:
2025-12-27 19:27:08 +00:00
parent 94458c2f1e
commit 1a1fc4743a
23 changed files with 4019 additions and 1232 deletions

View File

@@ -0,0 +1,265 @@
package sharing
import (
"errors"
"time"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/reqctx"
"github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type HTTPHandler struct {
sharingService *Service
accountService *account.Service
vfs *virtualfs.VirtualFS
db *bun.DB
authMiddleware fiber.Handler
}
// createShareRequest represents a request to create a share link
// @Description Request to create a new share link for files or directories
type createShareRequest struct {
// Array of file/directory IDs to share
Items []string `json:"items" example:"mElnUNCm8F22,kRp2XYTq9A55"`
// Optional expiration time for the share (ISO 8601)
ExpiresAt *time.Time `json:"expiresAt" example:"2025-01-15T00:00:00Z"`
}
func NewHTTPHandler(sharingService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
return &HTTPHandler{
sharingService: sharingService,
accountService: accountService,
vfs: vfs,
db: db,
authMiddleware: authMiddleware,
}
}
func (h *HTTPHandler) RegisterShareConsumeRoutes(r fiber.Router) *virtualfs.ScopedRouter {
g := r.Group("/shares/:shareID", h.shareMiddleware)
return &virtualfs.ScopedRouter{Router: g}
}
func (h *HTTPHandler) RegisterShareManagementRoutes(api *account.ScopedRouter) {
g := api.Group("/shares")
g.Get("/:shareID", h.getShare)
g.Post("/", h.createShare)
g.Delete("/:shareID", h.deleteShare)
}
func (h *HTTPHandler) shareMiddleware(c *fiber.Ctx) error {
shareID := c.Params("shareID")
share, err := h.sharingService.FindShareByPublicID(c.Context(), h.db, shareID)
if err != nil {
if errors.Is(err, ErrShareNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
}
// a share can be public or shared to specific accounts
// if latter, the accountId query param is expected and the route should be authenticated
// then the correct account is found using the authenticated user and the accountId query param
// finally, the account scope is resolved for the share
// otherwise, consumerAccount will be nil to attempt to resolve a public scope for the share
var consumerAccount *account.Account
qAccountID := c.Query("accountId")
if qAccountID != "" {
consumerAccountID, err := uuid.Parse(qAccountID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "invalid account ID",
})
}
u := reqctx.AuthenticatedUser(c).(*user.User)
if u != nil {
consumerAccount, err = h.accountService.AccountByID(c.Context(), h.db, u.ID, consumerAccountID)
if err != nil {
if errors.Is(err, account.ErrAccountNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
}
consumerAccount, err = h.accountService.AccountByID(c.Context(), h.db, u.ID, consumerAccountID)
if err != nil {
if errors.Is(err, account.ErrAccountNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
}
}
}
scope, err := h.sharingService.ResolveScopeForShare(c.Context(), h.db, consumerAccount, share)
if err != nil {
if errors.Is(err, ErrShareNotFound) || errors.Is(err, ErrShareExpired) || errors.Is(err, ErrShareRevoked) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
}
if scope == nil {
// no scope can be resolved for the share
// the user is not authorized to access the share
// return 404 to hide the existence of the share
return c.SendStatus(fiber.StatusNotFound)
}
reqctx.SetVFSAccessScope(c, scope)
return c.Next()
}
// getShare retrieves a share by its ID
// @Summary Get share
// @Description Retrieve share link details by ID
// @Tags shares
// @Accept json
// @Produce json
// @Param accountID path string true "Account ID" format(uuid)
// @Param shareID path string true "Share ID"
// @Success 200 {object} Share "Share details"
// @Failure 401 {string} string "Not authenticated"
// @Failure 404 {string} string "Share not found"
// @Security BearerAuth
// @Router /accounts/{accountID}/shares/{shareID} [get]
func (h *HTTPHandler) getShare(c *fiber.Ctx) error {
shareID := c.Params("shareID")
share, err := h.sharingService.FindShareByPublicID(c.Context(), h.db, shareID)
if err != nil {
return httperr.Internal(err)
}
return c.JSON(share)
}
// createShare creates a new share link for files or directories
// @Summary Create share
// @Description Create a new share link for one or more files or directories
// @Tags shares
// @Accept json
// @Produce json
// @Param accountID path string true "Account ID" format(uuid)
// @Param request body createShareRequest true "Share details"
// @Success 200 {object} Share "Created share"
// @Failure 400 {object} map[string]string "Invalid request or no items provided"
// @Failure 401 {string} string "Not authenticated"
// @Failure 404 {object} map[string]string "One or more items not found"
// @Security BearerAuth
// @Router /accounts/{accountID}/shares [post]
func (h *HTTPHandler) createShare(c *fiber.Ctx) error {
scope, ok := scopeFromCtx(c)
if !ok {
return c.SendStatus(fiber.StatusUnauthorized)
}
acc, ok := reqctx.CurrentAccount(c).(*account.Account)
if !ok || acc == nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
var req createShareRequest
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "invalid request",
})
}
if len(req.Items) == 0 {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "expects at least one item to share",
})
}
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return httperr.Internal(err)
}
defer tx.Rollback()
nodes, err := h.vfs.FindNodesByPublicID(c.Context(), tx, req.Items, scope)
if err != nil {
return httperr.Internal(err)
}
if len(nodes) != len(req.Items) {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "One or more items not found"})
}
opts := CreateShareOptions{
Items: nodes,
}
if req.ExpiresAt != nil {
opts.ExpiresAt = *req.ExpiresAt
}
share, err := h.sharingService.CreateShare(c.Context(), tx, acc.ID, opts)
if err != nil {
return httperr.Internal(err)
}
err = tx.Commit()
if err != nil {
return httperr.Internal(err)
}
return c.JSON(share)
}
// deleteShare deletes a share link
// @Summary Delete share
// @Description Delete a share link, revoking access for all users
// @Tags shares
// @Param accountID path string true "Account ID" format(uuid)
// @Param shareID path string true "Share ID"
// @Success 204 {string} string "Share deleted"
// @Failure 401 {string} string "Not authenticated"
// @Failure 404 {string} string "Share not found"
// @Security BearerAuth
// @Router /accounts/{accountID}/shares/{shareID} [delete]
func (h *HTTPHandler) deleteShare(c *fiber.Ctx) error {
shareID := c.Params("shareID")
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return httperr.Internal(err)
}
defer tx.Rollback()
err = h.sharingService.DeleteShareByPublicID(c.Context(), tx, shareID)
if err != nil {
if errors.Is(err, ErrShareNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
}
err = tx.Commit()
if err != nil {
return httperr.Internal(err)
}
return c.SendStatus(fiber.StatusNoContent)
}
func scopeFromCtx(c *fiber.Ctx) (*virtualfs.Scope, bool) {
scopeAny := reqctx.VFSAccessScope(c)
if scopeAny == nil {
return nil, false
}
scope, ok := scopeAny.(*virtualfs.Scope)
if !ok || scope == nil {
return nil, false
}
return scope, true
}

View File

@@ -0,0 +1,319 @@
package sharing
import (
"context"
"crypto/rand"
"database/sql"
"encoding/binary"
"errors"
"time"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid"
"github.com/sqids/sqids-go"
"github.com/uptrace/bun"
)
type Service struct {
sqid *sqids.Sqids
vfs *virtualfs.VirtualFS
}
type CreateShareOptions struct {
Items []*virtualfs.Node
ExpiresAt time.Time
}
// ListSharesOptions is the options for ListShares.
type ListSharesOptions struct {
// Items stores the list of nodes to filter the shares by. If nil, all shares for the account will be returned.
Items []*virtualfs.Node
// IncludesExpired includes shares that have expired in the results.
IncludesExpired bool
}
var (
ErrShareNotFound = errors.New("share not found")
ErrShareExpired = errors.New("share expired")
ErrShareRevoked = errors.New("share revoked")
ErrShareNoItems = errors.New("share has no items")
ErrNoPermissions = errors.New("no permissions found")
ErrNotSameParent = errors.New("items to be shared must be in the same parent directory")
)
func NewService(vfs *virtualfs.VirtualFS) (*Service, error) {
sqid, err := sqids.New()
if err != nil {
return nil, err
}
return &Service{vfs: vfs, sqid: sqid}, nil
}
// CreateShare creates a share record for a parent directory and its allowed items.
func (s *Service) CreateShare(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateShareOptions) (*Share, error) {
id, err := generateInternalID()
if err != nil {
return nil, err
}
pid, err := s.generatePublicID()
if err != nil {
return nil, err
}
var parentID *uuid.UUID
for _, item := range opts.Items {
if parentID == nil {
parentID = &item.ID
} else if item.ParentID != *parentID {
return nil, ErrNotSameParent
}
}
now := time.Now()
sh := &Share{
ID: id,
AccountID: accountID,
PublicID: pid,
SharedDirectoryID: opts.Items[0].ID,
CreatedAt: now,
UpdatedAt: now,
}
if !opts.ExpiresAt.IsZero() {
sh.ExpiresAt = &opts.ExpiresAt
}
shareItems := make([]ShareItem, len(opts.Items))
for i, item := range opts.Items {
si := &shareItems[i]
si.ID, err = generateInternalID()
if err != nil {
return nil, err
}
si.ShareID = sh.ID
si.NodeID = item.ID
si.CreatedAt = now
si.UpdatedAt = now
}
permID, err := generateInternalID()
if err != nil {
return nil, err
}
perm := SharePermission{
ID: permID,
ShareID: sh.ID,
// public access
AccountID: nil,
CanRead: true,
CanWrite: false,
CanDelete: false,
CanUpload: false,
CreatedAt: now,
UpdatedAt: now,
}
_, err = db.NewInsert().Model(sh).Returning("*").Exec(ctx)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(&shareItems).On("CONFLICT DO NOTHING").Returning("*").Exec(ctx)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(&perm).Returning("*").Exec(ctx)
if err != nil {
return nil, err
}
return sh, nil
}
// FindShareByPublicID loads a share by its public token.
func (s *Service) FindShareByPublicID(ctx context.Context, db bun.IDB, publicID string) (*Share, error) {
sh := new(Share)
err := db.NewSelect().Model(sh).
Where("public_id = ?", publicID).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrShareNotFound
}
return nil, err
}
return sh, nil
}
func (s *Service) ListShares(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts ListSharesOptions) ([]Share, error) {
var shares []Share
q := db.NewSelect().Model(&shares).
Where("account_id = ?", accountID)
if !opts.IncludesExpired {
q = q.Where("expires_at IS NULL OR expires_at > NOW()")
}
if len(opts.Items) > 0 {
ids := make([]uuid.UUID, 0, len(opts.Items))
for _, item := range opts.Items {
ids = append(ids, item.ID)
}
q = q.Where(
"EXISTS (SELECT 1 FROM share_items si WHERE si.share_id = share.id AND si.node_id IN (?))",
bun.In(ids),
)
}
err := q.Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return make([]Share, 0), nil
}
return nil, err
}
if len(shares) == 0 {
return make([]Share, 0), nil
}
return shares, nil
}
// ResolveScopeForShare validates a share and derives a VFS scope from its permissions and items.
// share_items must contain the allowed nodes (including a directory node to share its subtree).
func (s *Service) ResolveScopeForShare(ctx context.Context, db bun.IDB, consumerAccount *account.Account, share *Share) (*virtualfs.Scope, error) {
now := time.Now()
if share.ExpiresAt != nil && share.ExpiresAt.Before(now) {
return nil, ErrShareExpired
}
if share.RevokedAt != nil {
return nil, ErrShareRevoked
}
var permList []*SharePermission
err := db.NewSelect().Model(&permList).
Where("share_id = ?", share.ID).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrShareNotFound
}
return nil, err
}
if len(permList) == 0 {
return nil, ErrNoPermissions
}
var publicPerm *SharePermission
var accPerm *SharePermission
if len(permList) == 1 {
accPerm = permList[0]
if accPerm.AccountID == nil {
publicPerm = accPerm
}
} else {
for _, p := range permList {
if p.AccountID != nil && consumerAccount != nil && *p.AccountID == consumerAccount.ID {
accPerm = p
} else if p.AccountID == nil {
publicPerm = p
}
}
}
if accPerm == nil && publicPerm == nil {
return nil, ErrNoPermissions
}
var perm *SharePermission
if accPerm != nil {
perm = accPerm
} else {
perm = publicPerm
}
if perm.ExpiresAt != nil && perm.ExpiresAt.Before(now) {
return nil, ErrShareExpired
}
scope := &virtualfs.Scope{
AccountID: share.AccountID,
RootNodeID: share.SharedDirectoryID,
}
if perm.AccountID == nil {
// the share is public
scope.ActorKind = virtualfs.ScopeActorShare
scope.ActorID = share.ID
} else {
scope.ActorKind = virtualfs.ScopeActorAccount
scope.ActorID = *perm.AccountID
}
scope.AllowedOps = map[virtualfs.Operation]bool{
virtualfs.OperationRead: perm.CanRead,
virtualfs.OperationWrite: perm.CanWrite,
virtualfs.OperationDelete: perm.CanDelete,
virtualfs.OperationUpload: perm.CanUpload,
}
var items []*ShareItem
err = db.NewSelect().Model(&items).
Where("share_id = ?", share.ID).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrShareNoItems
}
return nil, err
}
if len(items) == 0 {
return nil, ErrShareNoItems
}
scope.AllowedNodes = make(map[uuid.UUID]struct{})
for _, item := range items {
scope.AllowedNodes[item.NodeID] = struct{}{}
}
return scope, nil
}
// DeleteShare deletes the given share completely. It will no longer be accessible after.
func (s *Service) DeleteShareByPublicID(ctx context.Context, db bun.IDB, publicID string) error {
r, err := db.NewDelete().Model(&Share{}).Where("public_id = ?", publicID).Exec(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrShareNotFound
}
return err
}
c, err := r.RowsAffected()
if err != nil {
return err
}
if c == 0 {
return ErrShareNotFound
}
return nil
}
func (s *Service) generatePublicID() (string, error) {
var b [8]byte
_, err := rand.Read(b[:])
if err != nil {
return "", err
}
n := binary.BigEndian.Uint64(b[:])
return s.sqid.Encode([]uint64{n})
}

View File

@@ -0,0 +1,56 @@
package sharing
import (
"time"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
// Share represents a share link for files or directories
// @Description Share link information including expiration and timestamps
type Share struct {
bun.BaseModel `bun:"node_shares"`
ID uuid.UUID `bun:",pk,type:uuid" json:"-"`
AccountID uuid.UUID `bun:"account_id,notnull,type:uuid" json:"-"`
// Unique share identifier (public ID)
PublicID string `bun:"public_id,notnull" json:"id" example:"kRp2XYTq9A55"`
SharedDirectoryID uuid.UUID `bun:"shared_directory_id,notnull,type:uuid" json:"-"`
RevokedAt *time.Time `bun:"revoked_at" json:"-"`
// When the share expires, null if it never expires (ISO 8601)
ExpiresAt *time.Time `bun:"expires_at" json:"expiresAt" example:"2025-01-15T00:00:00Z"`
// When the share was created (ISO 8601)
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt" example:"2024-12-13T15:04:05Z"`
// When the share was last updated (ISO 8601)
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt" example:"2024-12-13T16:30:00Z"`
}
type SharePermission struct {
bun.BaseModel `bun:"share_permissions"`
ID uuid.UUID `bun:",pk,type:uuid"`
ShareID uuid.UUID `bun:"share_id,notnull,type:uuid"`
AccountID *uuid.UUID `bun:"account_id,type:uuid"`
CanRead bool `bun:"can_read,notnull"`
CanWrite bool `bun:"can_write,notnull"`
CanDelete bool `bun:"can_delete,notnull"`
CanUpload bool `bun:"can_upload,notnull"`
ExpiresAt *time.Time `bun:"expires_at"`
CreatedAt time.Time `bun:"created_at,notnull"`
UpdatedAt time.Time `bun:"updated_at,notnull"`
}
type ShareItem struct {
bun.BaseModel `bun:"share_items"`
ID uuid.UUID `bun:",pk,type:uuid"`
ShareID uuid.UUID `bun:"share_id,notnull,type:uuid"`
NodeID uuid.UUID `bun:"node_id,notnull,type:uuid"`
CreatedAt time.Time `bun:"created_at,notnull"`
UpdatedAt time.Time `bun:"updated_at,notnull"`
}
func generateInternalID() (uuid.UUID, error) {
return uuid.NewV7()
}