mirror of
https://github.com/get-drexa/drive.git
synced 2025-12-04 23:31:40 +00:00
146 lines
3.7 KiB
Go
146 lines
3.7 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/get-drexa/drexa/internal/password"
|
|
"github.com/get-drexa/drexa/internal/user"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
// Token lifetimes and the binary layout of serialized refresh tokens.
const (
	// accessTokenValidFor is how long a signed access JWT stays valid after issuance.
	accessTokenValidFor = 15 * time.Minute

	// refreshTokenKeyLength is the size in bytes of the UUID key that prefixes
	// a serialized refresh token.
	refreshTokenKeyLength = 16
	// refreshTokenDataLength is the size in bytes of the random secret that
	// follows the key.
	refreshTokenDataLength = 32
	// refreshTokenLength is the total size of a serialized refresh token:
	// key followed by secret.
	refreshTokenLength = refreshTokenKeyLength + refreshTokenDataLength

	// refreshTokenValidFor is how long a refresh token remains usable after creation.
	refreshTokenValidFor = 30 * 24 * time.Hour
)
|
|
|
|
// TokenConfig holds the settings used to sign and verify access tokens.
type TokenConfig struct {
	// Issuer is written into, and required in, the JWT "iss" claim.
	// Audience is written into, and required in, the JWT "aud" claim.
	Issuer, Audience string

	// SecretKey is the HMAC key used to sign (HS256) and verify access tokens.
	SecretKey []byte
}
|
|
|
|
// RefreshToken is a refresh-token record mapped by bun to the refresh_tokens
// table. Only a hash of the secret (TokenHash) is persisted; the plaintext
// secret lives in Token, which bun ignores (`bun:"-"`), so it is present on
// values returned by GenerateRefreshToken but not on rows loaded from the
// database.
type RefreshToken struct {
|
|
bun.BaseModel `bun:"refresh_tokens"`
|
|
|
|
// ID is the primary key; GenerateRefreshToken fills it from newTokenID (UUIDv7).
ID uuid.UUID `bun:",pk,type:uuid"`
|
|
// GrantID references the grant this token belongs to.
// NOTE(review): GenerateRefreshToken leaves this zero — presumably the
// caller sets it before inserting; confirm.
GrantID uuid.UUID `bun:"grant_id,notnull"`
|
|
// Token is the plaintext secret (refreshTokenDataLength random bytes when
// generated). It is never written to the database.
Token []byte `bun:"-"`
|
|
// Key is a random UUID that forms the leading bytes of the serialized token
// (see SerializeRefreshToken); presumably used to look the row up — verify.
Key uuid.UUID `bun:"key,notnull"`
|
|
// TokenHash is the result of password.Hash applied to the plaintext secret.
TokenHash password.Hashed `bun:"token_hash,notnull"`
|
|
// ExpiresAt is the creation time plus refreshTokenValidFor for generated tokens.
ExpiresAt time.Time `bun:"expires_at,notnull"`
|
|
// CreatedAt is the time the token was generated.
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
|
|
// ConsumedAt is nil until the token has been consumed.
// NOTE(review): nothing in this file sets it — presumably done when the
// token is exchanged; confirm.
ConsumedAt *time.Time `bun:"consumed_at,nullzero"`
|
|
}
|
|
|
|
// RefreshTokenParts holds the two components decoded from a serialized
// refresh token by DeserializeRefreshToken.
type RefreshTokenParts struct {
|
|
// Key is the UUID decoded from the first refreshTokenKeyLength bytes.
Key uuid.UUID
|
|
// Token is the remaining secret bytes; presumably compared against
// RefreshToken.TokenHash by the caller — confirm.
Token []byte
|
|
}
|
|
|
|
func newTokenID() (uuid.UUID, error) {
|
|
return uuid.NewV7()
|
|
}
|
|
|
|
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
|
|
now := time.Now()
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
|
|
Issuer: c.Issuer,
|
|
Audience: jwt.ClaimStrings{c.Audience},
|
|
Subject: user.ID.String(),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(accessTokenValidFor)),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
})
|
|
|
|
signed, err := token.SignedString(c.SecretKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
|
}
|
|
|
|
return signed, nil
|
|
}
|
|
|
|
func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) {
|
|
now := time.Now()
|
|
|
|
buf := make([]byte, refreshTokenDataLength)
|
|
if _, err := rand.Read(buf); err != nil {
|
|
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
|
}
|
|
|
|
key, err := uuid.NewRandom()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate refresh token key: %w", err)
|
|
}
|
|
|
|
id, err := newTokenID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate token ID: %w", err)
|
|
}
|
|
|
|
hashed, err := password.Hash(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to hash refresh token: %w", err)
|
|
}
|
|
|
|
return &RefreshToken{
|
|
ID: id,
|
|
Token: buf,
|
|
Key: key,
|
|
TokenHash: hashed,
|
|
ExpiresAt: now.Add(refreshTokenValidFor),
|
|
CreatedAt: now,
|
|
}, nil
|
|
}
|
|
|
|
// ParseAccessToken parses a JWT access token and returns the claims.
|
|
// Returns an InvalidAccessTokenError if the token is invalid.
|
|
func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, error) {
|
|
parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
|
|
return c.SecretKey, nil
|
|
}, jwt.WithIssuer(c.Issuer), jwt.WithExpirationRequired(), jwt.WithAudience(c.Audience))
|
|
if err != nil {
|
|
return nil, newInvalidAccessTokenError(err)
|
|
}
|
|
return parsed.Claims.(*jwt.RegisteredClaims), nil
|
|
}
|
|
|
|
func SerializeRefreshToken(token *RefreshToken) ([]byte, error) {
|
|
keyBytes, err := token.Key.MarshalBinary()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal refresh token key: %w", err)
|
|
}
|
|
|
|
buf := make([]byte, 0, len(keyBytes)+len(token.Token))
|
|
buf = append(buf, keyBytes...)
|
|
buf = append(buf, token.Token...)
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
func DeserializeRefreshToken(buf []byte) (*RefreshTokenParts, error) {
|
|
if len(buf) != refreshTokenLength {
|
|
return nil, fmt.Errorf("refresh token is too short")
|
|
}
|
|
|
|
keyBytes := buf[:refreshTokenKeyLength]
|
|
token := buf[refreshTokenKeyLength:]
|
|
|
|
key, err := uuid.FromBytes(keyBytes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal refresh token key: %w", err)
|
|
}
|
|
|
|
return &RefreshTokenParts{
|
|
Key: key,
|
|
Token: token,
|
|
}, nil
|
|
}
|