24 Commits

Author SHA1 Message Date
085bbd4ffe fix: use db tx when creating directory 2025-12-05 00:55:41 +00:00
1c9e049448 fix: handle create dir conflict in api 2025-12-05 00:38:31 +00:00
3ea12cf51a feat: add query param to include dir path in query 2025-12-05 00:38:05 +00:00
1bedafd3de feat: add /users/me route 2025-12-04 00:49:59 +00:00
539b15dcb7 refactor: introduce reqctx pkg for request context 2025-12-04 00:48:43 +00:00
57167d5715 feat: impl cookie-based auth tokens exchange
implement access/refresh token exchange via cookies as well as automatic
access token refresh
2025-12-04 00:26:20 +00:00
d4c4e84fbf feat: use argon2id to hash refresh tokens in db 2025-12-03 23:05:00 +00:00
589158a8ed feat: impl directory endpoints 2025-12-03 00:56:44 +00:00
3a6fafacca feat: impl refresh token rotation 2025-12-03 00:07:39 +00:00
a7bdc831ea chore: add data dir to gitignore 2025-12-02 22:10:57 +00:00
ca9dfd90b2 feat: impl files endpoints 2025-12-02 22:08:50 +00:00
e8a558d652 refactor: wrap fiber.App in custom Server struct 2025-11-30 23:14:09 +00:00
3e96c42c4a fix: handle missing expected err cases 2025-11-30 20:08:31 +00:00
6c61cbe1fd fix: hierarchical keys should have acc id prefix 2025-11-30 19:49:13 +00:00
ccdeaf0364 fix: some weird upload bugs 2025-11-30 19:39:47 +00:00
a3110b67c3 fix: vfs node primary id not generated properly 2025-11-30 19:20:08 +00:00
1907cd83c8 fix: upload url generation 2025-11-30 19:19:54 +00:00
033ad65d5f feat: improve err logging 2025-11-30 19:19:33 +00:00
987edc0d4a fix: timestamp cols default not working 2025-11-30 17:42:35 +00:00
fb8e91dd47 feat: create root dir on acc registration 2025-11-30 17:37:59 +00:00
0b8aee5d60 refactor: make vfs methods accept bun.IDB 2025-11-30 17:31:24 +00:00
89b62f6d8a feat: introduce account 2025-11-30 17:12:50 +00:00
1c1392a0a1 refactor: replace KeyMode with ShouldPersistKey 2025-11-30 15:02:37 +00:00
6984bb209e refactor: node deletion 2025-11-30 01:16:44 +00:00
43 changed files with 2052 additions and 439 deletions

1
apps/backend/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
data/

View File

@@ -29,6 +29,5 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
log.Printf("starting server on :%d", config.Server.Port) log.Fatal(server.Start())
log.Fatal(server.Listen(fmt.Sprintf(":%d", config.Server.Port)))
} }

View File

@@ -28,3 +28,8 @@ storage:
# Required when backend is "s3" # Required when backend is "s3"
# bucket: my-drexa-bucket # bucket: my-drexa-bucket
cookie:
# Domain for cross-subdomain auth cookies.
# Set this when frontend and API are on different subdomains (e.g., "app.com" for web.app.com + api.app.com).
# Leave empty for single-domain or localhost setups.
# domain: app.com

View File

@@ -7,14 +7,16 @@ require (
github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/fiber/v2 v2.52.9
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/sqids/sqids-go v0.4.1 github.com/sqids/sqids-go v0.4.1
github.com/uptrace/bun v1.2.15 github.com/uptrace/bun v1.2.16
golang.org/x/crypto v0.40.0 github.com/uptrace/bun/extra/bundebug v1.2.16
golang.org/x/crypto v0.45.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
go.opentelemetry.io/otel v1.37.0 // indirect github.com/fatih/color v1.18.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
mellium.im/sasl v0.3.2 // indirect mellium.im/sasl v0.3.2 // indirect
) )
@@ -29,12 +31,12 @@ require (
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/uptrace/bun/dialect/pgdialect v1.2.15 github.com/uptrace/bun/dialect/pgdialect v1.2.16
github.com/uptrace/bun/driver/pgdriver v1.2.15 github.com/uptrace/bun/driver/pgdriver v1.2.16
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.34.0 // indirect golang.org/x/sys v0.38.0 // indirect
) )

View File

@@ -2,6 +2,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik= github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
@@ -34,16 +36,18 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE= github.com/uptrace/bun v1.2.16 h1:QlObi6ZIK5Ao7kAALnh91HWYNZUBbVwye52fmlQM9kc=
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038= github.com/uptrace/bun v1.2.16/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
github.com/uptrace/bun/dialect/pgdialect v1.2.15 h1:er+/3giAIqpfrXJw+KP9B7ujyQIi5XkPnFmgjAVL6bA= github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
github.com/uptrace/bun/dialect/pgdialect v1.2.15/go.mod h1:QSiz6Qpy9wlGFsfpf7UMSL6mXAL1jDJhFwuOVacCnOQ= github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
github.com/uptrace/bun/driver/pgdriver v1.2.15 h1:eZZ60ZtUUE6jjv6VAI1pCMaTgtx3sxmChQzwbvchOOo= github.com/uptrace/bun/driver/pgdriver v1.2.16 h1:b1kpXKUxtTSGYow5Vlsb+dKV3z0R7aSAJNfMfKp61ZU=
github.com/uptrace/bun/driver/pgdriver v1.2.15/go.mod h1:s2zz/BAeScal4KLFDI8PURwATN8s9RDBsElEbnPAjv4= github.com/uptrace/bun/driver/pgdriver v1.2.16/go.mod h1:H6lUZ9CBfp1X5Vq62YGSV7q96/v94ja9AYFjKvdoTk0=
github.com/uptrace/bun/extra/bundebug v1.2.16 h1:3OXAfHTU4ydu2+4j05oB1BxPx6+ypdWIVzTugl/7zl0=
github.com/uptrace/bun/extra/bundebug v1.2.16/go.mod h1:vk6R/1i67/S2RvUI5AH/m3P5e67mOkfDCmmCsAPUumo=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
@@ -54,15 +58,15 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -0,0 +1,23 @@
package account
import (
"time"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
// Account ties a user to a storage allocation: current usage in bytes
// tracked against a fixed quota.
type Account struct {
	bun.BaseModel `bun:"accounts"`

	// ID is the primary key (UUIDv7, see newAccountID).
	ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
	// UserID references the owning user.
	UserID uuid.UUID `bun:"user_id,notnull,type:uuid" json:"userId"`
	// StorageUsageBytes is the number of bytes currently consumed.
	StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"`
	// StorageQuotaBytes is the maximum allowed storage for this account.
	StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"`
	// nullzero lets the database fill these timestamps on insert.
	CreatedAt time.Time `bun:"created_at,notnull,nullzero" json:"createdAt"`
	UpdatedAt time.Time `bun:"updated_at,notnull,nullzero" json:"updatedAt"`
}
// newAccountID generates a time-ordered UUIDv7 primary key for a new account.
func newAccountID() (uuid.UUID, error) {
	return uuid.NewV7()
}

View File

@@ -0,0 +1,8 @@
package account
import "errors"
var (
	// ErrAccountNotFound is returned by lookups when no account matches.
	ErrAccountNotFound = errors.New("account not found")
	// ErrAccountAlreadyExists is returned when an insert hits a uniqueness violation.
	ErrAccountAlreadyExists = errors.New("account already exists")
)

View File

@@ -0,0 +1,130 @@
package account
import (
"errors"
"github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/reqctx"
"github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
// HTTPHandler exposes the account endpoints over HTTP (fiber).
type HTTPHandler struct {
	accountService *Service
	authService    *auth.Service
	db             *bun.DB
	// authMiddleware guards the per-account routes registered in RegisterRoutes.
	authMiddleware fiber.Handler
}
// registerAccountRequest is the JSON body accepted by POST /accounts.
type registerAccountRequest struct {
	Email       string `json:"email"`
	Password    string `json:"password"`
	DisplayName string `json:"displayName"`
}
// registerAccountResponse is returned by POST /accounts: the created account
// and user plus freshly issued auth tokens.
type registerAccountResponse struct {
	Account      *Account   `json:"account"`
	User         *user.User `json:"user"`
	AccessToken  string     `json:"accessToken"`
	RefreshToken string     `json:"refreshToken"`
}
// currentAccountKey is the fiber locals key under which accountMiddleware
// stores the resolved *Account.
const currentAccountKey = "currentAccount"

// CurrentAccount returns the account resolved by accountMiddleware, or nil
// when no account has been stored on the context.
//
// The comma-ok assertion is required: c.Locals returns an untyped nil before
// the middleware runs, and a bare .(*Account) assertion on it would panic —
// which made the nil check in getAccount unreachable.
func CurrentAccount(c *fiber.Ctx) *Account {
	account, _ := c.Locals(currentAccountKey).(*Account)
	return account
}
// NewHTTPHandler wires an account HTTP handler from its service, auth,
// database, and auth-middleware dependencies.
func NewHTTPHandler(accountService *Service, authService *auth.Service, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
	h := HTTPHandler{
		accountService: accountService,
		authService:    authService,
		db:             db,
		authMiddleware: authMiddleware,
	}
	return &h
}
// RegisterRoutes mounts the account endpoints on api. It returns the
// authenticated per-account sub-router so callers can attach further routes
// beneath /accounts/:accountID.
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) fiber.Router {
	// Registration is public; everything under /accounts/:accountID is not.
	api.Post("/accounts", h.registerAccount)

	grp := api.Group("/accounts/:accountID")
	grp.Use(h.authMiddleware)
	grp.Use(h.accountMiddleware)
	grp.Get("/", h.getAccount)
	return grp
}
// accountMiddleware resolves the :accountID path parameter for the
// authenticated user and stashes the account in the request locals for
// downstream handlers (see CurrentAccount).
func (h *HTTPHandler) accountMiddleware(c *fiber.Ctx) error {
	authedUser := reqctx.AuthenticatedUser(c).(*user.User)

	id, err := uuid.Parse(c.Params("accountID"))
	if err != nil {
		// A malformed ID can never match an account; report not found.
		return c.SendStatus(fiber.StatusNotFound)
	}

	acc, err := h.accountService.AccountByID(c.Context(), h.db, authedUser.ID, id)
	switch {
	case err == nil:
		c.Locals(currentAccountKey, acc)
		return c.Next()
	case errors.Is(err, ErrAccountNotFound):
		return c.SendStatus(fiber.StatusNotFound)
	default:
		return httperr.Internal(err)
	}
}
// getAccount renders the account resolved by accountMiddleware as JSON.
func (h *HTTPHandler) getAccount(c *fiber.Ctx) error {
	acc := CurrentAccount(c)
	if acc == nil {
		return c.SendStatus(fiber.StatusNotFound)
	}
	return c.JSON(acc)
}
// registerAccount handles POST /accounts: it creates a user + account pair
// and immediately grants auth tokens for the new user. The whole flow runs in
// a single transaction, so a failed token grant rolls back the registration.
func (h *HTTPHandler) registerAccount(c *fiber.Ctx) error {
	var req registerAccountRequest
	if err := c.BodyParser(&req); err != nil {
		return c.SendStatus(fiber.StatusBadRequest)
	}

	tx, err := h.db.BeginTx(c.Context(), nil)
	if err != nil {
		return httperr.Internal(err)
	}
	// No-op once Commit succeeds; undoes all writes on any early return.
	defer tx.Rollback()

	acc, u, err := h.accountService.Register(c.Context(), tx, RegisterOptions{
		Email:       req.Email,
		Password:    req.Password,
		DisplayName: req.DisplayName,
	})
	if err != nil {
		var exists *user.AlreadyExistsError
		if errors.As(err, &exists) || errors.Is(err, ErrAccountAlreadyExists) {
			return c.SendStatus(fiber.StatusConflict)
		}
		return httperr.Internal(err)
	}

	tokens, err := h.authService.GrantForUser(c.Context(), tx, u)
	if err != nil {
		return httperr.Internal(err)
	}

	if err := tx.Commit(); err != nil {
		return httperr.Internal(err)
	}

	return c.JSON(registerAccountResponse{
		Account:      acc,
		User:         u,
		AccessToken:  tokens.AccessToken,
		RefreshToken: tokens.RefreshToken,
	})
}

View File

@@ -0,0 +1,115 @@
package account
import (
"context"
"database/sql"
"errors"
"github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
// Service implements account registration and lookup on top of the user
// service and the virtual filesystem.
type Service struct {
	// NOTE(review): held by value while NewService receives a pointer —
	// confirm user.Service is safe to copy.
	userService user.Service
	vfs         *virtualfs.VirtualFS
}
// RegisterOptions carries the inputs for Service.Register.
// Password is the plaintext password; Register hashes it before storage.
type RegisterOptions struct {
	Email       string
	Password    string
	DisplayName string
}
// CreateAccountOptions carries the inputs for Service.CreateAccount.
type CreateAccountOptions struct {
	// OrganizationID is not read by any code visible in this file —
	// TODO(review): confirm whether it is still needed.
	OrganizationID uuid.UUID
	// QuotaBytes becomes the account's StorageQuotaBytes.
	QuotaBytes int64
}
// NewService builds an account Service from its dependencies.
func NewService(userService *user.Service, vfs *virtualfs.VirtualFS) *Service {
	s := Service{
		userService: *userService,
		vfs:         vfs,
	}
	return &s
}
// Register provisions a new tenant end to end: it hashes the password,
// registers the user, creates the user's account, and seeds the account's
// root directory in the virtual filesystem. All writes go through db, so
// callers typically pass a transaction.
func (s *Service) Register(ctx context.Context, db bun.IDB, opts RegisterOptions) (*Account, *user.User, error) {
	hashed, err := password.HashString(opts.Password)
	if err != nil {
		return nil, nil, err
	}

	u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
		Email:       opts.Email,
		Password:    hashed,
		DisplayName: opts.DisplayName,
	})
	if err != nil {
		return nil, nil, err
	}

	// TODO: make quota configurable
	const defaultQuotaBytes = int64(1) << 30 // 1GB
	acc, err := s.CreateAccount(ctx, db, u.ID, CreateAccountOptions{
		QuotaBytes: defaultQuotaBytes,
	})
	if err != nil {
		return nil, nil, err
	}

	// The root directory is created with uuid.Nil as its parent ID.
	if _, err := s.vfs.CreateDirectory(ctx, db, acc.ID, uuid.Nil, virtualfs.RootDirectoryName); err != nil {
		return nil, nil, err
	}

	return acc, u, nil
}
// CreateAccount inserts a new account row for userID with the quota from
// opts. Returns ErrAccountAlreadyExists when the insert violates a unique
// constraint.
func (s *Service) CreateAccount(ctx context.Context, db bun.IDB, userID uuid.UUID, opts CreateAccountOptions) (*Account, error) {
	id, err := newAccountID()
	if err != nil {
		return nil, err
	}

	acc := Account{
		ID:                id,
		UserID:            userID,
		StorageQuotaBytes: opts.QuotaBytes,
	}
	// Returning("*") pulls back database-populated columns (timestamps etc.).
	if _, err := db.NewInsert().Model(&acc).Returning("*").Exec(ctx); err != nil {
		if database.IsUniqueViolation(err) {
			return nil, ErrAccountAlreadyExists
		}
		return nil, err
	}
	return &acc, nil
}
// AccountByUserID looks up the account owned by userID.
// Returns ErrAccountNotFound when no row matches.
func (s *Service) AccountByUserID(ctx context.Context, db bun.IDB, userID uuid.UUID) (*Account, error) {
	acc := new(Account)
	if err := db.NewSelect().Model(acc).Where("user_id = ?", userID).Scan(ctx); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrAccountNotFound
		}
		return nil, err
	}
	return acc, nil
}
// AccountByID looks up the account with the given id, scoped to userID so a
// user can only resolve their own account.
// Returns ErrAccountNotFound when no row matches.
func (s *Service) AccountByID(ctx context.Context, db bun.IDB, userID uuid.UUID, id uuid.UUID) (*Account, error) {
	acc := new(Account)
	err := db.NewSelect().
		Model(acc).
		Where("user_id = ?", userID).
		Where("id = ?", id).
		Scan(ctx)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrAccountNotFound
		}
		return nil, err
	}
	return acc, nil
}

