diff --git a/apps/backend/internal/auth/service.go b/apps/backend/internal/auth/service.go index 76f5418..734db82 100644 --- a/apps/backend/internal/auth/service.go +++ b/apps/backend/internal/auth/service.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" + "github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/user" "github.com/google/uuid" "github.com/uptrace/bun" @@ -39,25 +40,27 @@ func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig) } } -func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) { - var u user.User - - err := s.db.NewSelect().Model(&u).Where("email = ?", email).Scan(ctx) +func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) { + u, err := s.userService.UserByEmail(ctx, email) if err != nil { + var nf *user.NotFoundError + if errors.As(err, &nf) { + return nil, ErrInvalidCredentials + } return nil, err } - ok, err := VerifyPassword(password, u.Password) + ok, err := password.Verify(plain, u.Password) if err != nil || !ok { return nil, ErrInvalidCredentials } - at, err := GenerateAccessToken(&u, &s.tokenConfig) + at, err := GenerateAccessToken(u, &s.tokenConfig) if err != nil { return nil, err } - rt, err := GenerateRefreshToken(&u, &s.tokenConfig) + rt, err := GenerateRefreshToken(u, &s.tokenConfig) if err != nil { return nil, err } @@ -68,48 +71,44 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password } return &LoginResult{ - User: &u, + User: u, AccessToken: at, RefreshToken: hex.EncodeToString(rt.Token), }, nil } func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) { - u := user.User{ + hashed, err := password.Hash(opts.password) + if err != nil { + return nil, err + } + + u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{ Email: opts.email, DisplayName: opts.displayName, - } - - exists, err := s.db.NewSelect().Model(&u).Where("email = ?", opts.email).Exists(ctx) - if err != nil { - return nil, err - } - if exists { - return nil, ErrUserExists - } - - u.Password, err = HashPassword(opts.password) + Password: hashed, + }) if err != nil { return nil, err } - _, err = s.db.NewInsert().Model(&u).Exec(ctx) + at, err := GenerateAccessToken(u, &s.tokenConfig) if err != nil { return nil, err } - at, err := GenerateAccessToken(&u, &s.tokenConfig) + rt, err := GenerateRefreshToken(u, &s.tokenConfig) if err != nil { return nil, err } - rt, err := GenerateRefreshToken(&u, &s.tokenConfig) + _, err = s.db.NewInsert().Model(rt).Exec(ctx) if err != nil { return nil, err } return &LoginResult{ - User: &u, + User: u, AccessToken: at, RefreshToken: hex.EncodeToString(rt.Token), }, nil diff --git a/apps/backend/internal/database/errs.go b/apps/backend/internal/database/errs.go new file mode 100644 index 0000000..9be9d6f --- /dev/null +++ b/apps/backend/internal/database/errs.go @@ -0,0 +1,61 @@ +package database + +import ( + "errors" + + "github.com/uptrace/bun/driver/pgdriver" +) + +// PostgreSQL SQLSTATE error codes. +// See: https://www.postgresql.org/docs/current/errcodes-appendix.html +const ( + PgUniqueViolation = "23505" + PgForeignKeyViolation = "23503" + PgNotNullViolation = "23502" +) + +// PostgreSQL protocol error field identifiers used with pgdriver.Error.Field(). +// See: https://www.postgresql.org/docs/current/protocol-error-fields.html +// +// Common fields: +// - 'C' - SQLSTATE code (e.g., "23505") +// - 'M' - Primary error message +// - 'D' - Detail message +// - 'H' - Hint +// - 's' - Schema name +// - 't' - Table name +// - 'c' - Column name +// - 'n' - Constraint name +const ( + pgFieldCode = 'C' + pgFieldConstraint = 'n' +) + +// IsUniqueViolation checks if the error is a PostgreSQL unique constraint violation. +func IsUniqueViolation(err error) bool { + return hasPgCode(err, PgUniqueViolation) +} + +// IsForeignKeyViolation checks if the error is a PostgreSQL foreign key violation. +func IsForeignKeyViolation(err error) bool { + return hasPgCode(err, PgForeignKeyViolation) +} + +// IsNotNullViolation checks if the error is a PostgreSQL not-null constraint violation. +func IsNotNullViolation(err error) bool { + return hasPgCode(err, PgNotNullViolation) +} + +// ConstraintName returns the constraint name from a PostgreSQL error, or empty string if not applicable. +func ConstraintName(err error) string { + var pgErr pgdriver.Error + if errors.As(err, &pgErr) { + return pgErr.Field(pgFieldConstraint) + } + return "" +} + +func hasPgCode(err error, code string) bool { + var pgErr pgdriver.Error + return errors.As(err, &pgErr) && pgErr.Field(pgFieldCode) == code +} diff --git a/apps/backend/internal/auth/password.go b/apps/backend/internal/password/password.go similarity index 81% rename from apps/backend/internal/auth/password.go rename to apps/backend/internal/password/password.go index 58316d1..29e2189 100644 --- a/apps/backend/internal/auth/password.go +++ b/apps/backend/internal/password/password.go @@ -1,4 +1,4 @@ -package auth +package password import ( "crypto/rand" @@ -11,6 +11,10 @@ import ( "golang.org/x/crypto/argon2" ) +// Hashed represents a securely hashed password. +// This type ensures plaintext passwords cannot be accidentally stored. +type Hashed string + // argon2id parameters const ( memory = 64 * 1024 @@ -34,14 +38,15 @@ type argon2Hash struct { hash []byte } -func HashPassword(password string) (string, error) { +// Hash securely hashes a plaintext password using argon2id. +func Hash(plain string) (Hashed, error) { salt := make([]byte, saltLength) if _, err := rand.Read(salt); err != nil { return "", fmt.Errorf("failed to generate salt: %w", err) } hash := argon2.IDKey( - []byte(password), + []byte(plain), salt, iterations, memory, @@ -52,7 +57,7 @@ func HashPassword(password string) (string, error) { b64Salt := base64.RawStdEncoding.EncodeToString(salt) b64Hash := base64.RawStdEncoding.EncodeToString(hash) - encodedHash := fmt.Sprintf( + encoded := fmt.Sprintf( "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, memory, @@ -62,17 +67,18 @@ func HashPassword(password string) (string, error) { b64Hash, ) - return encodedHash, nil + return Hashed(encoded), nil } -func VerifyPassword(password, encodedHash string) (bool, error) { - h, err := decodeHash(encodedHash) +// Verify checks if a plaintext password matches a hashed password. +func Verify(plain string, hashed Hashed) (bool, error) { + h, err := decodeHash(string(hashed)) if err != nil { return false, err } otherHash := argon2.IDKey( - []byte(password), + []byte(plain), h.salt, h.iterations, h.memory, diff --git a/apps/backend/internal/user/err.go b/apps/backend/internal/user/err.go index 1c74ca1..17473aa 100644 --- a/apps/backend/internal/user/err.go +++ b/apps/backend/internal/user/err.go @@ -7,13 +7,30 @@ import ( ) type NotFoundError struct { + // ID is the ID that was used to try to find the user. + // Not set if not tried. id uuid.UUID + + // Email is the email that was used to try to find the user. + // Not set if not tried. + email string } -func newNotFoundError(id uuid.UUID) *NotFoundError { - return &NotFoundError{id} +func newNotFoundError(id uuid.UUID, email string) *NotFoundError { + return &NotFoundError{id, email} } - func (e *NotFoundError) Error() string { return fmt.Sprintf("user not found: %v", e.id) } + +type AlreadyExistsError struct { + // Email is the email that was used to try to create the user. + Email string +} + +func newAlreadyExistsError(email string) *AlreadyExistsError { + return &AlreadyExistsError{email} +} +func (e *AlreadyExistsError) Error() string { + return fmt.Sprintf("user with email %s already exists", e.Email) +} diff --git a/apps/backend/internal/user/service.go b/apps/backend/internal/user/service.go index 696a1d9..276322c 100644 --- a/apps/backend/internal/user/service.go +++ b/apps/backend/internal/user/service.go @@ -5,6 +5,8 @@ import ( "database/sql" "errors" + "github.com/get-drexa/drexa/internal/database" + "github.com/get-drexa/drexa/internal/password" "github.com/google/uuid" "github.com/uptrace/bun" ) @@ -13,20 +15,60 @@ type Service struct { db *bun.DB } +type UserRegistrationOptions struct { + Email string + DisplayName string + Password password.Hashed +} + func NewService(db *bun.DB) *Service { return &Service{ db: db, } } +func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) { + u := User{ + Email: opts.Email, + DisplayName: opts.DisplayName, + Password: opts.Password, + } + + _, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx) + if err != nil { + if database.IsUniqueViolation(err) { + return nil, newAlreadyExistsError(u.Email) + } + return nil, err + } + + return &u, nil +} + func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) { var user User err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, newNotFoundError(id) + return nil, newNotFoundError(id, "") } return nil, err } return &user, nil } + +func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) { + var user User + err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, newNotFoundError(uuid.Nil, email) + } + return nil, err + } + return &user, nil +} + +func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) { + return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx) +} diff --git a/apps/backend/internal/user/user.go b/apps/backend/internal/user/user.go index 24d6662..3a1a828 100644 --- a/apps/backend/internal/user/user.go +++ b/apps/backend/internal/user/user.go @@ -1,6 +1,7 @@ package user import ( + "github.com/get-drexa/drexa/internal/password" "github.com/google/uuid" "github.com/uptrace/bun" ) @@ -8,10 +9,10 @@ import ( type User struct { bun.BaseModel `bun:"users"` - ID uuid.UUID `bun:",pk,type:uuid" json:"id"` - DisplayName string `bun:"display_name,notnull" json:"displayName"` - Email string `bun:"email,unique,notnull" json:"email"` - Password string `bun:"password,notnull" json:"-"` - StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` - StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` + ID uuid.UUID `bun:",pk,type:uuid" json:"id"` + DisplayName string `bun:"display_name,notnull" json:"displayName"` + Email string `bun:"email,unique,notnull" json:"email"` + Password password.Hashed `bun:"password,notnull" json:"-"` + StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` + StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` }