mirror of
https://github.com/get-drexa/drive.git
synced 2026-02-02 14:51:18 +00:00
test(backend): add auth error tests
Co-authored-by: Ona <no-reply@ona.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/get-drexa/drexa/internal/password"
|
"github.com/get-drexa/drexa/internal/password"
|
||||||
"github.com/get-drexa/drexa/internal/user"
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -194,6 +195,117 @@ func TestHTTP_Tokens_Rotation(t *testing.T) {
|
|||||||
}, http.StatusUnauthorized, nil)
|
}, http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTP_Tokens_InvalidRefreshToken(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
app := newAuthApp(t, ctx)
|
||||||
|
|
||||||
|
// Invalid base64
|
||||||
|
doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{
|
||||||
|
"refreshToken": "not-valid-base64!!!",
|
||||||
|
}, http.StatusUnauthorized, nil)
|
||||||
|
|
||||||
|
// Valid base64 but not a real token
|
||||||
|
doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{
|
||||||
|
"refreshToken": "dGhpcyBpcyBub3QgYSByZWFsIHRva2Vu",
|
||||||
|
}, http.StatusUnauthorized, nil)
|
||||||
|
|
||||||
|
// Empty token
|
||||||
|
doJSON(t, app, http.MethodPost, "/api/auth/tokens", map[string]any{
|
||||||
|
"refreshToken": "",
|
||||||
|
}, http.StatusUnauthorized, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTP_AuthMiddleware(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
app := newAuthAppWithProtectedRoute(t, ctx)
|
||||||
|
|
||||||
|
t.Run("no token returns 401", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||||
|
res, err := app.Test(req, 10_000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid bearer token returns 401", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||||
|
res, err := app.Test(req, 10_000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("malformed authorization header returns 401", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||||
|
req.Header.Set("Authorization", "NotBearer token")
|
||||||
|
res, err := app.Test(req, 10_000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid token succeeds", func(t *testing.T) {
|
||||||
|
// Login to get a valid token
|
||||||
|
var loginResp struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
}
|
||||||
|
doJSON(t, app, http.MethodPost, "/api/auth/login", map[string]any{
|
||||||
|
"email": "reg@example.com",
|
||||||
|
"password": "password123",
|
||||||
|
"tokenDelivery": auth.TokenDeliveryBody,
|
||||||
|
}, http.StatusOK, &loginResp)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+loginResp.AccessToken)
|
||||||
|
res, err := app.Test(req, 10_000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("expired token returns 401", func(t *testing.T) {
|
||||||
|
// Generate a properly signed but expired JWT using the same secret
|
||||||
|
expiredJWT := generateExpiredJWT(t, "drexa-test", "drexa-test", []byte("drexa-test-secret"))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/protected", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+expiredJWT)
|
||||||
|
res, err := app.Test(req, 10_000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected 401, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func newAuthApp(t *testing.T, ctx context.Context) *fiber.App {
|
func newAuthApp(t *testing.T, ctx context.Context) *fiber.App {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -257,6 +369,76 @@ func newAuthApp(t *testing.T, ctx context.Context) *fiber.App {
|
|||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newAuthAppWithProtectedRoute(t *testing.T, ctx context.Context) *fiber.App {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
pg, err := runPostgres(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("postgres testcontainer unavailable (docker not running/configured?): %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = pg.Terminate(ctx) })
|
||||||
|
|
||||||
|
postgresURL, err := pg.ConnectionString(ctx, "sslmode=disable")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("postgres connection string: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
db := database.NewFromPostgres(postgresURL)
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
if err := database.RunMigrations(ctx, db); err != nil {
|
||||||
|
t.Fatalf("RunMigrations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userSvc := user.NewService()
|
||||||
|
orgSvc := organization.NewService()
|
||||||
|
accSvc := account.NewService()
|
||||||
|
|
||||||
|
hashed, err := password.HashString("password123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HashString: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := userSvc.RegisterUser(ctx, db, user.UserRegistrationOptions{
|
||||||
|
Email: "reg@example.com",
|
||||||
|
DisplayName: "Reg User",
|
||||||
|
Password: hashed,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RegisterUser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
org, err := orgSvc.CreatePersonalOrganization(ctx, db, "Personal")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreatePersonalOrganization: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := accSvc.CreateAccount(ctx, db, org.ID, u.ID, account.RoleAdmin, account.StatusActive); err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := auth.NewService(userSvc, auth.TokenConfig{
|
||||||
|
Issuer: "drexa-test",
|
||||||
|
Audience: "drexa-test",
|
||||||
|
SecretKey: []byte("drexa-test-secret"),
|
||||||
|
})
|
||||||
|
|
||||||
|
authMiddleware := auth.NewAuthMiddleware(authSvc, db, auth.CookieConfig{})
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{
|
||||||
|
ErrorHandler: httperr.ErrorHandler,
|
||||||
|
})
|
||||||
|
api := app.Group("/api")
|
||||||
|
auth.NewHTTPHandler(authSvc, orgSvc, db, auth.CookieConfig{}).RegisterRoutes(api)
|
||||||
|
|
||||||
|
// Add a protected route that requires authentication
|
||||||
|
api.Get("/protected", authMiddleware, func(c *fiber.Ctx) error {
|
||||||
|
return c.SendStatus(fiber.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
func doRequest(
|
func doRequest(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
app *fiber.App,
|
app *fiber.App,
|
||||||
@@ -331,3 +513,24 @@ func runPostgres(ctx context.Context) (_ *postgres.PostgresContainer, err error)
|
|||||||
postgres.BasicWaitStrategies(),
|
postgres.BasicWaitStrategies(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateExpiredJWT creates a properly signed JWT that expired 1 hour ago
|
||||||
|
func generateExpiredJWT(t *testing.T, issuer, audience string, secretKey []byte) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audience: jwt.ClaimStrings{audience},
|
||||||
|
Subject: "test-user-id",
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(-1 * time.Hour)), // Expired 1 hour ago
|
||||||
|
IssuedAt: jwt.NewNumericDate(now.Add(-2 * time.Hour)), // Issued 2 hours ago
|
||||||
|
})
|
||||||
|
|
||||||
|
signed, err := token.SignedString(secretKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to sign expired JWT: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signed
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user