View File

@@ -0,0 +1,56 @@
package auth
import (
"time"
"github.com/gofiber/fiber/v2"
)
// CookieConfig controls auth cookie behavior.
type CookieConfig struct {
	// Domain for cross-subdomain cookies (e.g., "app.com" for web.app.com + api.app.com).
	// Leave empty for same-host cookies (localhost, single domain).
	// Set on both cookies written by setAuthCookies.
	Domain string
}
// authCookies collects the auth cookies present on the request, keyed by
// cookie name. Absent or empty cookies are omitted from the map.
func authCookies(c *fiber.Ctx) map[string]string {
	found := make(map[string]string, 2)
	for _, name := range []string{cookieKeyAccessToken, cookieKeyRefreshToken} {
		if v := c.Cookies(name); v != "" {
			found[name] = v
		}
	}
	return found
}
// setAuthCookies sets HTTP-only auth cookies with security settings derived
// from the request. The Secure flag follows the actual protocol, so it works
// automatically behind proxies/tunnels that terminate TLS.
func setAuthCookies(c *fiber.Ctx, accessToken, refreshToken string, cfg CookieConfig) {
	secure := c.Protocol() == "https"
	setTokenCookie(c, cookieKeyAccessToken, accessToken, accessTokenValidFor, cfg, secure)
	setTokenCookie(c, cookieKeyRefreshToken, refreshToken, refreshTokenValidFor, cfg, secure)
}

// setTokenCookie writes a single HTTP-only, SameSite=Lax auth cookie that
// expires validFor from now. Extracted so the access- and refresh-token
// cookies cannot drift apart in their security attributes.
func setTokenCookie(c *fiber.Ctx, name, value string, validFor time.Duration, cfg CookieConfig, secure bool) {
	c.Cookie(&fiber.Cookie{
		Name:     name,
		Value:    value,
		Path:     "/",
		Domain:   cfg.Domain,
		Expires:  time.Now().Add(validFor),
		SameSite: fiber.CookieSameSiteLaxMode,
		HTTPOnly: true,
		Secure:   secure,
	})
}

View File

