mirror of
https://github.com/get-drexa/drive.git
synced 2026-02-02 19:21:18 +00:00
test(backend): add auth error tests
Co-authored-by: Ona <no-reply@ona.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/get-drexa/drexa/internal/password"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
)
|
||||
|
||||
@@ -194,6 +195,117 @@ func TestHTTP_Tokens_Rotation(t *testing.T) {
|
||||
}, http.StatusUnauthorized, nil)
|
||||
}
|
||||
|
||||
func TestHTTP_Tokens_InvalidRefreshToken(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
app := newAuthApp(t, ctx)
|
||||
|
||||
// Invalid base64
|
||||
doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{
|
||||
"refreshToken": "not-valid-base64!!!",
|
||||
}, http.StatusUnauthorized, nil)
|
||||
|
||||
// Valid base64 but not a real token
|
||||
doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{
|
||||
"refreshToken": "dGhpcyBpcyBub3QgYSByZWFsIHRva2Vu",
|
||||
}, http.StatusUnauthorized, nil)
|
||||
|
||||
// Empty token
|
||||
doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{
|
||||
"refreshToken": "",
|
||||
}, http.StatusUnauthorized, nil)
|
||||
}
|
||||
|
||||
func TestHTTP_AuthMiddleware(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
app := newAuthAppWithProtectedRoute(t, ctx)
|
||||
|
||||
t.Run("no token returns 401", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||
res, err := app.Test(req, 10_000)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid bearer token returns 401", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
res, err := app.Test(req, 10_000)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("malformed authorization header returns 401", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||
req.Header.Set("Authorization", "NotBearer token")
|
||||
res, err := app.Test(req, 10_000)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid token succeeds", func(t *testing.T) {
|
||||
// Login to get a valid token
|
||||
var loginResp struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
doJSON(t, app, http.MethodPost, "/api/auth/login", map[string]any{
|
||||
"email": "reg@example.com",
|
||||
"password": "password123",
|
||||
"tokenDelivery": auth.TokenDeliveryBody,
|
||||
}, http.StatusOK, &loginResp)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+loginResp.AccessToken)
|
||||
res, err := app.Test(req, 10_000)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", res.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired token returns 401", func(t *testing.T) {
|
||||
// Generate a properly signed but expired JWT using the same secret
|
||||
expiredJWT := generateExpiredJWT(t, "drexa-test", "drexa-test", []byte("drexa-test-secret"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+expiredJWT)
|
||||
res, err := app.Test(req, 10_000)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newAuthApp(t *testing.T, ctx context.Context) *fiber.App {
|
||||
t.Helper()
|
||||
|
||||
@@ -257,6 +369,76 @@ func newAuthApp(t *testing.T, ctx context.Context) *fiber.App {
|
||||
return app
|
||||
}
|
||||
|
||||
func newAuthAppWithProtectedRoute(t *testing.T, ctx context.Context) *fiber.App {
|
||||
t.Helper()
|
||||
|
||||
pg, err := runPostgres(ctx)
|
||||
if err != nil {
|
||||
t.Skipf("postgres testcontainer unavailable (docker not running/configured?): %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = pg.Terminate(ctx) })
|
||||
|
||||
postgresURL, err := pg.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("postgres connection string: %v", err)
|
||||
}
|
||||
|
||||
db := database.NewFromPostgres(postgresURL)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
if err := database.RunMigrations(ctx, db); err != nil {
|
||||
t.Fatalf("RunMigrations: %v", err)
|
||||
}
|
||||
|
||||
userSvc := user.NewService()
|
||||
orgSvc := organization.NewService()
|
||||
accSvc := account.NewService()
|
||||
|
||||
hashed, err := password.HashString("password123")
|
||||
if err != nil {
|
||||
t.Fatalf("HashString: %v", err)
|
||||
}
|
||||
|
||||
u, err := userSvc.RegisterUser(ctx, db, user.UserRegistrationOptions{
|
||||
Email: "reg@example.com",
|
||||
DisplayName: "Reg User",
|
||||
Password: hashed,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterUser: %v", err)
|
||||
}
|
||||
|
||||
org, err := orgSvc.CreatePersonalOrganization(ctx, db, "Personal")
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePersonalOrganization: %v", err)
|
||||
}
|
||||
|
||||
if _, err := accSvc.CreateAccount(ctx, db, org.ID, u.ID, account.RoleAdmin, account.StatusActive); err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
authSvc := auth.NewService(userSvc, auth.TokenConfig{
|
||||
Issuer: "drexa-test",
|
||||
Audience: "drexa-test",
|
||||
SecretKey: []byte("drexa-test-secret"),
|
||||
})
|
||||
|
||||
authMiddleware := auth.NewAuthMiddleware(authSvc, db, auth.CookieConfig{})
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
ErrorHandler: httperr.ErrorHandler,
|
||||
})
|
||||
api := app.Group("/api")
|
||||
auth.NewHTTPHandler(authSvc, orgSvc, db, auth.CookieConfig{}).RegisterRoutes(api)
|
||||
|
||||
// Add a protected route that requires authentication
|
||||
api.Get("/protected", authMiddleware, func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func doRequest(
|
||||
t *testing.T,
|
||||
app *fiber.App,
|
||||
@@ -331,3 +513,24 @@ func runPostgres(ctx context.Context) (_ *postgres.PostgresContainer, err error)
|
||||
postgres.BasicWaitStrategies(),
|
||||
)
|
||||
}
|
||||
|
||||
// generateExpiredJWT creates a properly signed JWT that expired 1 hour ago
|
||||
func generateExpiredJWT(t *testing.T, issuer, audience string, secretKey []byte) string {
|
||||
t.Helper()
|
||||
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
|
||||
Issuer: issuer,
|
||||
Audience: jwt.ClaimStrings{audience},
|
||||
Subject: "test-user-id",
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(-1 * time.Hour)), // Expired 1 hour ago
|
||||
IssuedAt: jwt.NewNumericDate(now.Add(-2 * time.Hour)), // Issued 2 hours ago
|
||||
})
|
||||
|
||||
signed, err := token.SignedString(secretKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign expired JWT: %v", err)
|
||||
}
|
||||
|
||||
return signed
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user