diff --git a/apps/backend/internal/account/service_integration_test.go b/apps/backend/internal/account/service_integration_test.go index f220e3e..5bc7d07 100644 --- a/apps/backend/internal/account/service_integration_test.go +++ b/apps/backend/internal/account/service_integration_test.go @@ -1,6 +1,6 @@ //go:build integration -package account +package account_test import ( "context" @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/get-drexa/drexa/internal/account" "github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/organization" "github.com/get-drexa/drexa/internal/password" @@ -44,7 +45,7 @@ func TestService_AccountQueries(t *testing.T) { userSvc := user.NewService() orgSvc := organization.NewService() - accSvc := NewService() + accSvc := account.NewService() testUser, err := userSvc.RegisterUser(ctx, db, user.UserRegistrationOptions{ Email: "account@example.com", @@ -60,7 +61,7 @@ func TestService_AccountQueries(t *testing.T) { t.Fatalf("CreatePersonalOrganization(personal): %v", err) } - accPersonal, err := accSvc.CreateAccount(ctx, db, personalOrg.ID, testUser.ID, RoleAdmin, StatusActive) + accPersonal, err := accSvc.CreateAccount(ctx, db, personalOrg.ID, testUser.ID, account.RoleAdmin, account.StatusActive) if err != nil { t.Fatalf("CreateAccount(personal): %v", err) } @@ -86,8 +87,8 @@ func TestService_AccountQueries(t *testing.T) { if gotPersonal.OrgID != personalOrg.ID { t.Fatalf("unexpected personal org id: got %q want %q", gotPersonal.OrgID, personalOrg.ID) } - if gotPersonal.Role != RoleAdmin { - t.Fatalf("unexpected personal role: got %q want %q", gotPersonal.Role, RoleAdmin) + if gotPersonal.Role != account.RoleAdmin { + t.Fatalf("unexpected personal role: got %q want %q", gotPersonal.Role, account.RoleAdmin) } }) diff --git a/apps/backend/internal/drexa/api_integration_test.go b/apps/backend/internal/drexa/api_integration_test.go index 2ba49f3..9b06507 100644 --- a/apps/backend/internal/drexa/api_integration_test.go +++ b/apps/backend/internal/drexa/api_integration_test.go @@ -107,7 +107,7 @@ func TestRegistrationFlow(t *testing.T) { var drives []struct { ID string `json:"id"` } - doJSON(t, s.app, http.MethodGet, "/api/drives", reg.AccessToken, nil, http.StatusOK, &drives) + doJSON(t, s.app, http.MethodGet, "/api/my/drives", reg.AccessToken, nil, http.StatusOK, &drives) if len(drives) != 1 { t.Fatalf("expected 1 drive, got %d", len(drives)) } @@ -156,7 +156,7 @@ func TestRegistrationFlow(t *testing.T) { t, s.app, http.MethodGet, - fmt.Sprintf("/api/drives/%s/directories/root/content?limit=100", reg.Drive.ID), + fmt.Sprintf("/api/my/drives/%s/directories/root/content?limit=100", reg.Drive.ID), reg.AccessToken, nil, http.StatusOK, diff --git a/apps/backend/internal/drexa/server.go b/apps/backend/internal/drexa/server.go index 28e0b35..8f3e62a 100644 --- a/apps/backend/internal/drexa/server.go +++ b/apps/backend/internal/drexa/server.go @@ -123,12 +123,13 @@ func NewServer(c Config) (*Server, error) { api := app.Group("/api") auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api) - user.NewHTTPHandler(userService, db, authMiddleware).RegisterRoutes(api) - - account.NewHTTPHandler(accountService, db, authMiddleware).RegisterRoutes(api) registration.NewHTTPHandler(registrationService, authService, db, cookieConfig).RegisterRoutes(api) + user.NewHTTPHandler(userService, db, authMiddleware).RegisterRoutes(api) + account.NewHTTPHandler(accountService, db, authMiddleware).RegisterRoutes(api) - driveRouter := drive.NewHTTPHandler(driveService, accountService, vfs, db, authMiddleware).RegisterRoutes(api) + orgAPI := api.Group("/:orgSlug", authMiddleware, organization.NewMiddleware(organizationService, accountService, db)) + + driveRouter := drive.NewHTTPHandler(driveService, accountService, vfs, db, authMiddleware).RegisterRoutes(orgAPI) upload.NewHTTPHandler(uploadService, db).RegisterRoutes(driveRouter) shareHTTP := sharing.NewHTTPHandler(sharingService, accountService, driveService, vfs, db, optionalAuthMiddleware) diff --git a/apps/backend/internal/drive/http.go b/apps/backend/internal/drive/http.go index bd84e98..2be0a82 100644 --- a/apps/backend/internal/drive/http.go +++ b/apps/backend/internal/drive/http.go @@ -5,6 +5,7 @@ import ( "github.com/get-drexa/drexa/internal/account" "github.com/get-drexa/drexa/internal/httperr" + "github.com/get-drexa/drexa/internal/organization" "github.com/get-drexa/drexa/internal/reqctx" "github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/virtualfs" @@ -14,20 +15,20 @@ import ( ) type HTTPHandler struct { - driveService *Service - accountService *account.Service - vfs *virtualfs.VirtualFS - db *bun.DB - authMiddleware fiber.Handler + driveService *Service + accountService *account.Service + vfs *virtualfs.VirtualFS + db *bun.DB + authMiddleware fiber.Handler } func NewHTTPHandler(driveService *Service, accountService *account.Service, vfs *virtualfs.VirtualFS, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler { return &HTTPHandler{ - driveService: driveService, - accountService: accountService, - vfs: vfs, - db: db, - authMiddleware: authMiddleware, + driveService: driveService, + accountService: accountService, + vfs: vfs, + db: db, + authMiddleware: authMiddleware, } } @@ -45,8 +46,23 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router) *virtualfs.ScopedRouter { func (h *HTTPHandler) listDrives(c *fiber.Ctx) error { u := reqctx.AuthenticatedUser(c).(*user.User) + org, ok := reqctx.CurrentOrganization(c).(*organization.Organization) + if !ok || org == nil { + return c.SendStatus(fiber.StatusNotFound) + } - drives, err := h.driveService.ListDrivesForUser(c.Context(), h.db, u.ID) + acc, err := h.accountService.FindUserAccountInOrg(c.Context(), h.db, org.ID, u.ID) + if err != nil { + if errors.Is(err, account.ErrAccountNotFound) { + return c.SendStatus(fiber.StatusNotFound) + } + return httperr.Internal(err) + } + if acc.Status != account.StatusActive { + return c.SendStatus(fiber.StatusNotFound) + } + + drives, err := h.driveService.ListAccessibleDrives(c.Context(), h.db, org.ID, acc.ID) if err != nil { return httperr.Internal(err) } @@ -62,7 +78,10 @@ func (h *HTTPHandler) getDrive(c *fiber.Ctx) error { } func (h *HTTPHandler) driveMiddleware(c *fiber.Ctx) error { - u := reqctx.AuthenticatedUser(c).(*user.User) + org, ok := reqctx.CurrentOrganization(c).(*organization.Organization) + if !ok || org == nil { + return c.SendStatus(fiber.StatusNotFound) + } driveID, err := uuid.Parse(c.Params("driveID")) if err != nil { @@ -76,15 +95,12 @@ func (h *HTTPHandler) driveMiddleware(c *fiber.Ctx) error { } return httperr.Internal(err) } - - acc, err := h.accountService.FindUserAccountInOrg(c.Context(), h.db, drive.OrgID, u.ID) - if err != nil { - if errors.Is(err, account.ErrAccountNotFound) { - return c.SendStatus(fiber.StatusNotFound) - } - return httperr.Internal(err) + if drive.OrgID != org.ID { + return c.SendStatus(fiber.StatusNotFound) } - if acc.Status != account.StatusActive { + + acc, ok := reqctx.CurrentAccount(c).(*account.Account) + if !ok || acc == nil { return c.SendStatus(fiber.StatusNotFound) } @@ -107,7 +123,6 @@ func (h *HTTPHandler) driveMiddleware(c *fiber.Ctx) error { } reqctx.SetCurrentDrive(c, drive) - reqctx.SetCurrentAccount(c, acc) reqctx.SetVFSAccessScope(c, scope) return c.Next() diff --git a/apps/backend/internal/organization/err.go b/apps/backend/internal/organization/err.go new file mode 100644 index 0000000..cb52712 --- /dev/null +++ b/apps/backend/internal/organization/err.go @@ -0,0 +1,7 @@ +package organization + +import "errors" + +var ( + ErrOrganizationNotFound = errors.New("organization not found") +) diff --git a/apps/backend/internal/organization/middleware.go b/apps/backend/internal/organization/middleware.go new file mode 100644 index 0000000..54b7571 --- /dev/null +++ b/apps/backend/internal/organization/middleware.go @@ -0,0 +1,59 @@ +package organization + +import ( + "errors" + "strings" + + "github.com/get-drexa/drexa/internal/account" + "github.com/get-drexa/drexa/internal/httperr" + "github.com/get-drexa/drexa/internal/reqctx" + "github.com/get-drexa/drexa/internal/user" + "github.com/gofiber/fiber/v2" + "github.com/uptrace/bun" +) + +func NewMiddleware(orgService *Service, accountService *account.Service, db *bun.DB) fiber.Handler { + return func(c *fiber.Ctx) error { + slug := strings.ToLower(c.Params("orgSlug")) + if slug == "" { + return c.SendStatus(fiber.StatusNotFound) + } + + u, _ := reqctx.AuthenticatedUser(c).(*user.User) + if u == nil { + return c.SendStatus(fiber.StatusUnauthorized) + } + + var org *Organization + var err error + + if slug == reservedSlug { + org, err = orgService.PersonalOrganizationForUser(c.Context(), db, u.ID) + } else { + org, err = orgService.OrganizationBySlug(c.Context(), db, slug) + } + + if err != nil { + if errors.Is(err, ErrOrganizationNotFound) { + return c.SendStatus(fiber.StatusNotFound) + } + return httperr.Internal(err) + } + + acc, err := accountService.FindUserAccountInOrg(c.Context(), db, org.ID, u.ID) + if err != nil { + if errors.Is(err, account.ErrAccountNotFound) { + return c.SendStatus(fiber.StatusNotFound) + } + return httperr.Internal(err) + } + if acc.Status != account.StatusActive { + return c.SendStatus(fiber.StatusNotFound) + } + reqctx.SetCurrentAccount(c, acc) + + reqctx.SetCurrentOrganization(c, org) + + return c.Next() + } +} diff --git a/apps/backend/internal/organization/service.go b/apps/backend/internal/organization/service.go index 2ed8b3d..40b180f 100644 --- a/apps/backend/internal/organization/service.go +++ b/apps/backend/internal/organization/service.go @@ -2,7 +2,10 @@ package organization import ( "context" + "database/sql" + "errors" + "github.com/get-drexa/drexa/internal/account" "github.com/google/uuid" "github.com/uptrace/bun" ) @@ -37,8 +40,39 @@ func (s *Service) OrganizationByID(ctx context.Context, db bun.IDB, id uuid.UUID var org Organization err := db.NewSelect().Model(&org).Where("id = ?", id).Scan(ctx) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrOrganizationNotFound + } return nil, err } return &org, nil } +func (s *Service) OrganizationBySlug(ctx context.Context, db bun.IDB, slug string) (*Organization, error) { + var org Organization + err := db.NewSelect().Model(&org).Where("lower(slug) = lower(?)", slug).Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrOrganizationNotFound + } + return nil, err + } + return &org, nil +} + +func (s *Service) PersonalOrganizationForUser(ctx context.Context, db bun.IDB, userID uuid.UUID) (*Organization, error) { + var org Organization + err := db.NewSelect().Model(&org). + Join("JOIN accounts ON accounts.org_id = organization.id"). + Where("accounts.user_id = ?", userID). + Where("accounts.status = ?", account.StatusActive). + Where("organization.kind = ?", KindPersonal). + Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrOrganizationNotFound + } + return nil, err + } + return &org, nil +} diff --git a/apps/backend/internal/organization/service_integration_test.go b/apps/backend/internal/organization/service_integration_test.go index f6f11fe..9608614 100644 --- a/apps/backend/internal/organization/service_integration_test.go +++ b/apps/backend/internal/organization/service_integration_test.go @@ -8,7 +8,11 @@ import ( "testing" "time" + "github.com/get-drexa/drexa/internal/account" "github.com/get-drexa/drexa/internal/database" + "github.com/get-drexa/drexa/internal/password" + "github.com/get-drexa/drexa/internal/user" + "github.com/google/uuid" "github.com/testcontainers/testcontainers-go/modules/postgres" ) @@ -62,6 +66,64 @@ func TestService_CreatePersonalOrganization(t *testing.T) { t.Fatalf("unexpected org name: got %q want %q", got.Name, "Personal Org") } }) + + t.Run("organization by slug", func(t *testing.T) { + slug := "test-org" + orgWithSlug := &Organization{ + ID: uuid.Must(uuid.NewV7()), + Kind: KindTeam, + Name: "Team Org", + Slug: &slug, + } + if _, err := db.NewInsert().Model(orgWithSlug).Returning("*").Exec(ctx); err != nil { + t.Fatalf("insert org with slug: %v", err) + } + + for _, candidate := range []string{"TEST-ORG", "Test-Org", "test-org"} { + got, err := svc.OrganizationBySlug(ctx, db, candidate) + if err != nil { + t.Fatalf("OrganizationBySlug(%q): %v", candidate, err) + } + if got.ID != orgWithSlug.ID { + t.Fatalf("unexpected org id: got %q want %q", got.ID, orgWithSlug.ID) + } + if got.Slug == nil || *got.Slug != slug { + t.Fatalf("unexpected org slug: got %v want %q", got.Slug, slug) + } + } + }) + + t.Run("personal organization for user", func(t *testing.T) { + accSvc := account.NewService() + userSvc := user.NewService() + + hashed, err := password.HashString("test-password") + if err != nil { + t.Fatalf("HashString: %v", err) + } + + u, err := userSvc.RegisterUser(ctx, db, user.UserRegistrationOptions{ + Email: fmt.Sprintf("org-user-%s@example.com", uuid.Must(uuid.NewV7())), + DisplayName: "Org User", + Password: hashed, + }) + if err != nil { + t.Fatalf("RegisterUser: %v", err) + } + + acc, err := accSvc.CreateAccount(ctx, db, org.ID, u.ID, account.RoleAdmin, account.StatusActive) + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + got, err := svc.PersonalOrganizationForUser(ctx, db, acc.UserID) + if err != nil { + t.Fatalf("PersonalOrganizationForUser: %v", err) + } + if got.ID != org.ID { + t.Fatalf("unexpected org id: got %q want %q", got.ID, org.ID) + } + }) } func runPostgres(ctx context.Context) (_ *postgres.PostgresContainer, err error) { diff --git a/apps/backend/internal/reqctx/reqctx.go b/apps/backend/internal/reqctx/reqctx.go index 97ba7c3..ceb15cb 100644 --- a/apps/backend/internal/reqctx/reqctx.go +++ b/apps/backend/internal/reqctx/reqctx.go @@ -10,6 +10,7 @@ const authenticatedUserKey = "authenticatedUser" const vfsAccessScope = "vfsAccessScope" const currentAccountKey = "currentAccount" const currentDriveKey = "currentDrive" +const currentOrganizationKey = "currentOrganization" var ErrUnauthenticatedRequest = errors.New("unauthenticated request") @@ -40,6 +41,11 @@ func SetVFSAccessScope(c *fiber.Ctx, scope any) { c.Locals(vfsAccessScope, scope) } +// SetCurrentOrganization sets the current organization in the fiber context. +func SetCurrentOrganization(c *fiber.Ctx, organization any) { + c.Locals(currentOrganizationKey, organization) +} + // CurrentAccount returns the current account from the given fiber context. func CurrentAccount(c *fiber.Ctx) any { return c.Locals(currentAccountKey) @@ -50,6 +56,11 @@ func CurrentDrive(c *fiber.Ctx) any { return c.Locals(currentDriveKey) } +// CurrentOrganization returns the current organization from the given fiber context. +func CurrentOrganization(c *fiber.Ctx) any { + return c.Locals(currentOrganizationKey) +} + // VFSAccessScope returns the VFS access scope from the given fiber context. func VFSAccessScope(c *fiber.Ctx) any { return c.Locals(vfsAccessScope) diff --git a/apps/backend/internal/sharing/http.go b/apps/backend/internal/sharing/http.go index c72e883..f51b524 100644 --- a/apps/backend/internal/sharing/http.go +++ b/apps/backend/internal/sharing/http.go @@ -8,6 +8,7 @@ import ( "github.com/get-drexa/drexa/internal/drive" "github.com/get-drexa/drexa/internal/httperr" "github.com/get-drexa/drexa/internal/nullable" + "github.com/get-drexa/drexa/internal/organization" "github.com/get-drexa/drexa/internal/reqctx" "github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/virtualfs" @@ -78,14 +79,19 @@ func (h *HTTPHandler) shareMiddleware(c *fiber.Ctx) error { return httperr.Internal(err) } + drive, err := h.driveService.DriveByID(c.Context(), h.db, share.DriveID) + if err != nil { + return httperr.Internal(err) + } + + org, _ := reqctx.CurrentOrganization(c).(*organization.Organization) + if org != nil && drive.OrgID != org.ID { + return c.SendStatus(fiber.StatusNotFound) + } + var consumerAccount *account.Account u, _ := reqctx.AuthenticatedUser(c).(*user.User) if u != nil { - drive, err := h.driveService.DriveByID(c.Context(), h.db, share.DriveID) - if err != nil { - return httperr.Internal(err) - } - consumerAccount, err = h.accountService.FindUserAccountInOrg(c.Context(), h.db, drive.OrgID, u.ID) if err != nil { if errors.Is(err, account.ErrAccountNotFound) {