@@ -5,7 +5,9 @@ import (
"fmt" "fmt"
) )
var ErrUnauthenticatedRequest = errors.New("unauthenticated request") var ErrInvalidRefreshToken = errors.New("invalid refresh token")
var ErrRefreshTokenExpired = errors.New("refresh token expired")
var ErrRefreshTokenReused = errors.New("refresh token reused")
type InvalidAccessTokenError struct { type InvalidAccessTokenError struct {
err error err error

View File

@@ -0,0 +1,22 @@
package auth
import (
"time"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
// Grant is a persisted auth grant tied to a user. A non-nil RevokedAt marks
// the grant as revoked.
type Grant struct {
	bun.BaseModel `bun:"grants"`

	// ID is the primary key (UUIDv7, see newGrantID).
	ID     uuid.UUID `bun:",pk,type:uuid"`
	UserID uuid.UUID `bun:"user_id,notnull"`
	// nullzero lets the database fill these timestamps on insert.
	CreatedAt time.Time  `bun:"created_at,notnull,nullzero"`
	UpdatedAt time.Time  `bun:"updated_at,notnull,nullzero"`
	RevokedAt *time.Time `bun:"revoked_at,nullzero"`
}
// newGrantID generates a time-ordered UUIDv7 primary key for a new grant.
func newGrantID() (uuid.UUID, error) {
	return uuid.NewV7()
}

View File

@@ -4,42 +4,57 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type loginRequest struct { const (
Email string `json:"email"` tokenDeliveryCookie = "cookie"
Password string `json:"password"` tokenDeliveryBody = "body"
} )
type registerRequest struct { const (
Email string `json:"email"` cookieKeyAccessToken = "access_token"
Password string `json:"password"` cookieKeyRefreshToken = "refresh_token"
DisplayName string `json:"displayName"` )
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
TokenDelivery string `json:"tokenDelivery"`
} }
type loginResponse struct { type loginResponse struct {
User user.User `json:"user"` User user.User `json:"user"`
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken,omitempty"`
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken,omitempty"`
}
type refreshAccessTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
type tokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
} }
type HTTPHandler struct { type HTTPHandler struct {
service *Service service *Service
db *bun.DB db *bun.DB
cookieConfig CookieConfig
} }
func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler { func NewHTTPHandler(s *Service, db *bun.DB, cookieConfig CookieConfig) *HTTPHandler {
return &HTTPHandler{service: s, db: db} return &HTTPHandler{service: s, db: db, cookieConfig: cookieConfig}
} }
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
auth := api.Group("/auth") auth := api.Group("/auth")
auth.Post("/login", h.Login) auth.Post("/login", h.Login)
auth.Post("/register", h.Register) auth.Post("/tokens", h.refreshAccessToken)
} }
func (h *HTTPHandler) Login(c *fiber.Ctx) error { func (h *HTTPHandler) Login(c *fiber.Ctx) error {
@@ -50,7 +65,7 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
tx, err := h.db.BeginTx(c.Context(), nil) tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return httperr.Internal(err)
} }
defer tx.Rollback() defer tx.Rollback()
@@ -59,54 +74,61 @@ func (h *HTTPHandler) Login(c *fiber.Ctx) error {
if errors.Is(err, ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid credentials"})
} }
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return httperr.Internal(err)
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) return httperr.Internal(err)
} }
return c.JSON(loginResponse{ switch req.TokenDelivery {
User: *result.User, default:
AccessToken: result.AccessToken, return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token delivery method"})
RefreshToken: result.RefreshToken,
}) case tokenDeliveryCookie:
setAuthCookies(c, result.AccessToken, result.RefreshToken, h.cookieConfig)
return c.JSON(loginResponse{
User: *result.User,
})
case tokenDeliveryBody:
return c.JSON(loginResponse{
User: *result.User,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
})
}
} }
func (h *HTTPHandler) Register(c *fiber.Ctx) error { func (h *HTTPHandler) refreshAccessToken(c *fiber.Ctx) error {
req := new(registerRequest) req := new(refreshAccessTokenRequest)
if err := c.BodyParser(req); err != nil { if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
tx, err := h.db.BeginTx(c.Context(), nil) tx, err := h.db.BeginTx(c.Context(), nil)
if err != nil { if err != nil {
slog.Error("failed to begin transaction", "error", err) return httperr.Internal(err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
} }
defer tx.Rollback() defer tx.Rollback()
result, err := h.service.Register(c.Context(), tx, registerOptions{ result, err := h.service.RefreshAccessToken(c.Context(), tx, req.RefreshToken)
email: req.Email,
password: req.Password,
displayName: req.DisplayName,
})
if err != nil { if err != nil {
var ae *user.AlreadyExistsError if errors.Is(err, ErrInvalidRefreshToken) ||
if errors.As(err, &ae) { errors.Is(err, ErrRefreshTokenExpired) ||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "User already exists"}) errors.Is(err, ErrRefreshTokenReused) {
_ = tx.Commit()
slog.Info("invalid refresh token", "error", err)
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid refresh token"})
} }
slog.Error("failed to register user", "error", err) return httperr.Internal(err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
slog.Error("failed to commit transaction", "error", err) return httperr.Internal(err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
} }
return c.JSON(loginResponse{ return c.JSON(tokenResponse{
User: *result.User,
AccessToken: result.AccessToken, AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken, RefreshToken: result.RefreshToken,
}) })

View File

@@ -2,56 +2,81 @@ package auth
import ( import (
"errors" "errors"
"log/slog"
"strings" "strings"
"time"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/reqctx"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
const authenticatedUserKey = "authenticatedUser" // NewAuthMiddleware creates a middleware that authenticates requests via Bearer token or cookies.
// To obtain the authenticated user in subsequent handlers, see reqctx.AuthenticatedUser.
// NewBearerAuthMiddleware is a middleware that authenticates a request using a bearer token. func NewAuthMiddleware(s *Service, db *bun.DB, cookieConfig CookieConfig) fiber.Handler {
// To obtain the authenticated user in subsequent handlers, see AuthenticatedUser.
func NewBearerAuthMiddleware(s *Service, db *bun.DB) fiber.Handler {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
var at string
var rt string
var setCookies bool
authHeader := c.Get("Authorization") authHeader := c.Get("Authorization")
if authHeader == "" { if authHeader != "" {
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
slog.Info("invalid auth header")
return c.SendStatus(fiber.StatusUnauthorized)
}
at = parts[1]
setCookies = false
} else {
cookies := authCookies(c)
at = cookies[cookieKeyAccessToken]
rt = cookies[cookieKeyRefreshToken]
setCookies = true
}
if at == "" {
slog.Info("no access token")
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
parts := strings.Split(authHeader, " ") authResult, err := s.AuthenticateWithAccessToken(c.Context(), db, at)
if len(parts) != 2 || parts[0] != "Bearer" {
return c.SendStatus(fiber.StatusUnauthorized)
}
token := parts[1]
u, err := s.AuthenticateWithAccessToken(c.Context(), db, token)
if err != nil { if err != nil {
var e *InvalidAccessTokenError var e *InvalidAccessTokenError
if errors.As(err, &e) { if errors.As(err, &e) {
slog.Info("invalid access token")
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
var nf *user.NotFoundError var nf *user.NotFoundError
if errors.As(err, &nf) { if errors.As(err, &nf) {
slog.Info("user not found")
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
return c.SendStatus(fiber.StatusInternalServerError) return httperr.Internal(err)
} }
c.Locals(authenticatedUserKey, u) reqctx.SetAuthenticatedUser(c, authResult.User)
// if cookie based auth and access token is about to expire (within 5 minutes),
// attempt to refresh the access token. if there is any error, ignore it and let the request continue.
if setCookies && time.Until(authResult.Claims.ExpiresAt.Time) < 5*time.Minute && rt != "" {
tx, txErr := db.BeginTx(c.Context(), nil)
if txErr == nil {
newTokens, err := s.RefreshAccessToken(c.Context(), tx, rt)
if err == nil {
if commitErr := tx.Commit(); commitErr == nil {
setAuthCookies(c, newTokens.AccessToken, newTokens.RefreshToken, cookieConfig)
}
} else {
_ = tx.Rollback()
slog.Debug("auto-refresh failed", "error", err)
}
}
}
return c.Next() return c.Next()
} }
} }
// AuthenticatedUser returns the authenticated user from the given fiber context.
// Returns ErrUnauthenticatedRequest if not authenticated.
func AuthenticatedUser(c *fiber.Ctx) (*user.User, error) {
if u, ok := c.Locals(authenticatedUserKey).(*user.User); ok {
return u, nil
}
return nil, ErrUnauthenticatedRequest
}

View File

@@ -2,21 +2,30 @@ package auth
import ( import (
"context" "context"
"encoding/hex" "database/sql"
"encoding/base64"
"errors" "errors"
"log/slog"
"time"
"github.com/get-drexa/drexa/internal/password" "github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type LoginResult struct { type AuthenticationTokens struct {
User *user.User User *user.User
AccessToken string AccessToken string
RefreshToken string RefreshToken string
} }
type AccessTokenAuthentication struct {
Claims *jwt.RegisteredClaims
User *user.User
}
var ErrInvalidCredentials = errors.New("invalid credentials") var ErrInvalidCredentials = errors.New("invalid credentials")
type Service struct { type Service struct {
@@ -24,12 +33,6 @@ type Service struct {
tokenConfig TokenConfig tokenConfig TokenConfig
} }
type registerOptions struct {
displayName string
email string
password string
}
func NewService(userService *user.Service, tokenConfig TokenConfig) *Service { func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
return &Service{ return &Service{
userService: userService, userService: userService,
@@ -37,7 +40,204 @@ func NewService(userService *user.Service, tokenConfig TokenConfig) *Service {
} }
} }
func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*LoginResult, error) { func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*AuthenticationTokens, error) {
u, err := s.authenticateWithEmailAndPassword(ctx, db, email, plain)
if err != nil {
return nil, err
}
return s.GrantForUser(ctx, db, u)
}
// GrantForUser creates and persists a new Grant for the given user, then
// issues a fresh access/refresh token pair bound to that grant.
func (s *Service) GrantForUser(ctx context.Context, db bun.IDB, user *user.User) (*AuthenticationTokens, error) {
	grantID, err := newGrantID()
	if err != nil {
		return nil, err
	}
	grant := &Grant{
		ID:     grantID,
		UserID: user.ID,
	}
	if _, err = db.NewInsert().Model(grant).Exec(ctx); err != nil {
		return nil, err
	}
	return s.generateTokens(ctx, db, user, grant)
}
// AuthenticateWithAccessToken validates the given JWT access token and loads
// the user it refers to. Malformed tokens, unparsable subjects, and unknown
// users all yield an invalid-access-token error.
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*AccessTokenAuthentication, error) {
	claims, err := ParseAccessToken(token, &s.tokenConfig)
	if err != nil {
		slog.Info("failed to parse access token", "error", err)
		return nil, err
	}
	subject, err := uuid.Parse(claims.Subject)
	if err != nil {
		slog.Info("failed to parse access token subject", "error", err)
		return nil, newInvalidAccessTokenError(err)
	}
	u, err := s.userService.UserByID(ctx, db, subject)
	if err != nil {
		var notFound *user.NotFoundError
		if errors.As(err, &notFound) {
			// a token for a deleted user is treated as invalid, not internal
			return nil, newInvalidAccessTokenError(err)
		}
		return nil, err
	}
	return &AccessTokenAuthentication{Claims: claims, User: u}, nil
}
// RefreshAccessToken implements refresh token rotation: it validates the
// presented refresh token, marks it consumed, and issues a new access/refresh
// token pair under the same grant.
//
// The incoming token is base64url(key || secret): the key locates the DB row
// and the secret is verified against the stored hash. Returns
// ErrInvalidRefreshToken for malformed, unknown, or revoked tokens,
// ErrRefreshTokenExpired for expired ones, and ErrRefreshTokenReused when an
// already-consumed token is replayed (which also revokes the whole grant).
func (s *Service) RefreshAccessToken(ctx context.Context, db bun.IDB, refreshToken string) (*AuthenticationTokens, error) {
	t, err := base64.URLEncoding.DecodeString(refreshToken)
	if err != nil {
		return nil, ErrInvalidRefreshToken
	}
	rtParts, err := DeserializeRefreshToken(t)
	if err != nil {
		return nil, ErrInvalidRefreshToken
	}
	// Look the row up by the random key, never by the secret itself, so the
	// secret only ever exists as a hash server-side.
	rt := &RefreshToken{}
	err = db.NewSelect().Model(rt).Where("key = ?", rtParts.Key).Scan(ctx)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrInvalidRefreshToken
		}
		return nil, err
	}
	if rt.ExpiresAt.Before(time.Now()) {
		return nil, ErrRefreshTokenExpired
	}
	if rt.ConsumedAt != nil {
		// token reuse detected, invalidate all refresh tokens under the same grant
		_, err = db.NewDelete().Model((*RefreshToken)(nil)).Where("grant_id = ?", rt.GrantID).Exec(ctx)
		if err != nil {
			return nil, err
		}
		_, err = db.NewUpdate().Model((*Grant)(nil)).Set("revoked_at = ?", time.Now()).Where("id = ?", rt.GrantID).Exec(ctx)
		if err != nil {
			return nil, err
		}
		return nil, ErrRefreshTokenReused
	}
	grant := &Grant{
		ID: rt.GrantID,
	}
	err = db.NewSelect().Model(grant).WherePK().Scan(ctx)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// the grant for this refresh token was deleted, delete the refresh token
			_, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx)
			return nil, ErrInvalidRefreshToken
		}
		return nil, err
	}
	if grant.RevokedAt != nil {
		return nil, ErrInvalidRefreshToken
	}
	// Verify the presented secret against the stored hash only after the
	// cheap structural checks above have passed.
	ok, err := password.Verify(rtParts.Token, rt.TokenHash)
	if err != nil {
		return nil, err
	}
	if !ok {
		return nil, ErrInvalidRefreshToken
	}
	u, err := s.userService.UserByID(ctx, db, grant.UserID)
	if err != nil {
		var nf *user.NotFoundError
		if errors.As(err, &nf) {
			// the user for this grant was deleted, delete the grant and refresh token
			_, _ = db.NewDelete().Model(grant).WherePK().Exec(ctx)
			_, _ = db.NewDelete().Model(rt).WherePK().Exec(ctx)
			return nil, ErrInvalidRefreshToken
		}
		return nil, err
	}
	// Rotation: mint the replacement token under the same grant, consume the
	// old token, then persist the new one.
	newRT, err := GenerateRefreshToken(u, &s.tokenConfig)
	if err != nil {
		return nil, err
	}
	newRT.GrantID = grant.ID
	_, err = db.NewUpdate().Model(rt).Set("consumed_at = ?", time.Now()).WherePK().Exec(ctx)
	if err != nil {
		return nil, err
	}
	_, err = db.NewInsert().Model(newRT).Exec(ctx)
	if err != nil {
		return nil, err
	}
	at, err := GenerateAccessToken(u, &s.tokenConfig)
	if err != nil {
		return nil, err
	}
	srt, err := SerializeRefreshToken(newRT)
	if err != nil {
		return nil, err
	}
	return &AuthenticationTokens{
		User:         u,
		AccessToken:  at,
		RefreshToken: base64.URLEncoding.EncodeToString(srt),
	}, nil
}
// generateTokens issues an access token and a persisted refresh token for the
// given user under the given grant, returning the base64url-serialized pair.
func (s *Service) generateTokens(ctx context.Context, db bun.IDB, user *user.User, grant *Grant) (*AuthenticationTokens, error) {
	accessToken, err := GenerateAccessToken(user, &s.tokenConfig)
	if err != nil {
		return nil, err
	}
	refreshToken, err := GenerateRefreshToken(user, &s.tokenConfig)
	if err != nil {
		return nil, err
	}
	// bind the refresh token to the grant before persisting it
	refreshToken.GrantID = grant.ID
	if _, err = db.NewInsert().Model(refreshToken).Exec(ctx); err != nil {
		return nil, err
	}
	serialized, err := SerializeRefreshToken(refreshToken)
	if err != nil {
		return nil, err
	}
	return &AuthenticationTokens{
		User:         user,
		AccessToken:  accessToken,
		RefreshToken: base64.URLEncoding.EncodeToString(serialized),
	}, nil
}
func (s *Service) authenticateWithEmailAndPassword(ctx context.Context, db bun.IDB, email, plain string) (*user.User, error) {
u, err := s.userService.UserByEmail(ctx, db, email) u, err := s.userService.UserByEmail(ctx, db, email)
if err != nil { if err != nil {
var nf *user.NotFoundError var nf *user.NotFoundError
@@ -47,80 +247,10 @@ func (s *Service) LoginWithEmailAndPassword(ctx context.Context, db bun.IDB, ema
return nil, err return nil, err
} }
ok, err := password.Verify(plain, u.Password) ok, err := password.VerifyString(plain, u.Password)
if err != nil || !ok { if err != nil || !ok {
return nil, ErrInvalidCredentials return nil, ErrInvalidCredentials
} }
at, err := GenerateAccessToken(u, &s.tokenConfig) return u, nil
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &LoginResult{
User: u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) Register(ctx context.Context, db bun.IDB, opts registerOptions) (*LoginResult, error) {
hashed, err := password.Hash(opts.password)
if err != nil {
return nil, err
}
u, err := s.userService.RegisterUser(ctx, db, user.UserRegistrationOptions{
Email: opts.email,
DisplayName: opts.displayName,
Password: hashed,
})
if err != nil {
return nil, err
}
at, err := GenerateAccessToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
rt, err := GenerateRefreshToken(u, &s.tokenConfig)
if err != nil {
return nil, err
}
_, err = db.NewInsert().Model(rt).Exec(ctx)
if err != nil {
return nil, err
}
return &LoginResult{
User: u,
AccessToken: at,
RefreshToken: hex.EncodeToString(rt.Token),
}, nil
}
func (s *Service) AuthenticateWithAccessToken(ctx context.Context, db bun.IDB, token string) (*user.User, error) {
claims, err := ParseAccessToken(token, &s.tokenConfig)
if err != nil {
return nil, err
}
id, err := uuid.Parse(claims.Subject)
if err != nil {
return nil, newInvalidAccessTokenError(err)
}
return s.userService.UserByID(ctx, db, id)
} }

View File

@@ -2,11 +2,10 @@ package auth
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"time" "time"
"github.com/get-drexa/drexa/internal/password"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
@@ -15,7 +14,9 @@ import (
const ( const (
accessTokenValidFor = time.Minute * 15 accessTokenValidFor = time.Minute * 15
refreshTokenByteLength = 32 refreshTokenKeyLength = 16
refreshTokenDataLength = 32
refreshTokenLength = refreshTokenKeyLength + refreshTokenDataLength
refreshTokenValidFor = time.Hour * 24 * 30 refreshTokenValidFor = time.Hour * 24 * 30
) )
@@ -28,12 +29,19 @@ type TokenConfig struct {
type RefreshToken struct { type RefreshToken struct {
bun.BaseModel `bun:"refresh_tokens"` bun.BaseModel `bun:"refresh_tokens"`
ID uuid.UUID `bun:",pk,type:uuid"` ID uuid.UUID `bun:",pk,type:uuid"`
UserID uuid.UUID `bun:"user_id,notnull"` GrantID uuid.UUID `bun:"grant_id,notnull"`
Token []byte `bun:"-"` Token []byte `bun:"-"`
TokenHash string `bun:"token_hash,notnull"` Key uuid.UUID `bun:"key,notnull"`
ExpiresAt time.Time `bun:"expires_at,notnull"` TokenHash password.Hashed `bun:"token_hash,notnull"`
CreatedAt time.Time `bun:"created_at,notnull"` ExpiresAt time.Time `bun:"expires_at,notnull"`
CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
ConsumedAt *time.Time `bun:"consumed_at,nullzero"`
}
type RefreshTokenParts struct {
Key uuid.UUID
Token []byte
} }
func newTokenID() (uuid.UUID, error) { func newTokenID() (uuid.UUID, error) {
@@ -62,24 +70,31 @@ func GenerateAccessToken(user *user.User, c *TokenConfig) (string, error) {
func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) { func GenerateRefreshToken(user *user.User, c *TokenConfig) (*RefreshToken, error) {
now := time.Now() now := time.Now()
buf := make([]byte, refreshTokenByteLength) buf := make([]byte, refreshTokenDataLength)
if _, err := rand.Read(buf); err != nil { if _, err := rand.Read(buf); err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err) return nil, fmt.Errorf("failed to generate refresh token: %w", err)
} }
key, err := uuid.NewRandom()
if err != nil {
return nil, fmt.Errorf("failed to generate refresh token key: %w", err)
}
id, err := newTokenID() id, err := newTokenID()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate token ID: %w", err) return nil, fmt.Errorf("failed to generate token ID: %w", err)
} }
h := sha256.Sum256(buf) hashed, err := password.Hash(buf)
hex := hex.EncodeToString(h[:]) if err != nil {
return nil, fmt.Errorf("failed to hash refresh token: %w", err)
}
return &RefreshToken{ return &RefreshToken{
ID: id, ID: id,
UserID: user.ID,
Token: buf, Token: buf,
TokenHash: hex, Key: key,
TokenHash: hashed,
ExpiresAt: now.Add(refreshTokenValidFor), ExpiresAt: now.Add(refreshTokenValidFor),
CreatedAt: now, CreatedAt: now,
}, nil }, nil
@@ -96,3 +111,35 @@ func ParseAccessToken(token string, c *TokenConfig) (*jwt.RegisteredClaims, erro
} }
return parsed.Claims.(*jwt.RegisteredClaims), nil return parsed.Claims.(*jwt.RegisteredClaims), nil
} }
// SerializeRefreshToken encodes a refresh token as its 16-byte key followed
// immediately by the raw secret bytes.
func SerializeRefreshToken(token *RefreshToken) ([]byte, error) {
	keyBytes, err := token.Key.MarshalBinary()
	if err != nil {
		return nil, fmt.Errorf("failed to marshal refresh token key: %w", err)
	}
	out := make([]byte, 0, len(keyBytes)+len(token.Token))
	out = append(append(out, keyBytes...), token.Token...)
	return out, nil
}
// DeserializeRefreshToken splits a serialized refresh token into its lookup
// key and secret parts.
//
// The input must be exactly refreshTokenLength bytes: a 16-byte UUID key
// followed by refreshTokenDataLength random bytes. Any other length is
// rejected.
func DeserializeRefreshToken(buf []byte) (*RefreshTokenParts, error) {
	if len(buf) != refreshTokenLength {
		// exact-length check: the previous message claimed "too short" even
		// for over-long inputs
		return nil, fmt.Errorf("invalid refresh token length %d, want %d", len(buf), refreshTokenLength)
	}
	keyBytes := buf[:refreshTokenKeyLength]
	token := buf[refreshTokenKeyLength:]
	key, err := uuid.FromBytes(keyBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal refresh token key: %w", err)
	}
	return &RefreshTokenParts{
		Key:   key,
		Token: token,
	}, nil
}

View File

@@ -3,6 +3,7 @@ package blob
import ( import (
"context" "context"
"io" "io"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
@@ -16,25 +17,30 @@ type FSStore struct {
} }
type FSStoreConfig struct { type FSStoreConfig struct {
Root string Root string
UploadURL string
} }
func NewFSStore(config FSStoreConfig) *FSStore { func NewFSStore(config FSStoreConfig) *FSStore {
return &FSStore{config: config} return &FSStore{config: config}
} }
// SupportsDirectUpload reports whether clients may upload straight to the
// backing store. Always false for the local filesystem store.
func (s *FSStore) SupportsDirectUpload() bool {
	return false
}
// SupportsDirectDownload reports whether clients may download straight from
// the backing store. Always false for the local filesystem store.
func (s *FSStore) SupportsDirectDownload() bool {
	return false
}
func (s *FSStore) Initialize(ctx context.Context) error { func (s *FSStore) Initialize(ctx context.Context) error {
return os.MkdirAll(s.config.Root, 0755) return os.MkdirAll(s.config.Root, 0755)
} }
func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
return s.config.UploadURL, nil
}
func (s *FSStore) Put(ctx context.Context, key Key, reader io.Reader) error { func (s *FSStore) Put(ctx context.Context, key Key, reader io.Reader) error {
path := filepath.Join(s.config.Root, string(key)) path := filepath.Join(s.config.Root, string(key))
slog.Info("fs store: putting file", "path", path)
err := os.MkdirAll(filepath.Dir(path), 0755) err := os.MkdirAll(filepath.Dir(path), 0755)
if err != nil { if err != nil {
return err return err
@@ -60,6 +66,9 @@ func (s *FSStore) Put(ctx context.Context, key Key, reader io.Reader) error {
func (s *FSStore) Read(ctx context.Context, key Key) (io.ReadCloser, error) { func (s *FSStore) Read(ctx context.Context, key Key) (io.ReadCloser, error) {
path := filepath.Join(s.config.Root, string(key)) path := filepath.Join(s.config.Root, string(key))
slog.Info("fs store: reading file", "path", path)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@@ -73,6 +82,8 @@ func (s *FSStore) Read(ctx context.Context, key Key) (io.ReadCloser, error) {
func (s *FSStore) ReadRange(ctx context.Context, key Key, offset, length int64) (io.ReadCloser, error) { func (s *FSStore) ReadRange(ctx context.Context, key Key, offset, length int64) (io.ReadCloser, error) {
path := filepath.Join(s.config.Root, string(key)) path := filepath.Join(s.config.Root, string(key))
slog.Info("fs store: reading range", "path", path, "offset", offset, "length", length)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@@ -91,6 +102,9 @@ func (s *FSStore) ReadRange(ctx context.Context, key Key, offset, length int64)
func (s *FSStore) ReadSize(ctx context.Context, key Key) (int64, error) { func (s *FSStore) ReadSize(ctx context.Context, key Key) (int64, error) {
path := filepath.Join(s.config.Root, string(key)) path := filepath.Join(s.config.Root, string(key))
slog.Info("fs store: reading size", "path", path)
fi, err := os.Stat(path) fi, err := os.Stat(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@@ -102,7 +116,9 @@ func (s *FSStore) ReadSize(ctx context.Context, key Key) (int64, error) {
} }
func (s *FSStore) Delete(ctx context.Context, key Key) error { func (s *FSStore) Delete(ctx context.Context, key Key) error {
err := os.Remove(filepath.Join(s.config.Root, string(key))) path := filepath.Join(s.config.Root, string(key))
slog.Info("fs store: deleting file", "path", path)
err := os.Remove(path)
// no op if file does not exist // no op if file does not exist
// swallow error if file does not exist // swallow error if file does not exist
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
@@ -111,6 +127,15 @@ func (s *FSStore) Delete(ctx context.Context, key Key) error {
return nil return nil
} }
// DeletePrefix removes every object stored under the given key prefix.
// A missing prefix is treated as success.
func (s *FSStore) DeletePrefix(ctx context.Context, prefix Key) error {
	target := filepath.Join(s.config.Root, string(prefix))
	if err := os.RemoveAll(target); err != nil && !os.IsNotExist(err) {
		return err
	}
	return nil
}
func (s *FSStore) Update(ctx context.Context, key Key, opts UpdateOptions) error { func (s *FSStore) Update(ctx context.Context, key Key, opts UpdateOptions) error {
// Update is a no-op for FSStore // Update is a no-op for FSStore
return nil return nil
@@ -140,3 +165,11 @@ func (s *FSStore) Move(ctx context.Context, srcKey, dstKey Key) error {
return nil return nil
} }
// GenerateUploadURL is unsupported for the filesystem store; per the Store
// contract it returns an empty URL with no error.
func (s *FSStore) GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error) {
	return "", nil
}
// GenerateDownloadURL is unsupported for the filesystem store; per the Store
// contract it returns an empty URL with no error.
func (s *FSStore) GenerateDownloadURL(ctx context.Context, key Key, opts DownloadURLOptions) (string, error) {
	return "", nil
}

View File

@@ -2,13 +2,6 @@ package blob
type Key string type Key string
type KeyMode int
const (
KeyModeStable KeyMode = iota
KeyModeDerived
)
func (k Key) IsNil() bool { func (k Key) IsNil() bool {
return k == "" return k == ""
} }

View File

@@ -10,18 +10,34 @@ type UploadURLOptions struct {
Duration time.Duration Duration time.Duration
} }
type DownloadURLOptions struct {
Duration time.Duration
}
type UpdateOptions struct { type UpdateOptions struct {
ContentType string ContentType string
} }
type Store interface { type Store interface {
// SupportsDirectUpload returns true if the store allows files to be uploaded directly to the blob store.
SupportsDirectUpload() bool
// SupportsDirectDownload returns true if the store allows files to be downloaded directly from the blob store.
SupportsDirectDownload() bool
Initialize(ctx context.Context) error Initialize(ctx context.Context) error
GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
Put(ctx context.Context, key Key, reader io.Reader) error Put(ctx context.Context, key Key, reader io.Reader) error
Update(ctx context.Context, key Key, opts UpdateOptions) error Update(ctx context.Context, key Key, opts UpdateOptions) error
Delete(ctx context.Context, key Key) error Delete(ctx context.Context, key Key) error
DeletePrefix(ctx context.Context, prefix Key) error
Move(ctx context.Context, srcKey, dstKey Key) error Move(ctx context.Context, srcKey, dstKey Key) error
Read(ctx context.Context, key Key) (io.ReadCloser, error) Read(ctx context.Context, key Key) (io.ReadCloser, error)
ReadRange(ctx context.Context, key Key, offset, length int64) (io.ReadCloser, error) ReadRange(ctx context.Context, key Key, offset, length int64) (io.ReadCloser, error)
ReadSize(ctx context.Context, key Key) (int64, error) ReadSize(ctx context.Context, key Key) (int64, error)
// GenerateUploadURL generates a URL that can be used to upload a file directly to the blob store. If unsupported, returns an empty string with no error.
GenerateUploadURL(ctx context.Context, key Key, opts UploadURLOptions) (string, error)
// GenerateDownloadURL generates a URL that can be used to download a file directly from the blob store. If unsupported, returns an empty string with no error.
GenerateDownloadURL(ctx context.Context, key Key, opts DownloadURLOptions) (string, error)
} }

View File

@@ -0,0 +1,239 @@
package catalog
import (
"errors"
"time"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2"
)
const (
DirItemKindDirectory = "directory"
DirItemKindFile = "file"
)
type DirectoryInfo struct {
Kind string `json:"kind"`
ID string `json:"id"`
Path virtualfs.Path `json:"path,omitempty"`
Name string `json:"name"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt *time.Time `json:"deletedAt,omitempty"`
}
type createDirectoryRequest struct {
ParentID string `json:"parentID"`
Name string `json:"name"`
}
// currentDirectoryMiddleware resolves the :directoryID route parameter to a
// VFS node owned by the current account and stores it under the "directory"
// context local for downstream handlers.
func (h *HTTPHandler) currentDirectoryMiddleware(c *fiber.Ctx) error {
	acct := account.CurrentAccount(c)
	if acct == nil {
		return c.SendStatus(fiber.StatusUnauthorized)
	}
	node, err := h.vfs.FindNodeByPublicID(c.Context(), h.db, acct.ID, c.Params("directoryID"))
	switch {
	case errors.Is(err, virtualfs.ErrNodeNotFound):
		return c.SendStatus(fiber.StatusNotFound)
	case err != nil:
		return httperr.Internal(err)
	}
	c.Locals("directory", node)
	return c.Next()
}
// mustCurrentDirectoryNode returns the directory node resolved by
// currentDirectoryMiddleware; it panics if the middleware did not run first.
func mustCurrentDirectoryNode(c *fiber.Ctx) *virtualfs.Node {
	return c.Locals("directory").(*virtualfs.Node)
}
// createDirectory handles POST /directories: it validates the requested
// parent, creates the new directory under it inside a single transaction, and
// returns the new node's metadata.
func (h *HTTPHandler) createDirectory(c *fiber.Ctx) error {
	acct := account.CurrentAccount(c)
	if acct == nil {
		return c.SendStatus(fiber.StatusUnauthorized)
	}
	var body createDirectoryRequest
	if err := c.BodyParser(&body); err != nil {
		return c.SendStatus(fiber.StatusBadRequest)
	}
	tx, err := h.db.BeginTx(c.Context(), nil)
	if err != nil {
		return httperr.Internal(err)
	}
	defer tx.Rollback() // no-op once committed
	parent, err := h.vfs.FindNodeByPublicID(c.Context(), tx, acct.ID, body.ParentID)
	if err != nil {
		if errors.Is(err, virtualfs.ErrNodeNotFound) {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Parent not found"})
		}
		return httperr.Internal(err)
	}
	if parent.Kind != virtualfs.NodeKindDirectory {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Parent is not a directory"})
	}
	created, err := h.vfs.CreateDirectory(c.Context(), tx, acct.ID, parent.ID, body.Name)
	if err != nil {
		if errors.Is(err, virtualfs.ErrNodeConflict) {
			// unique (account, parent, name) constraint maps to 409
			return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "Directory already exists"})
		}
		return httperr.Internal(err)
	}
	if err := tx.Commit(); err != nil {
		return httperr.Internal(err)
	}
	return c.JSON(DirectoryInfo{
		Kind:      DirItemKindDirectory,
		ID:        created.PublicID,
		Name:      created.Name,
		CreatedAt: created.CreatedAt,
		UpdatedAt: created.UpdatedAt,
		DeletedAt: created.DeletedAt,
	})
}
// fetchDirectory returns metadata for the current directory. With
// ?include=path the directory's resolved path is included in the response.
func (h *HTTPHandler) fetchDirectory(c *fiber.Ctx) error {
	node := mustCurrentDirectoryNode(c)
	info := DirectoryInfo{
		Kind:      DirItemKindDirectory,
		ID:        node.PublicID,
		Name:      node.Name,
		CreatedAt: node.CreatedAt,
		UpdatedAt: node.UpdatedAt,
		DeletedAt: node.DeletedAt,
	}
	if c.Query("include") == "path" {
		path, err := h.vfs.RealPath(c.Context(), h.db, node)
		if err != nil {
			return httperr.Internal(err)
		}
		info.Path = path
	}
	return c.JSON(info)
}
// listDirectory returns the immediate children of the current directory as a
// JSON array mixing DirectoryInfo and FileInfo entries.
func (h *HTTPHandler) listDirectory(c *fiber.Ctx) error {
	dir := mustCurrentDirectoryNode(c)
	children, err := h.vfs.ListChildren(c.Context(), h.db, dir)
	if err != nil {
		if errors.Is(err, virtualfs.ErrNodeNotFound) {
			return c.SendStatus(fiber.StatusNotFound)
		}
		return httperr.Internal(err)
	}
	items := make([]any, len(children))
	for i := range children {
		child := children[i]
		switch child.Kind {
		case virtualfs.NodeKindDirectory:
			items[i] = DirectoryInfo{
				Kind:      DirItemKindDirectory,
				ID:        child.PublicID,
				Name:      child.Name,
				CreatedAt: child.CreatedAt,
				UpdatedAt: child.UpdatedAt,
				DeletedAt: child.DeletedAt,
			}
		case virtualfs.NodeKindFile:
			items[i] = FileInfo{
				Kind:      DirItemKindFile,
				ID:        child.PublicID,
				Name:      child.Name,
				Size:      child.Size,
				MimeType:  child.MimeType,
				CreatedAt: child.CreatedAt,
				UpdatedAt: child.UpdatedAt,
				DeletedAt: child.DeletedAt,
			}
		}
	}
	return c.JSON(items)
}
// patchDirectory handles PATCH /directories/:directoryID. Only renaming is
// supported; an empty name leaves the node untouched.
func (h *HTTPHandler) patchDirectory(c *fiber.Ctx) error {
	node := mustCurrentDirectoryNode(c)
	var body patchDirectoryRequest
	if err := c.BodyParser(&body); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
	}
	tx, err := h.db.BeginTx(c.Context(), nil)
	if err != nil {
		return httperr.Internal(err)
	}
	defer tx.Rollback() // no-op once committed
	if body.Name != "" {
		if renameErr := h.vfs.RenameNode(c.Context(), tx, node, body.Name); renameErr != nil {
			if errors.Is(renameErr, virtualfs.ErrNodeNotFound) {
				return c.SendStatus(fiber.StatusNotFound)
			}
			return httperr.Internal(renameErr)
		}
	}
	if err := tx.Commit(); err != nil {
		return httperr.Internal(err)
	}
	return c.JSON(DirectoryInfo{
		Kind:      DirItemKindDirectory,
		ID:        node.PublicID,
		Name:      node.Name,
		CreatedAt: node.CreatedAt,
		UpdatedAt: node.UpdatedAt,
		DeletedAt: node.DeletedAt,
	})
}
// deleteDirectory removes the current directory. With ?trash=true the node is
// soft-deleted (moved to trash); otherwise it is permanently deleted. All
// work happens inside a single transaction.
func (h *HTTPHandler) deleteDirectory(c *fiber.Ctx) error {
	node := mustCurrentDirectoryNode(c)
	tx, err := h.db.BeginTx(c.Context(), nil)
	if err != nil {
		return httperr.Internal(err)
	}
	defer tx.Rollback() // no-op once committed
	// BUG FIX: the delete calls previously ran on h.db instead of tx,
	// bypassing the transaction they were meant to run in (cf. deleteFile).
	if c.Query("trash") == "true" {
		err = h.vfs.SoftDeleteNode(c.Context(), tx, node)
	} else {
		err = h.vfs.PermanentlyDeleteNode(c.Context(), tx, node)
	}
	if err != nil {
		return httperr.Internal(err)
	}
	if err := tx.Commit(); err != nil {
		return httperr.Internal(err)
	}
	return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,2 @@
// Package catalog handles file and directory browsing.
package catalog

View File

@@ -0,0 +1,171 @@
package catalog
import (
"errors"
"fmt"
"time"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2"
)
type FileInfo struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Size int64 `json:"size"`
MimeType string `json:"mimeType"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt *time.Time `json:"deletedAt,omitempty"`
}
// mustCurrentFileNode returns the file node resolved by
// currentFileMiddleware; it panics if the middleware did not run first.
func mustCurrentFileNode(c *fiber.Ctx) *virtualfs.Node {
	return c.Locals("file").(*virtualfs.Node)
}
// currentFileMiddleware resolves the :fileID route parameter to a VFS node
// owned by the current account and stores it under the "file" context local
// for downstream handlers.
func (h *HTTPHandler) currentFileMiddleware(c *fiber.Ctx) error {
	acct := account.CurrentAccount(c)
	if acct == nil {
		return c.SendStatus(fiber.StatusUnauthorized)
	}
	node, err := h.vfs.FindNodeByPublicID(c.Context(), h.db, acct.ID, c.Params("fileID"))
	switch {
	case errors.Is(err, virtualfs.ErrNodeNotFound):
		return c.SendStatus(fiber.StatusNotFound)
	case err != nil:
		return httperr.Internal(err)
	}
	c.Locals("file", node)
	return c.Next()
}
// fetchFile returns the current file's metadata as JSON.
func (h *HTTPHandler) fetchFile(c *fiber.Ctx) error {
	node := mustCurrentFileNode(c)
	return c.JSON(FileInfo{
		Kind:      DirItemKindFile,
		ID:        node.PublicID,
		Name:      node.Name,
		Size:      node.Size,
		MimeType:  node.MimeType,
		CreatedAt: node.CreatedAt,
		UpdatedAt: node.UpdatedAt,
		DeletedAt: node.DeletedAt,
	})
}
// downloadFile streams the current file's content to the client, or redirects
// to a direct download URL when the blob store provides one.
func (h *HTTPHandler) downloadFile(c *fiber.Ctx) error {
	node := mustCurrentFileNode(c)
	content, err := h.vfs.ReadFile(c.Context(), h.db, node)
	if err != nil {
		if errors.Is(err, virtualfs.ErrUnsupportedOperation) {
			return c.SendStatus(fiber.StatusNotFound)
		}
		return httperr.Internal(err)
	}
	switch {
	case content.URL != "":
		// the store supports direct downloads; hand the client the URL
		return c.Redirect(content.URL, fiber.StatusTemporaryRedirect)
	case content.Reader != nil:
		if node.MimeType != "" {
			c.Set("Content-Type", node.MimeType)
		}
		if content.Size > 0 {
			return c.SendStream(content.Reader, int(content.Size))
		}
		// unknown size: stream without a Content-Length
		return c.SendStream(content.Reader)
	default:
		return httperr.Internal(errors.New("vfs returned neither a reader nor a URL"))
	}
}
// patchFile handles PATCH /files/:fileID. Only renaming is supported; an
// empty name leaves the node untouched.
func (h *HTTPHandler) patchFile(c *fiber.Ctx) error {
	node := mustCurrentFileNode(c)
	patch := new(patchFileRequest)
	if err := c.BodyParser(patch); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
	}
	tx, err := h.db.BeginTx(c.Context(), nil)
	if err != nil {
		return httperr.Internal(err)
	}
	defer tx.Rollback() // no-op once committed
	if patch.Name != "" {
		err := h.vfs.RenameNode(c.Context(), tx, node, patch.Name)
		if err != nil {
			if errors.Is(err, virtualfs.ErrNodeNotFound) {
				return c.SendStatus(fiber.StatusNotFound)
			}
			return httperr.Internal(err)
		}
	}
	err = tx.Commit()
	if err != nil {
		return httperr.Internal(err)
	}
	// TODO(review): leftover debug output — remove once fmt is no longer
	// needed in this file (or switch to slog.Debug).
	fmt.Printf("node deleted at: %v\n", node.DeletedAt)
	return c.JSON(FileInfo{
		// Kind was previously omitted, making this response inconsistent
		// with fetchFile and listDirectory ("kind":"" in the JSON).
		Kind:      DirItemKindFile,
		ID:        node.PublicID,
		Name:      node.Name,
		Size:      node.Size,
		MimeType:  node.MimeType,
		CreatedAt: node.CreatedAt,
		UpdatedAt: node.UpdatedAt,
		DeletedAt: node.DeletedAt,
	})
}
// deleteFile removes the current file. With ?trash=true the node is
// soft-deleted (moved to trash) and the updated metadata is returned;
// otherwise it is permanently deleted and 204 is returned.
func (h *HTTPHandler) deleteFile(c *fiber.Ctx) error {
	node := mustCurrentFileNode(c)
	tx, err := h.db.BeginTx(c.Context(), nil)
	if err != nil {
		return httperr.Internal(err)
	}
	defer tx.Rollback() // no-op once committed
	if c.Query("trash") == "true" {
		if err := h.vfs.SoftDeleteNode(c.Context(), tx, node); err != nil {
			return httperr.Internal(err)
		}
		// BUG FIX: this branch previously returned without committing, so
		// the deferred Rollback silently undid the soft delete.
		if err := tx.Commit(); err != nil {
			return httperr.Internal(err)
		}
		return c.JSON(FileInfo{
			Kind:      DirItemKindFile, // was omitted; keep responses consistent with fetchFile
			ID:        node.PublicID,
			Name:      node.Name,
			Size:      node.Size,
			MimeType:  node.MimeType,
			CreatedAt: node.CreatedAt,
			UpdatedAt: node.UpdatedAt,
			DeletedAt: node.DeletedAt,
		})
	}
	if err := h.vfs.PermanentlyDeleteNode(c.Context(), tx, node); err != nil {
		return httperr.Internal(err)
	}
	if err := tx.Commit(); err != nil {
		return httperr.Internal(err)
	}
	return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,42 @@
package catalog
import (
"github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
// HTTPHandler serves the catalog (file and directory browsing) HTTP endpoints.
type HTTPHandler struct {
	vfs *virtualfs.VirtualFS // virtual filesystem backing all node operations
	db *bun.DB // database handle; handlers open their own transactions from it
}

// patchFileRequest is the JSON body for PATCH /files/:fileID.
type patchFileRequest struct {
	Name string `json:"name"` // new file name; empty means "leave unchanged"
}

// patchDirectoryRequest is the JSON body for PATCH /directories/:directoryID.
type patchDirectoryRequest struct {
	Name string `json:"name"` // new directory name; empty means "leave unchanged"
}

// NewHTTPHandler builds a catalog HTTP handler around the given VFS and DB.
func NewHTTPHandler(vfs *virtualfs.VirtualFS, db *bun.DB) *HTTPHandler {
	return &HTTPHandler{vfs: vfs, db: db}
}
// RegisterRoutes mounts the catalog endpoints on the given router:
// /files/:fileID (fetch, download, patch, delete) and /directories plus
// /directories/:directoryID (fetch, list, patch, delete).
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
	fg := api.Group("/files/:fileID")
	// resolve :fileID once for every file route
	fg.Use(h.currentFileMiddleware)
	fg.Get("/", h.fetchFile)
	fg.Get("/content", h.downloadFile)
	fg.Patch("/", h.patchFile)
	fg.Delete("/", h.deleteFile)
	// creation takes the parent ID in the request body, so it sits outside
	// the :directoryID group
	api.Post("/directories", h.createDirectory)
	dg := api.Group("/directories/:directoryID")
	dg.Use(h.currentDirectoryMiddleware)
	dg.Get("/", h.fetchDirectory)
	dg.Get("/content", h.listDirectory)
	dg.Patch("/", h.patchDirectory)
	dg.Delete("/", h.deleteDirectory)
}

View File

@@ -7,31 +7,49 @@ CREATE TABLE IF NOT EXISTS users (
display_name TEXT, display_name TEXT,
email TEXT NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE,
password TEXT NOT NULL, password TEXT NOT NULL,
storage_usage_bytes BIGINT NOT NULL,
storage_quota_bytes BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX idx_users_email ON users(email); CREATE INDEX idx_users_email ON users(email);
CREATE TABLE IF NOT EXISTS refresh_tokens ( CREATE TABLE IF NOT EXISTS accounts (
id UUID PRIMARY KEY, id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
storage_usage_bytes BIGINT NOT NULL DEFAULT 0,
storage_quota_bytes BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_accounts_user_id ON accounts(user_id);
CREATE TABLE IF NOT EXISTS grants (
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id UUID PRIMARY KEY,
grant_id UUID NOT NULL REFERENCES grants(id) ON DELETE CASCADE,
key UUID NOT NULL UNIQUE,
token_hash TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ NOT NULL, expires_at TIMESTAMPTZ NOT NULL,
consumed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX idx_refresh_tokens_user_id ON refresh_tokens(user_id); CREATE INDEX idx_refresh_tokens_grant_id ON refresh_tokens(grant_id);
CREATE INDEX idx_refresh_tokens_token_hash ON refresh_tokens(token_hash);
CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at); CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
-- Virtual filesystem nodes (unified files + directories) -- Virtual filesystem nodes (unified files + directories)
CREATE TABLE IF NOT EXISTS vfs_nodes ( CREATE TABLE IF NOT EXISTS vfs_nodes (
id UUID PRIMARY KEY, id UUID PRIMARY KEY,
public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak) public_id TEXT NOT NULL UNIQUE, -- opaque ID for external API (no timestamp leak)
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory parent_id UUID REFERENCES vfs_nodes(id) ON DELETE CASCADE, -- NULL = root directory
kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')), kind TEXT NOT NULL CHECK (kind IN ('file', 'directory')),
status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')), status TEXT NOT NULL DEFAULT 'ready' CHECK (status IN ('pending', 'ready')),
@@ -46,17 +64,17 @@ CREATE TABLE IF NOT EXISTS vfs_nodes (
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ, -- soft delete for trash deleted_at TIMESTAMPTZ, -- soft delete for trash
-- No duplicate names in same parent (per user, excluding deleted) -- No duplicate names in same parent (per account, excluding deleted)
CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (user_id, parent_id, name, deleted_at) CONSTRAINT unique_node_name UNIQUE NULLS NOT DISTINCT (account_id, parent_id, name, deleted_at)
); );
CREATE INDEX idx_vfs_nodes_user_id ON vfs_nodes(user_id) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_account_id ON vfs_nodes(account_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_parent_id ON vfs_nodes(parent_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_user_parent ON vfs_nodes(user_id, parent_id) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_account_parent ON vfs_nodes(account_id, parent_id) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(user_id, kind) WHERE deleted_at IS NULL; CREATE INDEX idx_vfs_nodes_kind ON vfs_nodes(account_id, kind) WHERE deleted_at IS NULL;
CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(user_id, deleted_at) WHERE deleted_at IS NOT NULL; CREATE INDEX idx_vfs_nodes_deleted ON vfs_nodes(account_id, deleted_at) WHERE deleted_at IS NOT NULL;
CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id); CREATE INDEX idx_vfs_nodes_public_id ON vfs_nodes(public_id);
CREATE UNIQUE INDEX idx_vfs_nodes_user_root ON vfs_nodes(user_id) WHERE parent_id IS NULL; -- one root per user CREATE UNIQUE INDEX idx_vfs_nodes_account_root ON vfs_nodes(account_id) WHERE parent_id IS NULL; -- one root per account
CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job CREATE INDEX idx_vfs_nodes_pending ON vfs_nodes(created_at) WHERE status = 'pending'; -- for cleanup job
CREATE TABLE IF NOT EXISTS node_shares ( CREATE TABLE IF NOT EXISTS node_shares (
@@ -92,3 +110,9 @@ CREATE TRIGGER update_vfs_nodes_updated_at BEFORE UPDATE ON vfs_nodes
CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares CREATE TRIGGER update_node_shares_updated_at BEFORE UPDATE ON node_shares
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER update_accounts_updated_at BEFORE UPDATE ON accounts
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER update_grants_updated_at BEFORE UPDATE ON grants
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();

View File

@@ -27,6 +27,7 @@ type Config struct {
Database DatabaseConfig `yaml:"database"` Database DatabaseConfig `yaml:"database"`
JWT JWTConfig `yaml:"jwt"` JWT JWTConfig `yaml:"jwt"`
Storage StorageConfig `yaml:"storage"` Storage StorageConfig `yaml:"storage"`
Cookie CookieConfig `yaml:"cookie"`
} }
type ServerConfig struct { type ServerConfig struct {
@@ -52,6 +53,13 @@ type StorageConfig struct {
Bucket string `yaml:"bucket"` Bucket string `yaml:"bucket"`
} }
// CookieConfig controls auth cookie behavior.
// Domain is optional - only needed for cross-subdomain setups (e.g., "app.com" for web.app.com + api.app.com).
// Secure flag is derived from the request protocol automatically.
type CookieConfig struct {
	// Domain sets the cookie Domain attribute; empty means host-only cookies.
	Domain string `yaml:"domain"`
}
// ConfigFromFile loads configuration from a YAML file. // ConfigFromFile loads configuration from a YAML file.
// JWT secret key is loaded from JWT_SECRET_KEY env var (base64 encoded), // JWT secret key is loaded from JWT_SECRET_KEY env var (base64 encoded),
// falling back to the file path specified in jwt.secret_key_path. // falling back to the file path specified in jwt.secret_key_path.

View File

@@ -4,22 +4,49 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/auth"
"github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/catalog"
"github.com/get-drexa/drexa/internal/database" "github.com/get-drexa/drexa/internal/database"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/get-drexa/drexa/internal/upload" "github.com/get-drexa/drexa/internal/upload"
"github.com/get-drexa/drexa/internal/user" "github.com/get-drexa/drexa/internal/user"
"github.com/get-drexa/drexa/internal/virtualfs" "github.com/get-drexa/drexa/internal/virtualfs"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/logger"
"github.com/uptrace/bun"
"github.com/uptrace/bun/extra/bundebug"
) )
func NewServer(c Config) (*fiber.App, error) { type Server struct {
app := fiber.New() config Config
db := database.NewFromPostgres(c.Database.PostgresURL) app *fiber.App
db *bun.DB
blobStore blob.Store
vfs *virtualfs.VirtualFS
userService *user.Service
authService *auth.Service
accountService *account.Service
uploadService *upload.Service
authMiddleware fiber.Handler
}
func NewServer(c Config) (*Server, error) {
app := fiber.New(fiber.Config{
ErrorHandler: httperr.ErrorHandler,
StreamRequestBody: true,
// Trust proxy headers (X-Forwarded-Proto, X-Forwarded-For) for proper
// protocol detection behind reverse proxies, tunnels (ngrok, cloudflare), etc.
EnableTrustedProxyCheck: true,
TrustedProxies: []string{"127.0.0.1", "::1"},
ProxyHeader: fiber.HeaderXForwardedFor,
})
app.Use(logger.New()) app.Use(logger.New())
db := database.NewFromPostgres(c.Database.PostgresURL)
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
// Initialize blob store based on config // Initialize blob store based on config
var blobStore blob.Store var blobStore blob.Store
switch c.Storage.Backend { switch c.Storage.Backend {
@@ -33,7 +60,7 @@ func NewServer(c Config) (*fiber.App, error) {
return nil, fmt.Errorf("unknown storage backend: %s", c.Storage.Backend) return nil, fmt.Errorf("unknown storage backend: %s", c.Storage.Backend)
} }
err := blobStore.Initialize(context.Background()) err := blobStore.Initialize(context.TODO())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize blob store: %w", err) return nil, fmt.Errorf("failed to initialize blob store: %w", err)
} }
@@ -49,7 +76,7 @@ func NewServer(c Config) (*fiber.App, error) {
return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode) return nil, fmt.Errorf("unknown storage mode: %s", c.Storage.Mode)
} }
vfs, err := virtualfs.NewVirtualFS(db, blobStore, keyResolver) vfs, err := virtualfs.New(blobStore, keyResolver)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create virtual file system: %w", err) return nil, fmt.Errorf("failed to create virtual file system: %w", err)
} }
@@ -61,10 +88,38 @@ func NewServer(c Config) (*fiber.App, error) {
SecretKey: c.JWT.SecretKey, SecretKey: c.JWT.SecretKey,
}) })
uploadService := upload.NewService(vfs, blobStore) uploadService := upload.NewService(vfs, blobStore)
accountService := account.NewService(userService, vfs)
cookieConfig := auth.CookieConfig{
Domain: c.Cookie.Domain,
}
authMiddleware := auth.NewAuthMiddleware(authService, db, cookieConfig)
api := app.Group("/api") api := app.Group("/api")
auth.NewHTTPHandler(authService, db).RegisterRoutes(api) auth.NewHTTPHandler(authService, db, cookieConfig).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService).RegisterRoutes(api) user.NewHTTPHandler(userService, db, authMiddleware).RegisterRoutes(api)
return app, nil accountRouter := account.NewHTTPHandler(accountService, authService, db, authMiddleware).RegisterRoutes(api)
upload.NewHTTPHandler(uploadService, db).RegisterRoutes(accountRouter)
catalog.NewHTTPHandler(vfs, db).RegisterRoutes(accountRouter)
s := &Server{
config: c,
app: app,
db: db,
blobStore: blobStore,
vfs: vfs,
userService: userService,
authService: authService,
accountService: accountService,
uploadService: uploadService,
authMiddleware: authMiddleware,
}
return s, nil
}
func (s *Server) Start() error {
return s.app.Listen(fmt.Sprintf(":%d", s.config.Server.Port))
} }

View File

@@ -0,0 +1,42 @@
package httperr
import (
"fmt"
"github.com/gofiber/fiber/v2"
)
// HTTPError represents an HTTP error: a status code, a client-facing
// message, and an optional underlying cause.
type HTTPError struct {
	Code    int
	Message string
	Err     error
}

// Error implements the error interface.
func (e *HTTPError) Error() string {
	if e.Err == nil {
		return fmt.Sprintf("HTTP %d: %s", e.Code, e.Message)
	}
	return fmt.Sprintf("HTTP %d: %s: %v", e.Code, e.Message, e.Err)
}

// Unwrap returns the underlying error (nil when there is none), so that
// errors.Is / errors.As can see through an HTTPError.
func (e *HTTPError) Unwrap() error {
	return e.Err
}

// NewHTTPError creates a new HTTPError with the given status code, message,
// and underlying error (which may be nil).
func NewHTTPError(code int, message string, err error) *HTTPError {
	e := &HTTPError{Code: code, Message: message, Err: err}
	return e
}
// Internal creates a new HTTPError with status 500.
func Internal(err error) *HTTPError {
	return &HTTPError{
		Code:    fiber.StatusInternalServerError,
		Message: "Internal",
		Err:     err,
	}
}

View File

@@ -0,0 +1,64 @@
package httperr
import (
"errors"
"log/slog"
"github.com/gofiber/fiber/v2"
)
// ErrorHandler is Fiber's global error handler: it classifies the error,
// logs it with request context where appropriate, and writes a JSON
// {"error": message} response with the resolved status code.
func ErrorHandler(c *fiber.Ctx, err error) error {
	// Defaults used when the error carries no status information.
	code := fiber.StatusInternalServerError
	message := "Internal"

	var httpErr *HTTPError
	var fiberErr *fiber.Error
	switch {
	case errors.As(err, &httpErr):
		code = httpErr.Code
		message = httpErr.Message
		// An underlying cause is logged as an error; without one this is
		// just an expected HTTP failure, so log at warn level.
		if httpErr.Err != nil {
			slog.Error("HTTP error",
				"status", code,
				"message", message,
				"error", httpErr.Err.Error(),
				"path", c.Path(),
				"method", c.Method(),
			)
		} else {
			slog.Warn("HTTP error",
				"status", code,
				"message", message,
				"path", c.Path(),
				"method", c.Method(),
			)
		}
	case errors.As(err, &fiberErr):
		// Fiber errors already carry a status and message; no logging here.
		code = fiberErr.Code
		message = fiberErr.Message
	default:
		// Unknown error type: keep the 500 defaults and log it.
		slog.Error("Unhandled error",
			"status", code,
			"error", err.Error(),
			"path", c.Path(),
			"method", c.Method(),
		)
	}

	// Explicitly set the JSON content type, then emit the response body.
	c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8)
	return c.Status(code).JSON(fiber.Map{
		"error": message,
	})
}

View File

@@ -38,15 +38,21 @@ type argon2Hash struct {
hash []byte hash []byte
} }
// Hash securely hashes a plaintext password using argon2id. // HashString hashes the given password string.
func Hash(plain string) (Hashed, error) { // This is a convenience function that converts the password string to a byte slice and hashes it using Hash.
func HashString(pw string) (Hashed, error) {
return Hash([]byte(pw))
}
// Hash hashes the provided bytes.
func Hash(pw []byte) (Hashed, error) {
salt := make([]byte, saltLength) salt := make([]byte, saltLength)
if _, err := rand.Read(salt); err != nil { if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("failed to generate salt: %w", err) return "", fmt.Errorf("failed to generate salt: %w", err)
} }
hash := argon2.IDKey( hash := argon2.IDKey(
[]byte(plain), pw,
salt, salt,
iterations, iterations,
memory, memory,
@@ -70,15 +76,19 @@ func Hash(plain string) (Hashed, error) {
return Hashed(encoded), nil return Hashed(encoded), nil
} }
func VerifyString(plain string, hashed Hashed) (bool, error) {
return Verify([]byte(plain), hashed)
}
// Verify checks if a plaintext password matches a hashed password. // Verify checks if a plaintext password matches a hashed password.
func Verify(plain string, hashed Hashed) (bool, error) { func Verify(plain []byte, hashed Hashed) (bool, error) {
h, err := decodeHash(string(hashed)) h, err := decodeHash(string(hashed))
if err != nil { if err != nil {
return false, err return false, err
} }
otherHash := argon2.IDKey( otherHash := argon2.IDKey(
[]byte(plain), plain,
h.salt, h.salt,
h.iterations, h.iterations,
h.memory, h.memory,

View File

@@ -0,0 +1,23 @@
package reqctx
import (
"errors"
"github.com/gofiber/fiber/v2"
)
const authenticatedUserKey = "authenticatedUser"

// ErrUnauthenticatedRequest signals a request with no authenticated user.
// NOTE(review): AuthenticatedUser below returns nil — not this error — for
// unauthenticated requests; the sentinel is provided for callers that want
// a named error to return.
var ErrUnauthenticatedRequest = errors.New("unauthenticated request")

// AuthenticatedUser returns the authenticated user previously stored with
// SetAuthenticatedUser, or nil when no user is set for this request.
// The caller must type-assert the returned value to the appropriate user
// type; use the comma-ok form, since the result may be nil.
func AuthenticatedUser(c *fiber.Ctx) any {
	return c.Locals(authenticatedUserKey)
}

// SetAuthenticatedUser stores the authenticated user in the fiber context.
func SetAuthenticatedUser(c *fiber.Ctx, user any) {
	c.Locals(authenticatedUserKey, user)
}

View File

@@ -6,4 +6,5 @@ var (
ErrNotFound = errors.New("not found") ErrNotFound = errors.New("not found")
ErrParentNotDirectory = errors.New("parent is not a directory") ErrParentNotDirectory = errors.New("parent is not a directory")
ErrConflict = errors.New("node conflict") ErrConflict = errors.New("node conflict")
ErrContentNotUploaded = errors.New("content has not been uploaded")
) )

View File

@@ -2,9 +2,12 @@ package upload
import ( import (
"errors" "errors"
"fmt"
"github.com/get-drexa/drexa/internal/auth" "github.com/get-drexa/drexa/internal/account"
"github.com/get-drexa/drexa/internal/httperr"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
) )
type createUploadRequest struct { type createUploadRequest struct {
@@ -18,10 +21,11 @@ type updateUploadRequest struct {
type HTTPHandler struct { type HTTPHandler struct {
service *Service service *Service
db *bun.DB
} }
func NewHTTPHandler(s *Service) *HTTPHandler { func NewHTTPHandler(s *Service, db *bun.DB) *HTTPHandler {
return &HTTPHandler{service: s} return &HTTPHandler{service: s, db: db}
} }
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) { func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
@@ -33,9 +37,9 @@ func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
} }
func (h *HTTPHandler) Create(c *fiber.Ctx) error { func (h *HTTPHandler) Create(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) account := account.CurrentAccount(c)
if err != nil { if account == nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.SendStatus(fiber.StatusUnauthorized)
} }
req := new(createUploadRequest) req := new(createUploadRequest)
@@ -43,38 +47,54 @@ func (h *HTTPHandler) Create(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
} }
upload, err := h.service.CreateUpload(c.Context(), u.ID, CreateUploadOptions{ upload, err := h.service.CreateUpload(c.Context(), h.db, account.ID, CreateUploadOptions{
ParentID: req.ParentID, ParentID: req.ParentID,
Name: req.Name, Name: req.Name,
}) })
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
if errors.Is(err, ErrParentNotDirectory) {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Parent is not a directory"})
}
if errors.Is(err, ErrConflict) {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "A file with this name already exists"})
}
return httperr.Internal(err)
}
if upload.UploadURL == "" {
upload.UploadURL = fmt.Sprintf("%s%s/%s/content", c.BaseURL(), c.OriginalURL(), upload.ID)
} }
return c.JSON(upload) return c.JSON(upload)
} }
func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error { func (h *HTTPHandler) ReceiveContent(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) account := account.CurrentAccount(c)
if err != nil { if account == nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.SendStatus(fiber.StatusUnauthorized)
} }
uploadID := c.Params("uploadID") uploadID := c.Params("uploadID")
err = h.service.ReceiveUpload(c.Context(), u.ID, uploadID, c.Request().BodyStream()) err := h.service.ReceiveUpload(c.Context(), h.db, account.ID, uploadID, c.Context().RequestBodyStream())
defer c.Request().CloseBodyStream() defer c.Context().Request.CloseBodyStream()
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound)
}
return httperr.Internal(err)
} }
return c.SendStatus(fiber.StatusNoContent) return c.SendStatus(fiber.StatusNoContent)
} }
func (h *HTTPHandler) Update(c *fiber.Ctx) error { func (h *HTTPHandler) Update(c *fiber.Ctx) error {
u, err := auth.AuthenticatedUser(c) account := account.CurrentAccount(c)
if err != nil { if account == nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"}) return c.SendStatus(fiber.StatusUnauthorized)
} }
req := new(updateUploadRequest) req := new(updateUploadRequest)
@@ -83,12 +103,15 @@ func (h *HTTPHandler) Update(c *fiber.Ctx) error {
} }
if req.Status == StatusCompleted { if req.Status == StatusCompleted {
upload, err := h.service.CompleteUpload(c.Context(), u.ID, c.Params("uploadID")) upload, err := h.service.CompleteUpload(c.Context(), h.db, account.ID, c.Params("uploadID"))
if err != nil { if err != nil {
if errors.Is(err, ErrNotFound) { if errors.Is(err, ErrNotFound) {
return c.SendStatus(fiber.StatusNotFound) return c.SendStatus(fiber.StatusNotFound)
} }
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) if errors.Is(err, ErrContentNotUploaded) {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Content has not been uploaded"})
}
return httperr.Internal(err)
} }
return c.JSON(upload) return c.JSON(upload)
} }

