fix: ret invalid cred if usr not found in login

This commit is contained in:
2025-11-26 01:42:52 +00:00
parent 389fe35a0a
commit 06c3951293
6 changed files with 168 additions and 42 deletions

View File

@@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -39,25 +40,27 @@ func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig)
} }
} }
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) { func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) {
var u user.User u, err := s.userService.UserByEmail(ctx, email)
err := s.db.NewSelect().Model(&u).Where("email = ?", email).Scan(ctx)
if err != nil { if err != nil {
var nf *user.NotFoundError
if errors.As(err, &nf) {
return nil, ErrInvalidCredentials
}
return nil, err return nil, err
} }
ok, err := VerifyPassword(password, u.Password) ok, err := password.Verify(plain, u.Password)
if err != nil || !ok { if err != nil || !ok {
return nil, ErrInvalidCredentials return nil, ErrInvalidCredentials
} }
at, err := GenerateAccessToken(&u, &s.tokenConfig) at, err := GenerateAccessToken(u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rt, err := GenerateRefreshToken(&u, &s.tokenConfig) rt, err := GenerateRefreshToken(u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -68,48 +71,44 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password
} }
return &LoginResult{ return &LoginResult{
User: &u, User: u,
AccessToken: at, AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token), RefreshToken: hex.EncodeToString(rt.Token),
}, nil }, nil
} }
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) { func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
u := user.User{ hashed, err := password.Hash(opts.password)
if err != nil {
return nil, err
}
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{
Email: opts.email, Email: opts.email,
DisplayName: opts.displayName, DisplayName: opts.displayName,
} Password: hashed,
})
exists, err := s.db.NewSelect().Model(&u).Where("email = ?", opts.email).Exists(ctx)
if err != nil {
return nil, err
}
if exists {
return nil, ErrUserExists
}
u.Password, err = HashPassword(opts.password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = s.db.NewInsert().Model(&u).Exec(ctx) at, err := GenerateAccessToken(u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
at, err := GenerateAccessToken(&u, &s.tokenConfig) rt, err := GenerateRefreshToken(u, &s.tokenConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rt, err := GenerateRefreshToken(&u, &s.tokenConfig) _, err = s.db.NewInsert().Model(rt).Exec(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &LoginResult{ return &LoginResult{
User: &u, User: u,
AccessToken: at, AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token), RefreshToken: hex.EncodeToString(rt.Token),
}, nil }, nil

View File

@@ -0,0 +1,61 @@
package database
import (
"errors"
"github.com/uptrace/bun/driver/pgdriver"
)
// PostgreSQL SQLSTATE error codes.
// See: https://www.postgresql.org/docs/current/errcodes-appendix.html
const (
PgUniqueViolation = "23505"
PgForeignKeyViolation = "23503"
PgNotNullViolation = "23502"
)
// PostgreSQL protocol error field identifiers used with pgdriver.Error.Field().
// See: https://www.postgresql.org/docs/current/protocol-error-fields.html
//
// Common fields:
// - 'C' - SQLSTATE code (e.g., "23505")
// - 'M' - Primary error message
// - 'D' - Detail message
// - 'H' - Hint
// - 's' - Schema name
// - 't' - Table name
// - 'c' - Column name
// - 'n' - Constraint name
const (
pgFieldCode = 'C'
pgFieldConstraint = 'n'
)
// IsUniqueViolation checks if the error is a PostgreSQL unique constraint violation.
func IsUniqueViolation(err error) bool {
return hasPgCode(err, PgUniqueViolation)
}
// IsForeignKeyViolation checks if the error is a PostgreSQL foreign key violation.
func IsForeignKeyViolation(err error) bool {
return hasPgCode(err, PgForeignKeyViolation)
}
// IsNotNullViolation checks if the error is a PostgreSQL not-null constraint violation.
func IsNotNullViolation(err error) bool {
return hasPgCode(err, PgNotNullViolation)
}
// ConstraintName returns the constraint name from a PostgreSQL error, or empty string if not applicable.
func ConstraintName(err error) string {
var pgErr pgdriver.Error
if errors.As(err, &pgErr) {
return pgErr.Field(pgFieldConstraint)
}
return ""
}
func hasPgCode(err error, code string) bool {
var pgErr pgdriver.Error
return errors.As(err, &pgErr) && pgErr.Field(pgFieldCode) == code
}

View File

@@ -1,4 +1,4 @@
package auth package password
import ( import (
"crypto/rand" "crypto/rand"
@@ -11,6 +11,10 @@ import (
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
) )
// Hashed represents a securely hashed password.
// This type ensures plaintext passwords cannot be accidentally stored.
type Hashed string
// argon2id parameters // argon2id parameters
const ( const (
memory = 64 * 1024 memory = 64 * 1024
@@ -34,14 +38,15 @@ type argon2Hash struct {
hash []byte hash []byte
} }
func HashPassword(password string) (string, error) { // Hash securely hashes a plaintext password using argon2id.
func Hash(plain string) (Hashed, error) {
salt := make([]byte, saltLength) salt := make([]byte, saltLength)
if _, err := rand.Read(salt); err != nil { if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("failed to generate salt: %w", err) return "", fmt.Errorf("failed to generate salt: %w", err)
} }
hash := argon2.IDKey( hash := argon2.IDKey(
[]byte(password), []byte(plain),
salt, salt,
iterations, iterations,
memory, memory,
@@ -52,7 +57,7 @@ func HashPassword(password string) (string, error) {
b64Salt := base64.RawStdEncoding.EncodeToString(salt) b64Salt := base64.RawStdEncoding.EncodeToString(salt)
b64Hash := base64.RawStdEncoding.EncodeToString(hash) b64Hash := base64.RawStdEncoding.EncodeToString(hash)
encodedHash := fmt.Sprintf( encoded := fmt.Sprintf(
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, argon2.Version,
memory, memory,
@@ -62,17 +67,18 @@ func HashPassword(password string) (string, error) {
b64Hash, b64Hash,
) )
return encodedHash, nil return Hashed(encoded), nil
} }
func VerifyPassword(password, encodedHash string) (bool, error) { // Verify checks if a plaintext password matches a hashed password.
h, err := decodeHash(encodedHash) func Verify(plain string, hashed Hashed) (bool, error) {
h, err := decodeHash(string(hashed))
if err != nil { if err != nil {
return false, err return false, err
} }
otherHash := argon2.IDKey( otherHash := argon2.IDKey(
[]byte(password), []byte(plain),
h.salt, h.salt,
h.iterations, h.iterations,
h.memory, h.memory,

View File

@@ -7,13 +7,30 @@ import (
) )
type NotFoundError struct { type NotFoundError struct {
// ID is the ID that was used to try to find the user.
// Not set if not tried.
id uuid.UUID id uuid.UUID
// Email is the email that was used to try to find the user.
// Not set if not tried.
email string
} }
func newNotFoundError(id uuid.UUID) *NotFoundError { func newNotFoundError(id uuid.UUID, email string) *NotFoundError {
return &NotFoundError{id} return &NotFoundError{id, email}
} }
func (e *NotFoundError) Error() string { func (e *NotFoundError) Error() string {
return fmt.Sprintf("user not found: %v", e.id) return fmt.Sprintf("user not found: %v", e.id)
} }
type AlreadyExistsError struct {
// Email is the email that was used to try to create the user.
Email string
}
func newAlreadyExistsError(email string) *AlreadyExistsError {
return &AlreadyExistsError{email}
}
func (e *AlreadyExistsError) Error() string {
return fmt.Sprintf("user with email %s already exists", e.Email)
}

View File

@@ -5,6 +5,8 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/password"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@@ -13,20 +15,60 @@ type Service struct {
db *bun.DB db *bun.DB
} }
type UserRegistrationOptions struct {
Email string
DisplayName string
Password password.Hashed
}
func NewService(db *bun.DB) *Service { func NewService(db *bun.DB) *Service {
return &Service{ return &Service{
db: db, db: db,
} }
} }
func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) {
u := User{
Email: opts.Email,
DisplayName: opts.DisplayName,
Password: opts.Password,
}
_, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx)
if err != nil {
if database.IsUniqueViolation(err) {
return nil, newAlreadyExistsError(u.Email)
}
return nil, err
}
return &u, nil
}
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) { func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
var user User var user User
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx) err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(id) return nil, newNotFoundError(id, "")
} }
return nil, err return nil, err
} }
return &user, nil return &user, nil
} }
func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) {
var user User
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, newNotFoundError(uuid.Nil, email)
}
return nil, err
}
return &user, nil
}
func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) {
return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
}

View File

@@ -1,6 +1,7 @@
package user package user
import ( import (
"github.com/get-drexa/drexa/internal/password"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@@ -11,7 +12,7 @@ type User struct {
ID uuid.UUID `bun:",pk,type:uuid" json:"id"` ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name,notnull" json:"displayName"` DisplayName string `bun:"display_name,notnull" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"` Email string `bun:"email,unique,notnull" json:"email"`
Password string `bun:"password,notnull" json:"-"` Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
} }