feat: impl bearer auth middleware

This commit is contained in:
2025-11-26 01:09:42 +00:00
parent 81e3f7af75
commit 389fe35a0a
9 changed files with 161 additions and 25 deletions

View File

@@ -1,6 +1,11 @@
package auth package auth
import "fmt" import (
"errors"
"fmt"
)
var ErrUnauthenticatedRequest = errors.New("unauthenticated request")
type InvalidAccessTokenError struct { type InvalidAccessTokenError struct {
err error err error

View File

@@ -3,6 +3,7 @@ package auth
import ( import (
"errors" "errors"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@@ -20,7 +21,7 @@ type registerRequest struct {
} }
type loginResponse struct { type loginResponse struct {
User User `json:"user"` User user.User `json:"user"`
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken"`
} }
@@ -56,7 +57,7 @@ func login(c *fiber.Ctx) error {
} }
return c.JSON(loginResponse{ return c.JSON(loginResponse{
User: result.User, User: *result.User,
AccessToken: result.AccessToken, AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken, RefreshToken: result.RefreshToken,
}) })
@@ -83,7 +84,7 @@ func register(c *fiber.Ctx) error {
} }
return c.JSON(loginResponse{ return c.JSON(loginResponse{
User: result.User, User: *result.User,
AccessToken: result.AccessToken, AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken, RefreshToken: result.RefreshToken,
}) })

View File

@@ -0,0 +1,56 @@
package auth
import (
"errors"
"strings"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
)
const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service) fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.SendStatus(fiber.StatusUnauthorized)
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return c.SendStatus(fiber.StatusUnauthorized)
}
token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), token)
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {
return c.SendStatus(fiber.StatusUnauthorized)
}
var nf *user.NotFoundError
if errors.As(err, &nf) {
return c.SendStatus(fiber.StatusUnauthorized)
}
return c.SendStatus(fiber.StatusInternalServerError)
}
c.Locals(authenticatedUserKey, u)
return c.Next()
}
}
// AuthenticatedUser returns the authenticated user from the given fiber context.
// Returns ErrUnauthenticatedRequest if not authenticated.
func AuthenticatedUser(c *fiber.Ctx) (*user.User, error) {
if u, ok := c.Locals(authenticatedUserKey).(*user.User); ok {
return u, nil
}
return nil, ErrUnauthenticatedRequest
}

View File

@@ -5,11 +5,13 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"github.com/get-drexa/drexa/internal/user"
"github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type LoginResult struct { type LoginResult struct {
User User User *user.User
AccessToken string AccessToken string
RefreshToken string RefreshToken string
} }
@@ -19,6 +21,7 @@ var ErrUserExists = errors.New("user already exists")
type Service struct { type Service struct {
db *bun.DB db *bun.DB
userService *user.Service
tokenConfig TokenConfig tokenConfig TokenConfig
} }
@@ -28,32 +31,33 @@ type registerOptions struct {
password string password string
} }
func NewService(db *bun.DB, tokenConfig TokenConfig) *Service { func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{ return &Service{
db: db, db: db,
userService: userService,
tokenConfig: tokenConfig, tokenConfig: tokenConfig,
} }
} }
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) { func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) {
var user User var u user.User
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx) err := s.db.NewSelect().Model(&u).Where("email = ?", email).Scan(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ok, err := VerifyPassword(password, user.Password) ok, err := VerifyPassword(password, u.Password)
if err != nil || !ok { if err != nil || !ok {
return nil, ErrInvalidCredentials return nil, ErrInvalidCredentials
} }
at, err := GenerateAccessToken(&user, &s.tokenConfig) at, err := GenerateAccessToken(&u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rt, err := GenerateRefreshToken(&user, &s.tokenConfig) rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -64,19 +68,19 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password
} }
return &LoginResult{ return &LoginResult{
User: user, User: &u,
AccessToken: at, AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token), RefreshToken: hex.EncodeToString(rt.Token),
}, nil }, nil
} }
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) { func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
user := User{ u := user.User{
Email: opts.email, Email: opts.email,
DisplayName: opts.displayName, DisplayName: opts.displayName,
} }
exists, err := s.db.NewSelect().Model(&user).Where("email = ?", opts.email).Exists(ctx) exists, err := s.db.NewSelect().Model(&u).Where("email = ?", opts.email).Exists(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -84,29 +88,43 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
return nil, ErrUserExists return nil, ErrUserExists
} }
user.Password, err = HashPassword(opts.password) u.Password, err = HashPassword(opts.password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = s.db.NewInsert().Model(&user).Exec(ctx) _, err = s.db.NewInsert().Model(&u).Exec(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
at, err := GenerateAccessToken(&user, &s.tokenConfig) at, err := GenerateAccessToken(&u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rt, err := GenerateRefreshToken(&user, &s.tokenConfig) rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &LoginResult{ return &LoginResult{
User: user, User: &u,
AccessToken: at, AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token), RefreshToken: hex.EncodeToString(rt.Token),
}, nil }, nil
} }
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
return nil, err
}
id, err := uuid.Parse(claims.Subject)
if err != nil {
return nil, newInvalidAccessTokenError(err)
}
return s.userService.UserByID(ctx, id)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/get-drexa/drexa/internal/user"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -35,7 +36,7 @@ type RefreshToken struct {
CreatedAt time.Time `bun:"created_at,notnull"` CreatedAt time.Time `bun:"created_at,notnull"`
} }
func GenerateAccessToken(user *User, c *TokenConfig) (string, error) { func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
now := time.Now() now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
@@ -54,7 +55,7 @@ func GenerateAccessToken(user *User, c *TokenConfig) (string, error) {
return signed, nil return signed, nil
} }
func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) { func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) {
now := time.Now() now := time.Now()
buf := make([]byte, refreshTokenByteLength) buf := make([]byte, refreshTokenByteLength)
@@ -80,6 +81,8 @@ func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) {
}, nil }, nil
} }
// ParseAccessToken parses a JWT access token and returns the claims.
// Returns an InvalidAccessTokenError if the token is invalid.
func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, error) { func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, error) {
parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) { parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
return c.SecretKey, nil return c.SecretKey, nil

View File

@@ -9,6 +9,7 @@ import (
"github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@@ -24,7 +25,8 @@ func NewServer(c ServerConfig) *fiber.App {
app := fiber.New() app := fiber.New()
db := database.NewFromPostgres(c.PostgresURL) db := database.NewFromPostgres(c.PostgresURL)
authService := auth.NewService(db, auth.TokenConfig{ userService := user.NewService(db)
authService := auth.NewService(db, userService, auth.TokenConfig{
Issuer: c.JWTIssuer, Issuer: c.JWTIssuer,
Audience: c.JWTAudience, Audience: c.JWTAudience,
SecretKey: c.JWTSecretKey, SecretKey: c.JWTSecretKey,

View File

@@ -0,0 +1,19 @@
package user
import (
"fmt"
"github.com/google/uuid"
)
type NotFoundError struct {
id uuid.UUID
}
func newNotFoundError(id uuid.UUID) *NotFoundError {
return &NotFoundError{id}
}
func (e *NotFoundError) Error() string {
return fmt.Sprintf("user not found: %v", e.id)
}

View File

@@ -0,0 +1,32 @@
package user
import (
"context"
"database/sql"
"errors"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type Service struct {
db *bun.DB
}
func NewService(db *bun.DB) *Service {
return &Service{
db: db,
}
}
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
var user User
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(id)
}
return nil, err
}
return &user, nil
}

View File

@@ -1,4 +1,4 @@
package auth package user
import ( import (
"github.com/google/uuid" "github.com/google/uuid"