View File

@@ -3,6 +3,7 @@ package upload
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"io" "io"
"sync" "sync"
"time" "time"
@@ -10,6 +11,7 @@ import (
"github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/blob"
"github.com/get-drexa/drexa/internal/virtualfs" "github.com/get-drexa/drexa/internal/virtualfs"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/uptrace/bun"
) )
type Service struct { type Service struct {
@@ -33,8 +35,8 @@ type CreateUploadOptions struct {
Name string Name string
} }
func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts CreateUploadOptions) (*Upload, error) { func (s *Service) CreateUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateUploadOptions) (*Upload, error) {
parentNode, err := s.vfs.FindNodeByPublicID(ctx, userID, opts.ParentID) parentNode, err := s.vfs.FindNodeByPublicID(ctx, db, accountID, opts.ParentID)
if err != nil { if err != nil {
if errors.Is(err, virtualfs.ErrNodeNotFound) { if errors.Is(err, virtualfs.ErrNodeNotFound) {
return nil, ErrNotFound return nil, ErrNotFound
@@ -46,7 +48,7 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return nil, ErrParentNotDirectory return nil, ErrParentNotDirectory
} }
node, err := s.vfs.CreateFile(ctx, userID, virtualfs.CreateFileOptions{ node, err := s.vfs.CreateFile(ctx, db, accountID, virtualfs.CreateFileOptions{
ParentID: parentNode.ID, ParentID: parentNode.ID,
Name: opts.Name, Name: opts.Name,
}) })
@@ -57,11 +59,17 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return nil, err return nil, err
} }
uploadURL, err := s.blobStore.GenerateUploadURL(ctx, node.BlobKey, blob.UploadURLOptions{ var uploadURL string
Duration: 1 * time.Hour, if s.blobStore.SupportsDirectUpload() {
}) uploadURL, err = s.blobStore.GenerateUploadURL(ctx, node.BlobKey, blob.UploadURLOptions{
if err != nil { Duration: 1 * time.Hour,
return nil, err })
if err != nil {
_ = s.vfs.PermanentlyDeleteNode(ctx, db, node)
return nil, err
}
} else {
uploadURL = ""
} }
upload := &Upload{ upload := &Upload{
@@ -76,7 +84,8 @@ func (s *Service) CreateUpload(ctx context.Context, userID uuid.UUID, opts Creat
return upload, nil return upload, nil
} }
func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID string, reader io.Reader) error { func (s *Service) ReceiveUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string, reader io.Reader) error {
fmt.Printf("reader: %v\n", reader)
n, ok := s.pendingUploads.Load(uploadID) n, ok := s.pendingUploads.Load(uploadID)
if !ok { if !ok {
return ErrNotFound return ErrNotFound
@@ -87,11 +96,11 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return ErrNotFound return ErrNotFound
} }
if upload.TargetNode.UserID != userID { if upload.TargetNode.AccountID != accountID {
return ErrNotFound return ErrNotFound
} }
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromReader(reader)) err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromReader(reader))
if err != nil { if err != nil {
return err return err
} }
@@ -101,7 +110,7 @@ func (s *Service) ReceiveUpload(ctx context.Context, userID uuid.UUID, uploadID
return nil return nil
} }
func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID string) (*Upload, error) { func (s *Service) CompleteUpload(ctx context.Context, db bun.IDB, accountID uuid.UUID, uploadID string) (*Upload, error) {
n, ok := s.pendingUploads.Load(uploadID) n, ok := s.pendingUploads.Load(uploadID)
if !ok { if !ok {
return nil, ErrNotFound return nil, ErrNotFound
@@ -112,7 +121,7 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
return nil, ErrNotFound return nil, ErrNotFound
} }
if upload.TargetNode.UserID != userID { if upload.TargetNode.AccountID != accountID {
return nil, ErrNotFound return nil, ErrNotFound
} }
@@ -120,8 +129,11 @@ func (s *Service) CompleteUpload(ctx context.Context, userID uuid.UUID, uploadID
return upload, nil return upload, nil
} }
err := s.vfs.WriteFile(ctx, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey)) err := s.vfs.WriteFile(ctx, db, upload.TargetNode, virtualfs.FileContentFromBlobKey(upload.TargetNode.BlobKey))
if err != nil { if err != nil {
if errors.Is(err, blob.ErrNotFound) {
return nil, ErrContentNotUploaded
}
return nil, err return nil, err
} }

