mirror of
https://github.com/get-drexa/drive.git
synced 2025-11-30 21:41:39 +00:00
feat: impl bearer auth middleware
This commit is contained in:
@@ -1,6 +1,11 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrUnauthenticatedRequest = errors.New("unauthenticated request")
|
||||||
|
|
||||||
type InvalidAccessTokenError struct {
|
type InvalidAccessTokenError struct {
|
||||||
err error
|
err error
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,7 +21,7 @@ type registerRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type loginResponse struct {
|
type loginResponse struct {
|
||||||
User User `json:"user"`
|
User user.User `json:"user"`
|
||||||
AccessToken string `json:"accessToken"`
|
AccessToken string `json:"accessToken"`
|
||||||
RefreshToken string `json:"refreshToken"`
|
RefreshToken string `json:"refreshToken"`
|
||||||
}
|
}
|
||||||
@@ -56,7 +57,7 @@ func login(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(loginResponse{
|
return c.JSON(loginResponse{
|
||||||
User: result.User,
|
User: *result.User,
|
||||||
AccessToken: result.AccessToken,
|
AccessToken: result.AccessToken,
|
||||||
RefreshToken: result.RefreshToken,
|
RefreshToken: result.RefreshToken,
|
||||||
})
|
})
|
||||||
@@ -83,7 +84,7 @@ func register(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(loginResponse{
|
return c.JSON(loginResponse{
|
||||||
User: result.User,
|
User: *result.User,
|
||||||
AccessToken: result.AccessToken,
|
AccessToken: result.AccessToken,
|
||||||
RefreshToken: result.RefreshToken,
|
RefreshToken: result.RefreshToken,
|
||||||
})
|
})
|
||||||
|
|||||||
56
apps/backend/internal/auth/middleware.go
Normal file
56
apps/backend/internal/auth/middleware.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const authenticatedUserKey = "authenticatedUser"
|
||||||
|
|
||||||
|
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token.
|
||||||
|
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
|
||||||
|
func NewBearerAuthMiddleware(s *Service) fiber.Handler {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
authHeader := c.Get("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
return c.SendStatus(fiber.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(authHeader, " ")
|
||||||
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
|
return c.SendStatus(fiber.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := parts[1]
|
||||||
|
u, err := s.AuthenticateWithAccessToken(c.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
var e *InvalidAccessTokenError
|
||||||
|
if errors.As(err, &e) {
|
||||||
|
return c.SendStatus(fiber.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
var nf *user.NotFoundError
|
||||||
|
if errors.As(err, &nf) {
|
||||||
|
return c.SendStatus(fiber.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.SendStatus(fiber.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Locals(authenticatedUserKey, u)
|
||||||
|
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticatedUser returns the authenticated user from the given fiber context.
|
||||||
|
// Returns ErrUnauthenticatedRequest if not authenticated.
|
||||||
|
func AuthenticatedUser(c *fiber.Ctx) (*user.User, error) {
|
||||||
|
if u, ok := c.Locals(authenticatedUserKey).(*user.User); ok {
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
return nil, ErrUnauthenticatedRequest
|
||||||
|
}
|
||||||
@@ -5,11 +5,13 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LoginResult struct {
|
type LoginResult struct {
|
||||||
User User
|
User *user.User
|
||||||
AccessToken string
|
AccessToken string
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
}
|
}
|
||||||
@@ -19,6 +21,7 @@ var ErrUserExists = errors.New("user already exists")
|
|||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
db *bun.DB
|
db *bun.DB
|
||||||
|
userService *user.Service
|
||||||
tokenConfig TokenConfig
|
tokenConfig TokenConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,32 +31,33 @@ type registerOptions struct {
|
|||||||
password string
|
password string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(db *bun.DB, tokenConfig TokenConfig) *Service {
|
func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
db: db,
|
db: db,
|
||||||
|
userService: userService,
|
||||||
tokenConfig: tokenConfig,
|
tokenConfig: tokenConfig,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) {
|
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) {
|
||||||
var user User
|
var u user.User
|
||||||
|
|
||||||
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
|
err := s.db.NewSelect().Model(&u).Where("email = ?", email).Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := VerifyPassword(password, user.Password)
|
ok, err := VerifyPassword(password, u.Password)
|
||||||
if err != nil || !ok {
|
if err != nil || !ok {
|
||||||
return nil, ErrInvalidCredentials
|
return nil, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
at, err := GenerateAccessToken(&user, &s.tokenConfig)
|
at, err := GenerateAccessToken(&u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rt, err := GenerateRefreshToken(&user, &s.tokenConfig)
|
rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -64,19 +68,19 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &LoginResult{
|
return &LoginResult{
|
||||||
User: user,
|
User: &u,
|
||||||
AccessToken: at,
|
AccessToken: at,
|
||||||
RefreshToken: hex.EncodeToString(rt.Token),
|
RefreshToken: hex.EncodeToString(rt.Token),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
|
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
|
||||||
user := User{
|
u := user.User{
|
||||||
Email: opts.email,
|
Email: opts.email,
|
||||||
DisplayName: opts.displayName,
|
DisplayName: opts.displayName,
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := s.db.NewSelect().Model(&user).Where("email = ?", opts.email).Exists(ctx)
|
exists, err := s.db.NewSelect().Model(&u).Where("email = ?", opts.email).Exists(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -84,29 +88,43 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes
|
|||||||
return nil, ErrUserExists
|
return nil, ErrUserExists
|
||||||
}
|
}
|
||||||
|
|
||||||
user.Password, err = HashPassword(opts.password)
|
u.Password, err = HashPassword(opts.password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.db.NewInsert().Model(&user).Exec(ctx)
|
_, err = s.db.NewInsert().Model(&u).Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
at, err := GenerateAccessToken(&user, &s.tokenConfig)
|
at, err := GenerateAccessToken(&u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rt, err := GenerateRefreshToken(&user, &s.tokenConfig)
|
rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &LoginResult{
|
return &LoginResult{
|
||||||
User: user,
|
User: &u,
|
||||||
AccessToken: at,
|
AccessToken: at,
|
||||||
RefreshToken: hex.EncodeToString(rt.Token),
|
RefreshToken: hex.EncodeToString(rt.Token),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) {
|
||||||
|
claims, err := ParseAccessToken(token, &s.tokenConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := uuid.Parse(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return nil, newInvalidAccessTokenError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.userService.UserByID(ctx, id)
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -35,7 +36,7 @@ type RefreshToken struct {
|
|||||||
CreatedAt time.Time `bun:"created_at,notnull"`
|
CreatedAt time.Time `bun:"created_at,notnull"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateAccessToken(user *User, c *TokenConfig) (string, error) {
|
func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
|
||||||
@@ -54,7 +55,7 @@ func GenerateAccessToken(user *User, c *TokenConfig) (string, error) {
|
|||||||
return signed, nil
|
return signed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) {
|
func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
buf := make([]byte, refreshTokenByteLength)
|
buf := make([]byte, refreshTokenByteLength)
|
||||||
@@ -80,6 +81,8 @@ func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseAccessToken parses a JWT access token and returns the claims.
|
||||||
|
// Returns an InvalidAccessTokenError if the token is invalid.
|
||||||
func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, error) {
|
func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, error) {
|
||||||
parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
|
parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
|
||||||
return c.SecretKey, nil
|
return c.SecretKey, nil
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/get-drexa/drexa/internal/auth"
|
"github.com/get-drexa/drexa/internal/auth"
|
||||||
"github.com/get-drexa/drexa/internal/database"
|
"github.com/get-drexa/drexa/internal/database"
|
||||||
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,7 +25,8 @@ func NewServer(c ServerConfig) *fiber.App {
|
|||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
db := database.NewFromPostgres(c.PostgresURL)
|
db := database.NewFromPostgres(c.PostgresURL)
|
||||||
|
|
||||||
authService := auth.NewService(db, auth.TokenConfig{
|
userService := user.NewService(db)
|
||||||
|
authService := auth.NewService(db, userService, auth.TokenConfig{
|
||||||
Issuer: c.JWTIssuer,
|
Issuer: c.JWTIssuer,
|
||||||
Audience: c.JWTAudience,
|
Audience: c.JWTAudience,
|
||||||
SecretKey: c.JWTSecretKey,
|
SecretKey: c.JWTSecretKey,
|
||||||
|
|||||||
19
apps/backend/internal/user/err.go
Normal file
19
apps/backend/internal/user/err.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NotFoundError struct {
|
||||||
|
id uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNotFoundError(id uuid.UUID) *NotFoundError {
|
||||||
|
return &NotFoundError{id}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotFoundError) Error() string {
|
||||||
|
return fmt.Sprintf("user not found: %v", e.id)
|
||||||
|
}
|
||||||
32
apps/backend/internal/user/service.go
Normal file
32
apps/backend/internal/user/service.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
db *bun.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(db *bun.DB) *Service {
|
||||||
|
return &Service{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
|
||||||
|
var user User
|
||||||
|
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, newNotFoundError(id)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package auth
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
Reference in New Issue
Block a user