feat(backend): return list of orgs in login resp

This commit is contained in:
2026-01-02 15:57:35 +00:00
parent 4688ba49c1
commit b40a80e4b6
8 changed files with 98 additions and 25 deletions

View File

@@ -14,13 +14,6 @@ RUN curl -LO https://github.com/neovim/neovim/releases/latest/download/nvim-linu
&& ln -s /opt/nvim-linux-x86_64/bin/nvim /usr/local/bin/nvim \
&& rm nvim-linux-x86_64.tar.gz
# Install lazygit
RUN LAZYGIT_VERSION=$(curl -s "https://api.github.com/repos/jesseduffield/lazygit/releases/latest" | grep -Po '"tag_name": "v\K[^"]*') \
&& curl -Lo lazygit.tar.gz "https://github.com/jesseduffield/lazygit/releases/latest/download/lazygit_${LAZYGIT_VERSION}_Linux_x86_64.tar.gz" \
&& tar xf lazygit.tar.gz lazygit \
&& install lazygit /usr/local/bin \
&& rm lazygit.tar.gz lazygit
# Install Bun as the node user
USER node
RUN curl -fsSL https://bun.sh/install | bash

View File

@@ -15,7 +15,6 @@
"golangciLintVersion": "2.6.1"
}
},
"postCreateCommand": "./scripts/setup-git.sh",
"customizations": {
"vscode": {
"extensions": [
@@ -97,6 +96,8 @@
}
},
"forwardPorts": [3000, 5173, 8080, 1455],
"appPort": [3000, 5173, 8080, 1455],
"postCreateCommand": "./scripts/editor-setup.sh",
"portsAttributes": {
"3000": {
"label": "Development Server",
@@ -111,5 +112,9 @@
"onAutoForward": "notify"
}
},
"mounts": [
"source=${localEnv:HOME}/.config/nvim,target=/home/node/.config/nvim,type=bind,consistency=cached",
"source=${localWorkspaceFolder}/.codex,target=/home/node/.codex,type=bind,consistency=cached"
],
"remoteUser": "node"
}

View File