View File

@@ -0,0 +1,31 @@
package user
import (
"github.com/get-drexa/drexa/internal/reqctx"
"github.com/gofiber/fiber/v2"
"github.com/uptrace/bun"
)
// HTTPHandler serves user-related HTTP endpoints.
type HTTPHandler struct {
	service        *Service
	db             *bun.DB
	authMiddleware fiber.Handler
}

// NewHTTPHandler builds an HTTPHandler from its dependencies.
func NewHTTPHandler(service *Service, db *bun.DB, authMiddleware fiber.Handler) *HTTPHandler {
	h := &HTTPHandler{
		service:        service,
		db:             db,
		authMiddleware: authMiddleware,
	}
	return h
}
// RegisterRoutes mounts the user routes under /users; every route in the
// group runs the auth middleware first.
func (h *HTTPHandler) RegisterRoutes(api fiber.Router) {
	users := api.Group("/users")
	users.Use(h.authMiddleware)
	users.Get("/me", h.getAuthenticatedUser)
}
// getAuthenticatedUser returns the currently authenticated user as JSON,
// or 401 when no user is attached to the request context.
//
// The comma-ok assertion is required: reqctx.AuthenticatedUser returns a
// plain `any`, which is nil when the auth middleware did not set a user,
// and a bare .(*User) assertion on a nil value panics before the nil
// check is ever reached.
func (h *HTTPHandler) getAuthenticatedUser(c *fiber.Ctx) error {
	u, ok := reqctx.AuthenticatedUser(c).(*User)
	if !ok || u == nil {
		return c.SendStatus(fiber.StatusUnauthorized)
	}
	return c.JSON(u)
}

