diff --git a/apps/backend/internal/auth/err.go b/apps/backend/internal/auth/err.go index ab887e6..0def7f4 100644 --- a/apps/backend/internal/auth/err.go +++ b/apps/backend/internal/auth/err.go @@ -1,6 +1,11 @@ package auth -import "fmt" +import ( + "errors" + "fmt" +) + +var ErrUnauthenticatedRequest = errors.New("unauthenticated request") type InvalidAccessTokenError struct { err error diff --git a/apps/backend/internal/auth/http.go b/apps/backend/internal/auth/http.go index c899f2f..b28df9e 100644 --- a/apps/backend/internal/auth/http.go +++ b/apps/backend/internal/auth/http.go @@ -3,6 +3,7 @@ package auth import ( "errors" + "github.com/get-drexa/drexa/internal/user" "github.com/gofiber/fiber/v2" ) @@ -20,9 +21,9 @@ type registerRequest struct { } type loginResponse struct { - User User `json:"user"` - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` + User user.User `json:"user"` + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` } func RegisterAPIRoutes(api fiber.Router, s *Service) { @@ -56,7 +57,7 @@ func login(c *fiber.Ctx) error { } return c.JSON(loginResponse{ - User: result.User, + User: *result.User, AccessToken: result.AccessToken, RefreshToken: result.RefreshToken, }) @@ -83,7 +84,7 @@ func register(c *fiber.Ctx) error { } return c.JSON(loginResponse{ - User: result.User, + User: *result.User, AccessToken: result.AccessToken, RefreshToken: result.RefreshToken, }) diff --git a/apps/backend/internal/auth/middleware.go b/apps/backend/internal/auth/middleware.go new file mode 100644 index 0000000..f3d136a --- /dev/null +++ b/apps/backend/internal/auth/middleware.go @@ -0,0 +1,56 @@ +package auth + +import ( + "errors" + "strings" + + "github.com/get-drexa/drexa/internal/user" + "github.com/gofiber/fiber/v2" +) + +const authenticatedUserKey = "authenticatedUser" + +// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token. +// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser. +func NewBearerAuthMiddleware(s *Service) fiber.Handler { + return func(c *fiber.Ctx) error { + authHeader := c.Get("Authorization") + if authHeader == "" { + return c.SendStatus(fiber.StatusUnauthorized) + } + + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return c.SendStatus(fiber.StatusUnauthorized) + } + + token := parts[1] + u, err := s.AuthenticateWithAccessToken(c.Context(), token) + if err != nil { + var e *InvalidAccessTokenError + if errors.As(err, &e) { + return c.SendStatus(fiber.StatusUnauthorized) + } + + var nf *user.NotFoundError + if errors.As(err, &nf) { + return c.SendStatus(fiber.StatusUnauthorized) + } + + return c.SendStatus(fiber.StatusInternalServerError) + } + + c.Locals(authenticatedUserKey, u) + + return c.Next() + } +} + +// AuthenticatedUser returns the authenticated user from the given fiber context. +// Returns ErrUnauthenticatedRequest if not authenticated. +func AuthenticatedUser(c *fiber.Ctx) (*user.User, error) { + if u, ok := c.Locals(authenticatedUserKey).(*user.User); ok { + return u, nil + } + return nil, ErrUnauthenticatedRequest +} diff --git a/apps/backend/internal/auth/service.go b/apps/backend/internal/auth/service.go index 228ea84..76f5418 100644 --- a/apps/backend/internal/auth/service.go +++ b/apps/backend/internal/auth/service.go @@ -5,11 +5,13 @@ import ( "encoding/hex" "errors" + "github.com/get-drexa/drexa/internal/user" + "github.com/google/uuid" "github.com/uptrace/bun" ) type LoginResult struct { - User User + User *user.User AccessToken string RefreshToken string } @@ -19,6 +21,7 @@ var ErrUserExists = errors.New("user already exists") type Service struct { db *bun.DB + userService *user.Service tokenConfig TokenConfig } @@ -28,32 +31,33 @@ type registerOptions struct { password string } -func NewService(db *bun.DB, tokenConfig TokenConfig) *Service { +func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) *Service { return &Service{ db: db, + userService: userService, tokenConfig: tokenConfig, } } func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) { - var user User + var u user.User - err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx) + err := s.db.NewSelect().Model(&u).Where("email = ?", email).Scan(ctx) if err != nil { return nil, err } - ok, err := VerifyPassword(password, user.Password) + ok, err := VerifyPassword(password, u.Password) if err != nil || !ok { return nil, ErrInvalidCredentials } - at, err := GenerateAccessToken(&user, &s.tokenConfig) + at, err := GenerateAccessToken(&u, &s.tokenConfig) if err != nil { return nil, err } - rt, err := GenerateRefreshToken(&user, &s.tokenConfig) + rt, err := GenerateRefreshToken(&u, &s.tokenConfig) if err != nil { return nil, err } @@ -64,19 +68,19 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password } return &LoginResult{ - User: user, + User: &u, AccessToken: at, RefreshToken: hex.EncodeToString(rt.Token), }, nil } func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) { - user := User{ + u := user.User{ Email: opts.email, DisplayName: opts.displayName, } - exists, err := s.db.NewSelect().Model(&user).Where("email = ?", opts.email).Exists(ctx) + exists, err := s.db.NewSelect().Model(&u).Where("email = ?", opts.email).Exists(ctx) if err != nil { return nil, err } @@ -84,29 +88,43 @@ func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginRes return nil, ErrUserExists } - user.Password, err = HashPassword(opts.password) + u.Password, err = HashPassword(opts.password) if err != nil { return nil, err } - _, err = s.db.NewInsert().Model(&user).Exec(ctx) + _, err = s.db.NewInsert().Model(&u).Exec(ctx) if err != nil { return nil, err } - at, err := GenerateAccessToken(&user, &s.tokenConfig) + at, err := GenerateAccessToken(&u, &s.tokenConfig) if err != nil { return nil, err } - rt, err := GenerateRefreshToken(&user, &s.tokenConfig) + rt, err := GenerateRefreshToken(&u, &s.tokenConfig) if err != nil { return nil, err } return &LoginResult{ - User: user, + User: &u, AccessToken: at, RefreshToken: hex.EncodeToString(rt.Token), }, nil } + +func (s *Service) AuthenticateWithAccessToken(ctx context.Context, token string) (*user.User, error) { + claims, err := ParseAccessToken(token, &s.tokenConfig) + if err != nil { + return nil, err + } + + id, err := uuid.Parse(claims.Subject) + if err != nil { + return nil, newInvalidAccessTokenError(err) + } + + return s.userService.UserByID(ctx, id) +} diff --git a/apps/backend/internal/auth/tokens.go b/apps/backend/internal/auth/tokens.go index 74c9898..ebd9e41 100644 --- a/apps/backend/internal/auth/tokens.go +++ b/apps/backend/internal/auth/tokens.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/get-drexa/drexa/internal/user" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/uptrace/bun" @@ -35,7 +36,7 @@ type RefreshToken struct { CreatedAt time.Time `bun:"created_at,notnull"` } -func GenerateAccessToken(user *User, c *TokenConfig) (string, error) { +func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) { now := time.Now() token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ @@ -54,7 +55,7 @@ func GenerateAccessToken(user *User, c *TokenConfig) (string, error) { return signed, nil } -func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) { +func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) { now := time.Now() buf := make([]byte, refreshTokenByteLength) @@ -80,6 +81,8 @@ func GenerateRefreshToken(user *User, c *TokenConfig) (*RefreshToken, error) { }, nil } +// ParseAccessToken parses a JWT access token and returns the claims. +// Returns an InvalidAccessTokenError if the token is invalid. func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, error) { parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) { return c.SecretKey, nil diff --git a/apps/backend/internal/drexa/server.go b/apps/backend/internal/drexa/server.go index 5b2ce3e..99c41a1 100644 --- a/apps/backend/internal/drexa/server.go +++ b/apps/backend/internal/drexa/server.go @@ -9,6 +9,7 @@ import ( "github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/database" + "github.com/get-drexa/drexa/internal/user" "github.com/gofiber/fiber/v2" ) @@ -24,7 +25,8 @@ func NewServer(c ServerConfig) *fiber.App { app := fiber.New() db := database.NewFromPostgres(c.PostgresURL) - authService := auth.NewService(db, auth.TokenConfig{ + userService := user.NewService(db) + authService := auth.NewService(db, userService, auth.TokenConfig{ Issuer: c.JWTIssuer, Audience: c.JWTAudience, SecretKey: c.JWTSecretKey, diff --git a/apps/backend/internal/user/err.go b/apps/backend/internal/user/err.go new file mode 100644 index 0000000..1c74ca1 --- /dev/null +++ b/apps/backend/internal/user/err.go @@ -0,0 +1,19 @@ +package user + +import ( + "fmt" + + "github.com/google/uuid" +) + +type NotFoundError struct { + id uuid.UUID +} + +func newNotFoundError(id uuid.UUID) *NotFoundError { + return &NotFoundError{id} +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("user not found: %v", e.id) +} diff --git a/apps/backend/internal/user/service.go b/apps/backend/internal/user/service.go new file mode 100644 index 0000000..696a1d9 --- /dev/null +++ b/apps/backend/internal/user/service.go @@ -0,0 +1,32 @@ +package user + +import ( + "context" + "database/sql" + "errors" + + "github.com/google/uuid" + "github.com/uptrace/bun" +) + +type Service struct { + db *bun.DB +} + +func NewService(db *bun.DB) *Service { + return &Service{ + db: db, + } +} + +func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) { + var user User + err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, newNotFoundError(id) + } + return nil, err + } + return &user, nil +} diff --git a/apps/backend/internal/auth/user.go b/apps/backend/internal/user/user.go similarity index 97% rename from apps/backend/internal/auth/user.go rename to apps/backend/internal/user/user.go index c95ef8b..24d6662 100644 --- a/apps/backend/internal/auth/user.go +++ b/apps/backend/internal/user/user.go @@ -1,4 +1,4 @@ -package auth +package user import ( "github.com/google/uuid"