mirror of
https://github.com/get-drexa/drive.git
synced 2025-12-01 05:51:39 +00:00
fix: ret invalid cred if usr not found in login
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/get-drexa/drexa/internal/password"
|
||||||
"github.com/get-drexa/drexa/internal/user"
|
"github.com/get-drexa/drexa/internal/user"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -39,25 +40,27 @@ func NewService(db *bun.DB, userService *user.Service, tokenConfig TokenConfig)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password string) (*LoginResult, error) {
|
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, plain string) (*LoginResult, error) {
|
||||||
var u user.User
|
u, err := s.userService.UserByEmail(ctx, email)
|
||||||
|
|
||||||
err := s.db.NewSelect().Model(&u).Where("email = ?", email).Scan(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var nf *user.NotFoundError
|
||||||
|
if errors.As(err, &nf) {
|
||||||
|
return nil, ErrInvalidCredentials
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := VerifyPassword(password, u.Password)
|
ok, err := password.Verify(plain, u.Password)
|
||||||
if err != nil || !ok {
|
if err != nil || !ok {
|
||||||
return nil, ErrInvalidCredentials
|
return nil, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
at, err := GenerateAccessToken(&u, &s.tokenConfig)
|
at, err := GenerateAccessToken(u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
|
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -68,48 +71,44 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, email, password
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &LoginResult{
|
return &LoginResult{
|
||||||
User: &u,
|
User: u,
|
||||||
AccessToken: at,
|
AccessToken: at,
|
||||||
RefreshToken: hex.EncodeToString(rt.Token),
|
RefreshToken: hex.EncodeToString(rt.Token),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
|
func (s *Service) Register(ctx context.Context, opts registerOptions) (*LoginResult, error) {
|
||||||
u := user.User{
|
hashed, err := password.Hash(opts.password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := s.userService.RegisterUser(ctx, user.UserRegistrationOptions{
|
||||||
Email: opts.email,
|
Email: opts.email,
|
||||||
DisplayName: opts.displayName,
|
DisplayName: opts.displayName,
|
||||||
}
|
Password: hashed,
|
||||||
|
})
|
||||||
exists, err := s.db.NewSelect().Model(&u).Where("email = ?", opts.email).Exists(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if exists {
|
|
||||||
return nil, ErrUserExists
|
|
||||||
}
|
|
||||||
|
|
||||||
u.Password, err = HashPassword(opts.password)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.db.NewInsert().Model(&u).Exec(ctx)
|
at, err := GenerateAccessToken(u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
at, err := GenerateAccessToken(&u, &s.tokenConfig)
|
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rt, err := GenerateRefreshToken(&u, &s.tokenConfig)
|
_, err = s.db.NewInsert().Model(rt).Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &LoginResult{
|
return &LoginResult{
|
||||||
User: &u,
|
User: u,
|
||||||
AccessToken: at,
|
AccessToken: at,
|
||||||
RefreshToken: hex.EncodeToString(rt.Token),
|
RefreshToken: hex.EncodeToString(rt.Token),
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
61
apps/backend/internal/database/errs.go
Normal file
61
apps/backend/internal/database/errs.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/uptrace/bun/driver/pgdriver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreSQL SQLSTATE error codes.
|
||||||
|
// See: https://www.postgresql.org/docs/current/errcodes-appendix.html
|
||||||
|
const (
|
||||||
|
PgUniqueViolation = "23505"
|
||||||
|
PgForeignKeyViolation = "23503"
|
||||||
|
PgNotNullViolation = "23502"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreSQL protocol error field identifiers used with pgdriver.Error.Field().
|
||||||
|
// See: https://www.postgresql.org/docs/current/protocol-error-fields.html
|
||||||
|
//
|
||||||
|
// Common fields:
|
||||||
|
// - 'C' - SQLSTATE code (e.g., "23505")
|
||||||
|
// - 'M' - Primary error message
|
||||||
|
// - 'D' - Detail message
|
||||||
|
// - 'H' - Hint
|
||||||
|
// - 's' - Schema name
|
||||||
|
// - 't' - Table name
|
||||||
|
// - 'c' - Column name
|
||||||
|
// - 'n' - Constraint name
|
||||||
|
const (
|
||||||
|
pgFieldCode = 'C'
|
||||||
|
pgFieldConstraint = 'n'
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsUniqueViolation checks if the error is a PostgreSQL unique constraint violation.
|
||||||
|
func IsUniqueViolation(err error) bool {
|
||||||
|
return hasPgCode(err, PgUniqueViolation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsForeignKeyViolation checks if the error is a PostgreSQL foreign key violation.
|
||||||
|
func IsForeignKeyViolation(err error) bool {
|
||||||
|
return hasPgCode(err, PgForeignKeyViolation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotNullViolation checks if the error is a PostgreSQL not-null constraint violation.
|
||||||
|
func IsNotNullViolation(err error) bool {
|
||||||
|
return hasPgCode(err, PgNotNullViolation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConstraintName returns the constraint name from a PostgreSQL error, or empty string if not applicable.
|
||||||
|
func ConstraintName(err error) string {
|
||||||
|
var pgErr pgdriver.Error
|
||||||
|
if errors.As(err, &pgErr) {
|
||||||
|
return pgErr.Field(pgFieldConstraint)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasPgCode(err error, code string) bool {
|
||||||
|
var pgErr pgdriver.Error
|
||||||
|
return errors.As(err, &pgErr) && pgErr.Field(pgFieldCode) == code
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package auth
|
package password
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
@@ -11,6 +11,10 @@ import (
|
|||||||
"golang.org/x/crypto/argon2"
|
"golang.org/x/crypto/argon2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Hashed represents a securely hashed password.
|
||||||
|
// This type ensures plaintext passwords cannot be accidentally stored.
|
||||||
|
type Hashed string
|
||||||
|
|
||||||
// argon2id parameters
|
// argon2id parameters
|
||||||
const (
|
const (
|
||||||
memory = 64 * 1024
|
memory = 64 * 1024
|
||||||
@@ -34,14 +38,15 @@ type argon2Hash struct {
|
|||||||
hash []byte
|
hash []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func HashPassword(password string) (string, error) {
|
// Hash securely hashes a plaintext password using argon2id.
|
||||||
|
func Hash(plain string) (Hashed, error) {
|
||||||
salt := make([]byte, saltLength)
|
salt := make([]byte, saltLength)
|
||||||
if _, err := rand.Read(salt); err != nil {
|
if _, err := rand.Read(salt); err != nil {
|
||||||
return "", fmt.Errorf("failed to generate salt: %w", err)
|
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := argon2.IDKey(
|
hash := argon2.IDKey(
|
||||||
[]byte(password),
|
[]byte(plain),
|
||||||
salt,
|
salt,
|
||||||
iterations,
|
iterations,
|
||||||
memory,
|
memory,
|
||||||
@@ -52,7 +57,7 @@ func HashPassword(password string) (string, error) {
|
|||||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||||
|
|
||||||
encodedHash := fmt.Sprintf(
|
encoded := fmt.Sprintf(
|
||||||
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||||
argon2.Version,
|
argon2.Version,
|
||||||
memory,
|
memory,
|
||||||
@@ -62,17 +67,18 @@ func HashPassword(password string) (string, error) {
|
|||||||
b64Hash,
|
b64Hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
return encodedHash, nil
|
return Hashed(encoded), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
// Verify checks if a plaintext password matches a hashed password.
|
||||||
h, err := decodeHash(encodedHash)
|
func Verify(plain string, hashed Hashed) (bool, error) {
|
||||||
|
h, err := decodeHash(string(hashed))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
otherHash := argon2.IDKey(
|
otherHash := argon2.IDKey(
|
||||||
[]byte(password),
|
[]byte(plain),
|
||||||
h.salt,
|
h.salt,
|
||||||
h.iterations,
|
h.iterations,
|
||||||
h.memory,
|
h.memory,
|
||||||
@@ -7,13 +7,30 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type NotFoundError struct {
|
type NotFoundError struct {
|
||||||
|
// ID is the ID that was used to try to find the user.
|
||||||
|
// Not set if not tried.
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
|
|
||||||
|
// Email is the email that was used to try to find the user.
|
||||||
|
// Not set if not tried.
|
||||||
|
email string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNotFoundError(id uuid.UUID) *NotFoundError {
|
func newNotFoundError(id uuid.UUID, email string) *NotFoundError {
|
||||||
return &NotFoundError{id}
|
return &NotFoundError{id, email}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *NotFoundError) Error() string {
|
func (e *NotFoundError) Error() string {
|
||||||
return fmt.Sprintf("user not found: %v", e.id)
|
return fmt.Sprintf("user not found: %v", e.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AlreadyExistsError struct {
|
||||||
|
// Email is the email that was used to try to create the user.
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAlreadyExistsError(email string) *AlreadyExistsError {
|
||||||
|
return &AlreadyExistsError{email}
|
||||||
|
}
|
||||||
|
func (e *AlreadyExistsError) Error() string {
|
||||||
|
return fmt.Sprintf("user with email %s already exists", e.Email)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/get-drexa/drexa/internal/database"
|
||||||
|
"github.com/get-drexa/drexa/internal/password"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
@@ -13,20 +15,60 @@ type Service struct {
|
|||||||
db *bun.DB
|
db *bun.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UserRegistrationOptions struct {
|
||||||
|
Email string
|
||||||
|
DisplayName string
|
||||||
|
Password password.Hashed
|
||||||
|
}
|
||||||
|
|
||||||
func NewService(db *bun.DB) *Service {
|
func NewService(db *bun.DB) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) RegisterUser(ctx context.Context, opts UserRegistrationOptions) (*User, error) {
|
||||||
|
u := User{
|
||||||
|
Email: opts.Email,
|
||||||
|
DisplayName: opts.DisplayName,
|
||||||
|
Password: opts.Password,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := s.db.NewInsert().Model(&u).Returning("*").Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if database.IsUniqueViolation(err) {
|
||||||
|
return nil, newAlreadyExistsError(u.Email)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
|
func (s *Service) UserByID(ctx context.Context, id uuid.UUID) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
|
err := s.db.NewSelect().Model(&user).Where("id = ?", id).Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, newNotFoundError(id)
|
return nil, newNotFoundError(id, "")
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) UserByEmail(ctx context.Context, email string) (*User, error) {
|
||||||
|
var user User
|
||||||
|
err := s.db.NewSelect().Model(&user).Where("email = ?", email).Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, newNotFoundError(uuid.Nil, email)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) UserExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
|
return s.db.NewSelect().Model(&User{}).Where("email = ?", email).Exists(ctx)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/get-drexa/drexa/internal/password"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
@@ -8,10 +9,10 @@ import (
|
|||||||
type User struct {
|
type User struct {
|
||||||
bun.BaseModel `bun:"users"`
|
bun.BaseModel `bun:"users"`
|
||||||
|
|
||||||
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
|
ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
|
||||||
DisplayName string `bun:"display_name,notnull" json:"displayName"`
|
DisplayName string `bun:"display_name,notnull" json:"displayName"`
|
||||||
Email string `bun:"email,unique,notnull" json:"email"`
|
Email string `bun:"email,unique,notnull" json:"email"`
|
||||||
Password string `bun:"password,notnull" json:"-"`
|
Password password.Hashed `bun:"password,notnull" json:"-"`
|
||||||
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
|
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
|
||||||
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
|
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user