View File

@@ -11,14 +11,12 @@ import (
type User struct { type User struct {
bun.BaseModel `bun:"users"` bun.BaseModel `bun:"users"`
ID uuid.UUID `bun:",pk,type:uuid" json:"id"` ID uuid.UUID `bun:",pk,type:uuid" json:"id"`
DisplayName string `bun:"display_name" json:"displayName"` DisplayName string `bun:"display_name" json:"displayName"`
Email string `bun:"email,unique,notnull" json:"email"` Email string `bun:"email,unique,notnull" json:"email"`
Password password.Hashed `bun:"password,notnull" json:"-"` Password password.Hashed `bun:"password,notnull" json:"-"`
StorageUsageBytes int64 `bun:"storage_usage_bytes,notnull" json:"storageUsageBytes"` CreatedAt time.Time `bun:"created_at,notnull,nullzero" json:"-"`
StorageQuotaBytes int64 `bun:"storage_quota_bytes,notnull" json:"storageQuotaBytes"` UpdatedAt time.Time `bun:"updated_at,notnull,nullzero" json:"-"`
CreatedAt time.Time `bun:"created_at,notnull" json:"createdAt"`
UpdatedAt time.Time `bun:"updated_at,notnull" json:"updatedAt"`
} }
func newUserID() (uuid.UUID, error) { func newUserID() (uuid.UUID, error) {

View File

@@ -3,6 +3,7 @@ package virtualfs
import "errors" import "errors"
var ( var (
ErrNodeNotFound = errors.New("node not found") ErrNodeNotFound = errors.New("node not found")
ErrNodeConflict = errors.New("node conflict") ErrNodeConflict = errors.New("node conflict")
ErrUnsupportedOperation = errors.New("unsupported operation")
) )

View File

@@ -0,0 +1,34 @@
package virtualfs
import (
"io"
"github.com/get-drexa/drexa/internal/blob"
)
// FileContent describes a file's content in one of several alternative
// forms; the zero value carries no content. The constructors below each
// populate exactly one source field (plus Size where known).
type FileContent struct {
	// Size in bytes; 0 when unknown (e.g. FileContentFromReader).
	Size int64
	// Reader streams the content; constructors wrap it with io.NopCloser.
	Reader io.ReadCloser
	// BlobKey references content already stored in the blob store.
	BlobKey blob.Key
	// URL references externally hosted content.
	URL string
}
// EmptyFileContent returns a zero FileContent (no reader, key, or URL).
func EmptyFileContent() FileContent {
	return FileContent{}
}

// FileContentFromReader builds content streamed from reader with size
// unknown (left at 0). The reader is wrapped so Close is a no-op.
func FileContentFromReader(reader io.Reader) FileContent {
	return FileContent{Reader: io.NopCloser(reader)}
}

// FileContentFromReaderWithSize is like FileContentFromReader but also
// records the known content size in bytes.
func FileContentFromReaderWithSize(reader io.Reader, size int64) FileContent {
	return FileContent{Reader: io.NopCloser(reader), Size: size}
}

// FileContentFromBlobKey references content already present in the blob
// store under blobKey.
func FileContentFromBlobKey(blobKey blob.Key) FileContent {
	return FileContent{BlobKey: blobKey}
}

// FileContentFromURL references externally hosted content by URL.
func FileContentFromURL(url string) FileContent {
	return FileContent{URL: url}
}

View File

@@ -0,0 +1,36 @@
package virtualfs
import (
"context"
"github.com/get-drexa/drexa/internal/blob"
"github.com/google/uuid"
"github.com/uptrace/bun"
)
// FlatKeyResolver assigns each node a flat, randomly generated (UUIDv7)
// blob key; no directory structure is encoded in the key.
type FlatKeyResolver struct{}

// Compile-time check that FlatKeyResolver satisfies BlobKeyResolver.
var _ BlobKeyResolver = &FlatKeyResolver{}

// NewFlatKeyResolver returns a new FlatKeyResolver.
func NewFlatKeyResolver() *FlatKeyResolver {
	return &FlatKeyResolver{}
}

// ShouldPersistKey reports true: a flat key is generated once and must be
// stored on the node, since it cannot be re-derived from the path.
func (r *FlatKeyResolver) ShouldPersistKey() bool {
	return true
}
// Resolve returns the node's stored blob key, generating a fresh UUIDv7
// key for nodes that do not have one yet.
func (r *FlatKeyResolver) Resolve(ctx context.Context, db bun.IDB, node *Node) (blob.Key, error) {
	if node.BlobKey != "" {
		return node.BlobKey, nil
	}
	id, err := uuid.NewV7()
	if err != nil {
		return "", err
	}
	return blob.Key(id.String()), nil
}
// ResolveDeletionKeys plans deletion of exactly the keys collected for the
// node's subtree; flat keys share no common prefix, so no Prefix is set.
func (r *FlatKeyResolver) ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error) {
	return &DeletionPlan{Keys: allKeys}, nil
}

View File

@@ -0,0 +1,40 @@
package virtualfs
import (
"context"
"fmt"
"github.com/get-drexa/drexa/internal/blob"
"github.com/uptrace/bun"
)
// HierarchicalKeyResolver derives blob keys from the node's position in
// the account's tree, so blob keys mirror the directory hierarchy.
type HierarchicalKeyResolver struct {
	// db is used by ResolveDeletionKeys, which receives no IDB argument.
	db *bun.DB
}

// Compile-time check that HierarchicalKeyResolver satisfies BlobKeyResolver.
var _ BlobKeyResolver = &HierarchicalKeyResolver{}

// NewHierarchicalKeyResolver returns a resolver that walks node paths via db.
func NewHierarchicalKeyResolver(db *bun.DB) *HierarchicalKeyResolver {
	return &HierarchicalKeyResolver{db: db}
}

// ShouldPersistKey reports false: the key is re-derived from the node's
// path on every access instead of being stored on the node.
func (r *HierarchicalKeyResolver) ShouldPersistKey() bool {
	return false
}
// Resolve derives the node's blob key as "<accountID>/<absolute path>",
// walking the tree through the supplied db handle.
func (r *HierarchicalKeyResolver) Resolve(ctx context.Context, db bun.IDB, node *Node) (blob.Key, error) {
	p, err := buildNodeAbsolutePathString(ctx, db, node)
	if err != nil {
		return "", err
	}
	key := fmt.Sprintf("%s/%s", node.AccountID, p)
	return blob.Key(key), nil
}
// ResolveDeletionKeys returns a prefix-based deletion plan covering the
// node's entire subtree in the blob store.
//
// The prefix must match the keys produced by Resolve, which have the form
// "<accountID>/<path>". Without the account-ID prefix the plan would not
// match any stored object (or, in a shared bucket, could match another
// account's objects).
func (r *HierarchicalKeyResolver) ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error) {
	path, err := buildNodeAbsolutePathString(ctx, r.db, node)
	if err != nil {
		return nil, err
	}
	return &DeletionPlan{Prefix: blob.Key(fmt.Sprintf("%s/%s", node.AccountID, path))}, nil
}

View File

@@ -4,53 +4,19 @@ import (
"context" "context"
"github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/blob"
"github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type BlobKeyResolver interface { type BlobKeyResolver interface {
KeyMode() blob.KeyMode // ShouldPersistKey returns true if the resolved key should be stored in node.BlobKey.
Resolve(ctx context.Context, node *Node) (blob.Key, error) // Flat keys (e.g. UUIDs) return true - key is generated once and stored.
// Hierarchical keys return false - key is derived from path each time.
ShouldPersistKey() bool
Resolve(ctx context.Context, db bun.IDB, node *Node) (blob.Key, error)
ResolveDeletionKeys(ctx context.Context, node *Node, allKeys []blob.Key) (*DeletionPlan, error)
} }
type FlatKeyResolver struct{} type DeletionPlan struct {
Prefix blob.Key
func NewFlatKeyResolver() *FlatKeyResolver { Keys []blob.Key
return &FlatKeyResolver{}
}
func (r *FlatKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeStable
}
func (r *FlatKeyResolver) Resolve(ctx context.Context, node *Node) (blob.Key, error) {
if node.BlobKey == "" {
id, err := uuid.NewV7()
if err != nil {
return "", err
}
return blob.Key(id.String()), nil
}
return node.BlobKey, nil
}
type HierarchicalKeyResolver struct {
db *bun.DB
}
func NewHierarchicalKeyResolver(db *bun.DB) *HierarchicalKeyResolver {
return &HierarchicalKeyResolver{db: db}
}
func (r *HierarchicalKeyResolver) KeyMode() blob.KeyMode {
return blob.KeyModeDerived
}
func (r *HierarchicalKeyResolver) Resolve(ctx context.Context, node *Node) (blob.Key, error) {
path, err := buildNodeAbsolutePath(ctx, r.db, node.ID)
if err != nil {
return "", err
}
return blob.Key(path), nil
} }

