mirror of
https://github.com/get-drexa/drive.git
synced 2025-12-04 07:21:38 +00:00
feat: impl refresh token rotation
This commit is contained in:
@@ -113,7 +113,7 @@ func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u)
|
||||
result, err := h.authService.GrantForUser(c.Context(), tx, u)
|
||||
if err != nil {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
)
|
||||
|
||||
var ErrUnauthenticatedRequest = errors.New("unauthenticated request")
|
||||
var ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||
var ErrRefreshTokenExpired = errors.New("refresh token expired")
|
||||
var ErrRefreshTokenReused = errors.New("refresh token reused")
|
||||
|
||||
type InvalidAccessTokenError struct {
|
||||
err error
|
||||
|
||||
22
apps/backend/internal/auth/grant.go
Normal file
22
apps/backend/internal/auth/grant.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type Grant struct {
|
||||
bun.BaseModel `bun:"grants"`
|
||||
|
||||
ID uuid.UUID `bun:",pk,type:uuid"`
|
||||
UserID uuid.UUID `bun:"user_id,notnull"`
|
||||
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
|
||||
UpdatedAt time.Time `bun:"updated_at,notnull,nullzero"`
|
||||
RevokedAt *time.Time `bun:"revoked_at,nullzero"`
|
||||
}
|
||||
|
||||
func newGrantID() (uuid.UUID, error) {
|
||||
return uuid.NewV7()
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/httperr"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
@@ -20,6 +21,15 @@ type loginResponse struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
type refreshAccessTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
type tokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
type HTTPHandler struct {
|
||||
service *Service
|
||||
db *bun.DB
|
||||
@@ -32,6 +42,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
|
||||
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
|
||||
auth := api.Group("/auth")
|
||||
auth.Post("/login", h.Login)
|
||||
auth.Post("/tokens", h.refreshAccessToken)
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
@@ -46,7 +57,7 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
|
||||
result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidCredentials) {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
|
||||
@@ -64,3 +75,37 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) refreshAccessToken(c *fiber.Ctx) error {
|
||||
req := new(refreshAccessTokenRequest)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
|
||||
}
|
||||
|
||||
tx, err := h.db.BeginTx(c.Context(), nil)
|
||||
if err != nil {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
result, err := h.service.RefreshAccessToken(c.Context(), tx, req.RefreshToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidRefreshToken) ||
|
||||
errors.Is(err, ErrRefreshTokenExpired) ||
|
||||
errors.Is(err, ErrRefreshTokenReused) {
|
||||
_ = tx.Commit()
|
||||
slog.Info("invalid refresh token", "error", err)
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid refresh token"})
|
||||
}
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return httperr.Internal(err)
|
||||
}
|
||||
|
||||
return c.JSON(tokenResponse{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,9 +2,10 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/get-drexa/drexa/internal/password"
|
||||
"github.com/get-drexa/drexa/internal/user"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type AuthenticationResult struct {
|
||||
type AuthenticationTokens struct {
|
||||
User *user.User
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
@@ -32,64 +33,35 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) {
|
||||
at, err := GenerateAccessToken(user, &s.tokenConfig)
|
||||
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationTokens, error) {
|
||||
u, err := s.authenticateWithEmailAndPassword(ctx, db, email, plain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rt, err := GenerateRefreshToken(user, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AuthenticationResult{
|
||||
User: user,
|
||||
AccessToken: at,
|
||||
RefreshToken: hex.EncodeToString(rt.Token),
|
||||
}, nil
|
||||
return s.GrantForUser(ctx, db, u)
|
||||
}
|
||||
|
||||
func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) {
|
||||
u, err := s.userService.UserByEmail(ctx, db, email)
|
||||
if err != nil {
|
||||
var nf *user.NotFoundError
|
||||
if errors.As(err, &nf) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ok, err := password.Verify(plain, u.Password)
|
||||
if err != nil || !ok {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
at, err := GenerateAccessToken(u, &s.tokenConfig)
|
||||
func (s *Service) GrantForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationTokens, error) {
|
||||
id, err := newGrantID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
|
||||
grant := &Grant{
|
||||
ID: id,
|
||||
UserID: user.ID,
|
||||
}
|
||||
_, err = db.NewInsert().Model(grant).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||
result, err := s.generateTokens(ctx, db, user, grant)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AuthenticationResult{
|
||||
User: u,
|
||||
AccessToken: at,
|
||||
RefreshToken: hex.EncodeToString(rt.Token),
|
||||
}, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
|
||||
@@ -107,3 +79,139 @@ func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, t
|
||||
|
||||
return s.userService.UserByID(ctx, db, id)
|
||||
}
|
||||
|
||||
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {
|
||||
rtBytes, err := DecodeRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rtHash := HashRefreshToken(rtBytes)
|
||||
|
||||
rt := &RefreshToken{}
|
||||
err = db.NewSelect().Model(rt).Where("token_hash = ?", rtHash).Scan(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rt.ExpiresAt.Before(time.Now()) {
|
||||
return nil, ErrRefreshTokenExpired
|
||||
}
|
||||
|
||||
if rt.ConsumedAt != nil {
|
||||
// token reuse detected, invalidate all refresh tokens under the same grant
|
||||
_, err = db.NewDelete().Model((*RefreshToken)(nil)).Where("grant_id = ?", rt.GrantID).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.NewUpdate().Model((*Grant)(nil)).Set("revoked_at = ?", time.Now()).Where("id = ?", rt.GrantID).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, ErrRefreshTokenReused
|
||||
}
|
||||
|
||||
grant := &Grant{
|
||||
ID: rt.GrantID,
|
||||
}
|
||||
err = db.NewSelect().Model(grant).WherePK().Scan(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// the grant for this refresh token was deleted, delete the refresh token
|
||||
_, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx)
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if grant.RevokedAt != nil {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
u, err := s.userService.UserByID(ctx, db, grant.UserID)
|
||||
if err != nil {
|
||||
var nf *user.NotFoundError
|
||||
if errors.As(err, &nf) {
|
||||
// the user for this grant was deleted, delete the grant and refresh token
|
||||
_, _ = db.NewDelete().Model(grant).WherePK().Exec(ctx)
|
||||
_, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx)
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRT, err := GenerateRefreshToken(u, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRT.GrantID = grant.ID
|
||||
|
||||
_, err = db.NewUpdate().Model(rt).Set("consumed_at = ?", time.Now()).WherePK().Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.NewInsert().Model(newRT).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
at, err := GenerateAccessToken(u, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AuthenticationTokens{
|
||||
User: u,
|
||||
AccessToken: at,
|
||||
RefreshToken: EncodeRefreshToken(newRT.Token),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.User, grant *Grant) (*AuthenticationTokens, error) {
|
||||
at, err := GenerateAccessToken(user, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rt, err := GenerateRefreshToken(user, &s.tokenConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rt.GrantID = grant.ID
|
||||
|
||||
_, err = db.NewInsert().Model(rt).Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AuthenticationTokens{
|
||||
User: user,
|
||||
AccessToken: at,
|
||||
RefreshToken: EncodeRefreshToken(rt.Token),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) authenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*user.User, error) {
|
||||
u, err := s.userService.UserByEmail(ctx, db, email)
|
||||
if err != nil {
|
||||
var nf *user.NotFoundError
|
||||
if errors.As(err, &nf) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ok, err := password.Verify(plain, u.Password)
|
||||
if err != nil || !ok {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
@@ -28,12 +28,13 @@ type TokenConfig struct {
|
||||
type RefreshToken struct {
|
||||
bun.BaseModel `bun:"refresh_tokens"`
|
||||
|
||||
ID uuid.UUID `bun:",pk,type:uuid"`
|
||||
UserID uuid.UUID `bun:"user_id,notnull"`
|
||||
Token []byte `bun:"-"`
|
||||
TokenHash string `bun:"token_hash,notnull"`
|
||||
ExpiresAt time.Time `bun:"expires_at,notnull"`
|
||||
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
|
||||
ID uuid.UUID `bun:",pk,type:uuid"`
|
||||
GrantID uuid.UUID `bun:"grant_id,notnull"`
|
||||
Token []byte `bun:"-"`
|
||||
TokenHash string `bun:"token_hash,notnull"`
|
||||
ExpiresAt time.Time `bun:"expires_at,notnull"`
|
||||
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
|
||||
ConsumedAt *time.Time `bun:"consumed_at,nullzero"`
|
||||
}
|
||||
|
||||
func newTokenID() (uuid.UUID, error) {
|
||||
@@ -77,7 +78,6 @@ func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error
|
||||
|
||||
return &RefreshToken{
|
||||
ID: id,
|
||||
UserID: user.ID,
|
||||
Token: buf,
|
||||
TokenHash: hex,
|
||||
ExpiresAt: now.Add(refreshTokenValidFor),
|
||||
@@ -96,3 +96,16 @@ func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, erro
|
||||
}
|
||||
return parsed.Claims.(*jwt.RegisteredClaims), nil
|
||||
}
|
||||
|
||||
func EncodeRefreshToken(token []byte) string {
|
||||
return hex.EncodeToString(token)
|
||||
}
|
||||
|
||||
func DecodeRefreshToken(token string) ([]byte, error) {
|
||||
return hex.DecodeString(token)
|
||||
}
|
||||
|
||||
func HashRefreshToken(token []byte) string {
|
||||
h := sha256.Sum256(token)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
@@ -24,15 +24,24 @@ CREATE TABLE IF NOT EXISTS accounts (
|
||||
|
||||
CREATE INDEX idx_accounts_user_id ON accounts(user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
CREATE TABLE IF NOT EXISTS grants (
|
||||
id UUID PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
revoked_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id UUID PRIMARY KEY,
|
||||
grant_id UUID NOT NULL REFERENCES grants(id) ON DELETE CASCADE,
|
||||
token_hash TEXT NOT NULL UNIQUE,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
consumed_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_refresh_tokens_user_id ON refresh_tokens(user_id);
|
||||
CREATE INDEX idx_refresh_tokens_grant_id ON refresh_tokens(grant_id);
|
||||
CREATE INDEX idx_refresh_tokens_token_hash ON refresh_tokens(token_hash);
|
||||
CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
||||
|
||||
@@ -103,4 +112,7 @@ CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares
|
||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts
|
||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TRIGGER update_grants_updated_at BEFORE UPDATE ON grants
|
||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||
@@ -265,7 +265,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node := Node{
|
||||
node := &Node{
|
||||
ID: id,
|
||||
PublicID: pid,
|
||||
AccountID: accountID,
|
||||
@@ -283,7 +283,7 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &node, nil
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
|
||||
|
||||
Reference in New Issue
Block a user