@@ -5,6 +5,7 @@ import (
"log/slog"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/organization"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
@@ -36,6 +37,8 @@ type loginRequest struct {
type loginResponse struct {
// Authenticated user information
User user.User `json:"user"`
// Organizations the user is a member of
Organizations []organization.Organization `json:"organizations"`
// JWT access token (only included when tokenDelivery is "body")
AccessToken string `json:"accessToken,omitempty" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI1NTBlODQwMC1lMjliLTQxZDQtYTcxNi00NDY2NTU0NDAwMDAifQ.signature"`
// Base64 URL encoded refresh token (only included when tokenDelivery is "body")
@@ -59,13 +62,14 @@ type tokenResponse struct {
}
type HTTPHandler struct {
service *Service
db *bun.DB
cookieConfig CookieConfig
service *Service
organizationService *organization.Service
db *bun.DB
cookieConfig CookieConfig
}
func NewHTTPHandler(s *Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
return &HTTPHandler{service: s, db: db, cookieConfig: cookieConfig}
func NewHTTPHandler(s *Service, organizationService *organization.Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
return &HTTPHandler{service: s, organizationService: organizationService, db: db, cookieConfig: cookieConfig}
}
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
@@ -91,6 +95,10 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
if req.TokenDelivery != TokenDeliveryCookie && req.TokenDelivery != TokenDeliveryBody {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
}
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return httperr.Internal(err)
@@ -105,27 +113,33 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
return httperr.Internal(err)
}
orgs, err := h.organizationService.ListOrganizationsForUser(c.Context(), tx, result.User.ID)
if err != nil {
return httperr.Internal(err)
}
if err := tx.Commit(); err != nil {
return httperr.Internal(err)
}
switch req.TokenDelivery {
default:
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
case TokenDeliveryCookie:
SetAuthCookies(c, result.AccessToken, result.RefreshToken, h.cookieConfig)
return c.JSON(loginResponse{
User: *result.User,
User: *result.User,
Organizations: orgs,
})
case TokenDeliveryBody:
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
User: *result.User,
Organizations: orgs,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}
return httperr.Internal(errors.New("unreachable token delivery"))
}
// refreshAccessToken exchanges a refresh token for new access and refresh tokens

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/organization"
"github.com/gofiber/fiber/v2"
"github.com/testcontainers/testcontainers-go/modules/postgres"
)
@@ -103,6 +104,45 @@ func TestRegistrationFlow(t *testing.T) {
t.Fatalf("expected account.id and drive.id to be set")
}
t.Run("login includes organizations", func(t *testing.T) {
type loginResponse struct {
User struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
} `json:"user"`
Organizations []struct {
ID string `json:"id"`
Kind string `json:"kind"`
Name string `json:"name"`
Slug string `json:"slug"`
} `json:"organizations"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
loginBody := map[string]any{
"email": "alice@example.com",
"password": "password123",
"tokenDelivery": "body",
}
var login loginResponse
doJSON(t, s.app, http.MethodPost, "/api/auth/login", "", loginBody, http.StatusOK, &login)
if login.AccessToken == "" {
t.Fatalf("expected access token in login response")
}
if login.User.ID != reg.User.ID {
t.Fatalf("unexpected user id: got %q want %q", login.User.ID, reg.User.ID)
}
if len(login.Organizations) != 1 {
t.Fatalf("expected 1 organization, got %d", len(login.Organizations))
}
if login.Organizations[0].Kind != string(organization.KindPersonal) {
t.Fatalf("unexpected organization kind: %q", login.Organizations[0].Kind)
}
})
t.Run("drives list", func(t *testing.T) {
var drives []struct {
ID string `json:"id"`

View File

@@ -122,7 +122,7 @@ func NewServer(c Config) (*Server, error) {
optionalAuthMiddleware := auth.NewOptionalAuthMiddleware(authService, db, cookieConfig)
api := app.Group("/api")
auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
auth.NewHTTPHandler(authService, organizationService, db, cookieConfig).RegisterRoutes(api)
registration.NewHTTPHandler(registrationService, authService, db, cookieConfig).RegisterRoutes(api)
user.NewHTTPHandler(userService, db, authMiddleware).RegisterRoutes(api)
account.NewHTTPHandler(accountService, db, authMiddleware).RegisterRoutes(api)

View File

@@ -27,7 +27,7 @@ func NewMiddleware(orgService *Service, accountService *account.Service, db *bun
var org *Organization
var err error
if slug == reservedSlug {
if slug == ReservedSlug {
org, err = orgService.PersonalOrganizationForUser(c.Context(), db, u.ID)
} else {
org, err = orgService.OrganizationBySlug(c.Context(), db, slug)

View File

@@ -76,3 +76,21 @@ func (s *Service) PersonalOrganizationForUser(ctx context.Context, db bun.IDB, u
}
return &org, nil
}
func (s *Service) ListOrganizationsForUser(ctx context.Context, db bun.IDB, userID uuid.UUID) ([]Organization, error) {
orgs := make([]Organization, 0)
err := db.NewSelect().
Model(&orgs).
Join("JOIN accounts ON accounts.org_id = organization.id").
Where("accounts.user_id = ?", userID).
Where("accounts.status = ?", account.StatusActive).
Order("organization.kind ASC", "organization.name ASC").
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return orgs, nil
}
return nil, err
}
return orgs, nil
}

View File

@@ -11,9 +11,12 @@ const (
slugMaxLength = 63
)
// ReservedSlug is the special organization slug used to address the authenticated
// user's personal organization in API routes (e.g. /api/my/...).
const ReservedSlug = "my"
var (
slugPattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
reservedSlug = "my"
slugPattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
)
var ErrInvalidSlug = errors.New("invalid organization slug")
@@ -24,7 +27,7 @@ func NormalizeSlug(input string) (string, error) {
if len(slug) < slugMinLength || len(slug) > slugMaxLength {
return "", ErrInvalidSlug
}
if slug == reservedSlug {
if slug == ReservedSlug {
return "", ErrInvalidSlug
}
if !slugPattern.MatchString(slug) {