From d4c4e84fbfa5c5bbc2d0e7a8ff5727fbab5b9c57 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Wed, 3 Dec 2025 23:05:00 +0000 Subject: [PATCH] feat: use argon2id to hash refresh tokens in db --- apps/backend/internal/account/service.go | 2 +- apps/backend/internal/auth/service.go | 37 +++++++-- apps/backend/internal/auth/tokens.go | 78 +++++++++++++------ .../database/migrations/001_initial.up.sql | 2 +- apps/backend/internal/password/password.go | 20 +++-- 5 files changed, 102 insertions(+), 37 deletions(-) diff --git a/apps/backend/internal/account/service.go b/apps/backend/internal/account/service.go index fd664d5..ea9ddf2 100644 --- a/apps/backend/internal/account/service.go +++ b/apps/backend/internal/account/service.go @@ -37,7 +37,7 @@ func NewService(userService *user.Service, vfs *virtualfs.VirtualFS) *Service { } func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) { - hashed, err := password.Hash(opts.Password) + hashed, err := password.HashString(opts.Password) if err != nil { return nil, nil, err } diff --git a/apps/backend/internal/auth/service.go b/apps/backend/internal/auth/service.go index 68423d7..3c5e0b5 100644 --- a/apps/backend/internal/auth/service.go +++ b/apps/backend/internal/auth/service.go @@ -3,6 +3,7 @@ package auth import ( "context" "database/sql" + "encoding/base64" "errors" "log/slog" "time" @@ -81,15 +82,18 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t } func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) { - rtBytes, err := DecodeRefreshToken(refreshToken) + t, err := base64.URLEncoding.DecodeString(refreshToken) if err != nil { - return nil, err + return nil, ErrInvalidRefreshToken } - rtHash := HashRefreshToken(rtBytes) + rtParts, err := DeserializeRefreshToken(t) + if err != nil { + return nil, ErrInvalidRefreshToken + } rt := &RefreshToken{} - err = db.NewSelect().Model(rt).Where("token_hash = ?", rtHash).Scan(ctx) + err = db.NewSelect().Model(rt).Where("key = ?", rtParts.Key).Scan(ctx) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrInvalidRefreshToken @@ -133,6 +137,14 @@ func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshTok return nil, ErrInvalidRefreshToken } + ok, err := password.Verify(rtParts.Token, rt.TokenHash) + if err != nil { + return nil, err + } + if !ok { + return nil, ErrInvalidRefreshToken + } + u, err := s.userService.UserByID(ctx, db, grant.UserID) if err != nil { var nf *user.NotFoundError @@ -166,10 +178,15 @@ func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshTok return nil, err } + srt, err := SerializeRefreshToken(newRT) + if err != nil { + return nil, err + } + return &AuthenticationTokens{ User: u, AccessToken: at, - RefreshToken: EncodeRefreshToken(newRT.Token), + RefreshToken: base64.URLEncoding.EncodeToString(srt), }, nil } @@ -183,7 +200,6 @@ func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.Use if err != nil { return nil, err } - rt.GrantID = grant.ID _, err = db.NewInsert().Model(rt).Exec(ctx) @@ -191,10 +207,15 @@ func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.Use return nil, err } + srt, err := SerializeRefreshToken(rt) + if err != nil { + return nil, err + } + return &AuthenticationTokens{ User: user, AccessToken: at, - RefreshToken: EncodeRefreshToken(rt.Token), + RefreshToken: base64.URLEncoding.EncodeToString(srt), }, nil } @@ -208,7 +229,7 @@ func (s *Service) authenticateWithEmailAndPassword(ctx context.Context, db bun.I return nil, err } - ok, err := password.Verify(plain, u.Password) + ok, err := password.VerifyString(plain, u.Password) if err != nil || !ok { return nil, ErrInvalidCredentials } diff --git a/apps/backend/internal/auth/tokens.go b/apps/backend/internal/auth/tokens.go index 71e5aa5..73fcacf 100644 --- a/apps/backend/internal/auth/tokens.go +++ b/apps/backend/internal/auth/tokens.go @@ -2,11 +2,10 @@ package auth import ( "crypto/rand" - "crypto/sha256" - "encoding/hex" "fmt" "time" + "github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/user" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" @@ -15,7 +14,9 @@ import ( const ( accessTokenValidFor = time.Minute * 15 - refreshTokenByteLength = 32 + refreshTokenKeyLength = 16 + refreshTokenDataLength = 32 + refreshTokenLength = refreshTokenKeyLength + refreshTokenDataLength refreshTokenValidFor = time.Hour * 24 * 30 ) @@ -28,13 +29,19 @@ type TokenConfig struct { type RefreshToken struct { bun.BaseModel `bun:"refresh_tokens"` - ID uuid.UUID `bun:",pk,type:uuid"` - GrantID uuid.UUID `bun:"grant_id,notnull"` - Token []byte `bun:"-"` - TokenHash string `bun:"token_hash,notnull"` - ExpiresAt time.Time `bun:"expires_at,notnull"` - CreatedAt time.Time `bun:"created_at,notnull,nullzero"` - ConsumedAt *time.Time `bun:"consumed_at,nullzero"` + ID uuid.UUID `bun:",pk,type:uuid"` + GrantID uuid.UUID `bun:"grant_id,notnull"` + Token []byte `bun:"-"` + Key uuid.UUID `bun:"key,notnull"` + TokenHash password.Hashed `bun:"token_hash,notnull"` + ExpiresAt time.Time `bun:"expires_at,notnull"` + CreatedAt time.Time `bun:"created_at,notnull,nullzero"` + ConsumedAt *time.Time `bun:"consumed_at,nullzero"` +} + +type RefreshTokenParts struct { + Key uuid.UUID + Token []byte } func newTokenID() (uuid.UUID, error) { @@ -63,23 +70,31 @@ func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) { func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) { now := time.Now() - buf := make([]byte, refreshTokenByteLength) + buf := make([]byte, refreshTokenDataLength) if _, err := rand.Read(buf); err != nil { return nil, fmt.Errorf("failed to generate refresh token: %w", err) } + key, err := uuid.NewRandom() + if err != nil { + return nil, fmt.Errorf("failed to generate refresh token key: %w", err) + } + id, err := newTokenID() if err != nil { return nil, fmt.Errorf("failed to generate token ID: %w", err) } - h := sha256.Sum256(buf) - hex := hex.EncodeToString(h[:]) + hashed, err := password.Hash(buf) + if err != nil { + return nil, fmt.Errorf("failed to hash refresh token: %w", err) + } return &RefreshToken{ ID: id, Token: buf, - TokenHash: hex, + Key: key, + TokenHash: hashed, ExpiresAt: now.Add(refreshTokenValidFor), CreatedAt: now, }, nil @@ -97,15 +112,34 @@ func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, erro return parsed.Claims.(*jwt.RegisteredClaims), nil } -func EncodeRefreshToken(token []byte) string { - return hex.EncodeToString(token) +func SerializeRefreshToken(token *RefreshToken) ([]byte, error) { + keyBytes, err := token.Key.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("failed to marshal refresh token key: %w", err) + } + + buf := make([]byte, 0, len(keyBytes)+len(token.Token)) + buf = append(buf, keyBytes...) + buf = append(buf, token.Token...) + + return buf, nil } -func DecodeRefreshToken(token string) ([]byte, error) { - return hex.DecodeString(token) -} +func DeserializeRefreshToken(buf []byte) (*RefreshTokenParts, error) { + if len(buf) != refreshTokenLength { + return nil, fmt.Errorf("refresh token is too short") + } -func HashRefreshToken(token []byte) string { - h := sha256.Sum256(token) - return hex.EncodeToString(h[:]) + keyBytes := buf[:refreshTokenKeyLength] + token := buf[refreshTokenKeyLength:] + + key, err := uuid.FromBytes(keyBytes) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal refresh token key: %w", err) + } + + return &RefreshTokenParts{ + Key: key, + Token: token, + }, nil } diff --git a/apps/backend/internal/database/migrations/001_initial.up.sql b/apps/backend/internal/database/migrations/001_initial.up.sql index 6683bab..c4d9b0e 100644 --- a/apps/backend/internal/database/migrations/001_initial.up.sql +++ b/apps/backend/internal/database/migrations/001_initial.up.sql @@ -35,6 +35,7 @@ CREATE TABLE IF NOT EXISTS grants ( CREATE TABLE IF NOT EXISTS refresh_tokens ( id UUID PRIMARY KEY, grant_id UUID NOT NULL REFERENCES grants(id) ON DELETE CASCADE, + key UUID NOT NULL UNIQUE, token_hash TEXT NOT NULL UNIQUE, expires_at TIMESTAMPTZ NOT NULL, consumed_at TIMESTAMPTZ, @@ -42,7 +43,6 @@ CREATE TABLE IF NOT EXISTS refresh_tokens ( ); CREATE INDEX idx_refresh_tokens_grant_id ON refresh_tokens(grant_id); -CREATE INDEX idx_refresh_tokens_token_hash ON refresh_tokens(token_hash); CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at); -- Virtual filesystem nodes (unified files + directories) diff --git a/apps/backend/internal/password/password.go b/apps/backend/internal/password/password.go index 29e2189..3e70534 100644 --- a/apps/backend/internal/password/password.go +++ b/apps/backend/internal/password/password.go @@ -38,15 +38,21 @@ type argon2Hash struct { hash []byte } -// Hash securely hashes a plaintext password using argon2id. -func Hash(plain string) (Hashed, error) { +// HashString hashes the given password string. +// This is a convenience function that converts the password string to a byte slice and hashes it using Hash. +func HashString(pw string) (Hashed, error) { + return Hash([]byte(pw)) +} + +// Hash hashes the provided bytes. +func Hash(pw []byte) (Hashed, error) { salt := make([]byte, saltLength) if _, err := rand.Read(salt); err != nil { return "", fmt.Errorf("failed to generate salt: %w", err) } hash := argon2.IDKey( - []byte(plain), + pw, salt, iterations, memory, @@ -70,15 +76,19 @@ func Hash(plain string) (Hashed, error) { return Hashed(encoded), nil } +func VerifyString(plain string, hashed Hashed) (bool, error) { + return Verify([]byte(plain), hashed) +} + // Verify checks if a plaintext password matches a hashed password. -func Verify(plain string, hashed Hashed) (bool, error) { +func Verify(plain []byte, hashed Hashed) (bool, error) { h, err := decodeHash(string(hashed)) if err != nil { return false, err } otherHash := argon2.IDKey( - []byte(plain), + plain, h.salt, h.iterations, h.memory,