From 3a6fafaccadf7023bc7274ee1e208cb778941fe3 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Wed, 3 Dec 2025 00:07:39 +0000 Subject: [PATCH] feat: impl refresh token rotation --- apps/backend/internal/account/http.go | 2 +- apps/backend/internal/auth/err.go | 3 + apps/backend/internal/auth/grant.go | 22 ++ apps/backend/internal/auth/http.go | 47 ++++- apps/backend/internal/auth/service.go | 194 ++++++++++++++---- apps/backend/internal/auth/tokens.go | 27 ++- .../database/migrations/001_initial.up.sql | 16 +- apps/backend/internal/virtualfs/vfs.go | 4 +- 8 files changed, 259 insertions(+), 56 deletions(-) create mode 100644 apps/backend/internal/auth/grant.go diff --git a/apps/backend/internal/account/http.go b/apps/backend/internal/account/http.go index 016b640..bbe8127 100644 --- a/apps/backend/internal/account/http.go +++ b/apps/backend/internal/account/http.go @@ -113,7 +113,7 @@ func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error { return httperr.Internal(err) } - result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u) + result, err := h.authService.GrantForUser(c.Context(), tx, u) if err != nil { return httperr.Internal(err) } diff --git a/apps/backend/internal/auth/err.go b/apps/backend/internal/auth/err.go index 0def7f4..5d72eb7 100644 --- a/apps/backend/internal/auth/err.go +++ b/apps/backend/internal/auth/err.go @@ -6,6 +6,9 @@ import ( ) var ErrUnauthenticatedRequest = errors.New("unauthenticated request") +var ErrInvalidRefreshToken = errors.New("invalid refresh token") +var ErrRefreshTokenExpired = errors.New("refresh token expired") +var ErrRefreshTokenReused = errors.New("refresh token reused") type InvalidAccessTokenError struct { err error diff --git a/apps/backend/internal/auth/grant.go b/apps/backend/internal/auth/grant.go new file mode 100644 index 0000000..90d9934 --- /dev/null +++ b/apps/backend/internal/auth/grant.go @@ -0,0 +1,22 @@ +package auth + +import ( + "time" + + "github.com/google/uuid" + "github.com/uptrace/bun" +) + +type Grant struct { + bun.BaseModel `bun:"grants"` + + ID uuid.UUID `bun:",pk,type:uuid"` + UserID uuid.UUID `bun:"user_id,notnull"` + CreatedAt time.Time `bun:"created_at,notnull,nullzero"` + UpdatedAt time.Time `bun:"updated_at,notnull,nullzero"` + RevokedAt *time.Time `bun:"revoked_at,nullzero"` +} + +func newGrantID() (uuid.UUID, error) { + return uuid.NewV7() +} diff --git a/apps/backend/internal/auth/http.go b/apps/backend/internal/auth/http.go index 58cea2b..fa0fabc 100644 --- a/apps/backend/internal/auth/http.go +++ b/apps/backend/internal/auth/http.go @@ -2,6 +2,7 @@ package auth import ( "errors" + "log/slog" "github.com/get-drexa/drexa/internal/httperr" "github.com/get-drexa/drexa/internal/user" @@ -20,6 +21,15 @@ type loginResponse struct { RefreshToken string `json:"refreshToken"` } +type refreshAccessTokenRequest struct { + RefreshToken string `json:"refreshToken"` +} + +type tokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` +} + type HTTPHandler struct { service *Service db *bun.DB @@ -32,6 +42,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { auth := api.Group("/auth") auth.Post("/login", h.Login) + auth.Post("/tokens", h.refreshAccessToken) } func (h *HTTPHandler) Login(c *fiber.Ctx) error { @@ -46,7 +57,7 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error { } defer tx.Rollback() - result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password) + result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password) if err != nil { if errors.Is(err, ErrInvalidCredentials) { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"}) @@ -64,3 +75,37 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error { RefreshToken: result.RefreshToken, }) } + +func (h *HTTPHandler) refreshAccessToken(c *fiber.Ctx) error { + req := new(refreshAccessTokenRequest) + if err := c.BodyParser(req); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) + } + + tx, err := h.db.BeginTx(c.Context(), nil) + if err != nil { + return httperr.Internal(err) + } + defer tx.Rollback() + + result, err := h.service.RefreshAccessToken(c.Context(), tx, req.RefreshToken) + if err != nil { + if errors.Is(err, ErrInvalidRefreshToken) || + errors.Is(err, ErrRefreshTokenExpired) || + errors.Is(err, ErrRefreshTokenReused) { + _ = tx.Commit() + slog.Info("invalid refresh token", "error", err) + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid refresh token"}) + } + return httperr.Internal(err) + } + + if err := tx.Commit(); err != nil { + return httperr.Internal(err) + } + + return c.JSON(tokenResponse{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + }) +} diff --git a/apps/backend/internal/auth/service.go b/apps/backend/internal/auth/service.go index 5b1e288..68423d7 100644 --- a/apps/backend/internal/auth/service.go +++ b/apps/backend/internal/auth/service.go @@ -2,9 +2,10 @@ package auth import ( "context" - "encoding/hex" + "database/sql" "errors" "log/slog" + "time" "github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/user" @@ -12,7 +13,7 @@ import ( "github.com/uptrace/bun" ) -type AuthenticationResult struct { +type AuthenticationTokens struct { User *user.User AccessToken string RefreshToken string @@ -32,64 +33,35 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service { } } -func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) { - at, err := GenerateAccessToken(user, &s.tokenConfig) +func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationTokens, error) { + u, err := s.authenticateWithEmailAndPassword(ctx, db, email, plain) if err != nil { return nil, err } - - rt, err := GenerateRefreshToken(user, &s.tokenConfig) - if err != nil { - return nil, err - } - - _, err = db.NewInsert().Model(rt).Exec(ctx) - if err != nil { - return nil, err - } - - return &AuthenticationResult{ - User: user, - AccessToken: at, - RefreshToken: hex.EncodeToString(rt.Token), - }, nil + return s.GrantForUser(ctx, db, u) } -func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) { - u, err := s.userService.UserByEmail(ctx, db, email) - if err != nil { - var nf *user.NotFoundError - if errors.As(err, &nf) { - return nil, ErrInvalidCredentials - } - return nil, err - } - - ok, err := password.Verify(plain, u.Password) - if err != nil || !ok { - return nil, ErrInvalidCredentials - } - - at, err := GenerateAccessToken(u, &s.tokenConfig) +func (s *Service) GrantForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationTokens, error) { + id, err := newGrantID() if err != nil { return nil, err } - rt, err := GenerateRefreshToken(u, &s.tokenConfig) + grant := &Grant{ + ID: id, + UserID: user.ID, + } + _, err = db.NewInsert().Model(grant).Exec(ctx) if err != nil { return nil, err } - _, err = db.NewInsert().Model(rt).Exec(ctx) + result, err := s.generateTokens(ctx, db, user, grant) if err != nil { return nil, err } - return &AuthenticationResult{ - User: u, - AccessToken: at, - RefreshToken: hex.EncodeToString(rt.Token), - }, nil + return result, nil } func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) { @@ -107,3 +79,139 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t return s.userService.UserByID(ctx, db, id) } + +func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) { + rtBytes, err := DecodeRefreshToken(refreshToken) + if err != nil { + return nil, err + } + + rtHash := HashRefreshToken(rtBytes) + + rt := &RefreshToken{} + err = db.NewSelect().Model(rt).Where("token_hash = ?", rtHash).Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrInvalidRefreshToken + } + return nil, err + } + + if rt.ExpiresAt.Before(time.Now()) { + return nil, ErrRefreshTokenExpired + } + + if rt.ConsumedAt != nil { + // token reuse detected, invalidate all refresh tokens under the same grant + _, err = db.NewDelete().Model((*RefreshToken)(nil)).Where("grant_id = ?", rt.GrantID).Exec(ctx) + if err != nil { + return nil, err + } + + _, err = db.NewUpdate().Model((*Grant)(nil)).Set("revoked_at = ?", time.Now()).Where("id = ?", rt.GrantID).Exec(ctx) + if err != nil { + return nil, err + } + + return nil, ErrRefreshTokenReused + } + + grant := &Grant{ + ID: rt.GrantID, + } + err = db.NewSelect().Model(grant).WherePK().Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + // the grant for this refresh token was deleted, delete the refresh token + _, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx) + return nil, ErrInvalidRefreshToken + } + return nil, err + } + + if grant.RevokedAt != nil { + return nil, ErrInvalidRefreshToken + } + + u, err := s.userService.UserByID(ctx, db, grant.UserID) + if err != nil { + var nf *user.NotFoundError + if errors.As(err, &nf) { + // the user for this grant was deleted, delete the grant and refresh token + _, _ = db.NewDelete().Model(grant).WherePK().Exec(ctx) + _, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx) + return nil, ErrInvalidRefreshToken + } + return nil, err + } + + newRT, err := GenerateRefreshToken(u, &s.tokenConfig) + if err != nil { + return nil, err + } + newRT.GrantID = grant.ID + + _, err = db.NewUpdate().Model(rt).Set("consumed_at = ?", time.Now()).WherePK().Exec(ctx) + if err != nil { + return nil, err + } + + _, err = db.NewInsert().Model(newRT).Exec(ctx) + if err != nil { + return nil, err + } + + at, err := GenerateAccessToken(u, &s.tokenConfig) + if err != nil { + return nil, err + } + + return &AuthenticationTokens{ + User: u, + AccessToken: at, + RefreshToken: EncodeRefreshToken(newRT.Token), + }, nil +} + +func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.User, grant *Grant) (*AuthenticationTokens, error) { + at, err := GenerateAccessToken(user, &s.tokenConfig) + if err != nil { + return nil, err + } + + rt, err := GenerateRefreshToken(user, &s.tokenConfig) + if err != nil { + return nil, err + } + + rt.GrantID = grant.ID + + _, err = db.NewInsert().Model(rt).Exec(ctx) + if err != nil { + return nil, err + } + + return &AuthenticationTokens{ + User: user, + AccessToken: at, + RefreshToken: EncodeRefreshToken(rt.Token), + }, nil +} + +func (s *Service) authenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*user.User, error) { + u, err := s.userService.UserByEmail(ctx, db, email) + if err != nil { + var nf *user.NotFoundError + if errors.As(err, &nf) { + return nil, ErrInvalidCredentials + } + return nil, err + } + + ok, err := password.Verify(plain, u.Password) + if err != nil || !ok { + return nil, ErrInvalidCredentials + } + + return u, nil +} diff --git a/apps/backend/internal/auth/tokens.go b/apps/backend/internal/auth/tokens.go index 42340e9..71e5aa5 100644 --- a/apps/backend/internal/auth/tokens.go +++ b/apps/backend/internal/auth/tokens.go @@ -28,12 +28,13 @@ type TokenConfig struct { type RefreshToken struct { bun.BaseModel `bun:"refresh_tokens"` - ID uuid.UUID `bun:",pk,type:uuid"` - UserID uuid.UUID `bun:"user_id,notnull"` - Token []byte `bun:"-"` - TokenHash string `bun:"token_hash,notnull"` - ExpiresAt time.Time `bun:"expires_at,notnull"` - CreatedAt time.Time `bun:"created_at,notnull,nullzero"` + ID uuid.UUID `bun:",pk,type:uuid"` + GrantID uuid.UUID `bun:"grant_id,notnull"` + Token []byte `bun:"-"` + TokenHash string `bun:"token_hash,notnull"` + ExpiresAt time.Time `bun:"expires_at,notnull"` + CreatedAt time.Time `bun:"created_at,notnull,nullzero"` + ConsumedAt *time.Time `bun:"consumed_at,nullzero"` } func newTokenID() (uuid.UUID, error) { @@ -77,7 +78,6 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error return &RefreshToken{ ID: id, - UserID: user.ID, Token: buf, TokenHash: hex, ExpiresAt: now.Add(refreshTokenValidFor), @@ -96,3 +96,16 @@ func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, erro } return parsed.Claims.(*jwt.RegisteredClaims), nil } + +func EncodeRefreshToken(token []byte) string { + return hex.EncodeToString(token) +} + +func DecodeRefreshToken(token string) ([]byte, error) { + return hex.DecodeString(token) +} + +func HashRefreshToken(token []byte) string { + h := sha256.Sum256(token) + return hex.EncodeToString(h[:]) +} diff --git a/apps/backend/internal/database/migrations/001_initial.up.sql b/apps/backend/internal/database/migrations/001_initial.up.sql index 096bdc5..6683bab 100644 --- a/apps/backend/internal/database/migrations/001_initial.up.sql +++ b/apps/backend/internal/database/migrations/001_initial.up.sql @@ -24,15 +24,24 @@ CREATE TABLE IF NOT EXISTS accounts ( CREATE INDEX idx_accounts_user_id ON accounts(user_id); -CREATE TABLE IF NOT EXISTS refresh_tokens ( +CREATE TABLE IF NOT EXISTS grants ( id UUID PRIMARY KEY, user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + revoked_at TIMESTAMPTZ +); + +CREATE TABLE IF NOT EXISTS refresh_tokens ( + id UUID PRIMARY KEY, + grant_id UUID NOT NULL REFERENCES grants(id) ON DELETE CASCADE, token_hash TEXT NOT NULL UNIQUE, expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE INDEX idx_refresh_tokens_user_id ON refresh_tokens(user_id); +CREATE INDEX idx_refresh_tokens_grant_id ON refresh_tokens(grant_id); CREATE INDEX idx_refresh_tokens_token_hash ON refresh_tokens(token_hash); CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at); @@ -103,4 +112,7 @@ CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); + +CREATE TRIGGER update_grants_updated_at BEFORE UPDATE ON grants FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); \ No newline at end of file diff --git a/apps/backend/internal/virtualfs/vfs.go b/apps/backend/internal/virtualfs/vfs.go index a651173..6041129 100644 --- a/apps/backend/internal/virtualfs/vfs.go +++ b/apps/backend/internal/virtualfs/vfs.go @@ -265,7 +265,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID return nil, err } - node := Node{ + node := &Node{ ID: id, PublicID: pid, AccountID: accountID, @@ -283,7 +283,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID return nil, err } - return &node, nil + return node, nil } func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {