mirror of
https://github.com/get-drexa/drive.git
synced 2025-12-04 15:21:39 +00:00
feat: use argon2id to hash refresh tokens in db
This commit is contained in:
@@ -37,7 +37,7 @@ func NewService(userService *user.Service, vfs *virtualfs.VirtualFS) *Service {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) {
|
func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) {
|
||||||
hashed, err := password.Hash(opts.Password)
|
hashed, err := password.HashString(opts.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"time"
|
"time"
|
||||||
@@ -81,15 +82,18 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {
|
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {
|
||||||
rtBytes, err := DecodeRefreshToken(refreshToken)
|
t, err := base64.URLEncoding.DecodeString(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, ErrInvalidRefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
rtHash := HashRefreshToken(rtBytes)
|
rtParts, err := DeserializeRefreshToken(t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidRefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
rt := &RefreshToken{}
|
rt := &RefreshToken{}
|
||||||
err = db.NewSelect().Model(rt).Where("token_hash = ?", rtHash).Scan(ctx)
|
err = db.NewSelect().Model(rt).Where("key = ?", rtParts.Key).Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrInvalidRefreshToken
|
return nil, ErrInvalidRefreshToken
|
||||||
@@ -133,6 +137,14 @@ func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshTok
|
|||||||
return nil, ErrInvalidRefreshToken
|
return nil, ErrInvalidRefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ok, err := password.Verify(rtParts.Token, rt.TokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidRefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
u, err := s.userService.UserByID(ctx, db, grant.UserID)
|
u, err := s.userService.UserByID(ctx, db, grant.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var nf *user.NotFoundError
|
var nf *user.NotFoundError
|
||||||
@@ -166,10 +178,15 @@ func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshTok
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srt, err := SerializeRefreshToken(newRT)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &AuthenticationTokens{
|
return &AuthenticationTokens{
|
||||||
User: u,
|
User: u,
|
||||||
AccessToken: at,
|
AccessToken: at,
|
||||||
RefreshToken: EncodeRefreshToken(newRT.Token),
|
RefreshToken: base64.URLEncoding.EncodeToString(srt),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,7 +200,6 @@ func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.Use
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rt.GrantID = grant.ID
|
rt.GrantID = grant.ID
|
||||||
|
|
||||||
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||||
@@ -191,10 +207,15 @@ func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.Use
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srt, err := SerializeRefreshToken(rt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &AuthenticationTokens{
|
return &AuthenticationTokens{
|
||||||
User: user,
|
User: user,
|
||||||
AccessToken: at,
|
AccessToken: at,
|
||||||
RefreshToken: EncodeRefreshToken(rt.Token),
|
RefreshToken: base64.URLEncoding.EncodeToString(srt),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,7 +229,7 @@ func (s *Service) authenticateWithEmailAndPassword(ctx context.Context, db bun.I
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := password.Verify(plain, u.Password)
|
ok, err := password.VerifyString(plain, u.Password)
|
||||||
if err != nil || !ok {
|
if err != nil || !ok {
|
||||||
return nil, ErrInvalidCredentials
|
return nil, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,10 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/get-drexa/drexa/internal/password"
|
||||||
"github.com/get-drexa/drexa/internal/user"
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -15,7 +14,9 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
accessTokenValidFor = time.Minute * 15
|
accessTokenValidFor = time.Minute * 15
|
||||||
refreshTokenByteLength = 32
|
refreshTokenKeyLength = 16
|
||||||
|
refreshTokenDataLength = 32
|
||||||
|
refreshTokenLength = refreshTokenKeyLength + refreshTokenDataLength
|
||||||
refreshTokenValidFor = time.Hour * 24 * 30
|
refreshTokenValidFor = time.Hour * 24 * 30
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,13 +29,19 @@ type TokenConfig struct {
|
|||||||
type RefreshToken struct {
|
type RefreshToken struct {
|
||||||
bun.BaseModel `bun:"refresh_tokens"`
|
bun.BaseModel `bun:"refresh_tokens"`
|
||||||
|
|
||||||
ID uuid.UUID `bun:",pk,type:uuid"`
|
ID uuid.UUID `bun:",pk,type:uuid"`
|
||||||
GrantID uuid.UUID `bun:"grant_id,notnull"`
|
GrantID uuid.UUID `bun:"grant_id,notnull"`
|
||||||
Token []byte `bun:"-"`
|
Token []byte `bun:"-"`
|
||||||
TokenHash string `bun:"token_hash,notnull"`
|
Key uuid.UUID `bun:"key,notnull"`
|
||||||
ExpiresAt time.Time `bun:"expires_at,notnull"`
|
TokenHash password.Hashed `bun:"token_hash,notnull"`
|
||||||
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
|
ExpiresAt time.Time `bun:"expires_at,notnull"`
|
||||||
ConsumedAt *time.Time `bun:"consumed_at,nullzero"`
|
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
|
||||||
|
ConsumedAt *time.Time `bun:"consumed_at,nullzero"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshTokenParts struct {
|
||||||
|
Key uuid.UUID
|
||||||
|
Token []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTokenID() (uuid.UUID, error) {
|
func newTokenID() (uuid.UUID, error) {
|
||||||
@@ -63,23 +70,31 @@ func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
|
|||||||
func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) {
|
func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
buf := make([]byte, refreshTokenByteLength)
|
buf := make([]byte, refreshTokenDataLength)
|
||||||
if _, err := rand.Read(buf); err != nil {
|
if _, err := rand.Read(buf); err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
key, err := uuid.NewRandom()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate refresh token key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
id, err := newTokenID()
|
id, err := newTokenID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate token ID: %w", err)
|
return nil, fmt.Errorf("failed to generate token ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h := sha256.Sum256(buf)
|
hashed, err := password.Hash(buf)
|
||||||
hex := hex.EncodeToString(h[:])
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to hash refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &RefreshToken{
|
return &RefreshToken{
|
||||||
ID: id,
|
ID: id,
|
||||||
Token: buf,
|
Token: buf,
|
||||||
TokenHash: hex,
|
Key: key,
|
||||||
|
TokenHash: hashed,
|
||||||
ExpiresAt: now.Add(refreshTokenValidFor),
|
ExpiresAt: now.Add(refreshTokenValidFor),
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -97,15 +112,34 @@ func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, erro
|
|||||||
return parsed.Claims.(*jwt.RegisteredClaims), nil
|
return parsed.Claims.(*jwt.RegisteredClaims), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func EncodeRefreshToken(token []byte) string {
|
func SerializeRefreshToken(token *RefreshToken) ([]byte, error) {
|
||||||
return hex.EncodeToString(token)
|
keyBytes, err := token.Key.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal refresh token key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 0, len(keyBytes)+len(token.Token))
|
||||||
|
buf = append(buf, keyBytes...)
|
||||||
|
buf = append(buf, token.Token...)
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeRefreshToken(token string) ([]byte, error) {
|
func DeserializeRefreshToken(buf []byte) (*RefreshTokenParts, error) {
|
||||||
return hex.DecodeString(token)
|
if len(buf) != refreshTokenLength {
|
||||||
}
|
return nil, fmt.Errorf("refresh token is too short")
|
||||||
|
}
|
||||||
|
|
||||||
func HashRefreshToken(token []byte) string {
|
keyBytes := buf[:refreshTokenKeyLength]
|
||||||
h := sha256.Sum256(token)
|
token := buf[refreshTokenKeyLength:]
|
||||||
return hex.EncodeToString(h[:])
|
|
||||||
|
key, err := uuid.FromBytes(keyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal refresh token key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RefreshTokenParts{
|
||||||
|
Key: key,
|
||||||
|
Token: token,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ CREATE TABLE IF NOT EXISTS grants (
|
|||||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||||
id UUID PRIMARY KEY,
|
id UUID PRIMARY KEY,
|
||||||
grant_id UUID NOT NULL REFERENCES grants(id) ON DELETE CASCADE,
|
grant_id UUID NOT NULL REFERENCES grants(id) ON DELETE CASCADE,
|
||||||
|
key UUID NOT NULL UNIQUE,
|
||||||
token_hash TEXT NOT NULL UNIQUE,
|
token_hash TEXT NOT NULL UNIQUE,
|
||||||
expires_at TIMESTAMPTZ NOT NULL,
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
consumed_at TIMESTAMPTZ,
|
consumed_at TIMESTAMPTZ,
|
||||||
@@ -42,7 +43,6 @@ CREATE TABLE IF NOT EXISTS refresh_tokens (
|
|||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX idx_refresh_tokens_grant_id ON refresh_tokens(grant_id);
|
CREATE INDEX idx_refresh_tokens_grant_id ON refresh_tokens(grant_id);
|
||||||
CREATE INDEX idx_refresh_tokens_token_hash ON refresh_tokens(token_hash);
|
|
||||||
CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
||||||
|
|
||||||
-- Virtual filesystem nodes (unified files + directories)
|
-- Virtual filesystem nodes (unified files + directories)
|
||||||
|
|||||||
@@ -38,15 +38,21 @@ type argon2Hash struct {
|
|||||||
hash []byte
|
hash []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash securely hashes a plaintext password using argon2id.
|
// HashString hashes the given password string.
|
||||||
func Hash(plain string) (Hashed, error) {
|
// This is a convenience function that converts the password string to a byte slice and hashes it using Hash.
|
||||||
|
func HashString(pw string) (Hashed, error) {
|
||||||
|
return Hash([]byte(pw))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash hashes the provided bytes.
|
||||||
|
func Hash(pw []byte) (Hashed, error) {
|
||||||
salt := make([]byte, saltLength)
|
salt := make([]byte, saltLength)
|
||||||
if _, err := rand.Read(salt); err != nil {
|
if _, err := rand.Read(salt); err != nil {
|
||||||
return "", fmt.Errorf("failed to generate salt: %w", err)
|
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := argon2.IDKey(
|
hash := argon2.IDKey(
|
||||||
[]byte(plain),
|
pw,
|
||||||
salt,
|
salt,
|
||||||
iterations,
|
iterations,
|
||||||
memory,
|
memory,
|
||||||
@@ -70,15 +76,19 @@ func Hash(plain string) (Hashed, error) {
|
|||||||
return Hashed(encoded), nil
|
return Hashed(encoded), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func VerifyString(plain string, hashed Hashed) (bool, error) {
|
||||||
|
return Verify([]byte(plain), hashed)
|
||||||
|
}
|
||||||
|
|
||||||
// Verify checks if a plaintext password matches a hashed password.
|
// Verify checks if a plaintext password matches a hashed password.
|
||||||
func Verify(plain string, hashed Hashed) (bool, error) {
|
func Verify(plain []byte, hashed Hashed) (bool, error) {
|
||||||
h, err := decodeHash(string(hashed))
|
h, err := decodeHash(string(hashed))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
otherHash := argon2.IDKey(
|
otherHash := argon2.IDKey(
|
||||||
[]byte(plain),
|
plain,
|
||||||
h.salt,
|
h.salt,
|
||||||
h.iterations,
|
h.iterations,
|
||||||
h.memory,
|
h.memory,
|
||||||
|
|||||||
Reference in New Issue
Block a user