View File

@@ -25,25 +25,29 @@ const (
type Node struct { type Node struct {
bun.BaseModel `bun:"vfs_nodes"` bun.BaseModel `bun:"vfs_nodes"`
ID uuid.UUID `bun:",pk,type:uuid"` ID uuid.UUID `bun:",pk,type:uuid"`
PublicID string `bun:"public_id,notnull"` PublicID string `bun:"public_id,notnull"`
UserID uuid.UUID `bun:"user_id,notnull"` AccountID uuid.UUID `bun:"account_id,notnull,type:uuid"`
ParentID uuid.UUID `bun:"parent_id,notnull"` ParentID uuid.UUID `bun:"parent_id,nullzero"`
Kind NodeKind `bun:"kind,notnull"` Kind NodeKind `bun:"kind,notnull"`
Status NodeStatus `bun:"status,notnull"` Status NodeStatus `bun:"status,notnull"`
Name string `bun:"name,notnull"` Name string `bun:"name,notnull"`
BlobKey blob.Key `bun:"blob_key"` BlobKey blob.Key `bun:"blob_key,nullzero"`
Size int64 `bun:"size"` Size int64 `bun:"size"`
MimeType string `bun:"mime_type"` MimeType string `bun:"mime_type,nullzero"`
CreatedAt time.Time `bun:"created_at,notnull"` CreatedAt time.Time `bun:"created_at,notnull,nullzero"`
UpdatedAt time.Time `bun:"updated_at,notnull"` UpdatedAt time.Time `bun:"updated_at,notnull,nullzero"`
DeletedAt time.Time `bun:"deleted_at"` DeletedAt *time.Time `bun:"deleted_at,nullzero"`
}
func newNodeID() (uuid.UUID, error) {
return uuid.NewV7()
} }
// IsAccessible returns true if the node can be accessed. // IsAccessible returns true if the node can be accessed.
// If the node is not ready or if it is soft deleted, it cannot be accessed. // If the node is not ready or if it is soft deleted, it cannot be accessed.
func (n *Node) IsAccessible() bool { func (n *Node) IsAccessible() bool {
return n.DeletedAt.IsZero() && n.Status == NodeStatusReady return n.DeletedAt == nil && n.Status == NodeStatusReady
} }

View File

@@ -6,32 +6,53 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/google/uuid"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
const absolutePathQuery = `WITH RECURSIVE path AS ( const absolutePathQuery = `WITH RECURSIVE path AS (
SELECT id, parent_id, name, 1 as depth SELECT id, parent_id, name, 1 as depth
FROM vfs_nodes WHERE id = $1 AND deleted_at IS NULL FROM vfs_nodes WHERE id = ? AND deleted_at IS NULL
UNION ALL UNION ALL
SELECT n.id, n.parent_id, n.name, p.depth + 1 SELECT n.id, n.parent_id, n.name, p.depth + 1
FROM vfs_nodes n FROM vfs_nodes n
JOIN path p ON n.id = p.parent_id JOIN path p ON n.id = p.parent_id
WHERE n.deleted_at IS NULL
) )
SELECT name FROM path SELECT name FROM path
WHERE EXISTS (SELECT 1 FROM path WHERE parent_id IS NULL) WHERE EXISTS (SELECT 1 FROM path WHERE parent_id IS NULL)
ORDER BY depth DESC;` ORDER BY depth DESC;`
// absolutePathWithPublicIDQuery resolves a node's absolute path as
// (name, public_id) pairs ordered root-first. Like absolutePathQuery, the
// recursive member filters soft-deleted ancestors so a trashed directory
// never contributes a path segment; the EXISTS guard yields no rows at all
// when the walk did not reach a root (orphaned or deleted chain).
const absolutePathWithPublicIDQuery = `WITH RECURSIVE path AS (
SELECT id, parent_id, name, public_id, 1 as depth
FROM vfs_nodes WHERE id = ? AND deleted_at IS NULL
UNION ALL
SELECT n.id, n.parent_id, n.name, n.public_id, p.depth + 1
FROM vfs_nodes n
JOIN path p ON n.id = p.parent_id
WHERE n.deleted_at IS NULL
)
SELECT name, public_id FROM path
WHERE EXISTS (SELECT 1 FROM path WHERE parent_id IS NULL)
ORDER BY depth DESC;`
// Path represents a path to a node, ordered from the root segment down to
// the node itself (the query orders by depth descending).
type Path []PathSegment

// PathSegment is one component of a Path: the node's display name plus its
// public ID (serialized as "id" in JSON).
type PathSegment struct {
	Name     string `bun:"name" json:"name"`
	PublicID string `bun:"public_id" json:"id"`
}
func JoinPath(parts ...string) string { func JoinPath(parts ...string) string {
return strings.Join(parts, "/") return strings.Join(parts, "/")
} }
func buildNodeAbsolutePath(ctx context.Context, db *bun.DB, nodeID uuid.UUID) (string, error) { func buildNodeAbsolutePathString(ctx context.Context, db bun.IDB, node *Node) (string, error) {
var path []string var path []string
err := db.NewRaw(absolutePathQuery, nodeID).Scan(ctx, &path) err := db.NewRaw(absolutePathQuery, node.ID).Scan(ctx, &path)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return "", ErrNodeNotFound return "", ErrNodeNotFound
@@ -40,3 +61,15 @@ func buildNodeAbsolutePath(ctx context.Context, db *bun.DB, nodeID uuid.UUID) (s
} }
return JoinPath(path...), nil return JoinPath(path...), nil
} }
// buildNoteAbsolutePath returns the node's absolute path as a slice of
// (name, public ID) segments ordered root-first, or ErrNodeNotFound when
// the recursive query yields no rows.
//
// NOTE(review): the name looks like a typo for buildNodeAbsolutePath
// (cf. buildNodeAbsolutePathString above); renaming requires updating all
// callers in this package together.
func buildNoteAbsolutePath(ctx context.Context, db bun.IDB, node *Node) (Path, error) {
	var segments []PathSegment
	err := db.NewRaw(absolutePathWithPublicIDQuery, node.ID).Scan(ctx, &segments)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrNodeNotFound
		}
		return nil, err
	}
	return segments, nil
}

View File

