feat: impl refresh token rotation

This commit is contained in:
2025-12-03 00:07:39 +00:00
parent a7bdc831ea
commit 3a6fafacca
8 changed files with 259 additions and 56 deletions

View File

@@ -113,7 +113,7 @@ func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error {
return httperr.Internal(err)
}
result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u)
result, err := h.authService.GrantForUser(c.Context(), tx, u)
if err != nil {
return httperr.Internal(err)
}

View File

@@ -6,6 +6,9 @@ import (
)
var ErrUnauthenticatedRequest = errors.New("unauthenticated request")
var ErrInvalidRefreshToken = errors.New("invalid refresh token")
var ErrRefreshTokenExpired = errors.New("refresh token expired")
var ErrRefreshTokenReused = errors.New("refresh token reused")
type InvalidAccessTokenError struct {
err error

View File

@@ -0,0 +1,22 @@
package auth
import (
"time"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type Grant struct {
bun.BaseModel `bun:"grants"`
ID uuid.UUID `bun:",pk,type:uuid"`
UserID uuid.UUID `bun:"user_id,notnull"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
UpdatedAt time.Time `bun:"updated_at,notnull,nullzero"`
RevokedAt *time.Time `bun:"revoked_at,nullzero"`
}
func newGrantID() (uuid.UUID, error) {
return uuid.NewV7()
}

View File

@@ -2,6 +2,7 @@ package auth
import (
"errors"
"log/slog"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/user"
@@ -20,6 +21,15 @@ type loginResponse struct {
RefreshToken string `json:"refreshToken"`
}
type refreshAccessTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
type tokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
type HTTPHandler struct {
service *Service
db *bun.DB
@@ -32,6 +42,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
auth := api.Group("/auth")
auth.Post("/login", h.Login)
auth.Post("/tokens", h.refreshAccessToken)
}
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
@@ -46,7 +57,7 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
}
defer tx.Rollback()
result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
if err != nil {
if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
@@ -64,3 +75,37 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
RefreshToken: result.RefreshToken,
})
}
func (h *HTTPHandler) refreshAccessToken(c *fiber.Ctx) error {
req := new(refreshAccessTokenRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil {
return httperr.Internal(err)
}
defer tx.Rollback()
result, err := h.service.RefreshAccessToken(c.Context(), tx, req.RefreshToken)
if err != nil {
if errors.Is(err, ErrInvalidRefreshToken) ||
errors.Is(err, ErrRefreshTokenExpired) ||
errors.Is(err, ErrRefreshTokenReused) {
_ = tx.Commit()
slog.Info("invalid refresh token", "error", err)
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid refresh token"})
}
return httperr.Internal(err)
}
if err := tx.Commit(); err != nil {
return httperr.Internal(err)
}
return c.JSON(tokenResponse{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}

View File

@@ -2,9 +2,10 @@ package auth
import (
"context"
"encoding/hex"
"database/sql"
"errors"
"log/slog"
"time"
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user"
@@ -12,7 +13,7 @@ import (
"github.com/uptrace/bun"
)
type AuthenticationResult struct {
type AuthenticationTokens struct {
User *user.User
AccessToken string
RefreshToken string
@@ -32,64 +33,35 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
}
}
func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) {
at, err := GenerateAccessToken(user, &s.tokenConfig)
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationTokens, error) {
u, err := s.authenticateWithEmailAndPassword(ctx, db, email, plain)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(user, &s.tokenConfig)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &AuthenticationResult{
User: user,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
return s.GrantForUser(ctx, db, u)
}
func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) {
u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
return nil, ErrInvalidCredentials
}
return nil, err
}
ok, err := password.Verify(plain, u.Password)
if err != nil || !ok {
return nil, ErrInvalidCredentials
}
at, err := GenerateAccessToken(u, &s.tokenConfig)
func (s *Service) GrantForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationTokens, error) {
id, err := newGrantID()
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
grant := &Grant{
ID: id,
UserID: user.ID,
}
_, err = db.NewInsert().Model(grant).Exec(ctx)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
result, err := s.generateTokens(ctx, db, user, grant)
if err != nil {
return nil, err
}
return &AuthenticationResult{
User: u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
return result, nil
}
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
@@ -107,3 +79,139 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t
return s.userService.UserByID(ctx, db, id)
}
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {
rtBytes, err := DecodeRefreshToken(refreshToken)
if err != nil {
return nil, err
}
rtHash := HashRefreshToken(rtBytes)
rt := &RefreshToken{}
err = db.NewSelect().Model(rt).Where("token_hash = ?", rtHash).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrInvalidRefreshToken
}
return nil, err
}
if rt.ExpiresAt.Before(time.Now()) {
return nil, ErrRefreshTokenExpired
}
if rt.ConsumedAt != nil {
// token reuse detected, invalidate all refresh tokens under the same grant
_, err = db.NewDelete().Model((*RefreshToken)(nil)).Where("grant_id = ?", rt.GrantID).Exec(ctx)
if err != nil {
return nil, err
}
_, err = db.NewUpdate().Model((*Grant)(nil)).Set("revoked_at = ?", time.Now()).Where("id = ?", rt.GrantID).Exec(ctx)
if err != nil {
return nil, err
}
return nil, ErrRefreshTokenReused
}
grant := &Grant{
ID: rt.GrantID,
}
err = db.NewSelect().Model(grant).WherePK().Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// the grant for this refresh token was deleted, delete the refresh token
_, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx)
return nil, ErrInvalidRefreshToken
}
return nil, err
}
if grant.RevokedAt != nil {
return nil, ErrInvalidRefreshToken
}
u, err := s.userService.UserByID(ctx, db, grant.UserID)
if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
// the user for this grant was deleted, delete the grant and refresh token
_, _ = db.NewDelete().Model(grant).WherePK().Exec(ctx)
_, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx)
return nil, ErrInvalidRefreshToken
}
return nil, err
}
newRT, err := GenerateRefreshToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
newRT.GrantID = grant.ID
_, err = db.NewUpdate().Model(rt).Set("consumed_at = ?", time.Now()).WherePK().Exec(ctx)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(newRT).Exec(ctx)
if err != nil {
return nil, err
}
at, err := GenerateAccessToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
return &AuthenticationTokens{
User: u,
AccessToken: at,
RefreshToken: EncodeRefreshToken(newRT.Token),
}, nil
}
func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.User, grant *Grant) (*AuthenticationTokens, error) {
at, err := GenerateAccessToken(user, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(user, &s.tokenConfig)
if err != nil {
return nil, err
}
rt.GrantID = grant.ID
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &AuthenticationTokens{
User: user,
AccessToken: at,
RefreshToken: EncodeRefreshToken(rt.Token),
}, nil
}
func (s *Service) authenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*user.User, error) {
u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
return nil, ErrInvalidCredentials
}
return nil, err
}
ok, err := password.Verify(plain, u.Password)
if err != nil || !ok {
return nil, ErrInvalidCredentials
}
return u, nil
}

View File

@@ -28,12 +28,13 @@ type TokenConfig struct {
type RefreshToken struct {
bun.BaseModel `bun:"refresh_tokens"`
ID uuid.UUID `bun:",pk,type:uuid"`
UserID uuid.UUID `bun:"user_id,notnull"`
Token []byte `bun:"-"`
TokenHash string `bun:"token_hash,notnull"`
ExpiresAt time.Time `bun:"expires_at,notnull"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
ID uuid.UUID `bun:",pk,type:uuid"`
GrantID uuid.UUID `bun:"grant_id,notnull"`
Token []byte `bun:"-"`
TokenHash string `bun:"token_hash,notnull"`
ExpiresAt time.Time `bun:"expires_at,notnull"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
ConsumedAt *time.Time `bun:"consumed_at,nullzero"`
}
func newTokenID() (uuid.UUID, error) {
@@ -77,7 +78,6 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
return &RefreshToken{
ID: id,
UserID: user.ID,
Token: buf,
TokenHash: hex,
ExpiresAt: now.Add(refreshTokenValidFor),
@@ -96,3 +96,16 @@ func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, erro
}
return parsed.Claims.(*jwt.RegisteredClaims), nil
}
func EncodeRefreshToken(token []byte) string {
return hex.EncodeToString(token)
}
func DecodeRefreshToken(token string) ([]byte, error) {
return hex.DecodeString(token)
}
func HashRefreshToken(token []byte) string {
h := sha256.Sum256(token)
return hex.EncodeToString(h[:])
}

View File

@@ -24,15 +24,24 @@ CREATE TABLE IF NOT EXISTS accounts (
CREATE INDEX idx_accounts_user_id ON accounts(user_id);
CREATE TABLE IF NOT EXISTS refresh_tokens (
CREATE TABLE IF NOT EXISTS grants (
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id UUID PRIMARY KEY,
grant_id UUID NOT NULL REFERENCES grants(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ NOT NULL,
consumed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_refresh_tokens_user_id ON refresh_tokens(user_id);
CREATE INDEX idx_refresh_tokens_grant_id ON refresh_tokens(grant_id);
CREATE INDEX idx_refresh_tokens_token_hash ON refresh_tokens(token_hash);
CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
@@ -103,4 +112,7 @@ CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER update_grants_updated_at BEFORE UPDATE ON grants
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();

View File

@@ -265,7 +265,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID
return nil, err
}
node := Node{
node := &Node{
ID: id,
PublicID: pid,
AccountID: accountID,
@@ -283,7 +283,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID
return nil, err
}
return &node, nil
return node, nil
}
func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {