mirror of
https://github.com/get-drexa/drive.git
synced 2026-02-02 11:51:17 +00:00
feat(backend): return list of orgs in login resp
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/httperr"
|
||||
"github.com/get-drexa/drexa/internal/organization"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/uptrace/bun"
|
||||
@@ -36,6 +37,8 @@ type loginRequest struct {
|
||||
type loginResponse struct {
|
||||
// Authenticated user information
|
||||
User user.User `json:"user"`
|
||||
// Organizations the user is a member of
|
||||
Organizations []organization.Organization `json:"organizations"`
|
||||
// JWT access token (only included when tokenDelivery is "body")
|
||||
AccessToken string `json:"accessToken,omitempty" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI1NTBlODQwMC1lMjliLTQxZDQtYTcxNi00NDY2NTU0NDAwMDAifQ.signature"`
|
||||
// Base64 URL encoded refresh token (only included when tokenDelivery is "body")
|
||||
@@ -59,13 +62,14 @@ type tokenResponse struct {
|
||||
}
|
||||
|
||||
type HTTPHandler struct {
|
||||
service *Service
|
||||
db *bun.DB
|
||||
cookieConfig CookieConfig
|
||||
service *Service
|
||||
organizationService *organization.Service
|
||||
db *bun.DB
|
||||
cookieConfig CookieConfig
|
||||
}
|
||||
|
||||
func NewHTTPHandler(s *Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
|
||||
return &HTTPHandler{service: s, db: db, cookieConfig: cookieConfig}
|
||||
func NewHTTPHandler(s *Service, organizationService *organization.Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
|
||||
return &HTTPHandler{service: s, organizationService: organizationService, db: db, cookieConfig: cookieConfig}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||
@@ -91,6 +95,10 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||
}
|
||||
|
||||
if req.TokenDelivery != TokenDeliveryCookie && req.TokenDelivery != TokenDeliveryBody {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
|
||||
}
|
||||
|
||||
tx, err := h.db.BeginTx(c.Context(), nil)
|
||||
if err != nil {
|
||||
return httperr.Internal(err)
|
||||
@@ -105,27 +113,33 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
orgs, err := h.organizationService.ListOrganizationsForUser(c.Context(), tx, result.User.ID)
|
||||
if err != nil {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
switch req.TokenDelivery {
|
||||
default:
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
|
||||
|
||||
case TokenDeliveryCookie:
|
||||
SetAuthCookies(c, result.AccessToken, result.RefreshToken, h.cookieConfig)
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
User: *result.User,
|
||||
Organizations: orgs,
|
||||
})
|
||||
|
||||
case TokenDeliveryBody:
|
||||
return c.JSON(loginResponse{
|
||||
User: *result.User,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
User: *result.User,
|
||||
Organizations: orgs,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
return httperr.Internal(errors.New("unreachable token delivery"))
|
||||
}
|
||||
|
||||
// refreshAccessToken exchanges a refresh token for new access and refresh tokens
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/database"
|
||||
"github.com/get-drexa/drexa/internal/organization"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
)
|
||||
@@ -103,6 +104,45 @@ func TestRegistrationFlow(t *testing.T) {
|
||||
t.Fatalf("expected account.id and drive.id to be set")
|
||||
}
|
||||
|
||||
t.Run("login includes organizations", func(t *testing.T) {
|
||||
type loginResponse struct {
|
||||
User struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
} `json:"user"`
|
||||
Organizations []struct {
|
||||
ID string `json:"id"`
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
} `json:"organizations"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
loginBody := map[string]any{
|
||||
"email": "alice@example.com",
|
||||
"password": "password123",
|
||||
"tokenDelivery": "body",
|
||||
}
|
||||
|
||||
var login loginResponse
|
||||
doJSON(t, s.app, http.MethodPost, "/api/auth/login", "", loginBody, http.StatusOK, &login)
|
||||
if login.AccessToken == "" {
|
||||
t.Fatalf("expected access token in login response")
|
||||
}
|
||||
if login.User.ID != reg.User.ID {
|
||||
t.Fatalf("unexpected user id: got %q want %q", login.User.ID, reg.User.ID)
|
||||
}
|
||||
if len(login.Organizations) != 1 {
|
||||
t.Fatalf("expected 1 organization, got %d", len(login.Organizations))
|
||||
}
|
||||
if login.Organizations[0].Kind != string(organization.KindPersonal) {
|
||||
t.Fatalf("unexpected organization kind: %q", login.Organizations[0].Kind)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("drives list", func(t *testing.T) {
|
||||
var drives []struct {
|
||||
ID string `json:"id"`
|
||||
|
||||
@@ -122,7 +122,7 @@ func NewServer(c Config) (*Server, error) {
|
||||
optionalAuthMiddleware := auth.NewOptionalAuthMiddleware(authService, db, cookieConfig)
|
||||
|
||||
api := app.Group("/api")
|
||||
auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
|
||||
auth.NewHTTPHandler(authService, organizationService, db, cookieConfig).RegisterRoutes(api)
|
||||
registration.NewHTTPHandler(registrationService, authService, db, cookieConfig).RegisterRoutes(api)
|
||||
user.NewHTTPHandler(userService, db, authMiddleware).RegisterRoutes(api)
|
||||
account.NewHTTPHandler(accountService, db, authMiddleware).RegisterRoutes(api)
|
||||
|
||||
@@ -27,7 +27,7 @@ func NewMiddleware(orgService *Service, accountService *account.Service, db *bun
|
||||
var org *Organization
|
||||
var err error
|
||||
|
||||
if slug == reservedSlug {
|
||||
if slug == ReservedSlug {
|
||||
org, err = orgService.PersonalOrganizationForUser(c.Context(), db, u.ID)
|
||||
} else {
|
||||
org, err = orgService.OrganizationBySlug(c.Context(), db, slug)
|
||||
|
||||
@@ -76,3 +76,21 @@ func (s *Service) PersonalOrganizationForUser(ctx context.Context, db bun.IDB, u
|
||||
}
|
||||
return &org, nil
|
||||
}
|
||||
|
||||
func (s *Service) ListOrganizationsForUser(ctx context.Context, db bun.IDB, userID uuid.UUID) ([]Organization, error) {
|
||||
orgs := make([]Organization, 0)
|
||||
err := db.NewSelect().
|
||||
Model(&orgs).
|
||||
Join("JOIN accounts ON accounts.org_id = organization.id").
|
||||
Where("accounts.user_id = ?", userID).
|
||||
Where("accounts.status = ?", account.StatusActive).
|
||||
Order("organization.kind ASC", "organization.name ASC").
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return orgs, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return orgs, nil
|
||||
}
|
||||
|
||||
@@ -11,9 +11,12 @@ const (
|
||||
slugMaxLength = 63
|
||||
)
|
||||
|
||||
// ReservedSlug is the special organization slug used to address the authenticated
|
||||
// user's personal organization in API routes (e.g. /api/my/...).
|
||||
const ReservedSlug = "my"
|
||||
|
||||
var (
|
||||
slugPattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
|
||||
reservedSlug = "my"
|
||||
slugPattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
|
||||
)
|
||||
|
||||
var ErrInvalidSlug = errors.New("invalid organization slug")
|
||||
@@ -24,7 +27,7 @@ func NormalizeSlug(input string) (string, error) {
|
||||
if len(slug) < slugMinLength || len(slug) > slugMaxLength {
|
||||
return "", ErrInvalidSlug
|
||||
}
|
||||
if slug == reservedSlug {
|
||||
if slug == ReservedSlug {
|
||||
return "", ErrInvalidSlug
|
||||
}
|
||||
if !slugPattern.MatchString(slug) {
|
||||
|
||||
Reference in New Issue
Block a user