feat: impl bearer auth middleware

This commit is contained in:
2025-11-26 01:09:42 +00:00
parent 81e3f7af75
commit 389fe35a0a
9 changed files with 161 additions and 25 deletions

View File

@@ -1,6 +1,11 @@
package auth
import "fmt"
import (
"errors"
"fmt"
)
var ErrUnauthenticatedRequest = errors.New("unauthenticated request")
type InvalidAccessTokenError struct {
err error

View File

@@ -3,6 +3,7 @@ package auth
import (
"errors"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
)
@@ -20,9 +21,9 @@ type registerRequest struct {
}
type loginResponse struct {
User User `json:"user"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
User user.User `json:"user"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
func RegisterAPIRoutes(api fiber.Router, s *Service) {
@@ -56,7 +57,7 @@ func login(c *fiber.Ctx) error {
}
return c.JSON(loginResponse{
User: result.User,
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
@@ -83,7 +84,7 @@ func register(c *fiber.Ctx) error {
}
return c.JSON(loginResponse{
User: result.User,
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})

View File

@@ -0,0 +1,56 @@
package auth
import (
"errors"
"strings"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
)
const authenticatedUserKey = "authenticatedUser"
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service) fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.SendStatus(fiber.StatusUnauthorized)
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return c.SendStatus(fiber.StatusUnauthorized)
}
token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), token)
if err != nil {
var e *InvalidAccessTokenError
if errors.As(err, &e) {
return c.SendStatus(fiber.StatusUnauthorized)
}
var nf *user.NotFoundError
if errors.As(err, &nf) {
return c.SendStatus(fiber.StatusUnauthorized)
}
return c.SendStatus(fiber.StatusInternalServerError)
}
c.Locals(authenticatedUserKey, u)
return c.Next()
}
}
// AuthenticatedUser returns the authenticated user from the given fiber context.
// Returns ErrUnauthenticatedRequest if not authenticated.
func AuthenticatedUser(c *fiber.Ctx) (*user.User, error) {
if u, ok := c.Locals(authenticatedUserKey).(*user.User); ok {
return u, nil
}
return nil, ErrUnauthenticatedRequest
}

View File

@@ -5,11 +5,13 @@ import (
"encoding/hex"
"errors"
"github.com/get-drexa/drexa/internal/user"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type LoginResult struct {
User User
User *user.User
AccessToken string
RefreshToken string
}
@@ -19,6 +21,7 @@ var ErrUserExists = errors.New("user already exists")
type Service struct {
db *bun.DB
userService *user.Service
tokenConfig TokenConfig
}
@@ -28,32 +31,33 @@ type registerOptions struct {
password string
}
func NewService(db *bun.DB, tokenConfig TokenConfig) *Service {
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{
db: db,
userService: userService,
tokenConfig: tokenConfig,
}
}
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) {
var user User
var u user.User
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
err := s.db.NewSelect().Model(&u).Where("email = ?", email).Scan(ctx)
if err != nil {
return nil, err
}
ok, err := VerifyPassword(password, user.Password)
ok, err := VerifyPassword(password, u.Password)
if err != nil || !ok {
return nil, ErrInvalidCredentials
}
at, err := GenerateAccessToken(&user, &s.tokenConfig)
at, err := GenerateAccessToken(&u, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(&user, &s.tokenConfig)
rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
if err != nil {
return nil, err
}
@@ -64,19 +68,19 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password
}
return &LoginResult{
User: user,
User: &u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
user := User{
u := user.User{
Email: opts.email,
DisplayName: opts.displayName,
}
exists, err := s.db.NewSelect().Model(&user).Where("email = ?", opts.email).Exists(ctx)
exists, err := s.db.NewSelect().Model(&u).Where("email = ?", opts.email).Exists(ctx)
if err != nil {
return nil, err
}
@@ -84,29 +88,43 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
return nil, ErrUserExists
}
user.Password, err = HashPassword(opts.password)
u.Password, err = HashPassword(opts.password)
if err != nil {
return nil, err
}
_, err = s.db.NewInsert().Model(&user).Exec(ctx)
_, err = s.db.NewInsert().Model(&u).Exec(ctx)
if err != nil {
return nil, err
}
at, err := GenerateAccessToken(&user, &s.tokenConfig)
at, err := GenerateAccessToken(&u, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(&user, &s.tokenConfig)
rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
if err != nil {
return nil, err
}
return &LoginResult{
User: user,
User: &u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
return nil, err
}
id, err := uuid.Parse(claims.Subject)
if err != nil {
return nil, newInvalidAccessTokenError(err)
}
return s.userService.UserByID(ctx, id)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"time"
"github.com/get-drexa/drexa/internal/user"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/uptrace/bun"
@@ -35,7 +36,7 @@ type RefreshToken struct {
CreatedAt time.Time `bun:"created_at,notnull"`
}
func GenerateAccessToken(user *User, c *TokenConfig) (string, error) {
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
@@ -54,7 +55,7 @@ func GenerateAccessToken(user *User, c *TokenConfig) (string, error) {
return signed, nil
}
func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) {
func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) {
now := time.Now()
buf := make([]byte, refreshTokenByteLength)
@@ -80,6 +81,8 @@ func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) {
}, nil
}
// ParseAccessToken parses a JWT access token and returns the claims.
// Returns an InvalidAccessTokenError if the token is invalid.
func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, error) {
parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
return c.SecretKey, nil

View File

@@ -1,17 +0,0 @@
package auth
import (
"github.com/google/uuid"
"github.com/uptrace/bun"
)
type User struct {
bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name,notnull" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"`
Password string `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
}