diff --git a/apps/backend/internal/auth/http_integration_test.go b/apps/backend/internal/auth/http_integration_test.go index 40c1a7f..72e76eb 100644 --- a/apps/backend/internal/auth/http_integration_test.go +++ b/apps/backend/internal/auth/http_integration_test.go @@ -22,6 +22,7 @@ import ( "github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/user" "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v5" "github.com/testcontainers/testcontainers-go/modules/postgres" ) @@ -194,6 +195,117 @@ func TestHTTP_Tokens_Rotation(t *testing.T) { }, http.StatusUnauthorized, nil) } +func TestHTTP_Tokens_InvalidRefreshToken(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + app := newAuthApp(t, ctx) + + // Invalid base64 + doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{ + "refreshToken": "not-valid-base64!!!", + }, http.StatusUnauthorized, nil) + + // Valid base64 but not a real token + doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{ + "refreshToken": "dGhpcyBpcyBub3QgYSByZWFsIHRva2Vu", + }, http.StatusUnauthorized, nil) + + // Empty token + doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{ + "refreshToken": "", + }, http.StatusUnauthorized, nil) +} + +func TestHTTP_AuthMiddleware(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + app := newAuthAppWithProtectedRoute(t, ctx) + + t.Run("no token returns 401", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/protected", nil) + res, err := app.Test(req, 10_000) + if err != nil { + t.Fatalf("request failed: %v", err) + } + res.Body.Close() + + if res.StatusCode != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", res.StatusCode) + } + }) + + t.Run("invalid bearer token returns 401", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/protected", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + res, err := app.Test(req, 10_000) + if err != nil { + t.Fatalf("request failed: %v", err) + } + res.Body.Close() + + if res.StatusCode != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", res.StatusCode) + } + }) + + t.Run("malformed authorization header returns 401", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/protected", nil) + req.Header.Set("Authorization", "NotBearer token") + res, err := app.Test(req, 10_000) + if err != nil { + t.Fatalf("request failed: %v", err) + } + res.Body.Close() + + if res.StatusCode != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", res.StatusCode) + } + }) + + t.Run("valid token succeeds", func(t *testing.T) { + // Login to get a valid token + var loginResp struct { + AccessToken string `json:"accessToken"` + } + doJSON(t, app, http.MethodPost, "/api/auth/login", map[string]any{ + "email": "reg@example.com", + "password": "password123", + "tokenDelivery": auth.TokenDeliveryBody, + }, http.StatusOK, &loginResp) + + req := httptest.NewRequest(http.MethodGet, "/api/protected", nil) + req.Header.Set("Authorization", "Bearer "+loginResp.AccessToken) + res, err := app.Test(req, 10_000) + if err != nil { + t.Fatalf("request failed: %v", err) + } + res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", res.StatusCode) + } + }) + + t.Run("expired token returns 401", func(t *testing.T) { + // Generate a properly signed but expired JWT using the same secret + expiredJWT := generateExpiredJWT(t, "drexa-test", "drexa-test", []byte("drexa-test-secret")) + + req := httptest.NewRequest(http.MethodGet, "/api/protected", nil) + req.Header.Set("Authorization", "Bearer "+expiredJWT) + res, err := app.Test(req, 10_000) + if err != nil { + t.Fatalf("request failed: %v", err) + } + res.Body.Close() + + if res.StatusCode != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", res.StatusCode) + } + }) +} + func newAuthApp(t *testing.T, ctx context.Context) *fiber.App { t.Helper() @@ -257,6 +369,76 @@ func newAuthApp(t *testing.T, ctx context.Context) *fiber.App { return app } +func newAuthAppWithProtectedRoute(t *testing.T, ctx context.Context) *fiber.App { + t.Helper() + + pg, err := runPostgres(ctx) + if err != nil { + t.Skipf("postgres testcontainer unavailable (docker not running/configured?): %v", err) + } + t.Cleanup(func() { _ = pg.Terminate(ctx) }) + + postgresURL, err := pg.ConnectionString(ctx, "sslmode=disable") + if err != nil { + t.Fatalf("postgres connection string: %v", err) + } + + db := database.NewFromPostgres(postgresURL) + t.Cleanup(func() { _ = db.Close() }) + + if err := database.RunMigrations(ctx, db); err != nil { + t.Fatalf("RunMigrations: %v", err) + } + + userSvc := user.NewService() + orgSvc := organization.NewService() + accSvc := account.NewService() + + hashed, err := password.HashString("password123") + if err != nil { + t.Fatalf("HashString: %v", err) + } + + u, err := userSvc.RegisterUser(ctx, db, user.UserRegistrationOptions{ + Email: "reg@example.com", + DisplayName: "Reg User", + Password: hashed, + }) + if err != nil { + t.Fatalf("RegisterUser: %v", err) + } + + org, err := orgSvc.CreatePersonalOrganization(ctx, db, "Personal") + if err != nil { + t.Fatalf("CreatePersonalOrganization: %v", err) + } + + if _, err := accSvc.CreateAccount(ctx, db, org.ID, u.ID, account.RoleAdmin, account.StatusActive); err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + authSvc := auth.NewService(userSvc, auth.TokenConfig{ + Issuer: "drexa-test", + Audience: "drexa-test", + SecretKey: []byte("drexa-test-secret"), + }) + + authMiddleware := auth.NewAuthMiddleware(authSvc, db, auth.CookieConfig{}) + + app := fiber.New(fiber.Config{ + ErrorHandler: httperr.ErrorHandler, + }) + api := app.Group("/api") + auth.NewHTTPHandler(authSvc, orgSvc, db, auth.CookieConfig{}).RegisterRoutes(api) + + // Add a protected route that requires authentication + api.Get("/protected", authMiddleware, func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + return app +} + func doRequest( t *testing.T, app *fiber.App, @@ -331,3 +513,24 @@ func runPostgres(ctx context.Context) (_ *postgres.PostgresContainer, err error) postgres.BasicWaitStrategies(), ) } + +// generateExpiredJWT creates a properly signed JWT that expired 1 hour ago +func generateExpiredJWT(t *testing.T, issuer, audience string, secretKey []byte) string { + t.Helper() + + now := time.Now() + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + Issuer: issuer, + Audience: jwt.ClaimStrings{audience}, + Subject: "test-user-id", + ExpiresAt: jwt.NewNumericDate(now.Add(-1 * time.Hour)), // Expired 1 hour ago + IssuedAt: jwt.NewNumericDate(now.Add(-2 * time.Hour)), // Issued 2 hours ago + }) + + signed, err := token.SignedString(secretKey) + if err != nil { + t.Fatalf("failed to sign expired JWT: %v", err) + } + + return signed +}