mirror of
https://github.com/get-drexa/drive.git
synced 2025-12-01 14:01:40 +00:00
feat: introduce account
This commit is contained in:
23
apps/backend/internal/account/account.go
Normal file
23
apps/backend/internal/account/account.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
bun.BaseModel `bun:"accounts"`
|
||||
|
||||
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
|
||||
UserID uuid.UUID `bun:"user_id,notnull,type:uuid" json:"userId"`
|
||||
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
|
||||
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
|
||||
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
|
||||
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
|
||||
}
|
||||
|
||||
func newAccountID() (uuid.UUID, error) {
|
||||
return uuid.NewV7()
|
||||
}
|
||||
8
apps/backend/internal/account/err.go
Normal file
8
apps/backend/internal/account/err.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package account
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrAccountNotFound = errors.New("account not found")
|
||||
ErrAccountAlreadyExists = errors.New("account already exists")
|
||||
)
|
||||
131
apps/backend/internal/account/http.go
Normal file
131
apps/backend/internal/account/http.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/auth"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type HTTPHandler struct {
|
||||
accountService *Service
|
||||
authService *auth.Service
|
||||
db *bun.DB
|
||||
authMiddleware fiber.Handler
|
||||
}
|
||||
|
||||
type registerAccountRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
|
||||
type registerAccountResponse struct {
|
||||
Account *Account `json:"account"`
|
||||
User *user.User `json:"user"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
const currentAccountKey = "currentAccount"
|
||||
|
||||
func CurrentAccount(c *fiber.Ctx) *Account {
|
||||
return c.Locals(currentAccountKey).(*Account)
|
||||
}
|
||||
|
||||
func NewHTTPHandler(accountService *Service, authService *auth.Service, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
|
||||
return &HTTPHandler{accountService: accountService, authService: authService, db: db, authMiddleware: authMiddleware}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) fiber.Router {
|
||||
api.Post("/accounts", h.registerAccount)
|
||||
|
||||
account := api.Group("/accounts/:accountID")
|
||||
account.Use(h.authMiddleware)
|
||||
account.Use(h.accountMiddleware)
|
||||
|
||||
account.Get("/", h.getAccount)
|
||||
|
||||
return account
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) accountMiddleware(c *fiber.Ctx) error {
|
||||
user, err := auth.AuthenticatedUser(c)
|
||||
if err != nil {
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
accountID, err := uuid.Parse(c.Params("accountID"))
|
||||
if err != nil {
|
||||
return c.SendStatus(fiber.StatusNotFound)
|
||||
}
|
||||
|
||||
account, err := h.accountService.AccountByID(c.Context(), h.db, user.ID, accountID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAccountNotFound) {
|
||||
return c.SendStatus(fiber.StatusNotFound)
|
||||
}
|
||||
return c.SendStatus(fiber.StatusInternalServerError)
|
||||
}
|
||||
|
||||
c.Locals(currentAccountKey, account)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) getAccount(c *fiber.Ctx) error {
|
||||
account := CurrentAccount(c)
|
||||
if account == nil {
|
||||
return c.SendStatus(fiber.StatusNotFound)
|
||||
}
|
||||
return c.JSON(account)
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error {
|
||||
req := new(registerAccountRequest)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return c.SendStatus(fiber.StatusBadRequest)
|
||||
}
|
||||
|
||||
tx, err := h.db.BeginTx(c.Context(), nil)
|
||||
if err != nil {
|
||||
return c.SendStatus(fiber.StatusInternalServerError)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
acc, u, err := h.accountService.Register(c.Context(), tx, RegisterOptions{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
DisplayName: req.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
var ae *user.AlreadyExistsError
|
||||
if errors.As(err, &ae) {
|
||||
return c.SendStatus(fiber.StatusConflict)
|
||||
}
|
||||
if errors.Is(err, ErrAccountAlreadyExists) {
|
||||
return c.SendStatus(fiber.StatusConflict)
|
||||
}
|
||||
return c.SendStatus(fiber.StatusBadRequest)
|
||||
}
|
||||
|
||||
result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u)
|
||||
if err != nil {
|
||||
return c.SendStatus(fiber.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return c.SendStatus(fiber.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return c.JSON(registerAccountResponse{
|
||||
Account: acc,
|
||||
User: u,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
107
apps/backend/internal/account/service.go
Normal file
107
apps/backend/internal/account/service.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/database"
|
||||
"github.com/get-drexa/drexa/internal/password"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
"github.com/google/uuid"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
userService user.Service
|
||||
}
|
||||
|
||||
type RegisterOptions struct {
|
||||
Email string
|
||||
Password string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
type CreateAccountOptions struct {
|
||||
OrganizationID uuid.UUID
|
||||
QuotaBytes int64
|
||||
}
|
||||
|
||||
func NewService(userService *user.Service) *Service {
|
||||
return &Service{
|
||||
userService: *userService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) {
|
||||
hashed, err := password.Hash(opts.Password)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
|
||||
Email: opts.Email,
|
||||
Password: hashed,
|
||||
DisplayName: opts.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
acc, err := s.CreateAccount(ctx, db, u.ID, CreateAccountOptions{
|
||||
// TODO: make quota configurable
|
||||
QuotaBytes: 1024 * 1024 * 1024, // 1GB
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return acc, u, nil
|
||||
}
|
||||
|
||||
func (s *Service) CreateAccount(ctx context.Context, db bun.IDB, userID uuid.UUID, opts CreateAccountOptions) (*Account, error) {
|
||||
id, err := newAccountID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
StorageQuotaBytes: opts.QuotaBytes,
|
||||
}
|
||||
|
||||
_, err = db.NewInsert().Model(account).Returning("*").Exec(ctx)
|
||||
if err != nil {
|
||||
if database.IsUniqueViolation(err) {
|
||||
return nil, ErrAccountAlreadyExists
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *Service) AccountByUserID(ctx context.Context, db bun.IDB, userID uuid.UUID) (*Account, error) {
|
||||
var account Account
|
||||
err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Scan(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrAccountNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *Service) AccountByID(ctx context.Context, db bun.IDB, userID uuid.UUID, id uuid.UUID) (*Account, error) {
|
||||
var account Account
|
||||
err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Where("id = ?", id).Scan(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrAccountNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -14,12 +13,6 @@ type loginRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type registerRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
User user.User `json:"user"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
@@ -37,9 +30,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
|
||||
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||
auth := api.Group("/auth")
|
||||
|
||||
auth.Post("/login", h.Login)
|
||||
auth.Post("/register", h.Register)
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
@@ -54,7 +45,7 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
|
||||
result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidCredentials) {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
|
||||
@@ -72,42 +63,3 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) Register(c *fiber.Ctx) error {
|
||||
req := new(registerRequest)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||
}
|
||||
|
||||
tx, err := h.db.BeginTx(c.Context(), nil)
|
||||
if err != nil {
|
||||
slog.Error("failed to begin transaction", "error", err)
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
result, err := h.service.Register(c.Context(), tx, registerOptions{
|
||||
email: req.Email,
|
||||
password: req.Password,
|
||||
displayName: req.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
var ae *user.AlreadyExistsError
|
||||
if errors.As(err, &ae) {
|
||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
|
||||
}
|
||||
slog.Error("failed to register user", "error", err)
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
slog.Error("failed to commit transaction", "error", err)
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||
}
|
||||
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type LoginResult struct {
|
||||
type AuthenticationResult struct {
|
||||
User *user.User
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
@@ -24,12 +24,6 @@ type Service struct {
|
||||
tokenConfig TokenConfig
|
||||
}
|
||||
|
||||
type registerOptions struct {
|
||||
displayName string
|
||||
email string
|
||||
password string
|
||||
}
|
||||
|
||||
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
|
||||
return &Service{
|
||||
userService: userService,
|
||||
@@ -37,7 +31,30 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
|
||||
func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) {
|
||||
at, err := GenerateAccessToken(user, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rt, err := GenerateRefreshToken(user, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AuthenticationResult{
|
||||
User: user,
|
||||
AccessToken: at,
|
||||
RefreshToken: hex.EncodeToString(rt.Token),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) {
|
||||
u, err := s.userService.UserByEmail(ctx, db, email)
|
||||
if err != nil {
|
||||
var nf *user.NotFoundError
|
||||
@@ -67,44 +84,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, ema
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LoginResult{
|
||||
User: u,
|
||||
AccessToken: at,
|
||||
RefreshToken: hex.EncodeToString(rt.Token),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
|
||||
hashed, err := password.Hash(opts.password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
|
||||
Email: opts.email,
|
||||
DisplayName: opts.displayName,
|
||||
Password: hashed,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
at, err := GenerateAccessToken(u, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LoginResult{
|
||||
return &AuthenticationResult{
|
||||
User: u,
|
||||
AccessToken: at,
|
||||
RefreshToken: hex.EncodeToString(rt.Token),
|
||||
|
||||
@@ -7,14 +7,23 @@ CREATE TABLE IF NOT EXISTS users (
|
||||
display_name TEXT,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL,
|
||||
storage_usage_bytes BIGINT NOT NULL,
|
||||
storage_quota_bytes BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_users_email ON users(email);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id UUID PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
storage_usage_bytes BIGINT NOT NULL DEFAULT 0,
|
||||
storage_quota_bytes BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_accounts_user_id ON accounts(user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id UUID PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
@@ -31,7 +40,7 @@ CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
||||
CREATE TABLE IF NOT EXISTS vfs_nodes (
|
||||
id UUID PRIMARY KEY,
|
||||
public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak)
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory
|
||||
kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')),
|
||||
status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')),
|
||||
@@ -46,17 +55,17 @@ CREATE TABLE IF NOT EXISTS vfs_nodes (
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ, -- soft delete for trash
|
||||
|
||||
-- No duplicate names in same parent (per user, excluding deleted)
|
||||
CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (user_id, parent_id, name, deleted_at)
|
||||
-- No duplicate names in same parent (per account, excluding deleted)
|
||||
CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (account_id, parent_id, name, deleted_at)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_vfs_nodes_user_id ON vfs_nodes(user_id) WHERE deleted_at IS NULL;
|
||||
CREATE INDEX idx_vfs_nodes_account_id ON vfs_nodes(account_id) WHERE deleted_at IS NULL;
|
||||
CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL;
|
||||
CREATE INDEX idx_vfs_nodes_user_parent ON vfs_nodes(user_id, parent_id) WHERE deleted_at IS NULL;
|
||||
CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(user_id, kind) WHERE deleted_at IS NULL;
|
||||
CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(user_id, deleted_at) WHERE deleted_at IS NOT NULL;
|
||||
CREATE INDEX idx_vfs_nodes_account_parent ON vfs_nodes(account_id, parent_id) WHERE deleted_at IS NULL;
|
||||
CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(account_id, kind) WHERE deleted_at IS NULL;
|
||||
CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(account_id, deleted_at) WHERE deleted_at IS NOT NULL;
|
||||
CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id);
|
||||
CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_id IS NULL; -- one root per user
|
||||
CREATE UNIQUE INDEX idx_vfs_nodes_account_root ON vfs_nodes(account_id) WHERE parent_id IS NULL; -- one root per account
|
||||
CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_shares (
|
||||
@@ -91,4 +100,7 @@ CREATE TRIGGER update_vfs_nodes_updated_at BEFORE UPDATE ON vfs_nodes
|
||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares
|
||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts
|
||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/account"
|
||||
"github.com/get-drexa/drexa/internal/auth"
|
||||
"github.com/get-drexa/drexa/internal/blob"
|
||||
"github.com/get-drexa/drexa/internal/database"
|
||||
@@ -61,12 +62,16 @@ func NewServer(c Config) (*fiber.App, error) {
|
||||
SecretKey: c.JWT.SecretKey,
|
||||
})
|
||||
uploadService := upload.NewService(vfs, blobStore)
|
||||
accountService := account.NewService(userService)
|
||||
|
||||
authMiddleware := auth.NewBearerAuthMiddleware(authService, db)
|
||||
|
||||
api := app.Group("/api")
|
||||
|
||||
accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
|
||||
|
||||
auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
|
||||
upload.NewHTTPHandler(uploadService, authMiddleware).RegisterRoutes(api, authMiddleware)
|
||||
upload.NewHTTPHandler(uploadService).RegisterRoutes(accRouter)
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package upload
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/auth"
|
||||
"github.com/get-drexa/drexa/internal/account"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -17,17 +17,15 @@ type updateUploadRequest struct {
|
||||
}
|
||||
|
||||
type HTTPHandler struct {
|
||||
service *Service
|
||||
authMiddleware fiber.Handler
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewHTTPHandler(s *Service, authMiddleware fiber.Handler) *HTTPHandler {
|
||||
return &HTTPHandler{service: s, authMiddleware: authMiddleware}
|
||||
func NewHTTPHandler(s *Service) *HTTPHandler {
|
||||
return &HTTPHandler{service: s}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Handler) {
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||
upload := api.Group("/uploads")
|
||||
upload.Use(authMiddleware)
|
||||
|
||||
upload.Post("/", h.Create)
|
||||
upload.Put("/:uploadID/content", h.ReceiveContent)
|
||||
@@ -35,9 +33,9 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Hand
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) Create(c *fiber.Ctx) error {
|
||||
u, err := auth.AuthenticatedUser(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
||||
account := account.CurrentAccount(c)
|
||||
if account == nil {
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
req := new(createUploadRequest)
|
||||
@@ -45,7 +43,7 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||
}
|
||||
|
||||
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
|
||||
upload, err := h.service.CreateUpload(c.Context(), account.ID, CreateUploadOptions{
|
||||
ParentID: req.ParentID,
|
||||
Name: req.Name,
|
||||
})
|
||||
@@ -57,14 +55,14 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
|
||||
u, err := auth.AuthenticatedUser(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
||||
account := account.CurrentAccount(c)
|
||||
if account == nil {
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
uploadID := c.Params("uploadID")
|
||||
|
||||
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
|
||||
err := h.service.ReceiveUpload(c.Context(), account.ID, uploadID, c.Request().BodyStream())
|
||||
defer c.Request().CloseBodyStream()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
|
||||
@@ -74,9 +72,9 @@ func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) Update(c *fiber.Ctx) error {
|
||||
u, err := auth.AuthenticatedUser(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
|
||||
account := account.CurrentAccount(c)
|
||||
if account == nil {
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
req := new(updateUploadRequest)
|
||||
@@ -85,7 +83,7 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
if req.Status == StatusCompleted {
|
||||
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
|
||||
upload, err := h.service.CompleteUpload(c.Context(), account.ID, c.Params("uploadID"))
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
return c.SendStatus(fiber.StatusNotFound)
|
||||
|
||||
@@ -33,8 +33,8 @@ type CreateUploadOptions struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
|
||||
parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID)
|
||||
func (s *Service) CreateUpload(ctx context.Context, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
|
||||
parentNode, err := s.vfs.FindNodeByPublicID(ctx, accountID, opts.ParentID)
|
||||
if err != nil {
|
||||
if errors.Is(err, virtualfs.ErrNodeNotFound) {
|
||||
return nil, ErrNotFound
|
||||
@@ -46,7 +46,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
|
||||
return nil, ErrParentNotDirectory
|
||||
}
|
||||
|
||||
node, err := s.vfs.CreateFile(ctx, userID, virtualfs.CreateFileOptions{
|
||||
node, err := s.vfs.CreateFile(ctx, accountID, virtualfs.CreateFileOptions{
|
||||
ParentID: parentNode.ID,
|
||||
Name: opts.Name,
|
||||
})
|
||||
@@ -77,7 +77,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
|
||||
return upload, nil
|
||||
}
|
||||
|
||||
func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID string, reader io.Reader) error {
|
||||
func (s *Service) ReceiveUpload(ctx context.Context, accountID uuid.UUID, uploadID string, reader io.Reader) error {
|
||||
n, ok := s.pendingUploads.Load(uploadID)
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
@@ -88,7 +88,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
if upload.TargetNode.UserID != userID {
|
||||
if upload.TargetNode.AccountID != accountID {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) {
|
||||
func (s *Service) CompleteUpload(ctx context.Context, accountID uuid.UUID, uploadID string) (*Upload, error) {
|
||||
n, ok := s.pendingUploads.Load(uploadID)
|
||||
if !ok {
|
||||
return nil, ErrNotFound
|
||||
@@ -113,7 +113,7 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
if upload.TargetNode.UserID != userID {
|
||||
if upload.TargetNode.AccountID != accountID {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
|
||||
@@ -11,14 +11,12 @@ import (
|
||||
type User struct {
|
||||
bun.BaseModel `bun:"users"`
|
||||
|
||||
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
|
||||
DisplayName string `bun:"display_name" json:"displayName"`
|
||||
Email string `bun:"email,unique,notnull" json:"email"`
|
||||
Password password.Hashed `bun:"password,notnull" json:"-"`
|
||||
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
|
||||
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
|
||||
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
|
||||
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
|
||||
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
|
||||
DisplayName string `bun:"display_name" json:"displayName"`
|
||||
Email string `bun:"email,unique,notnull" json:"email"`
|
||||
Password password.Hashed `bun:"password,notnull" json:"-"`
|
||||
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
|
||||
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
|
||||
}
|
||||
|
||||
func newUserID() (uuid.UUID, error) {
|
||||
|
||||
@@ -25,13 +25,13 @@ const (
|
||||
type Node struct {
|
||||
bun.BaseModel `bun:"vfs_nodes"`
|
||||
|
||||
ID uuid.UUID `bun:",pk,type:uuid"`
|
||||
PublicID string `bun:"public_id,notnull"`
|
||||
UserID uuid.UUID `bun:"user_id,notnull"`
|
||||
ParentID uuid.UUID `bun:"parent_id,notnull"`
|
||||
Kind NodeKind `bun:"kind,notnull"`
|
||||
Status NodeStatus `bun:"status,notnull"`
|
||||
Name string `bun:"name,notnull"`
|
||||
ID uuid.UUID `bun:",pk,type:uuid"`
|
||||
PublicID string `bun:"public_id,notnull"`
|
||||
AccountID uuid.UUID `bun:"account_id,notnull,type:uuid"`
|
||||
ParentID uuid.UUID `bun:"parent_id"`
|
||||
Kind NodeKind `bun:"kind,notnull"`
|
||||
Status NodeStatus `bun:"status,notnull"`
|
||||
Name string `bun:"name,notnull"`
|
||||
|
||||
BlobKey blob.Key `bun:"blob_key"`
|
||||
Size int64 `bun:"size"`
|
||||
|
||||
@@ -63,10 +63,10 @@ func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Node, error) {
|
||||
func (vfs *VirtualFS) FindNode(ctx context.Context, accountID, fileID string) (*Node, error) {
|
||||
var node Node
|
||||
err := vfs.db.NewSelect().Model(&node).
|
||||
Where("user_id = ?", userID).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("id = ?", fileID).
|
||||
Where("status = ?", NodeStatusReady).
|
||||
Where("deleted_at IS NULL").
|
||||
@@ -80,10 +80,10 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Nod
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID, publicID string) (*Node, error) {
|
||||
func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, accountID uuid.UUID, publicID string) (*Node, error) {
|
||||
var node Node
|
||||
err := vfs.db.NewSelect().Model(&node).
|
||||
Where("user_id = ?", userID).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("public_id = ?", publicID).
|
||||
Where("status = ?", NodeStatusReady).
|
||||
Where("deleted_at IS NULL").
|
||||
@@ -104,7 +104,7 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
|
||||
|
||||
var nodes []*Node
|
||||
err := vfs.db.NewSelect().Model(&nodes).
|
||||
Where("user_id = ?", node.UserID).
|
||||
Where("account_id = ?", node.AccountID).
|
||||
Where("parent_id = ?", node.ID).
|
||||
Where("status = ?", NodeStatusReady).
|
||||
Where("deleted_at IS NULL").
|
||||
@@ -119,19 +119,19 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts CreateFileOptions) (*Node, error) {
|
||||
func (vfs *VirtualFS) CreateFile(ctx context.Context, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) {
|
||||
pid, err := vfs.generatePublicID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node := Node{
|
||||
PublicID: pid,
|
||||
UserID: userID,
|
||||
ParentID: opts.ParentID,
|
||||
Kind: NodeKindFile,
|
||||
Status: NodeStatusPending,
|
||||
Name: opts.Name,
|
||||
PublicID: pid,
|
||||
AccountID: accountID,
|
||||
ParentID: opts.ParentID,
|
||||
Kind: NodeKindFile,
|
||||
Status: NodeStatusPending,
|
||||
Name: opts.Name,
|
||||
}
|
||||
|
||||
if vfs.keyResolver.ShouldPersistKey() {
|
||||
@@ -233,19 +233,19 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
|
||||
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
|
||||
pid, err := vfs.generatePublicID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node := Node{
|
||||
PublicID: pid,
|
||||
UserID: userID,
|
||||
ParentID: parentID,
|
||||
Kind: NodeKindDirectory,
|
||||
Status: NodeStatusReady,
|
||||
Name: name,
|
||||
PublicID: pid,
|
||||
AccountID: accountID,
|
||||
ParentID: parentID,
|
||||
Kind: NodeKindDirectory,
|
||||
Status: NodeStatusReady,
|
||||
Name: name,
|
||||
}
|
||||
|
||||
_, err = vfs.db.NewInsert().Model(&node).Exec(ctx)
|
||||
|
||||
Reference in New Issue
Block a user