@@ -8,6 +8,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"time"
"github.com/gabriel-vasile/mimetype" "github.com/gabriel-vasile/mimetype"
"github.com/get-drexa/drexa/internal/blob" "github.com/get-drexa/drexa/internal/blob"
@@ -19,7 +20,6 @@ import (
) )
type VirtualFS struct { type VirtualFS struct {
db *bun.DB
blobStore blob.Store blobStore blob.Store
keyResolver BlobKeyResolver keyResolver BlobKeyResolver
@@ -37,36 +37,24 @@ type CreateFileOptions struct {
Name string Name string
} }
type FileContent struct { const RootDirectoryName = "root"
reader io.Reader
blobKey blob.Key
}
func FileContentFromReader(reader io.Reader) FileContent { func New(blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) {
return FileContent{reader: reader}
}
func FileContentFromBlobKey(blobKey blob.Key) FileContent {
return FileContent{blobKey: blobKey}
}
func NewVirtualFS(db *bun.DB, blobStore blob.Store, keyResolver BlobKeyResolver) (*VirtualFS, error) {
sqid, err := sqids.New() sqid, err := sqids.New()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &VirtualFS{ return &VirtualFS{
db: db,
blobStore: blobStore, blobStore: blobStore,
keyResolver: keyResolver, keyResolver: keyResolver,
sqid: sqid, sqid: sqid,
}, nil }, nil
} }
func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Node, error) { func (vfs *VirtualFS) FindNode(ctx context.Context, db bun.IDB, accountID, fileID string) (*Node, error) {
var node Node var node Node
err := vfs.db.NewSelect().Model(&node). err := db.NewSelect().Model(&node).
Where("user_id = ?", userID). Where("account_id = ?", accountID).
Where("id = ?", fileID). Where("id = ?", fileID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -80,13 +68,12 @@ func (vfs *VirtualFS) FindNode(ctx context.Context, userID, fileID string) (*Nod
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID, publicID string) (*Node, error) { func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, db bun.IDB, accountID uuid.UUID, publicID string) (*Node, error) {
var node Node var node Node
err := vfs.db.NewSelect().Model(&node). err := db.NewSelect().Model(&node).
Where("user_id = ?", userID). Where("account_id = ?", accountID).
Where("public_id = ?", publicID). Where("public_id = ?", publicID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL").
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@@ -97,14 +84,14 @@ func (vfs *VirtualFS) FindNodeByPublicID(ctx context.Context, userID uuid.UUID,
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, error) { func (vfs *VirtualFS) ListChildren(ctx context.Context, db bun.IDB, node *Node) ([]*Node, error) {
if !node.IsAccessible() { if !node.IsAccessible() {
return nil, ErrNodeNotFound return nil, ErrNodeNotFound
} }
var nodes []*Node var nodes []*Node
err := vfs.db.NewSelect().Model(&nodes). err := db.NewSelect().Model(&nodes).
Where("user_id = ?", node.UserID). Where("account_id = ?", node.AccountID).
Where("parent_id = ?", node.ID). Where("parent_id = ?", node.ID).
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -119,29 +106,35 @@ func (vfs *VirtualFS) ListChildren(ctx context.Context, node *Node) ([]*Node, er
return nodes, nil return nodes, nil
} }
func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts CreateFileOptions) (*Node, error) { func (vfs *VirtualFS) CreateFile(ctx context.Context, db bun.IDB, accountID uuid.UUID, opts CreateFileOptions) (*Node, error) {
pid, err := vfs.generatePublicID() pid, err := vfs.generatePublicID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
node := Node{ id, err := newNodeID()
PublicID: pid, if err != nil {
UserID: userID, return nil, err
ParentID: opts.ParentID,
Kind: NodeKindFile,
Status: NodeStatusPending,
Name: opts.Name,
} }
if vfs.keyResolver.KeyMode() == blob.KeyModeStable { node := Node{
node.BlobKey, err = vfs.keyResolver.Resolve(ctx, &node) ID: id,
PublicID: pid,
AccountID: accountID,
ParentID: opts.ParentID,
Kind: NodeKindFile,
Status: NodeStatusPending,
Name: opts.Name,
}
if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey, err = vfs.keyResolver.Resolve(ctx, db, &node)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
_, err = vfs.db.NewInsert().Model(&node).Returning("*").Exec(ctx) _, err = db.NewInsert().Model(&node).Returning("*").Exec(ctx)
if err != nil { if err != nil {
if database.IsUniqueViolation(err) { if database.IsUniqueViolation(err) {
return nil, ErrNodeConflict return nil, ErrNodeConflict
@@ -152,39 +145,39 @@ func (vfs *VirtualFS) CreateFile(ctx context.Context, userID uuid.UUID, opts Cre
return &node, nil return &node, nil
} }
func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileContent) error { func (vfs *VirtualFS) WriteFile(ctx context.Context, db bun.IDB, node *Node, content FileContent) error {
if content.reader == nil && content.blobKey.IsNil() { if content.Reader == nil && content.BlobKey.IsNil() {
return blob.ErrInvalidFileContent return blob.ErrInvalidFileContent
} }
if !node.DeletedAt.IsZero() { if node.DeletedAt != nil {
return ErrNodeNotFound return ErrNodeNotFound
} }
setCols := make([]string, 0, 4) setCols := make([]string, 0, 4)
if content.reader != nil { if content.Reader != nil {
key, err := vfs.keyResolver.Resolve(ctx, node) key, err := vfs.keyResolver.Resolve(ctx, db, node)
if err != nil { if err != nil {
return err return err
} }
buf := make([]byte, 3072) buf := make([]byte, 3072)
n, err := io.ReadFull(content.reader, buf) n, err := io.ReadFull(content.Reader, buf)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return err return err
} }
buf = buf[:n] buf = buf[:n]
mt := mimetype.Detect(buf) mt := mimetype.Detect(buf)
cr := ioext.NewCountingReader(io.MultiReader(bytes.NewReader(buf), content.reader)) cr := ioext.NewCountingReader(io.MultiReader(bytes.NewReader(buf), content.Reader))
err = vfs.blobStore.Put(ctx, key, cr) err = vfs.blobStore.Put(ctx, key, cr)
if err != nil { if err != nil {
return err return err
} }
if vfs.keyResolver.KeyMode() == blob.KeyModeStable { if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey = key node.BlobKey = key
setCols = append(setCols, "blob_key") setCols = append(setCols, "blob_key")
} }
@@ -195,9 +188,9 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
setCols = append(setCols, "mime_type", "size", "status") setCols = append(setCols, "mime_type", "size", "status")
} else { } else {
node.BlobKey = content.blobKey node.BlobKey = content.BlobKey
b, err := vfs.blobStore.ReadRange(ctx, content.blobKey, 0, 3072) b, err := vfs.blobStore.ReadRange(ctx, content.BlobKey, 0, 3072)
if err != nil { if err != nil {
return err return err
} }
@@ -213,7 +206,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
node.MimeType = mt.String() node.MimeType = mt.String()
node.Status = NodeStatusReady node.Status = NodeStatusReady
s, err := vfs.blobStore.ReadSize(ctx, content.blobKey) s, err := vfs.blobStore.ReadSize(ctx, content.BlobKey)
if err != nil { if err != nil {
return err return err
} }
@@ -222,7 +215,7 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
setCols = append(setCols, "mime_type", "blob_key", "size", "status") setCols = append(setCols, "mime_type", "blob_key", "size", "status")
} }
_, err := vfs.db.NewUpdate().Model(&node). _, err := db.NewUpdate().Model(node).
Column(setCols...). Column(setCols...).
WherePK(). WherePK().
Exec(ctx) Exec(ctx)
@@ -233,22 +226,56 @@ func (vfs *VirtualFS) WriteFile(ctx context.Context, node *Node, content FileCon
return nil return nil
} }
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) { func (vfs *VirtualFS) ReadFile(ctx context.Context, db bun.IDB, node *Node) (FileContent, error) {
if node.Kind != NodeKindFile {
return EmptyFileContent(), ErrUnsupportedOperation
}
key, err := vfs.keyResolver.Resolve(ctx, db, node)
if err != nil {
return EmptyFileContent(), err
}
if vfs.blobStore.SupportsDirectDownload() {
url, err := vfs.blobStore.GenerateDownloadURL(ctx, key, blob.DownloadURLOptions{
Duration: 1 * time.Hour,
})
if err != nil {
return EmptyFileContent(), err
}
return FileContentFromURL(url), nil
}
reader, err := vfs.blobStore.Read(ctx, key)
if err != nil {
return EmptyFileContent(), err
}
return FileContentFromReaderWithSize(reader, node.Size), nil
}
func (vfs *VirtualFS) CreateDirectory(ctx context.Context, db bun.IDB, accountID uuid.UUID, parentID uuid.UUID, name string) (*Node, error) {
pid, err := vfs.generatePublicID() pid, err := vfs.generatePublicID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
node := Node{ id, err := newNodeID()
PublicID: pid, if err != nil {
UserID: userID, return nil, err
ParentID: parentID,
Kind: NodeKindDirectory,
Status: NodeStatusReady,
Name: name,
} }
_, err = vfs.db.NewInsert().Model(&node).Exec(ctx) node := &Node{
ID: id,
PublicID: pid,
AccountID: accountID,
ParentID: parentID,
Kind: NodeKindDirectory,
Status: NodeStatusReady,
Name: name,
}
_, err = db.NewInsert().Model(node).Exec(ctx)
if err != nil { if err != nil {
if database.IsUniqueViolation(err) { if database.IsUniqueViolation(err) {
return nil, ErrNodeConflict return nil, ErrNodeConflict
@@ -256,15 +283,15 @@ func (vfs *VirtualFS) CreateDirectory(ctx context.Context, userID uuid.UUID, par
return nil, err return nil, err
} }
return &node, nil return node, nil
} }
func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
if !node.IsAccessible() { if !node.IsAccessible() {
return ErrNodeNotFound return ErrNodeNotFound
} }
_, err := vfs.db.NewUpdate().Model(node). _, err := db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
@@ -281,12 +308,12 @@ func (vfs *VirtualFS) SoftDeleteNode(ctx context.Context, node *Node) error {
return nil return nil
} }
func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) RestoreNode(ctx context.Context, db bun.IDB, node *Node) error {
if node.Status != NodeStatusReady { if node.Status != NodeStatusReady {
return ErrNodeNotFound return ErrNodeNotFound
} }
_, err := vfs.db.NewUpdate().Model(node). _, err := db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("deleted_at IS NOT NULL"). Where("deleted_at IS NOT NULL").
Set("deleted_at = NULL"). Set("deleted_at = NULL").
@@ -302,12 +329,17 @@ func (vfs *VirtualFS) RestoreNode(ctx context.Context, node *Node) error {
return nil return nil
} }
func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) error { func (vfs *VirtualFS) RenameNode(ctx context.Context, db bun.IDB, node *Node, name string) error {
if !node.IsAccessible() { if !node.IsAccessible() {
return ErrNodeNotFound return ErrNodeNotFound
} }
_, err := vfs.db.NewUpdate().Model(node). oldKey, err := vfs.keyResolver.Resolve(ctx, db, node)
if err != nil {
return err
}
_, err = db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -320,20 +352,44 @@ func (vfs *VirtualFS) RenameNode(ctx context.Context, node *Node, name string) e
} }
return err return err
} }
return nil
}
func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UUID) error { newKey, err := vfs.keyResolver.Resolve(ctx, db, node)
if !node.IsAccessible() {
return ErrNodeNotFound
}
oldKey, err := vfs.keyResolver.Resolve(ctx, node)
if err != nil { if err != nil {
return err return err
} }
_, err = vfs.db.NewUpdate().Model(node). if oldKey != newKey {
err = vfs.blobStore.Move(ctx, oldKey, newKey)
if err != nil {
return err
}
if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey = newKey
_, err = db.NewUpdate().Model(node).
WherePK().
Set("blob_key = ?", newKey).
Exec(ctx)
if err != nil {
return err
}
}
}
return nil
}
func (vfs *VirtualFS) MoveNode(ctx context.Context, db bun.IDB, node *Node, parentID uuid.UUID) error {
if !node.IsAccessible() {
return ErrNodeNotFound
}
oldKey, err := vfs.keyResolver.Resolve(ctx, db, node)
if err != nil {
return err
}
_, err = db.NewUpdate().Model(node).
WherePK(). WherePK().
Where("status = ?", NodeStatusReady). Where("status = ?", NodeStatusReady).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
@@ -350,7 +406,7 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return err return err
} }
newKey, err := vfs.keyResolver.Resolve(ctx, node) newKey, err := vfs.keyResolver.Resolve(ctx, db, node)
if err != nil { if err != nil {
return err return err
} }
@@ -360,9 +416,9 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return err return err
} }
if vfs.keyResolver.KeyMode() == blob.KeyModeStable { if vfs.keyResolver.ShouldPersistKey() {
node.BlobKey = newKey node.BlobKey = newKey
_, err = vfs.db.NewUpdate().Model(node). _, err = db.NewUpdate().Model(node).
WherePK(). WherePK().
Set("blob_key = ?", newKey). Set("blob_key = ?", newKey).
Exec(ctx) Exec(ctx)
@@ -374,18 +430,44 @@ func (vfs *VirtualFS) MoveNode(ctx context.Context, node *Node, parentID uuid.UU
return nil return nil
} }
func (vfs *VirtualFS) AbsolutePath(ctx context.Context, node *Node) (string, error) { func (vfs *VirtualFS) AbsolutePath(ctx context.Context, db bun.IDB, node *Node) (string, error) {
if !node.IsAccessible() { if !node.IsAccessible() {
return "", ErrNodeNotFound return "", ErrNodeNotFound
} }
return buildNodeAbsolutePath(ctx, vfs.db, node.ID) return buildNodeAbsolutePathString(ctx, db, node)
} }
func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, node *Node) error { func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, db bun.IDB, node *Node) error {
if !node.IsAccessible() { switch node.Kind {
return ErrNodeNotFound case NodeKindFile:
return vfs.permanentlyDeleteFileNode(ctx, db, node)
case NodeKindDirectory:
return vfs.permanentlyDeleteDirectoryNode(ctx, db, node)
default:
return ErrUnsupportedOperation
}
}
func (vfs *VirtualFS) permanentlyDeleteFileNode(ctx context.Context, db bun.IDB, node *Node) error {
key, err := vfs.keyResolver.Resolve(ctx, db, node)
if err != nil {
return err
} }
err = vfs.blobStore.Delete(ctx, key)
if err != nil {
return err
}
_, err = db.NewDelete().Model(node).WherePK().Exec(ctx)
if err != nil {
return err
}
return nil
}
func (vfs *VirtualFS) permanentlyDeleteDirectoryNode(ctx context.Context, db bun.IDB, node *Node) error {
const descendantsQuery = `WITH RECURSIVE descendants AS ( const descendantsQuery = `WITH RECURSIVE descendants AS (
SELECT id, blob_key FROM vfs_nodes WHERE id = ? SELECT id, blob_key FROM vfs_nodes WHERE id = ?
UNION ALL UNION ALL
@@ -399,43 +481,75 @@ func (vfs *VirtualFS) PermanentlyDeleteNode(ctx context.Context, node *Node) err
BlobKey blob.Key `bun:"blob_key"` BlobKey blob.Key `bun:"blob_key"`
} }
var blobKeys []blob.Key // If db is already a transaction, use it directly; otherwise start a new transaction
var tx bun.IDB
err := vfs.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { var startedTx *bun.Tx
var records []nodeRecord switch v := db.(type) {
err := tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records) case *bun.DB:
newTx, err := v.BeginTx(ctx, nil)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrNodeNotFound
}
return err return err
} }
startedTx = &newTx
tx = newTx
defer func() {
if startedTx != nil {
(*startedTx).Rollback()
}
}()
default:
// Assume it's already a transaction
tx = db
}
if len(records) == 0 { var records []nodeRecord
err := tx.NewRaw(descendantsQuery, node.ID).Scan(ctx, &records)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrNodeNotFound return ErrNodeNotFound
} }
nodeIDs := make([]uuid.UUID, 0, len(records))
for _, r := range records {
nodeIDs = append(nodeIDs, r.ID)
if !r.BlobKey.IsNil() {
blobKeys = append(blobKeys, r.BlobKey)
}
}
_, err = tx.NewDelete().
Model((*Node)(nil)).
Where("id IN (?)", bun.In(nodeIDs)).
Exec(ctx)
return err return err
}) }
if len(records) == 0 {
return ErrNodeNotFound
}
nodeIDs := make([]uuid.UUID, 0, len(records))
blobKeys := make([]blob.Key, 0, len(records))
for _, r := range records {
nodeIDs = append(nodeIDs, r.ID)
if !r.BlobKey.IsNil() {
blobKeys = append(blobKeys, r.BlobKey)
}
}
plan, err := vfs.keyResolver.ResolveDeletionKeys(ctx, node, blobKeys)
if err != nil { if err != nil {
return err return err
} }
// Delete blobs outside transaction (best effort) _, err = tx.NewDelete().
for _, key := range blobKeys { Model((*Node)(nil)).
_ = vfs.blobStore.Delete(ctx, key) Where("id IN (?)", bun.In(nodeIDs)).
Exec(ctx)
if err != nil {
return err
}
if !plan.Prefix.IsNil() {
_ = vfs.blobStore.DeletePrefix(ctx, plan.Prefix)
} else {
for _, key := range plan.Keys {
_ = vfs.blobStore.Delete(ctx, key)
}
}
// Only commit if we started the transaction
if startedTx != nil {
err := (*startedTx).Commit()
startedTx = nil // Prevent defer from rolling back
return err
} }
return nil return nil
@@ -450,3 +564,10 @@ func (vfs *VirtualFS) generatePublicID() (string, error) {
n := binary.BigEndian.Uint64(b[:]) n := binary.BigEndian.Uint64(b[:])
return vfs.sqid.Encode([]uint64{n}) return vfs.sqid.Encode([]uint64{n})
} }
// RealPath returns the node's absolute path as structured segments
// rather than a joined string. Nodes that fail the IsAccessible check
// yield ErrNodeNotFound.
func (vfs *VirtualFS) RealPath(ctx context.Context, db bun.IDB, node *Node) (Path, error) {
	if !node.IsAccessible() {
		return nil, ErrNodeNotFound
	}
	return buildNoteAbsolutePath(ctx, db, node)
}