Mirror of https://github.com/get-drexa/drive.git
Synced 2025-11-30 21:41:39 +00:00

Compare commits: 6984bb209e ... main (11 commits)

Commits:
- 3e96c42c4a
- 6c61cbe1fd
- ccdeaf0364
- a3110b67c3
- 1907cd83c8
- 033ad65d5f
- 987edc0d4a
- fb8e91dd47
- 0b8aee5d60
- 89b62f6d8a
- 1c1392a0a1
@@ -7,14 +7,16 @@ require (
 	github.com/gofiber/fiber/v2 v2.52.9
 	github.com/google/uuid v1.6.0
 	github.com/sqids/sqids-go v0.4.1
-	github.com/uptrace/bun v1.2.15
-	golang.org/x/crypto v0.40.0
+	github.com/uptrace/bun v1.2.16
+	github.com/uptrace/bun/extra/bundebug v1.2.16
+	golang.org/x/crypto v0.45.0
 	gopkg.in/yaml.v3 v3.0.1
 )
 
 require (
-	go.opentelemetry.io/otel v1.37.0 // indirect
-	go.opentelemetry.io/otel/trace v1.37.0 // indirect
+	github.com/fatih/color v1.18.0 // indirect
+	go.opentelemetry.io/otel v1.38.0 // indirect
+	go.opentelemetry.io/otel/trace v1.38.0 // indirect
 	mellium.im/sasl v0.3.2 // indirect
 )
 
@@ -29,12 +31,12 @@ require (
 	github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
 	github.com/rivo/uniseg v0.2.0 // indirect
 	github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
-	github.com/uptrace/bun/dialect/pgdialect v1.2.15
-	github.com/uptrace/bun/driver/pgdriver v1.2.15
+	github.com/uptrace/bun/dialect/pgdialect v1.2.16
+	github.com/uptrace/bun/driver/pgdriver v1.2.16
 	github.com/valyala/bytebufferpool v1.0.0 // indirect
 	github.com/valyala/fasthttp v1.51.0 // indirect
 	github.com/valyala/tcplisten v1.0.0 // indirect
 	github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
 	github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
-	golang.org/x/sys v0.34.0 // indirect
+	golang.org/x/sys v0.38.0 // indirect
 )
@@ -2,6 +2,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
 github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
+github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
 github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
 github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
 github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
@@ -34,16 +36,18 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
 github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
 github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
 github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
 github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
 github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
-github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
+github.com/uptrace/bun v1.2.16 h1:QlObi6ZIK5Ao7kAALnh91HWYNZUBbVwye52fmlQM9kc=
-github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
+github.com/uptrace/bun v1.2.16/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
-github.com/uptrace/bun/dialect/pgdialect v1.2.15 h1:er+/3giAIqpfrXJw+KP9B7ujyQIi5XkPnFmgjAVL6bA=
+github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
-github.com/uptrace/bun/dialect/pgdialect v1.2.15/go.mod h1:QSiz6Qpy9wlGFsfpf7UMSL6mXAL1jDJhFwuOVacCnOQ=
+github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
-github.com/uptrace/bun/driver/pgdriver v1.2.15 h1:eZZ60ZtUUE6jjv6VAI1pCMaTgtx3sxmChQzwbvchOOo=
+github.com/uptrace/bun/driver/pgdriver v1.2.16 h1:b1kpXKUxtTSGYow5Vlsb+dKV3z0R7aSAJNfMfKp61ZU=
-github.com/uptrace/bun/driver/pgdriver v1.2.15/go.mod h1:s2zz/BAeScal4KLFDI8PURwATN8s9RDBsElEbnPAjv4=
+github.com/uptrace/bun/driver/pgdriver v1.2.16/go.mod h1:H6lUZ9CBfp1X5Vq62YGSV7q96/v94ja9AYFjKvdoTk0=
+github.com/uptrace/bun/extra/bundebug v1.2.16 h1:3OXAfHTU4ydu2+4j05oB1BxPx6+ypdWIVzTugl/7zl0=
+github.com/uptrace/bun/extra/bundebug v1.2.16/go.mod h1:vk6R/1i67/S2RvUI5AH/m3P5e67mOkfDCmmCsAPUumo=
 github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
 github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
 github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
@@ -54,15 +58,15 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
 github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
 github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
 github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
-go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
+go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
-go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
+go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
-go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
+go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
-go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
+go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
-golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
+golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
-golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
+golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
+golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
-golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
apps/backend/internal/account/account.go (new file, 23 lines)
@@ -0,0 +1,23 @@
+package account
+
+import (
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/uptrace/bun"
+)
+
+type Account struct {
+	bun.BaseModel `bun:"accounts"`
+
+	ID                uuid.UUID `bun:",pk,type:uuid" json:"id"`
+	UserID            uuid.UUID `bun:"user_id,notnull,type:uuid" json:"userId"`
+	StorageUsageBytes int64     `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
+	StorageQuotaBytes int64     `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
+	CreatedAt         time.Time `bun:"created_at,notnull,nullzero" json:"createdAt"`
+	UpdatedAt         time.Time `bun:"updated_at,notnull,nullzero" json:"updatedAt"`
+}
+
+func newAccountID() (uuid.UUID, error) {
+	return uuid.NewV7()
+}
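Note: account primary keys come from uuid.NewV7 rather than a random UUID. A minimal illustrative sketch, not part of the diff, assuming github.com/google/uuid as pinned in go.mod:

	package main

	import (
		"fmt"

		"github.com/google/uuid"
	)

	func main() {
		// V7 IDs embed the creation time in their high bits, so IDs created later
		// compare greater, which keeps inserts roughly append-only in the
		// accounts primary-key index.
		first, _ := uuid.NewV7()
		second, _ := uuid.NewV7()
		fmt.Println(first.String() < second.String()) // normally true
	}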
apps/backend/internal/account/err.go (new file, 8 lines)
@@ -0,0 +1,8 @@
+package account
+
+import "errors"
+
+var (
+	ErrAccountNotFound      = errors.New("account not found")
+	ErrAccountAlreadyExists = errors.New("account already exists")
+)
apps/backend/internal/account/http.go (new file, 132 lines)
@@ -0,0 +1,132 @@
+package account
+
+import (
+	"errors"
+
+	"github.com/get-drexa/drexa/internal/auth"
+	"github.com/get-drexa/drexa/internal/httperr"
+	"github.com/get-drexa/drexa/internal/user"
+	"github.com/gofiber/fiber/v2"
+	"github.com/google/uuid"
+	"github.com/uptrace/bun"
+)
+
+type HTTPHandler struct {
+	accountService *Service
+	authService    *auth.Service
+	db             *bun.DB
+	authMiddleware fiber.Handler
+}
+
+type registerAccountRequest struct {
+	Email       string `json:"email"`
+	Password    string `json:"password"`
+	DisplayName string `json:"displayName"`
+}
+
+type registerAccountResponse struct {
+	Account      *Account   `json:"account"`
+	User         *user.User `json:"user"`
+	AccessToken  string     `json:"accessToken"`
+	RefreshToken string     `json:"refreshToken"`
+}
+
+const currentAccountKey = "currentAccount"
+
+func CurrentAccount(c *fiber.Ctx) *Account {
+	return c.Locals(currentAccountKey).(*Account)
+}
+
+func NewHTTPHandler(accountService *Service, authService *auth.Service, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
+	return &HTTPHandler{accountService: accountService, authService: authService, db: db, authMiddleware: authMiddleware}
+}
+
+func (h *HTTPHandler) RegisterRoutes(api fiber.Router) fiber.Router {
+	api.Post("/accounts", h.registerAccount)
+
+	account := api.Group("/accounts/:accountID")
+	account.Use(h.authMiddleware)
+	account.Use(h.accountMiddleware)
+
+	account.Get("/", h.getAccount)
+
+	return account
+}
+
+func (h *HTTPHandler) accountMiddleware(c *fiber.Ctx) error {
+	user, err := auth.AuthenticatedUser(c)
+	if err != nil {
+		return c.SendStatus(fiber.StatusUnauthorized)
+	}
+
+	accountID, err := uuid.Parse(c.Params("accountID"))
+	if err != nil {
+		return c.SendStatus(fiber.StatusNotFound)
+	}
+
+	account, err := h.accountService.AccountByID(c.Context(), h.db, user.ID, accountID)
+	if err != nil {
+		if errors.Is(err, ErrAccountNotFound) {
+			return c.SendStatus(fiber.StatusNotFound)
+		}
+		return httperr.Internal(err)
+	}
+
+	c.Locals(currentAccountKey, account)
+
+	return c.Next()
+}
+
+func (h *HTTPHandler) getAccount(c *fiber.Ctx) error {
+	account := CurrentAccount(c)
+	if account == nil {
+		return c.SendStatus(fiber.StatusNotFound)
+	}
+	return c.JSON(account)
+}
+
+func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error {
+	req := new(registerAccountRequest)
+	if err := c.BodyParser(req); err != nil {
+		return c.SendStatus(fiber.StatusBadRequest)
+	}
+
+	tx, err := h.db.BeginTx(c.Context(), nil)
+	if err != nil {
+		return httperr.Internal(err)
+	}
+	defer tx.Rollback()
+
+	acc, u, err := h.accountService.Register(c.Context(), tx, RegisterOptions{
+		Email:       req.Email,
+		Password:    req.Password,
+		DisplayName: req.DisplayName,
+	})
+	if err != nil {
+		var ae *user.AlreadyExistsError
+		if errors.As(err, &ae) {
+			return c.SendStatus(fiber.StatusConflict)
+		}
+		if errors.Is(err, ErrAccountAlreadyExists) {
+			return c.SendStatus(fiber.StatusConflict)
+		}
+		return httperr.Internal(err)
+	}
+
+	result, err := h.authService.GenerateTokenForUser(c.Context(), tx, u)
+	if err != nil {
+		return httperr.Internal(err)
+	}
+
+	err = tx.Commit()
+	if err != nil {
+		return httperr.Internal(err)
+	}
+
+	return c.JSON(registerAccountResponse{
+		Account:      acc,
+		User:         u,
+		AccessToken:  result.AccessToken,
+		RefreshToken: result.RefreshToken,
+	})
+}
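RegisterRoutes returns the /accounts/:accountID group so that other feature routers can nest under it and inherit the auth and account middleware. A hedged wiring sketch, mirroring how the server.go hunk later in this compare uses it:

	api := app.Group("/api")

	// Everything registered on accRouter lives under /api/accounts/:accountID
	// and passes through authMiddleware and accountMiddleware first.
	accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
	upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accRouter)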
apps/backend/internal/account/service.go (new file, 115 lines)
@@ -0,0 +1,115 @@
+package account
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+
+	"github.com/get-drexa/drexa/internal/database"
+	"github.com/get-drexa/drexa/internal/password"
+	"github.com/get-drexa/drexa/internal/user"
+	"github.com/get-drexa/drexa/internal/virtualfs"
+	"github.com/google/uuid"
+	"github.com/uptrace/bun"
+)
+
+type Service struct {
+	userService user.Service
+	vfs         *virtualfs.VirtualFS
+}
+
+type RegisterOptions struct {
+	Email       string
+	Password    string
+	DisplayName string
+}
+
+type CreateAccountOptions struct {
+	OrganizationID uuid.UUID
+	QuotaBytes     int64
+}
+
+func NewService(userService *user.Service, vfs *virtualfs.VirtualFS) *Service {
+	return &Service{
+		userService: *userService,
+		vfs:         vfs,
+	}
+}
+
+func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) {
+	hashed, err := password.Hash(opts.Password)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
+		Email:       opts.Email,
+		Password:    hashed,
+		DisplayName: opts.DisplayName,
+	})
+	if err != nil {
+		return nil, nil, err
+	}
+
+	acc, err := s.CreateAccount(ctx, db, u.ID, CreateAccountOptions{
+		// TODO: make quota configurable
+		QuotaBytes: 1024 * 1024 * 1024, // 1GB
+	})
+	if err != nil {
+		return nil, nil, err
+	}
+
+	_, err = s.vfs.CreateDirectory(ctx, db, acc.ID, uuid.Nil, virtualfs.RootDirectoryName)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return acc, u, nil
+}
+
+func (s *Service) CreateAccount(ctx context.Context, db bun.IDB, userID uuid.UUID, opts CreateAccountOptions) (*Account, error) {
+	id, err := newAccountID()
+	if err != nil {
+		return nil, err
+	}
+
+	account := &Account{
+		ID:                id,
+		UserID:            userID,
+		StorageQuotaBytes: opts.QuotaBytes,
+	}
+
+	_, err = db.NewInsert().Model(account).Returning("*").Exec(ctx)
+	if err != nil {
+		if database.IsUniqueViolation(err) {
+			return nil, ErrAccountAlreadyExists
+		}
+		return nil, err
+	}
+
+	return account, nil
+}
+
+func (s *Service) AccountByUserID(ctx context.Context, db bun.IDB, userID uuid.UUID) (*Account, error) {
+	var account Account
+	err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Scan(ctx)
+	if err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, ErrAccountNotFound
+		}
+		return nil, err
+	}
+	return &account, nil
+}
+
+func (s *Service) AccountByID(ctx context.Context, db bun.IDB, userID uuid.UUID, id uuid.UUID) (*Account, error) {
+	var account Account
+	err := db.NewSelect().Model(&account).Where("user_id = ?", userID).Where("id = ?", id).Scan(ctx)
+	if err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, ErrAccountNotFound
+		}
+		return nil, err
+	}
+	return &account, nil
+}
@@ -2,8 +2,8 @@ package auth
 
 import (
 	"errors"
-	"log/slog"
 
+	"github.com/get-drexa/drexa/internal/httperr"
 	"github.com/get-drexa/drexa/internal/user"
 	"github.com/gofiber/fiber/v2"
 	"github.com/uptrace/bun"
@@ -14,12 +14,6 @@ type loginRequest struct {
 	Password string `json:"password"`
 }
 
-type registerRequest struct {
-	Email       string `json:"email"`
-	Password    string `json:"password"`
-	DisplayName string `json:"displayName"`
-}
-
 type loginResponse struct {
 	User        user.User `json:"user"`
 	AccessToken string    `json:"accessToken"`
@@ -37,9 +31,7 @@ func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
 
 func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
 	auth := api.Group("/auth")
 
 	auth.Post("/login", h.Login)
-	auth.Post("/register", h.Register)
 }
 
 func (h *HTTPHandler) Login(c *fiber.Ctx) error {
@@ -50,59 +42,20 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
 
 	tx, err := h.db.BeginTx(c.Context(), nil)
 	if err != nil {
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
+		return httperr.Internal(err)
 	}
 	defer tx.Rollback()
 
-	result, err := h.service.LoginWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
+	result, err := h.service.AuthenticateWithEmailAndPassword(c.Context(), tx, req.Email, req.Password)
 	if err != nil {
 		if errors.Is(err, ErrInvalidCredentials) {
 			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
 		}
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
+		return httperr.Internal(err)
 	}
 
 	if err := tx.Commit(); err != nil {
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
+		return httperr.Internal(err)
-	}
-
-	return c.JSON(loginResponse{
-		User:         *result.User,
-		AccessToken:  result.AccessToken,
-		RefreshToken: result.RefreshToken,
-	})
-}
-
-func (h *HTTPHandler) Register(c *fiber.Ctx) error {
-	req := new(registerRequest)
-	if err := c.BodyParser(req); err != nil {
-		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
-	}
-
-	tx, err := h.db.BeginTx(c.Context(), nil)
-	if err != nil {
-		slog.Error("failed to begin transaction", "error", err)
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
-	}
-	defer tx.Rollback()
-
-	result, err := h.service.Register(c.Context(), tx, registerOptions{
-		email:       req.Email,
-		password:    req.Password,
-		displayName: req.DisplayName,
-	})
-	if err != nil {
-		var ae *user.AlreadyExistsError
-		if errors.As(err, &ae) {
-			return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"})
-		}
-		slog.Error("failed to register user", "error", err)
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
-	}
-
-	if err := tx.Commit(); err != nil {
-		slog.Error("failed to commit transaction", "error", err)
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
 	}
 
 	return c.JSON(loginResponse{
@@ -2,8 +2,10 @@ package auth
 
 import (
 	"errors"
+	"log/slog"
 	"strings"
 
+	"github.com/get-drexa/drexa/internal/httperr"
 	"github.com/get-drexa/drexa/internal/user"
 	"github.com/gofiber/fiber/v2"
 	"github.com/uptrace/bun"
@@ -17,11 +19,13 @@ func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
 	return func(c *fiber.Ctx) error {
 		authHeader := c.Get("Authorization")
 		if authHeader == "" {
+			slog.Info("no auth header")
 			return c.SendStatus(fiber.StatusUnauthorized)
 		}
 
 		parts := strings.Split(authHeader, " ")
 		if len(parts) != 2 || parts[0] != "Bearer" {
+			slog.Info("invalid auth header")
 			return c.SendStatus(fiber.StatusUnauthorized)
 		}
 
@@ -30,15 +34,17 @@ func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
 		if err != nil {
 			var e *InvalidAccessTokenError
 			if errors.As(err, &e) {
+				slog.Info("invalid access token")
 				return c.SendStatus(fiber.StatusUnauthorized)
 			}
 
 			var nf *user.NotFoundError
 			if errors.As(err, &nf) {
+				slog.Info("user not found")
 				return c.SendStatus(fiber.StatusUnauthorized)
 			}
 
-			return c.SendStatus(fiber.StatusInternalServerError)
+			return httperr.Internal(err)
 		}
 
 		c.Locals(authenticatedUserKey, u)
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/hex"
 	"errors"
+	"log/slog"
 
 	"github.com/get-drexa/drexa/internal/password"
 	"github.com/get-drexa/drexa/internal/user"
@@ -11,7 +12,7 @@ import (
 	"github.com/uptrace/bun"
 )
 
-type LoginResult struct {
+type AuthenticationResult struct {
 	User         *user.User
 	AccessToken  string
 	RefreshToken string
@@ -24,12 +25,6 @@ type Service struct {
 	tokenConfig TokenConfig
 }
 
-type registerOptions struct {
-	displayName string
-	email       string
-	password    string
-}
-
 func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
 	return &Service{
 		userService: userService,
@@ -37,7 +32,30 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
 	}
 }
 
-func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) {
+func (s *Service) GenerateTokenForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationResult, error) {
+	at, err := GenerateAccessToken(user, &s.tokenConfig)
+	if err != nil {
+		return nil, err
+	}
+
+	rt, err := GenerateRefreshToken(user, &s.tokenConfig)
+	if err != nil {
+		return nil, err
+	}
+
+	_, err = db.NewInsert().Model(rt).Exec(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	return &AuthenticationResult{
+		User:         user,
+		AccessToken:  at,
+		RefreshToken: hex.EncodeToString(rt.Token),
+	}, nil
+}
+
+func (s *Service) AuthenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationResult, error) {
 	u, err := s.userService.UserByEmail(ctx, db, email)
 	if err != nil {
 		var nf *user.NotFoundError
@@ -67,44 +85,7 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, ema
 		return nil, err
 	}
 
-	return &LoginResult{
+	return &AuthenticationResult{
-		User:         u,
-		AccessToken:  at,
-		RefreshToken: hex.EncodeToString(rt.Token),
-	}, nil
-}
-
-func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
-	hashed, err := password.Hash(opts.password)
-	if err != nil {
-		return nil, err
-	}
-
-	u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
-		Email:       opts.email,
-		DisplayName: opts.displayName,
-		Password:    hashed,
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	at, err := GenerateAccessToken(u, &s.tokenConfig)
-	if err != nil {
-		return nil, err
-	}
-
-	rt, err := GenerateRefreshToken(u, &s.tokenConfig)
-	if err != nil {
-		return nil, err
-	}
-
-	_, err = db.NewInsert().Model(rt).Exec(ctx)
-	if err != nil {
-		return nil, err
-	}
-
-	return &LoginResult{
 		User:         u,
 		AccessToken:  at,
 		RefreshToken: hex.EncodeToString(rt.Token),
@@ -114,11 +95,13 @@ func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions
 func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
 	claims, err := ParseAccessToken(token, &s.tokenConfig)
 	if err != nil {
+		slog.Info("failed to parse access token", "error", err)
 		return nil, err
 	}
 
 	id, err := uuid.Parse(claims.Subject)
 	if err != nil {
+		slog.Info("failed to parse access token subject", "error", err)
 		return nil, newInvalidAccessTokenError(err)
 	}
 
@@ -33,7 +33,7 @@ type RefreshToken struct {
 	Token     []byte    `bun:"-"`
 	TokenHash string    `bun:"token_hash,notnull"`
 	ExpiresAt time.Time `bun:"expires_at,notnull"`
-	CreatedAt time.Time `bun:"created_at,notnull"`
+	CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
 }
 
 func newTokenID() (uuid.UUID, error) {
@@ -16,8 +16,7 @@ type FSStore struct {
 }
 
 type FSStoreConfig struct {
 	Root string
-	UploadURL string
 }
 
 func NewFSStore(config FSStoreConfig) *FSStore {
@@ -28,10 +27,6 @@ func (s *FSStore) Initialize(ctx context.Context) error {
 	return os.MkdirAll(s.config.Root, 0755)
 }
 
-func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
-	return s.config.UploadURL, nil
-}
-
 func (s *FSStore) Put(ctx context.Context, key Key, reader io.Reader) error {
 	path := filepath.Join(s.config.Root, string(key))
 
@@ -149,3 +144,11 @@ func (s *FSStore) Move(ctx context.Context, srcKey, dstKey Key) error {
 
 	return nil
 }
+
+func (s *FSStore) SupportsDirectUpload() bool {
+	return false
+}
+
+func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
+	return "", nil
+}
@@ -2,13 +2,6 @@ package blob
 
 type Key string
 
-type KeyMode int
-
-const (
-	KeyModeStable KeyMode = iota
-	KeyModeDerived
-)
-
 func (k Key) IsNil() bool {
 	return k == ""
 }
@@ -16,7 +16,6 @@ type UpdateOptions struct {
 
 type Store interface {
 	Initialize(ctx context.Context) error
-	GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
 	Put(ctx context.Context, key Key, reader io.Reader) error
 	Update(ctx context.Context, key Key, opts UpdateOptions) error
 	Delete(ctx context.Context, key Key) error
@@ -25,4 +24,10 @@ type Store interface {
 	Read(ctx context.Context, key Key) (io.ReadCloser, error)
 	ReadRange(ctx context.Context, key Key, offset, length int64) (io.ReadCloser, error)
 	ReadSize(ctx context.Context, key Key) (int64, error)
+
+	// SupportsDirectUpload returns true if the store allows files to be uploaded directly to the blob store.
+	SupportsDirectUpload() bool
+
+	// GenerateUploadURL generates a URL that can be used to upload a file directly to the blob store. If unsupported, returns an empty string with no error.
+	GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
 }
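A minimal caller sketch for the revised Store contract (the uploadVia helper is hypothetical, not from the diff; assumes the context, io, and time imports): check SupportsDirectUpload before asking for a URL, and otherwise stream the bytes through the API with Put, which is how the upload service hunk later in this compare uses it.

	func uploadVia(ctx context.Context, store blob.Store, key blob.Key, r io.Reader) (string, error) {
		if store.SupportsDirectUpload() {
			// Direct mode: hand back a pre-authorized URL and let the client upload itself.
			return store.GenerateUploadURL(ctx, key, blob.UploadURLOptions{Duration: time.Hour})
		}
		// Proxied mode: the API server writes the body into the store.
		return "", store.Put(ctx, key, r)
	}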
@@ -7,14 +7,23 @@ CREATE TABLE IF NOT EXISTS users (
     display_name TEXT,
     email TEXT NOT NULL UNIQUE,
     password TEXT NOT NULL,
-    storage_usage_bytes BIGINT NOT NULL,
-    storage_quota_bytes BIGINT NOT NULL,
     created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
     updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
 );
 
 CREATE INDEX idx_users_email ON users(email);
 
+CREATE TABLE IF NOT EXISTS accounts (
+    id UUID PRIMARY KEY,
+    user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    storage_usage_bytes BIGINT NOT NULL DEFAULT 0,
+    storage_quota_bytes BIGINT NOT NULL,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX idx_accounts_user_id ON accounts(user_id);
+
 CREATE TABLE IF NOT EXISTS refresh_tokens (
     id UUID PRIMARY KEY,
     user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
@@ -31,7 +40,7 @@ CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
 CREATE TABLE IF NOT EXISTS vfs_nodes (
     id UUID PRIMARY KEY,
     public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak)
-    user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
     parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory
     kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')),
     status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')),
@@ -46,17 +55,17 @@ CREATE TABLE IF NOT EXISTS vfs_nodes (
     updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
     deleted_at TIMESTAMPTZ, -- soft delete for trash
 
-    -- No duplicate names in same parent (per user, excluding deleted)
+    -- No duplicate names in same parent (per account, excluding deleted)
-    CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (user_id, parent_id, name, deleted_at)
+    CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (account_id, parent_id, name, deleted_at)
 );
 
-CREATE INDEX idx_vfs_nodes_user_id ON vfs_nodes(user_id) WHERE deleted_at IS NULL;
+CREATE INDEX idx_vfs_nodes_account_id ON vfs_nodes(account_id) WHERE deleted_at IS NULL;
 CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL;
-CREATE INDEX idx_vfs_nodes_user_parent ON vfs_nodes(user_id, parent_id) WHERE deleted_at IS NULL;
+CREATE INDEX idx_vfs_nodes_account_parent ON vfs_nodes(account_id, parent_id) WHERE deleted_at IS NULL;
-CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(user_id, kind) WHERE deleted_at IS NULL;
+CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(account_id, kind) WHERE deleted_at IS NULL;
-CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(user_id, deleted_at) WHERE deleted_at IS NOT NULL;
+CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(account_id, deleted_at) WHERE deleted_at IS NOT NULL;
 CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id);
-CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_id IS NULL; -- one root per user
+CREATE UNIQUE INDEX idx_vfs_nodes_account_root ON vfs_nodes(account_id) WHERE parent_id IS NULL; -- one root per account
 CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job
 
 CREATE TABLE IF NOT EXISTS node_shares (
@@ -92,3 +101,6 @@ CREATE TRIGGER update_vfs_nodes_updated_at BEFORE UPDATE ON vfs_nodes
 
 CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares
     FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
+
+CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts
+    FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
@@ -4,22 +4,29 @@ import (
 	"context"
 	"fmt"
 
+	"github.com/get-drexa/drexa/internal/account"
 	"github.com/get-drexa/drexa/internal/auth"
 	"github.com/get-drexa/drexa/internal/blob"
 	"github.com/get-drexa/drexa/internal/database"
+	"github.com/get-drexa/drexa/internal/httperr"
 	"github.com/get-drexa/drexa/internal/upload"
 	"github.com/get-drexa/drexa/internal/user"
 	"github.com/get-drexa/drexa/internal/virtualfs"
 	"github.com/gofiber/fiber/v2"
 	"github.com/gofiber/fiber/v2/middleware/logger"
+	"github.com/uptrace/bun/extra/bundebug"
 )
 
 func NewServer(c Config) (*fiber.App, error) {
-	app := fiber.New()
-	db := database.NewFromPostgres(c.Database.PostgresURL)
+	app := fiber.New(fiber.Config{
+		ErrorHandler:      httperr.ErrorHandler,
+		StreamRequestBody: true,
+	})
 	app.Use(logger.New())
 
+	db := database.NewFromPostgres(c.Database.PostgresURL)
+	db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
+
 	// Initialize blob store based on config
 	var blobStore blob.Store
 	switch c.Storage.Backend {
@@ -49,7 +56,7 @@ func NewServer(c Config) (*fiber.App, error) {
 		return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode)
 	}
 
-	vfs, err := virtualfs.NewVirtualFS(db, blobStore, keyResolver)
+	vfs, err := virtualfs.NewVirtualFS(blobStore, keyResolver)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create virtual file system: %w", err)
 	}
@@ -61,12 +68,16 @@ func NewServer(c Config) (*fiber.App, error) {
 		SecretKey: c.JWT.SecretKey,
 	})
 	uploadService := upload.NewService(vfs, blobStore)
+	accountService := account.NewService(userService, vfs)
+
 	authMiddleware := auth.NewBearerAuthMiddleware(authService, db)
 
 	api := app.Group("/api")
 
+	accRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
+
 	auth.NewHTTPHandler(authService, db).RegisterRoutes(api)
-	upload.NewHTTPHandler(uploadService, authMiddleware).RegisterRoutes(api, authMiddleware)
+	upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accRouter)
 
 	return app, nil
 }
apps/backend/internal/httperr/error.go (new file, 42 lines)
@@ -0,0 +1,42 @@
+package httperr
+
+import (
+	"fmt"
+
+	"github.com/gofiber/fiber/v2"
+)
+
+// HTTPError represents an HTTP error with a status code and underlying error.
+type HTTPError struct {
+	Code    int
+	Message string
+	Err     error
+}
+
+// Error implements the error interface.
+func (e *HTTPError) Error() string {
+	if e.Err != nil {
+		return fmt.Sprintf("HTTP %d: %s: %v", e.Code, e.Message, e.Err)
+	}
+	return fmt.Sprintf("HTTP %d: %s", e.Code, e.Message)
+}
+
+// Unwrap returns the underlying error.
+func (e *HTTPError) Unwrap() error {
+	return e.Err
+}
+
+// NewHTTPError creates a new HTTPError with the given status code, message, and underlying error.
+func NewHTTPError(code int, message string, err error) *HTTPError {
+	return &HTTPError{
+		Code:    code,
+		Message: message,
+		Err:     err,
+	}
+}
+
+// Internal creates a new HTTPError with status 500.
+func Internal(err error) *HTTPError {
+	return NewHTTPError(fiber.StatusInternalServerError, "Internal", err)
+}
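A hedged usage sketch (the handler and its service method are hypothetical, not from this compare): wrap unexpected failures so the global error handler in handler.go below logs the cause and returns the JSON body.

	func (h *HTTPHandler) deleteNode(c *fiber.Ctx) error {
		if err := h.service.Delete(c.Context(), c.Params("nodeID")); err != nil {
			// Internal wraps err with a 500; the cause is logged centrally
			// by ErrorHandler and never leaked to the client.
			return httperr.Internal(err)
		}
		return c.SendStatus(fiber.StatusNoContent)
	}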
apps/backend/internal/httperr/handler.go (new file, 64 lines)
@@ -0,0 +1,64 @@
+package httperr
+
+import (
+	"errors"
+	"log/slog"
+
+	"github.com/gofiber/fiber/v2"
+)
+
+// ErrorHandler is a global error handler for Fiber that logs errors and returns appropriate responses.
+func ErrorHandler(c *fiber.Ctx, err error) error {
+	// Default status code
+	code := fiber.StatusInternalServerError
+	message := "Internal"
+
+	// Check if it's our custom HTTPError
+	var httpErr *HTTPError
+	if errors.As(err, &httpErr) {
+		code = httpErr.Code
+		message = httpErr.Message
+
+		// Log the error with underlying error details
+		if httpErr.Err != nil {
+			slog.Error("HTTP error",
+				"status", code,
+				"message", message,
+				"error", httpErr.Err.Error(),
+				"path", c.Path(),
+				"method", c.Method(),
+			)
+		} else {
+			slog.Warn("HTTP error",
+				"status", code,
+				"message", message,
+				"path", c.Path(),
+				"method", c.Method(),
+			)
+		}
+	} else {
+		// Check if it's a Fiber error
+		var fiberErr *fiber.Error
+		if errors.As(err, &fiberErr) {
+			code = fiberErr.Code
+			message = fiberErr.Message
+		} else {
+			// Generic error - log it
+			slog.Error("Unhandled error",
+				"status", code,
+				"error", err.Error(),
+				"path", c.Path(),
+				"method", c.Method(),
+			)
+		}
+	}
+
+	// Set Content-Type header
+	c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8)
+
+	// Return JSON response
+	return c.Status(code).JSON(fiber.Map{
+		"error": message,
+	})
+}
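The handler is installed once at app construction; this mirrors exactly how the server.go hunk later in this compare wires it:

	app := fiber.New(fiber.Config{
		ErrorHandler:      httperr.ErrorHandler,
		StreamRequestBody: true,
	})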
@@ -6,4 +6,5 @@ var (
 	ErrNotFound           = errors.New("not found")
 	ErrParentNotDirectory = errors.New("parent is not a directory")
 	ErrConflict           = errors.New("node conflict")
+	ErrContentNotUploaded = errors.New("content has not been uploaded")
 )
@@ -2,9 +2,12 @@ package upload
 
 import (
 	"errors"
+	"fmt"
 
-	"github.com/get-drexa/drexa/internal/auth"
+	"github.com/get-drexa/drexa/internal/account"
+	"github.com/get-drexa/drexa/internal/httperr"
 	"github.com/gofiber/fiber/v2"
+	"github.com/uptrace/bun"
 )
 
 type createUploadRequest struct {
@@ -17,17 +20,16 @@ type updateUploadRequest struct {
 }
 
 type HTTPHandler struct {
 	service *Service
-	authMiddleware fiber.Handler
+	db      *bun.DB
 }
 
-func NewHTTPHandler(s *Service, authMiddleware fiber.Handler) *HTTPHandler {
-	return &HTTPHandler{service: s, authMiddleware: authMiddleware}
+func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
+	return &HTTPHandler{service: s, db: db}
 }
 
-func (h *HTTPHandler) RegisterRoutes(api fiber.Router, authMiddleware fiber.Handler) {
+func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
 	upload := api.Group("/uploads")
-	upload.Use(authMiddleware)
 
 	upload.Post("/", h.Create)
 	upload.Put("/:uploadID/content", h.ReceiveContent)
@@ -35,9 +37,9 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
 }
 
 func (h *HTTPHandler) Create(c *fiber.Ctx) error {
-	u, err := auth.AuthenticatedUser(c)
-	if err != nil {
-		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
+	account := account.CurrentAccount(c)
+	if account == nil {
+		return c.SendStatus(fiber.StatusUnauthorized)
 	}
 
 	req := new(createUploadRequest)
@@ -45,38 +47,54 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
 		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
 	}
 
-	upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{
+	upload, err := h.service.CreateUpload(c.Context(), h.db, account.ID, CreateUploadOptions{
 		ParentID: req.ParentID,
 		Name:     req.Name,
 	})
 	if err != nil {
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
+		if errors.Is(err, ErrNotFound) {
+			return c.SendStatus(fiber.StatusNotFound)
+		}
+		if errors.Is(err, ErrParentNotDirectory) {
+			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Parent is not a directory"})
+		}
+		if errors.Is(err, ErrConflict) {
+			return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "A file with this name already exists"})
+		}
+		return httperr.Internal(err)
+	}
+
+	if upload.UploadURL == "" {
+		upload.UploadURL = fmt.Sprintf("%s%s/%s/content", c.BaseURL(), c.OriginalURL(), upload.ID)
 	}
 
 	return c.JSON(upload)
 }
 
 func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
-	u, err := auth.AuthenticatedUser(c)
-	if err != nil {
-		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
+	account := account.CurrentAccount(c)
+	if account == nil {
+		return c.SendStatus(fiber.StatusUnauthorized)
 	}
 
 	uploadID := c.Params("uploadID")
 
-	err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream())
-	defer c.Request().CloseBodyStream()
+	err := h.service.ReceiveUpload(c.Context(), h.db, account.ID, uploadID, c.Context().RequestBodyStream())
+	defer c.Context().Request.CloseBodyStream()
 	if err != nil {
-		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
+		if errors.Is(err, ErrNotFound) {
+			return c.SendStatus(fiber.StatusNotFound)
+		}
+		return httperr.Internal(err)
 	}
 
 	return c.SendStatus(fiber.StatusNoContent)
 }
 
 func (h *HTTPHandler) Update(c *fiber.Ctx) error {
-	u, err := auth.AuthenticatedUser(c)
-	if err != nil {
-		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
+	account := account.CurrentAccount(c)
+	if account == nil {
+		return c.SendStatus(fiber.StatusUnauthorized)
 	}
 
 	req := new(updateUploadRequest)
@@ -85,12 +103,15 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error {
 	}
 
 	if req.Status == StatusCompleted {
-		upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID"))
+		upload, err := h.service.CompleteUpload(c.Context(), h.db, account.ID, c.Params("uploadID"))
 		if err != nil {
 			if errors.Is(err, ErrNotFound) {
 				return c.SendStatus(fiber.StatusNotFound)
 			}
-			return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
+			if errors.Is(err, ErrContentNotUploaded) {
+				return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Content has not been uploaded"})
+			}
+			return httperr.Internal(err)
 		}
 		return c.JSON(upload)
 	}
@@ -3,6 +3,7 @@ package upload
 import (
     "context"
     "errors"
+    "fmt"
     "io"
     "sync"
     "time"
@@ -10,6 +11,7 @@ import (
     "github.com/get-drexa/drexa/internal/blob"
     "github.com/get-drexa/drexa/internal/virtualfs"
     "github.com/google/uuid"
+    "github.com/uptrace/bun"
 )
 
 type Service struct {
@@ -33,8 +35,8 @@ type CreateUploadOptions struct {
     Name string
 }
 
-func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
-    parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID)
+func (s *Service) CreateUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
+    parentNode, err := s.vfs.FindNodeByPublicID(ctx, db, accountID, opts.ParentID)
     if err != nil {
         if errors.Is(err, virtualfs.ErrNodeNotFound) {
             return nil, ErrNotFound
@@ -46,7 +48,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
         return nil, ErrParentNotDirectory
     }
 
-    node, err := s.vfs.CreateFile(ctx, userID, virtualfs.CreateFileOptions{
+    node, err := s.vfs.CreateFile(ctx, db, accountID, virtualfs.CreateFileOptions{
         ParentID: parentNode.ID,
         Name:     opts.Name,
     })
@@ -57,12 +59,17 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
         return nil, err
     }
 
-    uploadURL, err := s.blobStore.GenerateUploadURL(ctx, node.BlobKey, blob.UploadURLOptions{
-        Duration: 1 * time.Hour,
-    })
-    if err != nil {
-        _ = s.vfs.PermanentlyDeleteNode(ctx, node)
-        return nil, err
+    var uploadURL string
+    if s.blobStore.SupportsDirectUpload() {
+        uploadURL, err = s.blobStore.GenerateUploadURL(ctx, node.BlobKey, blob.UploadURLOptions{
+            Duration: 1 * time.Hour,
+        })
+        if err != nil {
+            _ = s.vfs.PermanentlyDeleteNode(ctx, db, node)
+            return nil, err
+        }
+    } else {
+        uploadURL = ""
     }
 
     upload := &Upload{
@@ -77,7 +84,8 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
     return upload, nil
 }
 
-func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID string, reader io.Reader) error {
+func (s *Service) ReceiveUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string, reader io.Reader) error {
+    fmt.Printf("reader: %v\n", reader)
     n, ok := s.pendingUploads.Load(uploadID)
     if !ok {
         return ErrNotFound
@@ -88,11 +96,11 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
         return ErrNotFound
     }
 
-    if upload.TargetNode.UserID != userID {
+    if upload.TargetNode.AccountID != accountID {
         return ErrNotFound
     }
 
-    err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromReader(reader))
+    err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromReader(reader))
     if err != nil {
         return err
     }
@@ -102,7 +110,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
     return nil
 }
 
-func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) {
+func (s *Service) CompleteUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string) (*Upload, error) {
     n, ok := s.pendingUploads.Load(uploadID)
     if !ok {
         return nil, ErrNotFound
@@ -113,7 +121,7 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
         return nil, ErrNotFound
     }
 
-    if upload.TargetNode.UserID != userID {
+    if upload.TargetNode.AccountID != accountID {
         return nil, ErrNotFound
     }
 
@@ -121,8 +129,11 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
         return upload, nil
     }
 
-    err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
+    err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
     if err != nil {
+        if errors.Is(err, blob.ErrNotFound) {
+            return nil, ErrContentNotUploaded
+        }
         return nil, err
     }
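Note: the service methods above now take a bun.IDB instead of reaching for a stored *bun.DB, so a caller can run them against the root connection or inside a transaction it owns. A sketch of the calling side under that assumption (the helper name and placement are illustrative, not part of the commit):

package upload_sketch

import (
    "context"

    "github.com/get-drexa/drexa/internal/upload"
    "github.com/google/uuid"
    "github.com/uptrace/bun"
)

// completeInTx runs CompleteUpload inside one transaction. bun.Tx satisfies
// bun.IDB, so it can be passed anywhere the root *bun.DB could.
func completeInTx(ctx context.Context, db *bun.DB, svc *upload.Service, accountID uuid.UUID, uploadID string) error {
    return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
        _, err := svc.CompleteUpload(ctx, tx, accountID, uploadID)
        return err // non-nil rolls back, nil commits
    })
}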
@@ -11,14 +11,12 @@ import (
 type User struct {
     bun.BaseModel `bun:"users"`
 
     ID          uuid.UUID       `bun:",pk,type:uuid" json:"id"`
     DisplayName string          `bun:"display_name" json:"displayName"`
     Email       string          `bun:"email,unique,notnull" json:"email"`
     Password    password.Hashed `bun:"password,notnull" json:"-"`
-    StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
-    StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
-    CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
-    UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
+    CreatedAt time.Time `bun:"created_at,notnull,nullzero" json:"createdAt"`
+    UpdatedAt time.Time `bun:"updated_at,notnull,nullzero" json:"updatedAt"`
 }
 
 func newUserID() (uuid.UUID, error) {
@@ -15,8 +15,8 @@ func NewFlatKeyResolver() *FlatKeyResolver {
     return &FlatKeyResolver{}
 }
 
-func (r *FlatKeyResolver) KeyMode() blob.KeyMode {
-    return blob.KeyModeStable
+func (r *FlatKeyResolver) ShouldPersistKey() bool {
+    return true
 }
 
 func (r *FlatKeyResolver) Resolve(ctx context.Context, node *Node) (blob.Key, error) {
@@ -2,6 +2,7 @@ package virtualfs
 
 import (
     "context"
+    "fmt"
 
     "github.com/get-drexa/drexa/internal/blob"
     "github.com/uptrace/bun"
@@ -17,8 +18,8 @@ func NewHierarchicalKeyResolver(db *bun.DB) *HierarchicalKeyResolver {
     return &HierarchicalKeyResolver{db: db}
 }
 
-func (r *HierarchicalKeyResolver) KeyMode() blob.KeyMode {
-    return blob.KeyModeDerived
+func (r *HierarchicalKeyResolver) ShouldPersistKey() bool {
+    return false
 }
 
 func (r *HierarchicalKeyResolver) Resolve(ctx context.Context, node *Node) (blob.Key, error) {
@@ -27,7 +28,7 @@ func (r *HierarchicalKeyResolver) Resolve(ctx context.Context, node *Node) (blob
         return "", err
     }
 
-    return blob.Key(path), nil
+    return blob.Key(fmt.Sprintf("%s/%s", node.AccountID, path)), nil
 }
 
 func (r *HierarchicalKeyResolver) ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error) {
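Note: with the change above, hierarchically derived blob keys are namespaced by account, so two accounts with the same path no longer collide in the blob store. A tiny illustration of the resulting key layout (the UUID and path are hypothetical):

package main

import (
    "fmt"

    "github.com/google/uuid"
)

func main() {
    // Hypothetical values, only to show the key shape the resolver now produces.
    accountID := uuid.MustParse("018f6b2a-0000-7000-8000-000000000000")
    path := "docs/report.pdf"
    fmt.Printf("%s/%s\n", accountID, path)
    // Output: 018f6b2a-0000-7000-8000-000000000000/docs/report.pdf
}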
@@ -7,7 +7,10 @@ import (
 )
 
 type BlobKeyResolver interface {
-    KeyMode() blob.KeyMode
+    // ShouldPersistKey returns true if the resolved key should be stored in node.BlobKey.
+    // Flat keys (e.g. UUIDs) return true - key is generated once and stored.
+    // Hierarchical keys return false - key is derived from path each time.
+    ShouldPersistKey() bool
     Resolve(ctx context.Context, node *Node) (blob.Key, error)
     ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error)
 }
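Note: the doc comments above define the contract for resolvers. A minimal sketch of a type satisfying the revised interface; the random-key scheme is illustrative and is not a resolver from the repository:

package virtualfs

// Sketch only: a resolver with persisted, randomly generated keys.

import (
    "context"

    "github.com/get-drexa/drexa/internal/blob"
    "github.com/google/uuid"
)

type randomKeyResolver struct{}

// ShouldPersistKey is true: the key is random, so it must be generated once
// and stored in node.BlobKey rather than re-derived later.
func (randomKeyResolver) ShouldPersistKey() bool { return true }

func (randomKeyResolver) Resolve(ctx context.Context, node *Node) (blob.Key, error) {
    return blob.Key(uuid.NewString()), nil
}

func (randomKeyResolver) ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error) {
    // Deletion planning is out of scope for this sketch.
    return nil, nil
}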
@@ -25,21 +25,25 @@ const (
 type Node struct {
     bun.BaseModel `bun:"vfs_nodes"`
 
     ID       uuid.UUID `bun:",pk,type:uuid"`
     PublicID string    `bun:"public_id,notnull"`
-    UserID   uuid.UUID `bun:"user_id,notnull"`
-    ParentID uuid.UUID `bun:"parent_id,notnull"`
+    AccountID uuid.UUID `bun:"account_id,notnull,type:uuid"`
+    ParentID  uuid.UUID `bun:"parent_id,nullzero"`
     Kind     NodeKind   `bun:"kind,notnull"`
     Status   NodeStatus `bun:"status,notnull"`
     Name     string     `bun:"name,notnull"`
 
-    BlobKey  blob.Key `bun:"blob_key"`
+    BlobKey  blob.Key `bun:"blob_key,nullzero"`
     Size     int64    `bun:"size"`
-    MimeType string   `bun:"mime_type"`
+    MimeType string   `bun:"mime_type,nullzero"`
 
-    CreatedAt time.Time `bun:"created_at,notnull"`
-    UpdatedAt time.Time `bun:"updated_at,notnull"`
-    DeletedAt time.Time `bun:"deleted_at"`
+    CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
+    UpdatedAt time.Time `bun:"updated_at,notnull,nullzero"`
+    DeletedAt time.Time `bun:"deleted_at,nullzero"`
+}
+
+func newNodeID() (uuid.UUID, error) {
+    return uuid.NewV7()
 }
 
 // IsAccessible returns true if the node can be accessed.
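Note: the nullzero tags added above make bun write a Go zero value as SQL NULL and scan NULL back into the zero value, which is why fields like DeletedAt can stay plain time.Time values instead of pointers. A trimmed-down sketch of the tag semantics, using a hypothetical model:

package model

import (
    "time"

    "github.com/uptrace/bun"
)

// Item illustrates the nullzero tag used on Node and User above.
type Item struct {
    bun.BaseModel `bun:"items"`

    ID int64 `bun:",pk,autoincrement"`

    // With nullzero, a zero time.Time is stored as SQL NULL instead of the
    // zero timestamp, and a NULL column scans back into the zero value.
    DeletedAt time.Time `bun:"deleted_at,nullzero"`
}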
@@ -12,7 +12,7 @@ import (
 
 const absolutePathQuery = `WITH RECURSIVE path AS (
     SELECT id, parent_id, name, 1 as depth
-    FROM vfs_nodes WHERE id = $1 AND deleted_at IS NULL
+    FROM vfs_nodes WHERE id = ? AND deleted_at IS NULL
 
     UNION ALL
 
@@ -29,7 +29,7 @@ func JoinPath(parts ...string) string {
     return strings.Join(parts, "/")
 }
 
-func buildNodeAbsolutePath(ctx context.Context, db *bun.DB, nodeID uuid.UUID) (string, error) {
+func buildNodeAbsolutePath(ctx context.Context, db bun.IDB, nodeID uuid.UUID) (string, error) {
     var path []string
     err := db.NewRaw(absolutePathQuery, nodeID).Scan(ctx, &path)
     if err != nil {
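Note: the query above switches from the Postgres-style $1 placeholder to bun's ? placeholder, which bun rewrites for the active dialect when the statement goes through NewRaw. A short sketch of the same pattern against a hypothetical query (the column set is illustrative):

package vfsutil

import (
    "context"

    "github.com/uptrace/bun"
)

const childNamesQuery = `SELECT name FROM vfs_nodes WHERE parent_id = ? AND deleted_at IS NULL`

// childNames shows bun binding a `?` placeholder in a raw query and scanning
// a single column into a slice of scalars.
func childNames(ctx context.Context, db bun.IDB, parentID string) ([]string, error) {
    var names []string
    if err := db.NewRaw(childNamesQuery, parentID).Scan(ctx, &names); err != nil {
        return nil, err
    }
    return names, nil
}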
@@ -19,7 +19,6 @@ import (
 )
 
 type VirtualFS struct {
-    db          *bun.DB
     blobStore   blob.Store
     keyResolver BlobKeyResolver
 
@@ -42,6 +41,8 @@ type FileContent struct {
     blobKey blob.Key
 }
 
+const RootDirectoryName = "root"
+
 func FileContentFromReader(reader io.Reader) FileContent {
     return FileContent{reader: reader}
 }
@@ -50,23 +51,22 @@ func FileContentFromBlobKey(blobKey blob.Key) FileContent {
     return FileContent{blobKey: blobKey}
 }
 
-func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) {
+func NewVirtualFS(blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) {
     sqid, err := sqids.New()
     if err != nil {
         return nil, err
     }
     return &VirtualFS{
-        db:          db,
         blobStore:   blobStore,
         keyResolver: keyResolver,
         sqid:        sqid,
     }, nil
 }
 
-func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Node, error) {
+func (vfs *VirtualFS) FindNode(ctx context.Context, db bun.IDB, accountID, fileID string) (*Node, error) {
     var node Node
-    err := vfs.db.NewSelect().Model(&node).
-        Where("user_id = ?", userID).
+    err := db.NewSelect().Model(&node).
+        Where("account_id = ?", accountID).
         Where("id = ?", fileID).
         Where("status = ?", NodeStatusReady).
         Where("deleted_at IS NULL").
@@ -80,10 +80,10 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Nod
     return &node, nil
 }
 
-func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID, publicID string) (*Node, error) {
+func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, db bun.IDB, accountID uuid.UUID, publicID string) (*Node, error) {
     var node Node
-    err := vfs.db.NewSelect().Model(&node).
-        Where("user_id = ?", userID).
+    err := db.NewSelect().Model(&node).
+        Where("account_id = ?", accountID).
         Where("public_id = ?", publicID).
         Where("status = ?", NodeStatusReady).
         Where("deleted_at IS NULL").
@@ -97,14 +97,14 @@ func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID,
     return &node, nil
 }
 
-func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, error) {
+func (vfs *VirtualFS) ListChildren(ctx context.Context, db bun.IDB, node *Node) ([]*Node, error) {
     if !node.IsAccessible() {
         return nil, ErrNodeNotFound
     }
 
     var nodes []*Node
-    err := vfs.db.NewSelect().Model(&nodes).
-        Where("user_id = ?", node.UserID).
+    err := db.NewSelect().Model(&nodes).
+        Where("account_id = ?", node.AccountID).
         Where("parent_id = ?", node.ID).
         Where("status = ?", NodeStatusReady).
         Where("deleted_at IS NULL").
@@ -119,29 +119,35 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
     return nodes, nil
 }
 
-func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts CreateFileOptions) (*Node, error) {
+func (vfs *VirtualFS) CreateFile(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) {
     pid, err := vfs.generatePublicID()
     if err != nil {
         return nil, err
     }
 
-    node := Node{
-        PublicID: pid,
-        UserID:   userID,
-        ParentID: opts.ParentID,
-        Kind:     NodeKindFile,
-        Status:   NodeStatusPending,
-        Name:     opts.Name,
+    id, err := newNodeID()
+    if err != nil {
+        return nil, err
     }
 
-    if vfs.keyResolver.KeyMode() == blob.KeyModeStable {
+    node := Node{
+        ID:        id,
+        PublicID:  pid,
+        AccountID: accountID,
+        ParentID:  opts.ParentID,
+        Kind:      NodeKindFile,
+        Status:    NodeStatusPending,
+        Name:      opts.Name,
+    }
+
+    if vfs.keyResolver.ShouldPersistKey() {
         node.BlobKey, err = vfs.keyResolver.Resolve(ctx, &node)
         if err != nil {
             return nil, err
         }
     }
 
-    _, err = vfs.db.NewInsert().Model(&node).Returning("*").Exec(ctx)
+    _, err = db.NewInsert().Model(&node).Returning("*").Exec(ctx)
     if err != nil {
         if database.IsUniqueViolation(err) {
             return nil, ErrNodeConflict
@@ -152,7 +158,7 @@ func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts Cre
     return &node, nil
 }
 
-func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileContent) error {
+func (vfs *VirtualFS) WriteFile(ctx context.Context, db bun.IDB, node *Node, content FileContent) error {
     if content.reader == nil && content.blobKey.IsNil() {
         return blob.ErrInvalidFileContent
     }
@@ -184,7 +190,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
         return err
     }
 
-    if vfs.keyResolver.KeyMode() == blob.KeyModeStable {
+    if vfs.keyResolver.ShouldPersistKey() {
         node.BlobKey = key
         setCols = append(setCols, "blob_key")
     }
@@ -222,7 +228,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
         setCols = append(setCols, "mime_type", "blob_key", "size", "status")
     }
 
-    _, err := vfs.db.NewUpdate().Model(&node).
+    _, err := db.NewUpdate().Model(node).
         Column(setCols...).
         WherePK().
         Exec(ctx)
@@ -233,22 +239,28 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
     return nil
 }
 
-func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
+func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
     pid, err := vfs.generatePublicID()
     if err != nil {
         return nil, err
     }
 
-    node := Node{
-        PublicID: pid,
-        UserID:   userID,
-        ParentID: parentID,
-        Kind:     NodeKindDirectory,
-        Status:   NodeStatusReady,
-        Name:     name,
+    id, err := newNodeID()
+    if err != nil {
+        return nil, err
     }
 
-    _, err = vfs.db.NewInsert().Model(&node).Exec(ctx)
+    node := Node{
+        ID:        id,
+        PublicID:  pid,
+        AccountID: accountID,
+        ParentID:  parentID,
+        Kind:      NodeKindDirectory,
+        Status:    NodeStatusReady,
+        Name:      name,
+    }
+
+    _, err = db.NewInsert().Model(node).Exec(ctx)
     if err != nil {
         if database.IsUniqueViolation(err) {
             return nil, ErrNodeConflict
@@ -259,12 +271,12 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, par
     return &node, nil
 }
 
-func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error {
+func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
     if !node.IsAccessible() {
         return ErrNodeNotFound
     }
 
-    _, err := vfs.db.NewUpdate().Model(node).
+    _, err := db.NewUpdate().Model(node).
         WherePK().
         Where("deleted_at IS NULL").
         Where("status = ?", NodeStatusReady).
@@ -281,12 +293,12 @@ func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error {
     return nil
 }
 
-func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error {
+func (vfs *VirtualFS) RestoreNode(ctx context.Context, db bun.IDB, node *Node) error {
     if node.Status != NodeStatusReady {
         return ErrNodeNotFound
     }
 
-    _, err := vfs.db.NewUpdate().Model(node).
+    _, err := db.NewUpdate().Model(node).
         WherePK().
         Where("deleted_at IS NOT NULL").
         Set("deleted_at = NULL").
@@ -302,12 +314,12 @@ func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error {
     return nil
 }
 
-func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) error {
+func (vfs *VirtualFS) RenameNode(ctx context.Context, db bun.IDB, node *Node, name string) error {
     if !node.IsAccessible() {
         return ErrNodeNotFound
     }
 
-    _, err := vfs.db.NewUpdate().Model(node).
+    _, err := db.NewUpdate().Model(node).
         WherePK().
         Where("status = ?", NodeStatusReady).
         Where("deleted_at IS NULL").
@@ -323,7 +335,7 @@ func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) e
     return nil
 }
 
-func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UUID) error {
+func (vfs *VirtualFS) MoveNode(ctx context.Context, db bun.IDB, node *Node, parentID uuid.UUID) error {
     if !node.IsAccessible() {
         return ErrNodeNotFound
     }
@@ -333,7 +345,7 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
         return err
     }
 
-    _, err = vfs.db.NewUpdate().Model(node).
+    _, err = db.NewUpdate().Model(node).
         WherePK().
         Where("status = ?", NodeStatusReady).
         Where("deleted_at IS NULL").
@@ -360,9 +372,9 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
         return err
     }
 
-    if vfs.keyResolver.KeyMode() == blob.KeyModeStable {
+    if vfs.keyResolver.ShouldPersistKey() {
         node.BlobKey = newKey
-        _, err = vfs.db.NewUpdate().Model(node).
+        _, err = db.NewUpdate().Model(node).
             WherePK().
             Set("blob_key = ?", newKey).
             Exec(ctx)
@@ -374,34 +386,34 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
     return nil
 }
 
-func (vfs *VirtualFS) AbsolutePath(ctx context.Context, node *Node) (string, error) {
+func (vfs *VirtualFS) AbsolutePath(ctx context.Context, db bun.IDB, node *Node) (string, error) {
     if !node.IsAccessible() {
         return "", ErrNodeNotFound
     }
-    return buildNodeAbsolutePath(ctx, vfs.db, node.ID)
+    return buildNodeAbsolutePath(ctx, db, node.ID)
 }
 
-func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, node *Node) error {
+func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
     if !node.IsAccessible() {
         return ErrNodeNotFound
     }
     switch node.Kind {
     case NodeKindFile:
-        return vfs.permanentlyDeleteFileNode(ctx, node)
+        return vfs.permanentlyDeleteFileNode(ctx, db, node)
     case NodeKindDirectory:
-        return vfs.permanentlyDeleteDirectoryNode(ctx, node)
+        return vfs.permanentlyDeleteDirectoryNode(ctx, db, node)
     default:
         return ErrUnsupportedOperation
     }
 }
 
-func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node) error {
+func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, db bun.IDB, node *Node) error {
     err := vfs.blobStore.Delete(ctx, node.BlobKey)
     if err != nil {
         return err
     }
 
-    _, err = vfs.db.NewDelete().Model(node).WherePK().Exec(ctx)
+    _, err = db.NewDelete().Model(node).WherePK().Exec(ctx)
     if err != nil {
         return err
     }
@@ -409,7 +421,7 @@ func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, node *Node)
     return nil
 }
 
-func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *Node) error {
+func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, db bun.IDB, node *Node) error {
     const descendantsQuery = `WITH RECURSIVE descendants AS (
         SELECT id, blob_key FROM vfs_nodes WHERE id = ?
         UNION ALL
@@ -423,14 +435,29 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *
         BlobKey blob.Key `bun:"blob_key"`
     }
 
-    tx, err := vfs.db.BeginTx(ctx, nil)
-    if err != nil {
-        return err
+    // If db is already a transaction, use it directly; otherwise start a new transaction
+    var tx bun.IDB
+    var startedTx *bun.Tx
+    switch v := db.(type) {
+    case *bun.DB:
+        newTx, err := v.BeginTx(ctx, nil)
+        if err != nil {
+            return err
+        }
+        startedTx = &newTx
+        tx = newTx
+        defer func() {
+            if startedTx != nil {
+                (*startedTx).Rollback()
+            }
+        }()
+    default:
+        // Assume it's already a transaction
+        tx = db
     }
-    defer tx.Rollback()
 
     var records []nodeRecord
-    err = tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records)
+    err := tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records)
     if err != nil {
         if errors.Is(err, sql.ErrNoRows) {
             return ErrNodeNotFound
@@ -472,7 +499,14 @@ func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, node *
         }
     }
 
-    return tx.Commit()
+    // Only commit if we started the transaction
+    if startedTx != nil {
+        err := (*startedTx).Commit()
+        startedTx = nil // Prevent defer from rolling back
+        return err
+    }
+
+    return nil
 }
 
 func (vfs *VirtualFS) generatePublicID() (string, error) {
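Note: the switch above begins a transaction only when it receives a root *bun.DB and reuses the caller's transaction otherwise. When a *bun.DB is in hand, bun's RunInTx helper expresses a similar commit/rollback guarantee with less bookkeeping; a sketch of that alternative shape (not the code this commit uses, and the delete statement is illustrative):

package virtualfs_sketch

import (
    "context"

    "github.com/google/uuid"
    "github.com/uptrace/bun"
)

// deleteSubtree runs its statements inside one transaction: a non-nil error
// from the closure rolls back, a nil return commits.
func deleteSubtree(ctx context.Context, db *bun.DB, nodeID uuid.UUID) error {
    return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
        // tx implements bun.IDB, so it can be handed to helpers that accept bun.IDB.
        _, err := tx.NewDelete().
            Table("vfs_nodes").
            Where("id = ?", nodeID).
            Exec(ctx)
        return err
    })
}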