mirror of
https://github.com/kennethnym/freya
synced 2026-07-02 22:31:14 +01:00
Compare commits
1 Commits
master
...
feat/agent
| Author | SHA1 | Date | |
|---|---|---|---|
|
af4df2cd2c
|
@@ -10,6 +10,6 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@freya/agent-protocol": "workspace:*",
|
"@freya/agent-protocol": "workspace:*",
|
||||||
"@nym.sh/jrpc": "^0.1.0"
|
"@nym.sh/jrpc": "1.1.0"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@
|
|||||||
"@freya/source-tfl": "workspace:*",
|
"@freya/source-tfl": "workspace:*",
|
||||||
"@freya/source-weatherkit": "workspace:*",
|
"@freya/source-weatherkit": "workspace:*",
|
||||||
"@freya/source-web-search": "workspace:*",
|
"@freya/source-web-search": "workspace:*",
|
||||||
"@nym.sh/jrpc": "^0.1.0",
|
"@nym.sh/jrpc": "1.1.0",
|
||||||
"@openrouter/sdk": "^0.9.11",
|
"@openrouter/sdk": "^0.9.11",
|
||||||
"arktype": "^2.1.29",
|
"arktype": "^2.1.29",
|
||||||
"better-auth": "^1",
|
"better-auth": "^1",
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ mock.module("../sources/user-sources.ts", () => ({
|
|||||||
}),
|
}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
mock.module("../conversations/storage.ts", () => ({
|
mock.module("../conversations/db-storage.ts", () => ({
|
||||||
conversations: (_db: Database, userId: string) => ({
|
conversations: (_db: Database, userId: string) => ({
|
||||||
async getOrCreateConversation() {
|
async getOrCreateConversation() {
|
||||||
return { id: `conversation-${userId}` }
|
return { id: `conversation-${userId}` }
|
||||||
|
|||||||
145
apps/freya-backend/src/agent/job.ts
Normal file
145
apps/freya-backend/src/agent/job.ts
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import type { AgentEvent } from "@freya/agent-protocol"
|
||||||
|
|
||||||
|
import {
|
||||||
|
AssistantMessagePayload,
|
||||||
|
ConversationEntryKind,
|
||||||
|
UserMessagePayload,
|
||||||
|
ToolCallPayload,
|
||||||
|
ToolResultPayload,
|
||||||
|
} from "@freya/core"
|
||||||
|
import { type } from "arktype"
|
||||||
|
|
||||||
|
import type { ConversationStorage } from "../conversations/storage"
|
||||||
|
import type { Job } from "../lib/job"
|
||||||
|
import type { JobExecutor } from "../lib/worker"
|
||||||
|
import type { NotificationCentral } from "../notification/notification-central"
|
||||||
|
import type { UserSessionManager } from "../session"
|
||||||
|
|
||||||
|
import { ConversationResponseStateStatus } from "../db/schema"
|
||||||
|
import { streamAgentResponse } from "./streaming"
|
||||||
|
|
||||||
|
export interface AgentResponseJobPayload {
|
||||||
|
conversationId: string
|
||||||
|
}
|
||||||
|
|
||||||
|
interface AgentResponseWorkerConfig {
|
||||||
|
conversationStorage: ConversationStorage
|
||||||
|
userSessionManager: UserSessionManager
|
||||||
|
notificationCentral: NotificationCentral
|
||||||
|
}
|
||||||
|
|
||||||
|
export class AgentResponseJobExecutor implements JobExecutor<AgentResponseJobPayload> {
|
||||||
|
private conversationStorage: ConversationStorage
|
||||||
|
private userSessionManager: UserSessionManager
|
||||||
|
private notificationCentral: NotificationCentral
|
||||||
|
|
||||||
|
constructor({
|
||||||
|
conversationStorage,
|
||||||
|
userSessionManager,
|
||||||
|
notificationCentral,
|
||||||
|
}: AgentResponseWorkerConfig) {
|
||||||
|
this.conversationStorage = conversationStorage
|
||||||
|
this.userSessionManager = userSessionManager
|
||||||
|
this.notificationCentral = notificationCentral
|
||||||
|
}
|
||||||
|
|
||||||
|
async execute(job: Job<AgentResponseJobPayload>): Promise<void> {
|
||||||
|
const conversation = await this.conversationStorage.findConversation(job.payload.conversationId)
|
||||||
|
if (!conversation) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const claimed = await this.conversationStorage.claimPendingConversationResponseState(
|
||||||
|
job.payload.conversationId,
|
||||||
|
)
|
||||||
|
if (!claimed) {
|
||||||
|
// conversation response state not found or already claimed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const pendingEntries = await this.conversationStorage.listPendingUserConversationEntries(
|
||||||
|
conversation.userId,
|
||||||
|
conversation.id,
|
||||||
|
)
|
||||||
|
if (pendingEntries.length === 0) {
|
||||||
|
await this.conversationStorage.clearConversationResponseState(job.payload.conversationId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const message = pendingEntries.reduce((acc, entry) => {
|
||||||
|
const payload = UserMessagePayload(entry.payload)
|
||||||
|
if (payload instanceof type.errors) {
|
||||||
|
return acc
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
acc + "\n" + payload.parts.reduce((msg, p) => (p.type === "text" ? msg + p.text : msg), "")
|
||||||
|
)
|
||||||
|
}, "")
|
||||||
|
|
||||||
|
const session = await this.userSessionManager.getOrCreate(conversation.userId)
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const event of streamAgentResponse({
|
||||||
|
agent: session.agent,
|
||||||
|
input: { message, signal: job.signal },
|
||||||
|
})) {
|
||||||
|
if (job.signal.aborted) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.recordAgentEvent(event, conversation.id)
|
||||||
|
await this.notificationCentral.notifyUser(conversation.userId, {
|
||||||
|
kind: "agent",
|
||||||
|
payload: event,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// if job is aborted, stop everything immediately, including clean up.
|
||||||
|
// the aborter is assumed responsibility on how to proceed.
|
||||||
|
if (!job.signal.aborted) {
|
||||||
|
await this.conversationStorage.clearConversationResponseState(job.payload.conversationId)
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error("[agent job executor] error streaming agent response:", err)
|
||||||
|
if (!job.signal.aborted) {
|
||||||
|
await this.conversationStorage.markResponseStateStatus(
|
||||||
|
[job.payload.conversationId],
|
||||||
|
ConversationResponseStateStatus.Failed,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async recordAgentEvent(event: AgentEvent, conversationId: string) {
|
||||||
|
switch (event.type) {
|
||||||
|
case "message_created":
|
||||||
|
await this.conversationStorage.appendEntry(conversationId, {
|
||||||
|
kind: ConversationEntryKind.AssistantMessage,
|
||||||
|
payload: {
|
||||||
|
role: "assistant",
|
||||||
|
parts: [{ type: "text", text: event.text }],
|
||||||
|
} satisfies AssistantMessagePayload,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
|
case "tool_started":
|
||||||
|
await this.conversationStorage.appendEntry(conversationId, {
|
||||||
|
kind: ConversationEntryKind.ToolCall,
|
||||||
|
payload: {
|
||||||
|
toolName: event.toolName,
|
||||||
|
} satisfies ToolCallPayload,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
|
case "tool_finished":
|
||||||
|
await this.conversationStorage.appendEntry(conversationId, {
|
||||||
|
kind: ConversationEntryKind.ToolResult,
|
||||||
|
payload: {
|
||||||
|
toolName: event.toolName,
|
||||||
|
ok: event.ok,
|
||||||
|
} satisfies ToolResultPayload,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -166,6 +166,16 @@ export class PiQueryAgent implements QueryAgent {
|
|||||||
this.handlePiEvent(event, pushRunEvent)
|
this.handlePiEvent(event, pushRunEvent)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
input.signal?.addEventListener(
|
||||||
|
"abort",
|
||||||
|
async () => {
|
||||||
|
await session.abort()
|
||||||
|
close()
|
||||||
|
unsubscribe()
|
||||||
|
},
|
||||||
|
{ once: true },
|
||||||
|
)
|
||||||
|
|
||||||
session
|
session
|
||||||
.prompt(input.message)
|
.prompt(input.message)
|
||||||
.then(() => {
|
.then(() => {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ export interface QueryAgentAsk {
|
|||||||
message: string
|
message: string
|
||||||
conversationId?: string
|
conversationId?: string
|
||||||
userMessageEntry?: QueryAgentConversationEntryRef
|
userMessageEntry?: QueryAgentConversationEntryRef
|
||||||
|
signal?: AbortSignal
|
||||||
}
|
}
|
||||||
|
|
||||||
export type QueryAgentStreamEvent =
|
export type QueryAgentStreamEvent =
|
||||||
|
|||||||
70
apps/freya-backend/src/agent/reconciler.ts
Normal file
70
apps/freya-backend/src/agent/reconciler.ts
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import type { ConversationStorage } from "../conversations/storage"
|
||||||
|
import type { AgentWorkScheduler } from "./scheduler"
|
||||||
|
|
||||||
|
interface AgentResponseReconcilerConfig {
|
||||||
|
storage: ConversationStorage
|
||||||
|
interval: number
|
||||||
|
scheduler: AgentWorkScheduler
|
||||||
|
signal: AbortSignal
|
||||||
|
}
|
||||||
|
|
||||||
|
export class AgentResponseReconciler {
|
||||||
|
private storage: ConversationStorage
|
||||||
|
private interval: number
|
||||||
|
private scheduler: AgentWorkScheduler
|
||||||
|
private signal: AbortSignal
|
||||||
|
|
||||||
|
private stopLoop: ReturnType<typeof setInterval> | null = null
|
||||||
|
|
||||||
|
constructor({ storage, interval, scheduler, signal }: AgentResponseReconcilerConfig) {
|
||||||
|
this.storage = storage
|
||||||
|
this.interval = interval
|
||||||
|
this.scheduler = scheduler
|
||||||
|
this.signal = signal
|
||||||
|
}
|
||||||
|
|
||||||
|
start() {
|
||||||
|
this.signal.throwIfAborted()
|
||||||
|
|
||||||
|
this.signal.addEventListener(
|
||||||
|
"abort",
|
||||||
|
() => {
|
||||||
|
if (this.stopLoop !== null) {
|
||||||
|
clearInterval(this.stopLoop)
|
||||||
|
this.stopLoop = null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ once: true },
|
||||||
|
)
|
||||||
|
|
||||||
|
this.stopLoop = setInterval(this.reconcile.bind(this), this.interval)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async reconcile() {
|
||||||
|
// enqueue pending responses
|
||||||
|
const pendingStates = await this.storage.listPendingResponseStates()
|
||||||
|
const now = new Date().getTime()
|
||||||
|
for (const state of pendingStates) {
|
||||||
|
if (state.maxWaitUntil.getTime() < now) {
|
||||||
|
this.scheduler.enqueueAgentResponse(state.conversationId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// re-enqueue stuck responses
|
||||||
|
const runningStates = await this.storage.listRunningResponseStates()
|
||||||
|
const stuckIds: string[] = []
|
||||||
|
for (const state of runningStates) {
|
||||||
|
if (state.runningSince && Math.max(now - state.runningSince.getTime(), 0) > 5 * 1000 * 60) {
|
||||||
|
// if the response is running for more than 5 minutes
|
||||||
|
// we assume that its stuck and enqueue it for retry
|
||||||
|
stuckIds.push(state.conversationId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (stuckIds.length > 0) {
|
||||||
|
await this.storage.markResponseStateStatus(stuckIds, "pending")
|
||||||
|
for (const id of stuckIds) {
|
||||||
|
this.scheduler.enqueueAgentResponse(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
160
apps/freya-backend/src/agent/scheduler.ts
Normal file
160
apps/freya-backend/src/agent/scheduler.ts
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import type { UserEvent } from "@freya/agent-protocol"
|
||||||
|
|
||||||
|
import { ConversationEntryKind, UserMessagePayload } from "@freya/core"
|
||||||
|
|
||||||
|
import type { ConversationStorage } from "../conversations/storage"
|
||||||
|
import type { Job, JobRegistry } from "../lib/job"
|
||||||
|
import type { AgentResponseJobPayload } from "./job"
|
||||||
|
import { ConversationNotFoundError } from "../conversations/errors";
|
||||||
|
import { ConversationResponseStateStatus } from "../db/schema";
|
||||||
|
|
||||||
|
interface AgentMessageSchedulerConfig {
|
||||||
|
storage: ConversationStorage
|
||||||
|
maxWaitTime: number
|
||||||
|
|
||||||
|
/**
|
||||||
|
* How long to wait before responding to the user.
|
||||||
|
*/
|
||||||
|
waitTIme: number
|
||||||
|
|
||||||
|
jobRegistry: JobRegistry<AgentResponseJobPayload>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schedules and manages the flow of messages between the user and the query agent for a specific conversation.
|
||||||
|
*/
|
||||||
|
export class AgentWorkScheduler {
|
||||||
|
private conversationStorage: ConversationStorage
|
||||||
|
private jobRegistry: JobRegistry<AgentResponseJobPayload>
|
||||||
|
|
||||||
|
private timing: {
|
||||||
|
maxWaitTime: number
|
||||||
|
waitTime: number
|
||||||
|
}
|
||||||
|
|
||||||
|
private timers = new Map<string, ReturnType<typeof setTimeout>>()
|
||||||
|
private runningJobs = new Map<string, Job<AgentResponseJobPayload>>()
|
||||||
|
|
||||||
|
constructor(config: AgentMessageSchedulerConfig) {
|
||||||
|
this.conversationStorage = config.storage
|
||||||
|
this.jobRegistry = config.jobRegistry
|
||||||
|
this.timing = {
|
||||||
|
maxWaitTime: config.maxWaitTime,
|
||||||
|
waitTime: config.waitTIme,
|
||||||
|
}
|
||||||
|
|
||||||
|
this.jobRegistry.addEventListener("settled", this.eraseJob.bind(this))
|
||||||
|
this.jobRegistry.addEventListener("cancelled", this.eraseJob.bind(this))
|
||||||
|
}
|
||||||
|
|
||||||
|
async receiveMessage(conversationId: string, message: string) {
|
||||||
|
await this.conversationStorage.transaction(async (storage) => {
|
||||||
|
const now = new Date()
|
||||||
|
|
||||||
|
const entry = await storage.appendEntry(conversationId, {
|
||||||
|
kind: ConversationEntryKind.UserMessage,
|
||||||
|
payload: {
|
||||||
|
role: "user",
|
||||||
|
parts: [{ type: "text", text: message }],
|
||||||
|
} satisfies UserMessagePayload,
|
||||||
|
})
|
||||||
|
|
||||||
|
await storage.upsertConversationResponseState(conversationId, {
|
||||||
|
maxWaitUntil: new Date(now.getTime() + this.timing.maxWaitTime),
|
||||||
|
pendingSinceEntryId: entry.id,
|
||||||
|
status: "pending",
|
||||||
|
})
|
||||||
|
|
||||||
|
return entry
|
||||||
|
})
|
||||||
|
this.scheduleAgentResponse(conversationId, this.timing.waitTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
async receiveUserEvent(conversationId: string, event: UserEvent) {
|
||||||
|
if (event.type === "typing") {
|
||||||
|
await this.delayAgentResponse(conversationId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enqueueAgentResponse(conversationId: string): void {
|
||||||
|
const existing = this.timers.get(conversationId)
|
||||||
|
if (existing) {
|
||||||
|
clearTimeout(existing)
|
||||||
|
this.timers.delete(conversationId)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.cancelCurrentJob(conversationId)
|
||||||
|
|
||||||
|
const job = this.jobRegistry.addJob({
|
||||||
|
payload: { conversationId },
|
||||||
|
})
|
||||||
|
this.runningJobs.set(conversationId, job)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async delayAgentResponse(conversationId: string) {
|
||||||
|
this.cancelCurrentJob(conversationId);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const ok = await this.conversationStorage.transaction(async (storage) => {
|
||||||
|
const state = await storage.findConversationResponseState(conversationId);
|
||||||
|
if (state && state.status !== ConversationResponseStateStatus.Failed) {
|
||||||
|
await storage.updateConversationResponseState(conversationId, {
|
||||||
|
status: ConversationResponseStateStatus.Pending,
|
||||||
|
// the agent response was cancelled, so its no longer running
|
||||||
|
// clear runningSince timestamp
|
||||||
|
runningSince: null,
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if (ok) {
|
||||||
|
await this.scheduleAgentResponse(conversationId, this.timing.waitTime)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof ConversationNotFoundError) {
|
||||||
|
// the user is typing but there isn't a scheduled agent response yet
|
||||||
|
// which means the user is typing their first message after the agent has previously responded
|
||||||
|
// swallow the error
|
||||||
|
} else {
|
||||||
|
console.error("[agent response scheduler] error delaying agent response", error)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async scheduleAgentResponse(conversationId: string, delay: number) {
|
||||||
|
const existing = this.timers.get(conversationId)
|
||||||
|
if (existing) {
|
||||||
|
clearTimeout(existing)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.cancelCurrentJob(conversationId)
|
||||||
|
|
||||||
|
this.timers.set(
|
||||||
|
conversationId,
|
||||||
|
setTimeout(() => {
|
||||||
|
this.enqueueAgentResponse(conversationId)
|
||||||
|
}, delay),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* cancels the current job for agent response for the given conversation id
|
||||||
|
* no-op if there is no active job for the conversation.
|
||||||
|
*/
|
||||||
|
private cancelCurrentJob(conversationId: string): void {
|
||||||
|
const job = this.runningJobs.get(conversationId)
|
||||||
|
if (!job) return
|
||||||
|
|
||||||
|
// If an active response is working on stale context, abort it so the next
|
||||||
|
// job can answer using the latest pending user messages.
|
||||||
|
this.jobRegistry.cancelJob(job)
|
||||||
|
}
|
||||||
|
|
||||||
|
private eraseJob(job: Job<AgentResponseJobPayload>) {
|
||||||
|
if (this.runningJobs.get(job.payload.conversationId) === job) {
|
||||||
|
this.runningJobs.delete(job.payload.conversationId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
66
apps/freya-backend/src/agent/service.ts
Normal file
66
apps/freya-backend/src/agent/service.ts
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import type { UserEvent } from "@freya/agent-protocol"
|
||||||
|
|
||||||
|
import type { ConversationStorage } from "../conversations/storage"
|
||||||
|
import type { NotificationCentral } from "../notification/notification-central"
|
||||||
|
import type { UserSessionManager } from "../session"
|
||||||
|
|
||||||
|
import { JobRegistry } from "../lib/job"
|
||||||
|
import { Worker } from "../lib/worker"
|
||||||
|
import { AgentResponseJobExecutor, type AgentResponseJobPayload } from "./job"
|
||||||
|
import { AgentResponseReconciler } from "./reconciler"
|
||||||
|
import { AgentWorkScheduler } from "./scheduler"
|
||||||
|
|
||||||
|
interface AgentServiceConfig {
|
||||||
|
storage: ConversationStorage
|
||||||
|
userSessionManager: UserSessionManager
|
||||||
|
notificationCentral: NotificationCentral
|
||||||
|
signal: AbortSignal
|
||||||
|
}
|
||||||
|
|
||||||
|
export class AgentService {
|
||||||
|
private readonly storage: ConversationStorage
|
||||||
|
private readonly scheduler: AgentWorkScheduler
|
||||||
|
private readonly reconciler: AgentResponseReconciler
|
||||||
|
private readonly worker: Worker<AgentResponseJobPayload>
|
||||||
|
|
||||||
|
private readonly jobRegistry = new JobRegistry<AgentResponseJobPayload>()
|
||||||
|
|
||||||
|
constructor({ storage, userSessionManager, notificationCentral, signal }: AgentServiceConfig) {
|
||||||
|
this.storage = storage
|
||||||
|
this.scheduler = new AgentWorkScheduler({
|
||||||
|
storage,
|
||||||
|
jobRegistry: this.jobRegistry,
|
||||||
|
waitTIme: 5 * 1000,
|
||||||
|
maxWaitTime: 5 * 1000 * 60,
|
||||||
|
})
|
||||||
|
this.reconciler = new AgentResponseReconciler({
|
||||||
|
signal,
|
||||||
|
storage: this.storage,
|
||||||
|
interval: 60 * 1000,
|
||||||
|
scheduler: this.scheduler,
|
||||||
|
})
|
||||||
|
this.worker = new Worker<AgentResponseJobPayload>({
|
||||||
|
signal,
|
||||||
|
concurrency: 10,
|
||||||
|
registry: this.jobRegistry,
|
||||||
|
runner: new AgentResponseJobExecutor({
|
||||||
|
conversationStorage: storage,
|
||||||
|
notificationCentral,
|
||||||
|
userSessionManager,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
start() {
|
||||||
|
this.worker.start()
|
||||||
|
this.reconciler.start()
|
||||||
|
}
|
||||||
|
|
||||||
|
async scheduleAgentResponse(conversationId: string, message: string) {
|
||||||
|
await this.scheduler.receiveMessage(conversationId, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
async handleUserEvent(conversationId: string, event: UserEvent) {
|
||||||
|
await this.scheduler.receiveUserEvent(conversationId, event)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,7 +9,6 @@ import type {
|
|||||||
QueryAgentEventListener,
|
QueryAgentEventListener,
|
||||||
QueryAgentStreamEvent,
|
QueryAgentStreamEvent,
|
||||||
} from "./query-agent.ts"
|
} from "./query-agent.ts"
|
||||||
import type { AgentResponseStreamItem } from "./streaming.ts"
|
|
||||||
|
|
||||||
import { streamAgentResponse } from "./streaming.ts"
|
import { streamAgentResponse } from "./streaming.ts"
|
||||||
|
|
||||||
@@ -47,17 +46,13 @@ describe("streamAgentResponse", () => {
|
|||||||
{ type: "done" },
|
{ type: "done" },
|
||||||
])
|
])
|
||||||
|
|
||||||
const { events, result } = await collectStreamAgentResponse(
|
const events = await collectStreamAgentResponse(
|
||||||
streamAgentResponse({
|
streamAgentResponse({
|
||||||
agent,
|
agent,
|
||||||
input: { message: "hello" },
|
input: { message: "hello" },
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(result).toEqual({
|
|
||||||
conversationId: "conversation-1",
|
|
||||||
message: "First message\nSecond message\nThird message",
|
|
||||||
})
|
|
||||||
expect(events).toEqual([
|
expect(events).toEqual([
|
||||||
{ type: "conversation_started", conversationId: "conversation-1" },
|
{ type: "conversation_started", conversationId: "conversation-1" },
|
||||||
{ type: "message_created", text: "First message" },
|
{ type: "message_created", text: "First message" },
|
||||||
@@ -74,17 +69,13 @@ describe("streamAgentResponse", () => {
|
|||||||
{ type: "done" },
|
{ type: "done" },
|
||||||
])
|
])
|
||||||
|
|
||||||
const { events, result } = await collectStreamAgentResponse(
|
const events = await collectStreamAgentResponse(
|
||||||
streamAgentResponse({
|
streamAgentResponse({
|
||||||
agent,
|
agent,
|
||||||
input: { message: "hello" },
|
input: { message: "hello" },
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(result).toEqual({
|
|
||||||
conversationId: "conversation-1",
|
|
||||||
message: " const value = 1 \n\n return value",
|
|
||||||
})
|
|
||||||
expect(events).toEqual([
|
expect(events).toEqual([
|
||||||
{ type: "conversation_started", conversationId: "conversation-1" },
|
{ type: "conversation_started", conversationId: "conversation-1" },
|
||||||
{ type: "message_created", text: " const value = 1 " },
|
{ type: "message_created", text: " const value = 1 " },
|
||||||
@@ -122,28 +113,12 @@ describe("streamAgentResponse", () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
async function collectStreamAgentResponse(
|
async function collectStreamAgentResponse(
|
||||||
stream: AsyncIterable<AgentResponseStreamItem>,
|
stream: AsyncIterable<AgentEvent>,
|
||||||
events: AgentEvent[] = [],
|
events: AgentEvent[] = [],
|
||||||
): Promise<{
|
): Promise<AgentEvent[]> {
|
||||||
events: AgentEvent[]
|
for await (const event of stream) {
|
||||||
result: { message: string; conversationId: string }
|
events.push(event)
|
||||||
}> {
|
|
||||||
let result: { message: string; conversationId: string } | null = null
|
|
||||||
|
|
||||||
for await (const item of stream) {
|
|
||||||
switch (item.type) {
|
|
||||||
case "event":
|
|
||||||
events.push(item.event)
|
|
||||||
break
|
|
||||||
case "result":
|
|
||||||
result = item.result
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!result) {
|
return events
|
||||||
throw new Error("Expected stream result")
|
|
||||||
}
|
|
||||||
|
|
||||||
return { events, result }
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
import type { AgentEvent, SendMessageResult } from "@freya/agent-protocol"
|
import type { AgentEvent } from "@freya/agent-protocol"
|
||||||
|
|
||||||
import type { QueryAgent, QueryAgentAsk } from "./query-agent.ts"
|
import type { QueryAgent, QueryAgentAsk } from "./query-agent.ts"
|
||||||
|
|
||||||
export type AgentResponseStreamItem =
|
export type AgentResponseStreamItem = { type: "event"; event: AgentEvent }
|
||||||
| { type: "event"; event: AgentEvent }
|
|
||||||
| { type: "result"; result: SendMessageResult }
|
|
||||||
|
|
||||||
export async function* streamAgentResponse({
|
export async function* streamAgentResponse({
|
||||||
agent,
|
agent,
|
||||||
@@ -12,18 +10,18 @@ export async function* streamAgentResponse({
|
|||||||
}: {
|
}: {
|
||||||
agent: QueryAgent
|
agent: QueryAgent
|
||||||
input: QueryAgentAsk
|
input: QueryAgentAsk
|
||||||
}): AsyncGenerator<AgentResponseStreamItem, void, void> {
|
}): AsyncGenerator<AgentEvent, void, void> {
|
||||||
let message = ""
|
let message = ""
|
||||||
let conversationId: string | null = null
|
let conversationId: string | null = null
|
||||||
const splitter = new AgentMessageSplitter()
|
const splitter = new AgentMessageSplitter()
|
||||||
|
|
||||||
function messageEvent(text: string): AgentResponseStreamItem | null {
|
function messageEvent(text: string): AgentEvent | null {
|
||||||
if (text.trim() === "") return null
|
if (text.trim() === "") return null
|
||||||
|
|
||||||
return { type: "event", event: { type: "message_created", text } }
|
return { type: "message_created", text }
|
||||||
}
|
}
|
||||||
|
|
||||||
function flushPendingMessage(): AgentResponseStreamItem | null {
|
function flushPendingMessage(): AgentEvent | null {
|
||||||
const text = splitter.flush()
|
const text = splitter.flush()
|
||||||
if (text === null) return null
|
if (text === null) return null
|
||||||
|
|
||||||
@@ -31,10 +29,14 @@ export async function* streamAgentResponse({
|
|||||||
}
|
}
|
||||||
|
|
||||||
for await (const event of agent.ask(input)) {
|
for await (const event of agent.ask(input)) {
|
||||||
|
if (input.signal?.aborted) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case "conversation":
|
case "conversation":
|
||||||
conversationId = event.conversationId
|
conversationId = event.conversationId
|
||||||
yield { type: "event", event: { type: "conversation_started", conversationId } }
|
yield { type: "conversation_started", conversationId }
|
||||||
break
|
break
|
||||||
|
|
||||||
case "text_delta":
|
case "text_delta":
|
||||||
@@ -50,7 +52,7 @@ export async function* streamAgentResponse({
|
|||||||
const item = flushPendingMessage()
|
const item = flushPendingMessage()
|
||||||
if (item) yield item
|
if (item) yield item
|
||||||
}
|
}
|
||||||
yield { type: "event", event: { type: "tool_started", toolName: event.toolName } }
|
yield { type: "tool_started", toolName: event.toolName }
|
||||||
break
|
break
|
||||||
|
|
||||||
case "tool_end":
|
case "tool_end":
|
||||||
@@ -59,12 +61,9 @@ export async function* streamAgentResponse({
|
|||||||
if (item) yield item
|
if (item) yield item
|
||||||
}
|
}
|
||||||
yield {
|
yield {
|
||||||
type: "event",
|
type: "tool_finished",
|
||||||
event: {
|
toolName: event.toolName,
|
||||||
type: "tool_finished",
|
ok: event.ok,
|
||||||
toolName: event.toolName,
|
|
||||||
ok: event.ok,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -73,7 +72,7 @@ export async function* streamAgentResponse({
|
|||||||
const item = flushPendingMessage()
|
const item = flushPendingMessage()
|
||||||
if (item) yield item
|
if (item) yield item
|
||||||
}
|
}
|
||||||
yield { type: "event", event: { type: "message_failed", error: event.message } }
|
yield { type: "message_failed", error: event.message }
|
||||||
throw new Error(event.message)
|
throw new Error(event.message)
|
||||||
|
|
||||||
case "done":
|
case "done":
|
||||||
@@ -81,26 +80,15 @@ export async function* streamAgentResponse({
|
|||||||
const item = flushPendingMessage()
|
const item = flushPendingMessage()
|
||||||
if (item) yield item
|
if (item) yield item
|
||||||
}
|
}
|
||||||
const result = createResult(message, conversationId)
|
yield { type: "message_finished" }
|
||||||
yield { type: "event", event: { type: "message_finished" } }
|
|
||||||
yield { type: "result", result }
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const item = flushPendingMessage()
|
const item = flushPendingMessage()
|
||||||
if (item) yield item
|
if (item) yield item
|
||||||
const result = createResult(message, conversationId)
|
|
||||||
yield { type: "event", event: { type: "message_finished" } }
|
|
||||||
yield { type: "result", result }
|
|
||||||
}
|
|
||||||
|
|
||||||
function createResult(message: string, conversationId: string | null): SendMessageResult {
|
yield { type: "message_finished" }
|
||||||
if (!conversationId) {
|
|
||||||
throw new Error("Agent response stream ended without a conversation id")
|
|
||||||
}
|
|
||||||
|
|
||||||
return { message, conversationId }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class AgentMessageSplitter {
|
class AgentMessageSplitter {
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import { describe, expect, test } from "bun:test"
|
import { describe, expect, test } from "bun:test"
|
||||||
import { Hono } from "hono"
|
import { Hono } from "hono"
|
||||||
|
|
||||||
import type { UserSessionManager } from "../session/index.ts"
|
import type { ConversationStorage } from "../conversations/storage.ts"
|
||||||
|
import type { NotificationCentral } from "../notification/notification-central.ts"
|
||||||
|
|
||||||
|
import type { AgentService } from "./service.ts"
|
||||||
import { registerAgentWebSocketHandlers } from "./ws.ts"
|
import { registerAgentWebSocketHandlers } from "./ws.ts"
|
||||||
|
|
||||||
describe("agent websocket handler", () => {
|
describe("agent websocket handler", () => {
|
||||||
@@ -11,7 +13,9 @@ describe("agent websocket handler", () => {
|
|||||||
const app = new Hono()
|
const app = new Hono()
|
||||||
|
|
||||||
registerAgentWebSocketHandlers(app, {
|
registerAgentWebSocketHandlers(app, {
|
||||||
sessionManager: {} as UserSessionManager,
|
agentService: {} as AgentService,
|
||||||
|
storage: {} as ConversationStorage,
|
||||||
|
notificationCentral: {} as NotificationCentral,
|
||||||
corsMiddleware: async (c, next) => {
|
corsMiddleware: async (c, next) => {
|
||||||
const origin = c.req.header("origin")
|
const origin = c.req.header("origin")
|
||||||
if (origin && origin !== "https://app.freya.test") {
|
if (origin && origin !== "https://app.freya.test") {
|
||||||
@@ -44,7 +48,9 @@ describe("agent websocket handler", () => {
|
|||||||
const app = new Hono()
|
const app = new Hono()
|
||||||
|
|
||||||
registerAgentWebSocketHandlers(app, {
|
registerAgentWebSocketHandlers(app, {
|
||||||
sessionManager: {} as UserSessionManager,
|
agentService: {} as AgentService,
|
||||||
|
storage: {} as ConversationStorage,
|
||||||
|
notificationCentral: {} as NotificationCentral,
|
||||||
corsMiddleware: async (_c, next) => {
|
corsMiddleware: async (_c, next) => {
|
||||||
await next()
|
await next()
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,53 +1,58 @@
|
|||||||
import type { AgentClientApi, AgentServerApi, SendMessageResult } from "@freya/agent-protocol"
|
import type { AgentClientApi, AgentServerApi, UserEvent } from "@freya/agent-protocol"
|
||||||
import type { JrpcChannel, JrpcMessage, JsonRpcMessage } from "@nym.sh/jrpc"
|
import type { JrpcChannel, JrpcMessage, JsonRpcMessage } from "@nym.sh/jrpc"
|
||||||
import type { Hono, MiddlewareHandler } from "hono"
|
import type { Hono, MiddlewareHandler } from "hono"
|
||||||
import type { WSContext } from "hono/ws"
|
import type { WSContext } from "hono/ws"
|
||||||
|
|
||||||
import { JsonRpcClient, JsonRpcServer } from "@nym.sh/jrpc"
|
import { JsonRpcClient, JsonRpcServer, deserializeJrpcMessage } from "@nym.sh/jrpc"
|
||||||
import { type } from "arktype"
|
|
||||||
import { upgradeWebSocket, websocket } from "hono/bun"
|
import { upgradeWebSocket, websocket } from "hono/bun"
|
||||||
|
|
||||||
import type { AuthSessionMiddleware } from "../auth/session-middleware.ts"
|
import type { AuthSessionMiddleware } from "../auth/session-middleware.ts"
|
||||||
import type { UserSessionManager } from "../session/index.ts"
|
import type { ConversationStorage } from "../conversations/storage.ts"
|
||||||
|
import type {
|
||||||
import { streamAgentResponse } from "./streaming.ts"
|
NotificationCentral,
|
||||||
|
NotificationPayload,
|
||||||
|
} from "../notification/notification-central.ts"
|
||||||
|
import type { AgentService } from "./service.ts"
|
||||||
|
|
||||||
interface AgentWebSocketHandlerDeps {
|
interface AgentWebSocketHandlerDeps {
|
||||||
sessionManager: UserSessionManager
|
agentService: AgentService
|
||||||
|
storage: ConversationStorage
|
||||||
|
notificationCentral: NotificationCentral
|
||||||
authSessionMiddleware: AuthSessionMiddleware
|
authSessionMiddleware: AuthSessionMiddleware
|
||||||
corsMiddleware: MiddlewareHandler
|
corsMiddleware: MiddlewareHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ValidSendMessageInput {
|
|
||||||
message: string
|
|
||||||
}
|
|
||||||
|
|
||||||
export const agentWebSocket = websocket
|
export const agentWebSocket = websocket
|
||||||
|
|
||||||
const SendMessageInputBody = type({
|
|
||||||
"+": "reject",
|
|
||||||
message: "string",
|
|
||||||
})
|
|
||||||
|
|
||||||
export function registerAgentWebSocketHandlers(
|
export function registerAgentWebSocketHandlers(
|
||||||
app: Hono,
|
app: Hono,
|
||||||
{ sessionManager, authSessionMiddleware, corsMiddleware }: AgentWebSocketHandlerDeps,
|
{
|
||||||
|
agentService,
|
||||||
|
storage,
|
||||||
|
notificationCentral,
|
||||||
|
authSessionMiddleware,
|
||||||
|
corsMiddleware,
|
||||||
|
}: AgentWebSocketHandlerDeps,
|
||||||
): void {
|
): void {
|
||||||
app.get(
|
app.get(
|
||||||
"/api/agent/ws",
|
"/api/agent/ws",
|
||||||
corsMiddleware,
|
corsMiddleware,
|
||||||
authSessionMiddleware,
|
authSessionMiddleware,
|
||||||
upgradeWebSocket((c) => {
|
upgradeWebSocket(async (c) => {
|
||||||
const user = c.get("user")
|
const user = c.get("user")
|
||||||
if (!user) {
|
if (!user) {
|
||||||
throw new Error("Authenticated WebSocket user missing")
|
throw new Error("Authenticated WebSocket user missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const conversation = await storage.getOrCreateConversation(user.id)
|
||||||
|
|
||||||
const channel = new HonoWebSocketJrpcChannel()
|
const channel = new HonoWebSocketJrpcChannel()
|
||||||
const connection = new AgentRpcConnection({
|
const connection = new AgentRpcConnection({
|
||||||
channel,
|
channel,
|
||||||
sessionManager,
|
notificationCentral,
|
||||||
|
agentService,
|
||||||
userId: user.id,
|
userId: user.id,
|
||||||
|
conversationId: conversation.id,
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -64,6 +69,7 @@ export function registerAgentWebSocketHandlers(
|
|||||||
},
|
},
|
||||||
|
|
||||||
onClose() {
|
onClose() {
|
||||||
|
connection.close()
|
||||||
channel.close()
|
channel.close()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -74,54 +80,52 @@ export function registerAgentWebSocketHandlers(
|
|||||||
class AgentRpcConnection implements AgentServerApi {
|
class AgentRpcConnection implements AgentServerApi {
|
||||||
private readonly client: JsonRpcClient<AgentClientApi>
|
private readonly client: JsonRpcClient<AgentClientApi>
|
||||||
private readonly server: JsonRpcServer<AgentServerApi>
|
private readonly server: JsonRpcServer<AgentServerApi>
|
||||||
private activeMessage: Promise<SendMessageResult> | null = null
|
private readonly agentService: AgentService
|
||||||
private readonly sessionManager: UserSessionManager
|
private readonly notificationCentral: NotificationCentral
|
||||||
private readonly userId: string
|
private readonly userId: string
|
||||||
|
private readonly conversationId: string
|
||||||
|
|
||||||
|
private cleanup: (() => void) | null = null
|
||||||
|
|
||||||
constructor({
|
constructor({
|
||||||
|
agentService,
|
||||||
|
notificationCentral,
|
||||||
channel,
|
channel,
|
||||||
sessionManager,
|
|
||||||
userId,
|
userId,
|
||||||
|
conversationId,
|
||||||
}: {
|
}: {
|
||||||
|
agentService: AgentService
|
||||||
|
notificationCentral: NotificationCentral
|
||||||
channel: JrpcChannel
|
channel: JrpcChannel
|
||||||
sessionManager: UserSessionManager
|
|
||||||
userId: string
|
userId: string
|
||||||
|
conversationId: string
|
||||||
}) {
|
}) {
|
||||||
this.sessionManager = sessionManager
|
|
||||||
this.userId = userId
|
|
||||||
this.client = new JsonRpcClient<AgentClientApi>(channel)
|
this.client = new JsonRpcClient<AgentClientApi>(channel)
|
||||||
|
this.agentService = agentService
|
||||||
|
this.notificationCentral = notificationCentral
|
||||||
|
this.userId = userId
|
||||||
|
this.conversationId = conversationId
|
||||||
this.server = new JsonRpcServer<AgentServerApi>(
|
this.server = new JsonRpcServer<AgentServerApi>(
|
||||||
{
|
{
|
||||||
sendMessage: this.sendMessage.bind(this),
|
sendMessage: this.sendMessage.bind(this),
|
||||||
|
notify: this.notify.bind(this),
|
||||||
ping: this.ping.bind(this),
|
ping: this.ping.bind(this),
|
||||||
},
|
},
|
||||||
channel,
|
channel,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
start(): Promise<void> {
|
notify(event: UserEvent): void {
|
||||||
return this.server.start()
|
this.agentService.handleUserEvent(this.conversationId, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
async sendMessage(message: string): Promise<SendMessageResult> {
|
async sendMessage(message: string): Promise<boolean> {
|
||||||
const parsed = SendMessageInputBody({ message })
|
|
||||||
if (parsed instanceof type.errors) {
|
|
||||||
throw new Error(parsed.summary)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.activeMessage) {
|
|
||||||
throw new Error("A message is already running")
|
|
||||||
}
|
|
||||||
|
|
||||||
const run = this.runMessage(parsed)
|
|
||||||
this.activeMessage = run
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
return await run
|
await this.agentService.scheduleAgentResponse(this.conversationId, message)
|
||||||
} finally {
|
return true
|
||||||
if (this.activeMessage === run) {
|
} catch (error) {
|
||||||
this.activeMessage = null
|
console.log("[agent rpc connection] error when scheduling agent response", error)
|
||||||
}
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,26 +133,22 @@ class AgentRpcConnection implements AgentServerApi {
|
|||||||
return "pong"
|
return "pong"
|
||||||
}
|
}
|
||||||
|
|
||||||
private async runMessage(input: ValidSendMessageInput): Promise<SendMessageResult> {
|
async start() {
|
||||||
const session = await this.sessionManager.getOrCreate(this.userId)
|
this.cleanup = this.notificationCentral.registerListenerForUser(
|
||||||
let result: SendMessageResult | null = null
|
this.userId,
|
||||||
|
this.onNotificationReceived.bind(this),
|
||||||
|
)
|
||||||
|
await this.server.start()
|
||||||
|
}
|
||||||
|
|
||||||
for await (const item of streamAgentResponse({ agent: session.agent, input })) {
|
close() {
|
||||||
switch (item.type) {
|
this.cleanup?.()
|
||||||
case "event":
|
}
|
||||||
await this.client.call("notify", item.event)
|
|
||||||
break
|
private async onNotificationReceived(notification: NotificationPayload) {
|
||||||
case "result":
|
if (notification.kind === "agent") {
|
||||||
result = item.result
|
await this.client.call("notify", notification.payload)
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!result) {
|
|
||||||
throw new Error("Agent response stream ended without a result")
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,7 +171,11 @@ class HonoWebSocketJrpcChannel implements JrpcChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
receive(message: unknown): void {
|
receive(message: unknown): void {
|
||||||
const parsed = parseJrpcMessage(message)
|
if (typeof message !== "string") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const parsed = deserializeJrpcMessage(message)
|
||||||
if (!parsed) {
|
if (!parsed) {
|
||||||
this.ws?.close(1003, "Invalid JSON-RPC message")
|
this.ws?.close(1003, "Invalid JSON-RPC message")
|
||||||
return
|
return
|
||||||
@@ -236,52 +240,6 @@ class HonoWebSocketJrpcChannel implements JrpcChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function parseJrpcMessage(message: unknown): JrpcMessage | null {
|
|
||||||
const text = webSocketMessageText(message)
|
|
||||||
if (text === null) return null
|
|
||||||
|
|
||||||
try {
|
|
||||||
const value: unknown = JSON.parse(text)
|
|
||||||
return isJrpcMessage(value) ? value : null
|
|
||||||
} catch {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function webSocketMessageText(message: unknown): string | null {
|
|
||||||
if (typeof message === "string") return message
|
|
||||||
if (message instanceof ArrayBuffer) return Buffer.from(message).toString("utf8")
|
|
||||||
if (ArrayBuffer.isView(message)) {
|
|
||||||
return Buffer.from(message.buffer, message.byteOffset, message.byteLength).toString("utf8")
|
|
||||||
}
|
|
||||||
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
function isJrpcMessage(value: unknown): value is JrpcMessage {
|
|
||||||
if (typeof value !== "object" || value === null) return false
|
|
||||||
if (!("jsonrpc" in value) || value.jsonrpc !== "2.0") return false
|
|
||||||
|
|
||||||
if ("method" in value) {
|
|
||||||
return "id" in value && typeof value.id === "number" && typeof value.method === "string"
|
|
||||||
}
|
|
||||||
|
|
||||||
if ("result" in value) {
|
|
||||||
return "id" in value && typeof value.id === "number"
|
|
||||||
}
|
|
||||||
|
|
||||||
if ("error" in value) {
|
|
||||||
return (
|
|
||||||
"id" in value &&
|
|
||||||
typeof value.id === "number" &&
|
|
||||||
typeof value.error === "object" &&
|
|
||||||
value.error !== null
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
function errorMessage(error: unknown): string {
|
function errorMessage(error: unknown): string {
|
||||||
return error instanceof Error ? error.message : String(error)
|
return error instanceof Error ? error.message : String(error)
|
||||||
}
|
}
|
||||||
|
|||||||
686
apps/freya-backend/src/conversations/db-storage.ts
Normal file
686
apps/freya-backend/src/conversations/db-storage.ts
Normal file
@@ -0,0 +1,686 @@
|
|||||||
|
import {
|
||||||
|
AssistantMessagePayload,
|
||||||
|
AttachmentPayload,
|
||||||
|
ConversationEntryKind,
|
||||||
|
ConversationEntryMetadata,
|
||||||
|
ConversationEntryVisibility,
|
||||||
|
ContextSummaryPayload,
|
||||||
|
GenericObjectPayload,
|
||||||
|
UserMessagePayload,
|
||||||
|
type ConversationEntryPayload,
|
||||||
|
} from "@freya/core"
|
||||||
|
import { type } from "arktype"
|
||||||
|
import { and, asc, desc, eq, gte, inArray } from "drizzle-orm"
|
||||||
|
import { alias } from "drizzle-orm/pg-core"
|
||||||
|
|
||||||
|
import type { Database } from "../db/index.ts"
|
||||||
|
import type {
|
||||||
|
AppendAttachmentEntryInput,
|
||||||
|
AppendAttachmentEntryResult,
|
||||||
|
AppendConversationEntryInput,
|
||||||
|
ConversationEntryRow,
|
||||||
|
ConversationResponseStateRow,
|
||||||
|
ConversationRow,
|
||||||
|
ConversationStorage,
|
||||||
|
CreateFileInput,
|
||||||
|
FileRow,
|
||||||
|
ListConversationEntriesParams,
|
||||||
|
UpdateConversationResponseStateInput,
|
||||||
|
UpsertConversationResponseStateInput,
|
||||||
|
} from "./storage.ts"
|
||||||
|
|
||||||
|
import {
|
||||||
|
conversationEntries,
|
||||||
|
ConversationResponseStateStatus,
|
||||||
|
conversationResponseState as conversationResponseStateTable,
|
||||||
|
conversations as conversationsTable,
|
||||||
|
files,
|
||||||
|
user,
|
||||||
|
} from "../db/schema.ts"
|
||||||
|
import { ConversationNotFoundError } from "./errors.ts"
|
||||||
|
|
||||||
|
const conversationEntryKind = type.enumerated(...Object.values(ConversationEntryKind))
|
||||||
|
const conversationEntryVisibility = type.enumerated(...Object.values(ConversationEntryVisibility))
|
||||||
|
const pendingSinceEntry = alias(conversationEntries, "pending_since_entry")
|
||||||
|
|
||||||
|
export class DrizzleConversationStorage implements ConversationStorage {
|
||||||
|
private readonly db: Database
|
||||||
|
private readonly inTransaction: boolean
|
||||||
|
|
||||||
|
constructor(db: Database, inTransaction = false) {
|
||||||
|
this.db = db
|
||||||
|
this.inTransaction = inTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
async transaction<T>(tx: (storage: ConversationStorage) => T | Promise<T>): Promise<T> {
|
||||||
|
if (this.inTransaction) return tx(this)
|
||||||
|
|
||||||
|
return this.db.transaction(async (transactionDb) =>
|
||||||
|
tx(new DrizzleConversationStorage(transactionDb, true)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async createConversation(userId: string): Promise<ConversationRow> {
|
||||||
|
return insertConversation(this.db, userId)
|
||||||
|
}
|
||||||
|
|
||||||
|
async listUserConversations(userId: string): Promise<ConversationRow[]> {
|
||||||
|
return this.db
|
||||||
|
.select()
|
||||||
|
.from(conversationsTable)
|
||||||
|
.where(eq(conversationsTable.userId, userId))
|
||||||
|
.orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
async findConversation(conversationId: string): Promise<ConversationRow | null> {
|
||||||
|
return findConversation(this.db, conversationId)
|
||||||
|
}
|
||||||
|
|
||||||
|
async getOrCreateConversation(userId: string): Promise<ConversationRow> {
|
||||||
|
return this.write(async (db) => {
|
||||||
|
await requireUserForUpdate(db, userId)
|
||||||
|
const existing = await latestConversation(db, userId)
|
||||||
|
if (existing) return existing
|
||||||
|
|
||||||
|
return insertConversation(db, userId)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async createFile(userId: string, input: CreateFileInput): Promise<FileRow> {
|
||||||
|
return insertFile(this.db, userId, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
async appendEntry(
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendConversationEntryInput,
|
||||||
|
): Promise<ConversationEntryRow> {
|
||||||
|
return this.write((db) => appendEntryToConversation(db, null, conversationId, input))
|
||||||
|
}
|
||||||
|
|
||||||
|
async appendAttachmentEntry(
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendAttachmentEntryInput,
|
||||||
|
): Promise<AppendAttachmentEntryResult> {
|
||||||
|
return this.write((db) => appendAttachmentEntryToConversation(db, null, conversationId, input))
|
||||||
|
}
|
||||||
|
|
||||||
|
async nextSequence(conversationId: string): Promise<number> {
|
||||||
|
return nextSequence(this.db, conversationId)
|
||||||
|
}
|
||||||
|
|
||||||
|
async listUserConversationEntries(
|
||||||
|
userId: string,
|
||||||
|
conversationId: string,
|
||||||
|
params: ListConversationEntriesParams = {},
|
||||||
|
): Promise<ConversationEntryRow[]> {
|
||||||
|
if (!(await findUserConversation(this.db, userId, conversationId))) {
|
||||||
|
throw new ConversationNotFoundError(conversationId, userId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.visibility) {
|
||||||
|
return this.db
|
||||||
|
.select()
|
||||||
|
.from(conversationEntries)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(conversationEntries.conversationId, conversationId),
|
||||||
|
eq(conversationEntries.visibility, params.visibility),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.orderBy(asc(conversationEntries.sequence))
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.db
|
||||||
|
.select()
|
||||||
|
.from(conversationEntries)
|
||||||
|
.where(eq(conversationEntries.conversationId, conversationId))
|
||||||
|
.orderBy(asc(conversationEntries.sequence))
|
||||||
|
}
|
||||||
|
|
||||||
|
async listPendingUserConversationEntries(
|
||||||
|
userId: string,
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationEntryRow[]> {
|
||||||
|
const entries = await this.db
|
||||||
|
.select({ entry: conversationEntries })
|
||||||
|
.from(conversationResponseStateTable)
|
||||||
|
.innerJoin(
|
||||||
|
conversationsTable,
|
||||||
|
and(
|
||||||
|
eq(conversationsTable.id, conversationResponseStateTable.conversationId),
|
||||||
|
eq(conversationsTable.userId, userId),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.innerJoin(
|
||||||
|
pendingSinceEntry,
|
||||||
|
and(
|
||||||
|
eq(pendingSinceEntry.id, conversationResponseStateTable.pendingSinceEntryId),
|
||||||
|
eq(pendingSinceEntry.conversationId, conversationResponseStateTable.conversationId),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.innerJoin(
|
||||||
|
conversationEntries,
|
||||||
|
and(
|
||||||
|
eq(conversationEntries.conversationId, conversationResponseStateTable.conversationId),
|
||||||
|
eq(conversationEntries.kind, ConversationEntryKind.UserMessage),
|
||||||
|
gte(conversationEntries.sequence, pendingSinceEntry.sequence),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(conversationResponseStateTable.conversationId, conversationId),
|
||||||
|
eq(conversationEntries.conversationId, conversationId),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.orderBy(asc(conversationEntries.sequence))
|
||||||
|
|
||||||
|
if (entries.length > 0) return entries.map(({ entry }) => entry)
|
||||||
|
if (await findUserConversation(this.db, userId, conversationId)) return []
|
||||||
|
|
||||||
|
throw new ConversationNotFoundError(conversationId, userId)
|
||||||
|
}
|
||||||
|
|
||||||
|
async findConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationResponseStateRow | null> {
|
||||||
|
const rows = await this.db
|
||||||
|
.select()
|
||||||
|
.from(conversationResponseStateTable)
|
||||||
|
.where(eq(conversationResponseStateTable.conversationId, conversationId))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
}
|
||||||
|
|
||||||
|
async listPendingResponseStates(): Promise<ConversationResponseStateRow[]> {
|
||||||
|
const rows = await this.db
|
||||||
|
.select()
|
||||||
|
.from(conversationResponseStateTable)
|
||||||
|
.where(eq(conversationResponseStateTable.status, ConversationResponseStateStatus.Pending))
|
||||||
|
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
|
||||||
|
async listRunningResponseStates(): Promise<ConversationResponseStateRow[]> {
|
||||||
|
const rows = await this.db
|
||||||
|
.select()
|
||||||
|
.from(conversationResponseStateTable)
|
||||||
|
.where(eq(conversationResponseStateTable.status, ConversationResponseStateStatus.Running))
|
||||||
|
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
|
||||||
|
async upsertConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
input: UpsertConversationResponseStateInput,
|
||||||
|
): Promise<ConversationResponseStateRow> {
|
||||||
|
const now = new Date()
|
||||||
|
|
||||||
|
return this.write(async (db) => {
|
||||||
|
if (!(await findConversationByIdForUpdate(db, conversationId))) {
|
||||||
|
throw new ConversationNotFoundError(conversationId, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
const rows = await db
|
||||||
|
.insert(conversationResponseStateTable)
|
||||||
|
.values({
|
||||||
|
conversationId,
|
||||||
|
status: input.status ?? ConversationResponseStateStatus.Pending,
|
||||||
|
pendingSinceEntryId: input.pendingSinceEntryId,
|
||||||
|
maxWaitUntil: input.maxWaitUntil,
|
||||||
|
runningSince: input.runningSince ?? null,
|
||||||
|
updatedAt: now,
|
||||||
|
})
|
||||||
|
.onConflictDoUpdate({
|
||||||
|
target: conversationResponseStateTable.conversationId,
|
||||||
|
set: {
|
||||||
|
status: input.status ?? ConversationResponseStateStatus.Pending,
|
||||||
|
maxWaitUntil: input.maxWaitUntil,
|
||||||
|
runningSince: input.runningSince ?? null,
|
||||||
|
updatedAt: now,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
return requireRow(rows)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async updateConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
input: UpdateConversationResponseStateInput,
|
||||||
|
): Promise<ConversationResponseStateRow | null> {
|
||||||
|
return this.write(async (db) => {
|
||||||
|
if (!(await findConversationByIdForUpdate(db, conversationId))) {
|
||||||
|
throw new ConversationNotFoundError(conversationId, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
const rows = await db
|
||||||
|
.update(conversationResponseStateTable)
|
||||||
|
.set({
|
||||||
|
status: input.status,
|
||||||
|
pendingSinceEntryId: input.pendingSinceEntryId,
|
||||||
|
maxWaitUntil: input.maxWaitUntil,
|
||||||
|
runningSince: input.runningSince,
|
||||||
|
updatedAt: new Date(),
|
||||||
|
})
|
||||||
|
.where(eq(conversationResponseStateTable.conversationId, conversationId))
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async markResponseStateStatus(
|
||||||
|
conversationIds: string[],
|
||||||
|
status: ConversationResponseStateStatus,
|
||||||
|
): Promise<ConversationResponseStateRow[]> {
|
||||||
|
return this.write(async (db) => {
|
||||||
|
const now = new Date()
|
||||||
|
|
||||||
|
let runningSince: Date | null
|
||||||
|
switch (status) {
|
||||||
|
case "pending":
|
||||||
|
case "failed":
|
||||||
|
runningSince = null
|
||||||
|
break
|
||||||
|
case "running":
|
||||||
|
runningSince = now
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
const rows = await db
|
||||||
|
.update(conversationResponseStateTable)
|
||||||
|
.set({
|
||||||
|
status,
|
||||||
|
runningSince,
|
||||||
|
updatedAt: now,
|
||||||
|
})
|
||||||
|
.where(inArray(conversationResponseStateTable.conversationId, conversationIds))
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
return rows
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async claimPendingConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationResponseStateRow | null> {
|
||||||
|
return this.write(async (db) => {
|
||||||
|
const now = new Date()
|
||||||
|
const rows = await db
|
||||||
|
.update(conversationResponseStateTable)
|
||||||
|
.set({
|
||||||
|
status: "running",
|
||||||
|
runningSince: now,
|
||||||
|
updatedAt: now,
|
||||||
|
})
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(conversationResponseStateTable.conversationId, conversationId),
|
||||||
|
eq(conversationResponseStateTable.status, "pending"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async clearConversationResponseState(conversationId: string): Promise<void> {
|
||||||
|
await this.write(async (db) => {
|
||||||
|
if (!(await findConversationByIdForUpdate(db, conversationId))) {
|
||||||
|
throw new ConversationNotFoundError(conversationId, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
await db
|
||||||
|
.delete(conversationResponseStateTable)
|
||||||
|
.where(eq(conversationResponseStateTable.conversationId, conversationId))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
private async write<T>(fn: (db: Database) => Promise<T>): Promise<T> {
|
||||||
|
if (this.inTransaction) return fn(this.db)
|
||||||
|
|
||||||
|
return this.db.transaction(fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createConversationStorage(db: Database): ConversationStorage {
|
||||||
|
return new DrizzleConversationStorage(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function conversations(db: Database, userId: string) {
|
||||||
|
const storage = createConversationStorage(db)
|
||||||
|
|
||||||
|
return {
|
||||||
|
createConversation(): Promise<ConversationRow> {
|
||||||
|
return storage.createConversation(userId)
|
||||||
|
},
|
||||||
|
|
||||||
|
listConversations(): Promise<ConversationRow[]> {
|
||||||
|
return storage.listUserConversations(userId)
|
||||||
|
},
|
||||||
|
|
||||||
|
getConversation(conversationId: string): Promise<ConversationRow | null> {
|
||||||
|
return findUserConversation(db, userId, conversationId)
|
||||||
|
},
|
||||||
|
|
||||||
|
getOrCreateConversation(): Promise<ConversationRow> {
|
||||||
|
return storage.getOrCreateConversation(userId)
|
||||||
|
},
|
||||||
|
|
||||||
|
createFile(input: CreateFileInput): Promise<FileRow> {
|
||||||
|
return storage.createFile(userId, input)
|
||||||
|
},
|
||||||
|
|
||||||
|
appendEntry(
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendConversationEntryInput,
|
||||||
|
): Promise<ConversationEntryRow> {
|
||||||
|
return db.transaction((tx) => appendEntryToConversation(tx, userId, conversationId, input))
|
||||||
|
},
|
||||||
|
|
||||||
|
appendAttachmentEntry(
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendAttachmentEntryInput,
|
||||||
|
): Promise<AppendAttachmentEntryResult> {
|
||||||
|
return db.transaction((tx) =>
|
||||||
|
appendAttachmentEntryToConversation(tx, userId, conversationId, input),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
|
||||||
|
listEntries(
|
||||||
|
conversationId: string,
|
||||||
|
params: ListConversationEntriesParams = {},
|
||||||
|
): Promise<ConversationEntryRow[]> {
|
||||||
|
return storage.listUserConversationEntries(userId, conversationId, params)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function conversationResponse(db: Database, _userId: string, conversationId: string) {
|
||||||
|
const storage = createConversationStorage(db)
|
||||||
|
|
||||||
|
return {
|
||||||
|
get(): Promise<ConversationResponseStateRow | null> {
|
||||||
|
return storage.findConversationResponseState(conversationId)
|
||||||
|
},
|
||||||
|
|
||||||
|
upsert(input: UpsertConversationResponseStateInput): Promise<ConversationResponseStateRow> {
|
||||||
|
return storage.upsertConversationResponseState(conversationId, input)
|
||||||
|
},
|
||||||
|
|
||||||
|
update(
|
||||||
|
input: UpdateConversationResponseStateInput,
|
||||||
|
): Promise<ConversationResponseStateRow | null> {
|
||||||
|
return storage.updateConversationResponseState(conversationId, input)
|
||||||
|
},
|
||||||
|
|
||||||
|
clear(): Promise<void> {
|
||||||
|
return storage.clearConversationResponseState(conversationId)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function payloadForKind(
|
||||||
|
kind: ConversationEntryKind,
|
||||||
|
payload: AppendConversationEntryInput["payload"],
|
||||||
|
): ConversationEntryPayload {
|
||||||
|
switch (kind) {
|
||||||
|
case ConversationEntryKind.UserMessage:
|
||||||
|
return UserMessagePayload.assert(payload)
|
||||||
|
case ConversationEntryKind.AssistantMessage:
|
||||||
|
return AssistantMessagePayload.assert(payload)
|
||||||
|
case ConversationEntryKind.Attachment:
|
||||||
|
return AttachmentPayload.assert(payload)
|
||||||
|
case ConversationEntryKind.ContextSummary:
|
||||||
|
return ContextSummaryPayload.assert(payload)
|
||||||
|
case ConversationEntryKind.ToolCall:
|
||||||
|
case ConversationEntryKind.ToolResult:
|
||||||
|
case ConversationEntryKind.SystemNote:
|
||||||
|
return GenericObjectPayload.assert(payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function appendEntryToConversation(
|
||||||
|
db: Database,
|
||||||
|
userId: string | null,
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendConversationEntryInput,
|
||||||
|
): Promise<ConversationEntryRow> {
|
||||||
|
const kind = conversationEntryKind.assert(input.kind)
|
||||||
|
const visibility = conversationEntryVisibility.assert(
|
||||||
|
input.visibility ?? defaultVisibilityForKind(kind),
|
||||||
|
)
|
||||||
|
const payload = payloadForKind(kind, input.payload)
|
||||||
|
const metadata = ConversationEntryMetadata.assert(input.metadata ?? {})
|
||||||
|
let fileId: string | null = null
|
||||||
|
|
||||||
|
if (input.kind === ConversationEntryKind.Attachment) {
|
||||||
|
fileId = input.fileId
|
||||||
|
}
|
||||||
|
|
||||||
|
const conversation = userId
|
||||||
|
? await findConversationForUpdate(db, userId, conversationId)
|
||||||
|
: await findConversationByIdForUpdate(db, conversationId)
|
||||||
|
if (!conversation) {
|
||||||
|
throw new ConversationNotFoundError(conversationId, userId ?? "")
|
||||||
|
}
|
||||||
|
if (fileId) await requireFile(db, conversation.userId, fileId)
|
||||||
|
|
||||||
|
const sequence = await nextSequence(db, conversationId)
|
||||||
|
const rows = await db
|
||||||
|
.insert(conversationEntries)
|
||||||
|
.values({
|
||||||
|
conversationId,
|
||||||
|
sequence,
|
||||||
|
kind,
|
||||||
|
visibility,
|
||||||
|
fileId,
|
||||||
|
payload,
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
await touchConversation(db, conversation.userId, conversationId)
|
||||||
|
return requireRow(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function appendAttachmentEntryToConversation(
|
||||||
|
db: Database,
|
||||||
|
userId: string | null,
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendAttachmentEntryInput,
|
||||||
|
): Promise<AppendAttachmentEntryResult> {
|
||||||
|
const payload = AttachmentPayload.assert(input.payload)
|
||||||
|
const visibility = conversationEntryVisibility.assert(
|
||||||
|
input.visibility ?? defaultVisibilityForKind(ConversationEntryKind.Attachment),
|
||||||
|
)
|
||||||
|
const metadata = ConversationEntryMetadata.assert(input.metadata ?? {})
|
||||||
|
const conversation = userId
|
||||||
|
? await findConversationForUpdate(db, userId, conversationId)
|
||||||
|
: await findConversationByIdForUpdate(db, conversationId)
|
||||||
|
|
||||||
|
if (!conversation) {
|
||||||
|
throw new ConversationNotFoundError(conversationId, userId ?? "")
|
||||||
|
}
|
||||||
|
|
||||||
|
const file = await insertFile(db, conversation.userId, input.file)
|
||||||
|
const sequence = await nextSequence(db, conversationId)
|
||||||
|
const rows = await db
|
||||||
|
.insert(conversationEntries)
|
||||||
|
.values({
|
||||||
|
conversationId,
|
||||||
|
sequence,
|
||||||
|
kind: ConversationEntryKind.Attachment,
|
||||||
|
visibility,
|
||||||
|
fileId: file.id,
|
||||||
|
payload,
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
await touchConversation(db, conversation.userId, conversationId)
|
||||||
|
return {
|
||||||
|
file,
|
||||||
|
entry: requireRow(rows),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function requireUserForUpdate(db: Database, userId: string): Promise<void> {
|
||||||
|
const rows = await db
|
||||||
|
.select({ id: user.id })
|
||||||
|
.from(user)
|
||||||
|
.where(eq(user.id, userId))
|
||||||
|
.limit(1)
|
||||||
|
.for("update")
|
||||||
|
|
||||||
|
requireRow(rows, `User not found: ${userId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function findConversation(
|
||||||
|
db: Database,
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationRow | null> {
|
||||||
|
const rows = await db
|
||||||
|
.select()
|
||||||
|
.from(conversationsTable)
|
||||||
|
.where(eq(conversationsTable.id, conversationId))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
}
|
||||||
|
|
||||||
|
async function findUserConversation(
|
||||||
|
db: Database,
|
||||||
|
userId: string,
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationRow | null> {
|
||||||
|
const rows = await db
|
||||||
|
.select()
|
||||||
|
.from(conversationsTable)
|
||||||
|
.where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId)))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
}
|
||||||
|
|
||||||
|
async function findConversationForUpdate(
|
||||||
|
db: Database,
|
||||||
|
userId: string,
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationRow | null> {
|
||||||
|
const rows = await db
|
||||||
|
.select()
|
||||||
|
.from(conversationsTable)
|
||||||
|
.where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId)))
|
||||||
|
.limit(1)
|
||||||
|
.for("update")
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
}
|
||||||
|
|
||||||
|
async function findConversationByIdForUpdate(
|
||||||
|
db: Database,
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationRow | null> {
|
||||||
|
const rows = await db
|
||||||
|
.select()
|
||||||
|
.from(conversationsTable)
|
||||||
|
.where(eq(conversationsTable.id, conversationId))
|
||||||
|
.limit(1)
|
||||||
|
.for("update")
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
}
|
||||||
|
|
||||||
|
async function latestConversation(db: Database, userId: string): Promise<ConversationRow | null> {
|
||||||
|
const rows = await db
|
||||||
|
.select()
|
||||||
|
.from(conversationsTable)
|
||||||
|
.where(eq(conversationsTable.userId, userId))
|
||||||
|
.orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
return rows[0] ?? null
|
||||||
|
}
|
||||||
|
|
||||||
|
async function insertConversation(db: Database, userId: string): Promise<ConversationRow> {
|
||||||
|
const rows = await db
|
||||||
|
.insert(conversationsTable)
|
||||||
|
.values({
|
||||||
|
userId,
|
||||||
|
})
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
return requireRow(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function requireFile(db: Database, userId: string, fileId: string): Promise<FileRow> {
|
||||||
|
const rows = await db
|
||||||
|
.select()
|
||||||
|
.from(files)
|
||||||
|
.where(and(eq(files.id, fileId), eq(files.userId, userId)))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
return requireRow(rows, `File not found: ${fileId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function insertFile(db: Database, userId: string, input: CreateFileInput): Promise<FileRow> {
|
||||||
|
const rows = await db
|
||||||
|
.insert(files)
|
||||||
|
.values({
|
||||||
|
userId,
|
||||||
|
storageKey: input.storageKey,
|
||||||
|
originalName: input.originalName ?? null,
|
||||||
|
mimeType: input.mimeType,
|
||||||
|
sizeBytes: input.sizeBytes,
|
||||||
|
metadata: input.metadata ?? {},
|
||||||
|
})
|
||||||
|
.returning()
|
||||||
|
|
||||||
|
return requireRow(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function touchConversation(
|
||||||
|
db: Database,
|
||||||
|
userId: string,
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<void> {
|
||||||
|
await db
|
||||||
|
.update(conversationsTable)
|
||||||
|
.set({ updatedAt: new Date() })
|
||||||
|
.where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async function nextSequence(db: Database, conversationId: string): Promise<number> {
|
||||||
|
const rows = await db
|
||||||
|
.select({ sequence: conversationEntries.sequence })
|
||||||
|
.from(conversationEntries)
|
||||||
|
.where(eq(conversationEntries.conversationId, conversationId))
|
||||||
|
.orderBy(desc(conversationEntries.sequence))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
return (rows[0]?.sequence ?? 0) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
function requireRow<T>(rows: T[], message = "Expected database row"): T {
|
||||||
|
const row = rows[0]
|
||||||
|
if (!row) throw new Error(message)
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
|
||||||
|
function defaultVisibilityForKind(kind: ConversationEntryKind): ConversationEntryVisibility {
|
||||||
|
switch (kind) {
|
||||||
|
case ConversationEntryKind.UserMessage:
|
||||||
|
case ConversationEntryKind.AssistantMessage:
|
||||||
|
case ConversationEntryKind.Attachment:
|
||||||
|
return ConversationEntryVisibility.UserVisible
|
||||||
|
case ConversationEntryKind.ToolCall:
|
||||||
|
case ConversationEntryKind.ToolResult:
|
||||||
|
case ConversationEntryKind.ContextSummary:
|
||||||
|
case ConversationEntryKind.SystemNote:
|
||||||
|
return ConversationEntryVisibility.Internal
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,7 +25,7 @@ const listEntriesCalls: Array<{
|
|||||||
params: ListConversationEntriesParams
|
params: ListConversationEntriesParams
|
||||||
}> = []
|
}> = []
|
||||||
|
|
||||||
mock.module("./storage.ts", () => ({
|
mock.module("./db-storage.ts", () => ({
|
||||||
conversations: (_db: Database, userId: string) => ({
|
conversations: (_db: Database, userId: string) => ({
|
||||||
async listConversations(): Promise<ConversationRow[]> {
|
async listConversations(): Promise<ConversationRow[]> {
|
||||||
return conversationRowsByUser.get(userId) ?? []
|
return conversationRowsByUser.get(userId) ?? []
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import type { AuthSessionMiddleware } from "../auth/session-middleware.ts"
|
|||||||
import type { Database } from "../db/index.ts"
|
import type { Database } from "../db/index.ts"
|
||||||
import type { ConversationRow } from "./storage.ts"
|
import type { ConversationRow } from "./storage.ts"
|
||||||
|
|
||||||
|
import { conversations } from "./db-storage.ts"
|
||||||
import { ConversationNotFoundError } from "./errors.ts"
|
import { ConversationNotFoundError } from "./errors.ts"
|
||||||
import { conversations } from "./storage.ts"
|
|
||||||
|
|
||||||
/** Hono environment populated by the conversations route middleware. */
|
/** Hono environment populated by the conversations route middleware. */
|
||||||
type Env = {
|
type Env = {
|
||||||
|
|||||||
@@ -2,28 +2,70 @@ import {
|
|||||||
AssistantMessagePayload,
|
AssistantMessagePayload,
|
||||||
AttachmentPayload,
|
AttachmentPayload,
|
||||||
ConversationEntryKind,
|
ConversationEntryKind,
|
||||||
|
ConversationEntryMetadata,
|
||||||
ConversationEntryVisibility,
|
ConversationEntryVisibility,
|
||||||
ContextSummaryPayload,
|
ContextSummaryPayload,
|
||||||
ConversationEntryMetadata,
|
|
||||||
GenericObjectPayload,
|
GenericObjectPayload,
|
||||||
UserMessagePayload,
|
UserMessagePayload,
|
||||||
type ConversationEntryPayload,
|
|
||||||
} from "@freya/core"
|
} from "@freya/core"
|
||||||
import { type } from "arktype"
|
|
||||||
import { and, asc, desc, eq } from "drizzle-orm"
|
|
||||||
|
|
||||||
import type { Database } from "../db/index.ts"
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
conversationEntries,
|
conversationEntries,
|
||||||
|
conversationResponseState as conversationResponseStateTable,
|
||||||
conversations as conversationsTable,
|
conversations as conversationsTable,
|
||||||
files,
|
files,
|
||||||
user,
|
type ConversationResponseStateStatus,
|
||||||
} from "../db/schema.ts"
|
} from "../db/schema.ts"
|
||||||
import { ConversationNotFoundError } from "./errors.ts"
|
|
||||||
|
|
||||||
const conversationEntryKind = type.enumerated(...Object.values(ConversationEntryKind))
|
export interface ConversationStorage {
|
||||||
const conversationEntryVisibility = type.enumerated(...Object.values(ConversationEntryVisibility))
|
transaction<T>(tx: (storage: ConversationStorage) => T | Promise<T>): Promise<T>
|
||||||
|
createConversation(userId: string): Promise<ConversationRow>
|
||||||
|
listUserConversations(userId: string): Promise<ConversationRow[]>
|
||||||
|
findConversation(conversationId: string): Promise<ConversationRow | null>
|
||||||
|
getOrCreateConversation(userId: string): Promise<ConversationRow>
|
||||||
|
createFile(userId: string, input: CreateFileInput): Promise<FileRow>
|
||||||
|
appendEntry(
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendConversationEntryInput,
|
||||||
|
): Promise<ConversationEntryRow>
|
||||||
|
appendAttachmentEntry(
|
||||||
|
conversationId: string,
|
||||||
|
input: AppendAttachmentEntryInput,
|
||||||
|
): Promise<AppendAttachmentEntryResult>
|
||||||
|
nextSequence(conversationId: string): Promise<number>
|
||||||
|
listUserConversationEntries(
|
||||||
|
userId: string,
|
||||||
|
conversationId: string,
|
||||||
|
params?: ListConversationEntriesParams,
|
||||||
|
): Promise<ConversationEntryRow[]>
|
||||||
|
listPendingUserConversationEntries(
|
||||||
|
userId: string,
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationEntryRow[]>
|
||||||
|
findConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationResponseStateRow | null>
|
||||||
|
// TODO: add pagination support
|
||||||
|
listPendingResponseStates(): Promise<ConversationResponseStateRow[]>
|
||||||
|
// TODO: add pagination support
|
||||||
|
listRunningResponseStates(): Promise<ConversationResponseStateRow[]>
|
||||||
|
upsertConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
input: UpsertConversationResponseStateInput,
|
||||||
|
): Promise<ConversationResponseStateRow>
|
||||||
|
updateConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
input: UpdateConversationResponseStateInput,
|
||||||
|
): Promise<ConversationResponseStateRow | null>
|
||||||
|
markResponseStateStatus(
|
||||||
|
conversationIds: string[],
|
||||||
|
status: ConversationResponseStateStatus,
|
||||||
|
): Promise<ConversationResponseStateRow[]>
|
||||||
|
claimPendingConversationResponseState(
|
||||||
|
conversationId: string,
|
||||||
|
): Promise<ConversationResponseStateRow | null>
|
||||||
|
clearConversationResponseState(conversationId: string): Promise<void>
|
||||||
|
}
|
||||||
|
|
||||||
/** Database row shape for a conversation owned by a user. */
|
/** Database row shape for a conversation owned by a user. */
|
||||||
export type ConversationRow = typeof conversationsTable.$inferSelect
|
export type ConversationRow = typeof conversationsTable.$inferSelect
|
||||||
@@ -31,6 +73,9 @@ export type ConversationRow = typeof conversationsTable.$inferSelect
|
|||||||
/** Database row shape for an entry in a conversation timeline. */
|
/** Database row shape for an entry in a conversation timeline. */
|
||||||
export type ConversationEntryRow = typeof conversationEntries.$inferSelect
|
export type ConversationEntryRow = typeof conversationEntries.$inferSelect
|
||||||
|
|
||||||
|
/** Database row shape for pending assistant response state in a conversation. */
|
||||||
|
export type ConversationResponseStateRow = typeof conversationResponseStateTable.$inferSelect
|
||||||
|
|
||||||
/** Database row shape for an uploaded file referenced by conversations. */
|
/** Database row shape for an uploaded file referenced by conversations. */
|
||||||
export type FileRow = typeof files.$inferSelect
|
export type FileRow = typeof files.$inferSelect
|
||||||
|
|
||||||
@@ -99,291 +144,26 @@ export interface ListConversationEntriesParams {
|
|||||||
visibility?: ConversationEntryVisibility
|
visibility?: ConversationEntryVisibility
|
||||||
}
|
}
|
||||||
|
|
||||||
export function conversations(db: Database, userId: string) {
|
/** Input for creating or replacing pending assistant response state. */
|
||||||
const storage = {
|
export interface UpsertConversationResponseStateInput {
|
||||||
async createConversation(): Promise<ConversationRow> {
|
status?: ConversationResponseStateStatus
|
||||||
return insertConversation(db, userId)
|
pendingSinceEntryId: string
|
||||||
},
|
maxWaitUntil: Date
|
||||||
|
runningSince?: Date | null
|
||||||
async listConversations(): Promise<ConversationRow[]> {
|
|
||||||
return db
|
|
||||||
.select()
|
|
||||||
.from(conversationsTable)
|
|
||||||
.where(eq(conversationsTable.userId, userId))
|
|
||||||
.orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt))
|
|
||||||
},
|
|
||||||
|
|
||||||
async getConversation(conversationId: string): Promise<ConversationRow | null> {
|
|
||||||
const rows = await db
|
|
||||||
.select()
|
|
||||||
.from(conversationsTable)
|
|
||||||
.where(
|
|
||||||
and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId)),
|
|
||||||
)
|
|
||||||
.limit(1)
|
|
||||||
|
|
||||||
return rows[0] ?? null
|
|
||||||
},
|
|
||||||
|
|
||||||
async getOrCreateConversation(): Promise<ConversationRow> {
|
|
||||||
return db.transaction(async (tx) => {
|
|
||||||
await requireUserForUpdate(tx, userId)
|
|
||||||
const existing = await latestConversation(tx, userId)
|
|
||||||
if (existing) return existing
|
|
||||||
|
|
||||||
return insertConversation(tx, userId)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
|
|
||||||
async createFile(input: CreateFileInput): Promise<FileRow> {
|
|
||||||
return insertFile(db, userId, input)
|
|
||||||
},
|
|
||||||
|
|
||||||
async appendEntry(
|
|
||||||
conversationId: string,
|
|
||||||
input: AppendConversationEntryInput,
|
|
||||||
): Promise<ConversationEntryRow> {
|
|
||||||
const kind = conversationEntryKind.assert(input.kind)
|
|
||||||
const visibility = conversationEntryVisibility.assert(
|
|
||||||
input.visibility ?? defaultVisibilityForKind(kind),
|
|
||||||
)
|
|
||||||
const payload = payloadForKind(kind, input.payload)
|
|
||||||
const metadata = ConversationEntryMetadata.assert(input.metadata ?? {})
|
|
||||||
let fileId: string | null = null
|
|
||||||
|
|
||||||
if (input.kind === ConversationEntryKind.Attachment) {
|
|
||||||
fileId = input.fileId
|
|
||||||
await requireFile(db, userId, fileId)
|
|
||||||
}
|
|
||||||
|
|
||||||
const rows = await db.transaction(async (tx) => {
|
|
||||||
if (!(await findConversationForUpdate(tx, userId, conversationId))) {
|
|
||||||
throw new ConversationNotFoundError(conversationId, userId)
|
|
||||||
}
|
|
||||||
const sequence = await nextSequence(tx, conversationId)
|
|
||||||
|
|
||||||
const rows = await tx
|
|
||||||
.insert(conversationEntries)
|
|
||||||
.values({
|
|
||||||
conversationId,
|
|
||||||
sequence,
|
|
||||||
kind,
|
|
||||||
visibility,
|
|
||||||
fileId,
|
|
||||||
payload,
|
|
||||||
metadata,
|
|
||||||
})
|
|
||||||
.returning()
|
|
||||||
|
|
||||||
await touchConversation(tx, userId, conversationId)
|
|
||||||
return rows
|
|
||||||
})
|
|
||||||
|
|
||||||
return requireRow(rows)
|
|
||||||
},
|
|
||||||
|
|
||||||
async appendAttachmentEntry(
|
|
||||||
conversationId: string,
|
|
||||||
input: AppendAttachmentEntryInput,
|
|
||||||
): Promise<AppendAttachmentEntryResult> {
|
|
||||||
const payload = AttachmentPayload.assert(input.payload)
|
|
||||||
const visibility = conversationEntryVisibility.assert(
|
|
||||||
input.visibility ?? defaultVisibilityForKind(ConversationEntryKind.Attachment),
|
|
||||||
)
|
|
||||||
const metadata = ConversationEntryMetadata.assert(input.metadata ?? {})
|
|
||||||
|
|
||||||
return db.transaction(async (tx) => {
|
|
||||||
if (!(await findConversationForUpdate(tx, userId, conversationId))) {
|
|
||||||
throw new ConversationNotFoundError(conversationId, userId)
|
|
||||||
}
|
|
||||||
|
|
||||||
const file = await insertFile(tx, userId, input.file)
|
|
||||||
const sequence = await nextSequence(tx, conversationId)
|
|
||||||
const rows = await tx
|
|
||||||
.insert(conversationEntries)
|
|
||||||
.values({
|
|
||||||
conversationId,
|
|
||||||
sequence,
|
|
||||||
kind: ConversationEntryKind.Attachment,
|
|
||||||
visibility,
|
|
||||||
fileId: file.id,
|
|
||||||
payload,
|
|
||||||
metadata,
|
|
||||||
})
|
|
||||||
.returning()
|
|
||||||
|
|
||||||
await touchConversation(tx, userId, conversationId)
|
|
||||||
return {
|
|
||||||
file,
|
|
||||||
entry: requireRow(rows),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
},
|
|
||||||
|
|
||||||
async listEntries(
|
|
||||||
conversationId: string,
|
|
||||||
params: ListConversationEntriesParams = {},
|
|
||||||
): Promise<ConversationEntryRow[]> {
|
|
||||||
if (!(await storage.getConversation(conversationId))) {
|
|
||||||
throw new ConversationNotFoundError(conversationId, userId)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.visibility) {
|
|
||||||
return db
|
|
||||||
.select()
|
|
||||||
.from(conversationEntries)
|
|
||||||
.where(
|
|
||||||
and(
|
|
||||||
eq(conversationEntries.conversationId, conversationId),
|
|
||||||
eq(conversationEntries.visibility, params.visibility),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.orderBy(asc(conversationEntries.sequence))
|
|
||||||
}
|
|
||||||
|
|
||||||
return db
|
|
||||||
.select()
|
|
||||||
.from(conversationEntries)
|
|
||||||
.where(eq(conversationEntries.conversationId, conversationId))
|
|
||||||
.orderBy(asc(conversationEntries.sequence))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return storage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function payloadForKind(
|
/** Input for patching pending assistant response state. */
|
||||||
kind: ConversationEntryKind,
|
export interface UpdateConversationResponseStateInput {
|
||||||
payload: AppendConversationEntryInput["payload"],
|
status?: ConversationResponseStateStatus
|
||||||
): ConversationEntryPayload {
|
pendingSinceEntryId?: string
|
||||||
switch (kind) {
|
maxWaitUntil?: Date
|
||||||
case ConversationEntryKind.UserMessage:
|
runningSince?: Date | null
|
||||||
return UserMessagePayload.assert(payload)
|
|
||||||
case ConversationEntryKind.AssistantMessage:
|
|
||||||
return AssistantMessagePayload.assert(payload)
|
|
||||||
case ConversationEntryKind.Attachment:
|
|
||||||
return AttachmentPayload.assert(payload)
|
|
||||||
case ConversationEntryKind.ContextSummary:
|
|
||||||
return ContextSummaryPayload.assert(payload)
|
|
||||||
case ConversationEntryKind.ToolCall:
|
|
||||||
case ConversationEntryKind.ToolResult:
|
|
||||||
case ConversationEntryKind.SystemNote:
|
|
||||||
return GenericObjectPayload.assert(payload)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function requireUserForUpdate(db: Database, userId: string): Promise<void> {
|
export {
|
||||||
const rows = await db
|
createConversationStorage,
|
||||||
.select({ id: user.id })
|
conversationResponse,
|
||||||
.from(user)
|
conversations,
|
||||||
.where(eq(user.id, userId))
|
DrizzleConversationStorage,
|
||||||
.limit(1)
|
findConversation,
|
||||||
.for("update")
|
} from "./db-storage.ts"
|
||||||
|
|
||||||
requireRow(rows, `User not found: ${userId}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
async function findConversationForUpdate(
|
|
||||||
db: Database,
|
|
||||||
userId: string,
|
|
||||||
conversationId: string,
|
|
||||||
): Promise<ConversationRow | null> {
|
|
||||||
const rows = await db
|
|
||||||
.select()
|
|
||||||
.from(conversationsTable)
|
|
||||||
.where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId)))
|
|
||||||
.limit(1)
|
|
||||||
.for("update")
|
|
||||||
|
|
||||||
return rows[0] ?? null
|
|
||||||
}
|
|
||||||
|
|
||||||
async function latestConversation(db: Database, userId: string): Promise<ConversationRow | null> {
|
|
||||||
const rows = await db
|
|
||||||
.select()
|
|
||||||
.from(conversationsTable)
|
|
||||||
.where(eq(conversationsTable.userId, userId))
|
|
||||||
.orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt))
|
|
||||||
.limit(1)
|
|
||||||
|
|
||||||
return rows[0] ?? null
|
|
||||||
}
|
|
||||||
|
|
||||||
async function insertConversation(db: Database, userId: string): Promise<ConversationRow> {
|
|
||||||
const rows = await db
|
|
||||||
.insert(conversationsTable)
|
|
||||||
.values({
|
|
||||||
userId,
|
|
||||||
})
|
|
||||||
.returning()
|
|
||||||
|
|
||||||
return requireRow(rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
async function requireFile(db: Database, userId: string, fileId: string): Promise<FileRow> {
|
|
||||||
const rows = await db
|
|
||||||
.select()
|
|
||||||
.from(files)
|
|
||||||
.where(and(eq(files.id, fileId), eq(files.userId, userId)))
|
|
||||||
.limit(1)
|
|
||||||
|
|
||||||
return requireRow(rows, `File not found: ${fileId}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
async function insertFile(db: Database, userId: string, input: CreateFileInput): Promise<FileRow> {
|
|
||||||
const rows = await db
|
|
||||||
.insert(files)
|
|
||||||
.values({
|
|
||||||
userId,
|
|
||||||
storageKey: input.storageKey,
|
|
||||||
originalName: input.originalName ?? null,
|
|
||||||
mimeType: input.mimeType,
|
|
||||||
sizeBytes: input.sizeBytes,
|
|
||||||
metadata: input.metadata ?? {},
|
|
||||||
})
|
|
||||||
.returning()
|
|
||||||
|
|
||||||
return requireRow(rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
async function touchConversation(
|
|
||||||
db: Database,
|
|
||||||
userId: string,
|
|
||||||
conversationId: string,
|
|
||||||
): Promise<void> {
|
|
||||||
await db
|
|
||||||
.update(conversationsTable)
|
|
||||||
.set({ updatedAt: new Date() })
|
|
||||||
.where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId)))
|
|
||||||
}
|
|
||||||
|
|
||||||
async function nextSequence(db: Database, conversationId: string): Promise<number> {
|
|
||||||
const rows = await db
|
|
||||||
.select({ sequence: conversationEntries.sequence })
|
|
||||||
.from(conversationEntries)
|
|
||||||
.where(eq(conversationEntries.conversationId, conversationId))
|
|
||||||
.orderBy(desc(conversationEntries.sequence))
|
|
||||||
.limit(1)
|
|
||||||
|
|
||||||
return (rows[0]?.sequence ?? 0) + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
function requireRow<T>(rows: T[], message = "Expected database row"): T {
|
|
||||||
const row = rows[0]
|
|
||||||
if (!row) throw new Error(message)
|
|
||||||
return row
|
|
||||||
}
|
|
||||||
|
|
||||||
function defaultVisibilityForKind(kind: ConversationEntryKind): ConversationEntryVisibility {
|
|
||||||
switch (kind) {
|
|
||||||
case ConversationEntryKind.UserMessage:
|
|
||||||
case ConversationEntryKind.AssistantMessage:
|
|
||||||
case ConversationEntryKind.Attachment:
|
|
||||||
return ConversationEntryVisibility.UserVisible
|
|
||||||
case ConversationEntryKind.ToolCall:
|
|
||||||
case ConversationEntryKind.ToolResult:
|
|
||||||
case ConversationEntryKind.ContextSummary:
|
|
||||||
case ConversationEntryKind.SystemNote:
|
|
||||||
return ConversationEntryVisibility.Internal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -48,6 +48,15 @@ const bytea = customType<{ data: Buffer }>({
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const ConversationResponseStateStatus = {
|
||||||
|
Pending: "pending",
|
||||||
|
Running: "running",
|
||||||
|
Failed: "failed",
|
||||||
|
} as const
|
||||||
|
|
||||||
|
export type ConversationResponseStateStatus =
|
||||||
|
(typeof ConversationResponseStateStatus)[keyof typeof ConversationResponseStateStatus]
|
||||||
|
|
||||||
export const userSources = pgTable(
|
export const userSources = pgTable(
|
||||||
"user_sources",
|
"user_sources",
|
||||||
{
|
{
|
||||||
@@ -146,6 +155,38 @@ export const conversationEntries = pgTable(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
export const conversationResponseState = pgTable(
|
||||||
|
"conversation_response_state",
|
||||||
|
{
|
||||||
|
conversationId: uuid("conversation_id")
|
||||||
|
.primaryKey()
|
||||||
|
.references(() => conversations.id, { onDelete: "cascade" }),
|
||||||
|
status: text("status")
|
||||||
|
.$type<ConversationResponseStateStatus>()
|
||||||
|
.notNull()
|
||||||
|
.default(ConversationResponseStateStatus.Pending),
|
||||||
|
pendingSinceEntryId: uuid("pending_since_entry_id")
|
||||||
|
.notNull()
|
||||||
|
.references(() => conversationEntries.id, { onDelete: "cascade" }),
|
||||||
|
maxWaitUntil: timestamp("max_wait_until").notNull(),
|
||||||
|
runningSince: timestamp("running_since"),
|
||||||
|
createdAt: timestamp("created_at").notNull().defaultNow(),
|
||||||
|
updatedAt: timestamp("updated_at")
|
||||||
|
.notNull()
|
||||||
|
.defaultNow()
|
||||||
|
.$onUpdate(() => new Date()),
|
||||||
|
},
|
||||||
|
(t) => [
|
||||||
|
index("conversation_response_state_status_max_wait_until_idx").on(t.status, t.maxWaitUntil),
|
||||||
|
index("conversation_response_state_running_since_idx").on(t.runningSince),
|
||||||
|
index("conversation_response_state_pending_since_entry_id_idx").on(t.pendingSinceEntryId),
|
||||||
|
check(
|
||||||
|
"conversation_response_state_status_check",
|
||||||
|
sql`${t.status} in ('pending', 'running', 'failed')`,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// FREYA — reminders source storage
|
// FREYA — reminders source storage
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ interface FeedResponse {
|
|||||||
items: Array<{
|
items: Array<{
|
||||||
id: string
|
id: string
|
||||||
type: string
|
type: string
|
||||||
priority: number
|
|
||||||
timestamp: string
|
timestamp: string
|
||||||
data: Record<string, unknown>
|
data: Record<string, unknown>
|
||||||
}>
|
}>
|
||||||
@@ -85,7 +84,7 @@ mock.module("../sources/user-sources.ts", () => ({
|
|||||||
}),
|
}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
mock.module("../conversations/storage.ts", () => ({
|
mock.module("../conversations/db-storage.ts", () => ({
|
||||||
conversations: (_db: Database, userId: string) => ({
|
conversations: (_db: Database, userId: string) => ({
|
||||||
async getOrCreateConversation() {
|
async getOrCreateConversation() {
|
||||||
return { id: `conversation-${userId}` }
|
return { id: `conversation-${userId}` }
|
||||||
@@ -118,7 +117,6 @@ describe("GET /api/feed", () => {
|
|||||||
id: "item-1",
|
id: "item-1",
|
||||||
sourceId: "test",
|
sourceId: "test",
|
||||||
type: "test",
|
type: "test",
|
||||||
priority: 0.8,
|
|
||||||
timestamp: new Date("2025-01-01T00:00:00.000Z"),
|
timestamp: new Date("2025-01-01T00:00:00.000Z"),
|
||||||
data: { value: 42 },
|
data: { value: 42 },
|
||||||
},
|
},
|
||||||
@@ -149,7 +147,6 @@ describe("GET /api/feed", () => {
|
|||||||
expect(body.items).toHaveLength(1)
|
expect(body.items).toHaveLength(1)
|
||||||
expect(body.items[0]!.id).toBe("item-1")
|
expect(body.items[0]!.id).toBe("item-1")
|
||||||
expect(body.items[0]!.type).toBe("test")
|
expect(body.items[0]!.type).toBe("test")
|
||||||
expect(body.items[0]!.priority).toBe(0.8)
|
|
||||||
expect(body.items[0]!.timestamp).toBe("2025-01-01T00:00:00.000Z")
|
expect(body.items[0]!.timestamp).toBe("2025-01-01T00:00:00.000Z")
|
||||||
expect(body.errors).toHaveLength(0)
|
expect(body.errors).toHaveLength(0)
|
||||||
})
|
})
|
||||||
@@ -160,7 +157,6 @@ describe("GET /api/feed", () => {
|
|||||||
id: "fresh-1",
|
id: "fresh-1",
|
||||||
sourceId: "test",
|
sourceId: "test",
|
||||||
type: "test",
|
type: "test",
|
||||||
priority: 0.5,
|
|
||||||
timestamp: new Date("2025-06-01T12:00:00.000Z"),
|
timestamp: new Date("2025-06-01T12:00:00.000Z"),
|
||||||
data: { fresh: true },
|
data: { fresh: true },
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -135,8 +135,9 @@ describe("schema sync", () => {
|
|||||||
|
|
||||||
// JSON Schema structure matches
|
// JSON Schema structure matches
|
||||||
const jsonSchema = enhancementResultJsonSchema
|
const jsonSchema = enhancementResultJsonSchema
|
||||||
|
const payloadKeys = Object.keys(payload).sort() as Array<(typeof jsonSchema.required)[number]>
|
||||||
expect(Object.keys(jsonSchema.properties).sort()).toEqual(Object.keys(payload).sort())
|
expect(Object.keys(jsonSchema.properties).sort()).toEqual(Object.keys(payload).sort())
|
||||||
expect([...jsonSchema.required].sort()).toEqual(Object.keys(payload).sort())
|
expect([...jsonSchema.required].sort()).toEqual(payloadKeys)
|
||||||
|
|
||||||
// syntheticItems item schema has the right required fields
|
// syntheticItems item schema has the right required fields
|
||||||
const itemSchema = jsonSchema.properties.syntheticItems.items
|
const itemSchema = jsonSchema.properties.syntheticItems.items
|
||||||
|
|||||||
116
apps/freya-backend/src/lib/job.ts
Normal file
116
apps/freya-backend/src/lib/job.ts
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import { Queue } from "./queue"
|
||||||
|
|
||||||
|
const JobStatus = {
|
||||||
|
Pending: "pending",
|
||||||
|
Running: "running",
|
||||||
|
} as const
|
||||||
|
type JobStatus = (typeof JobStatus)[keyof typeof JobStatus]
|
||||||
|
|
||||||
|
export interface Job<Payload> {
|
||||||
|
id: number
|
||||||
|
payload: Payload
|
||||||
|
signal: AbortSignal
|
||||||
|
}
|
||||||
|
|
||||||
|
interface PendingJob<Payload> {
|
||||||
|
status: typeof JobStatus.Pending
|
||||||
|
controller: AbortController
|
||||||
|
job: Job<Payload>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface RunningJob<Payload> {
|
||||||
|
status: typeof JobStatus.Running
|
||||||
|
controller: AbortController
|
||||||
|
job: Job<Payload>
|
||||||
|
}
|
||||||
|
|
||||||
|
type JobState<Payload> = PendingJob<Payload> | RunningJob<Payload>
|
||||||
|
|
||||||
|
type JobEventListener<Payload> = (job: Job<Payload>) => void
|
||||||
|
|
||||||
|
type JobEvent = "settled" | "cancelled"
|
||||||
|
|
||||||
|
export class JobRegistry<Payload> {
|
||||||
|
private queue = new Queue<Job<Payload>>()
|
||||||
|
|
||||||
|
private states = new Map<number, JobState<Payload>>()
|
||||||
|
|
||||||
|
private listeners: Record<JobEvent, JobEventListener<Payload>[]> = {
|
||||||
|
settled: [],
|
||||||
|
cancelled: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
addJob({ payload }: { payload: Payload }): Job<Payload> {
|
||||||
|
const controller = new AbortController()
|
||||||
|
const job: Job<Payload> = {
|
||||||
|
id: this.generateJobId(),
|
||||||
|
payload,
|
||||||
|
signal: controller.signal,
|
||||||
|
}
|
||||||
|
this.queue.enqueue(job)
|
||||||
|
this.states.set(job.id, { status: JobStatus.Pending, controller, job })
|
||||||
|
return job
|
||||||
|
}
|
||||||
|
|
||||||
|
async nextJob(signal?: AbortSignal): Promise<Job<Payload> | null> {
|
||||||
|
while (true) {
|
||||||
|
const job = await this.queue.next(signal)
|
||||||
|
if (!job) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
const state = this.states.get(job.id)
|
||||||
|
|
||||||
|
if (!state || state.job !== job || state.status === JobStatus.Running) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (state.controller.signal.aborted) {
|
||||||
|
this.states.delete(job.id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
this.states.set(job.id, { status: JobStatus.Running, controller: state.controller, job })
|
||||||
|
|
||||||
|
return job
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cancelJob(job: Job<unknown>): void {
|
||||||
|
const state = this.states.get(job.id)
|
||||||
|
if (state?.job === job) {
|
||||||
|
state?.controller.abort()
|
||||||
|
this.notifyListeners("cancelled", job.id)
|
||||||
|
this.states.delete(job.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
markJobAsCompleted(job: Job<unknown>): void {
|
||||||
|
const state = this.states.get(job.id)
|
||||||
|
if (state?.job === job) {
|
||||||
|
this.notifyListeners("settled", job.id)
|
||||||
|
this.states.delete(job.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
addEventListener(event: JobEvent, listener: JobEventListener<Payload>): () => void {
|
||||||
|
this.listeners[event].push(listener)
|
||||||
|
return () => {
|
||||||
|
this.listeners[event] = this.listeners[event].filter((l) => l !== listener)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private generateJobId(): number {
|
||||||
|
let id: number
|
||||||
|
do {
|
||||||
|
id = Math.floor(Math.random() * 1000000)
|
||||||
|
} while (this.states.has(id))
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
private notifyListeners(event: JobEvent, id: number): void {
|
||||||
|
const job = this.states.get(id)?.job
|
||||||
|
if (job) {
|
||||||
|
this.listeners[event].forEach((listener) => listener(job))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
69
apps/freya-backend/src/lib/queue.ts
Normal file
69
apps/freya-backend/src/lib/queue.ts
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
interface Item<T> {
|
||||||
|
value: T
|
||||||
|
next: Item<T> | null
|
||||||
|
}
|
||||||
|
|
||||||
|
export class Queue<T> {
|
||||||
|
private front: Item<T> | null = null
|
||||||
|
private back: Item<T> | null = null
|
||||||
|
private waiters: Array<(value: T) => void> = []
|
||||||
|
|
||||||
|
enqueue(value: T): void {
|
||||||
|
const waiter = this.waiters.shift()
|
||||||
|
if (waiter) {
|
||||||
|
waiter(value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const newItem: Item<T> = { value, next: null }
|
||||||
|
if (this.back) {
|
||||||
|
this.back.next = newItem
|
||||||
|
} else {
|
||||||
|
this.front = newItem
|
||||||
|
}
|
||||||
|
this.back = newItem
|
||||||
|
}
|
||||||
|
|
||||||
|
dequeue(): T | null {
|
||||||
|
if (!this.front) return null
|
||||||
|
const value = this.front.value
|
||||||
|
this.front = this.front.next
|
||||||
|
if (!this.front) this.back = null
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
next(signal?: AbortSignal): Promise<T | null> {
|
||||||
|
const value = this.dequeue()
|
||||||
|
if (value !== null) return Promise.resolve(value)
|
||||||
|
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
if (signal) {
|
||||||
|
if (signal.aborted) {
|
||||||
|
resolve(null)
|
||||||
|
} else {
|
||||||
|
let _resolve: (v: T) => void
|
||||||
|
|
||||||
|
const onAbort = () => {
|
||||||
|
this.waiters = this.waiters.filter((w) => w !== _resolve)
|
||||||
|
resolve(null)
|
||||||
|
}
|
||||||
|
|
||||||
|
signal.addEventListener(
|
||||||
|
"abort",
|
||||||
|
onAbort,
|
||||||
|
{ once: true },
|
||||||
|
)
|
||||||
|
|
||||||
|
_resolve = (v: T) => {
|
||||||
|
signal.removeEventListener("abort", onAbort)
|
||||||
|
resolve(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.waiters.push(_resolve)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this.waiters.push(resolve)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
51
apps/freya-backend/src/lib/worker.ts
Normal file
51
apps/freya-backend/src/lib/worker.ts
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import type { Job, JobRegistry } from "./job"
|
||||||
|
import type { Queue } from "./queue"
|
||||||
|
|
||||||
|
export interface JobExecutor<JobPayload> {
|
||||||
|
execute(job: Job<JobPayload>): Promise<void>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface WorkerConfig<Job> {
|
||||||
|
concurrency: number
|
||||||
|
registry: JobRegistry<Job>
|
||||||
|
runner: JobExecutor<Job>
|
||||||
|
signal: AbortSignal
|
||||||
|
}
|
||||||
|
|
||||||
|
export class Worker<Job> {
|
||||||
|
private concurrency: number
|
||||||
|
private registry: JobRegistry<Job>
|
||||||
|
private runner: JobExecutor<Job>
|
||||||
|
private signal: AbortSignal
|
||||||
|
|
||||||
|
constructor({ concurrency, registry, runner, signal }: WorkerConfig<Job>) {
|
||||||
|
this.concurrency = concurrency
|
||||||
|
this.registry = registry
|
||||||
|
this.runner = runner
|
||||||
|
this.signal = signal
|
||||||
|
}
|
||||||
|
|
||||||
|
start() {
|
||||||
|
if (this.signal.aborted) return
|
||||||
|
for (let i = 0; i < this.concurrency; i++) {
|
||||||
|
void this.pollJobFromRegistry()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async pollJobFromRegistry() {
|
||||||
|
while (!this.signal.aborted) {
|
||||||
|
const job = await this.registry.nextJob(this.signal)
|
||||||
|
if (!job) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await this.runner.execute(job)
|
||||||
|
} catch {
|
||||||
|
// TODO: handle logging of job execution errors
|
||||||
|
} finally {
|
||||||
|
this.registry.markJobAsCompleted(job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
36
apps/freya-backend/src/notification/notification-central.ts
Normal file
36
apps/freya-backend/src/notification/notification-central.ts
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import type { AgentEvent } from "@freya/agent-protocol"
|
||||||
|
|
||||||
|
export interface AgentNotification {
|
||||||
|
kind: "agent"
|
||||||
|
payload: AgentEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
export type NotificationPayload = AgentNotification
|
||||||
|
export type NotificationListener = (notification: NotificationPayload) => Promise<void>
|
||||||
|
|
||||||
|
export class NotificationCentral {
|
||||||
|
private listeners: Map<string, Set<NotificationListener>> = new Map()
|
||||||
|
|
||||||
|
registerListenerForUser(userId: string, listener: NotificationListener): () => void {
|
||||||
|
let listeners = this.listeners.get(userId)
|
||||||
|
if (!listeners) {
|
||||||
|
listeners = new Set()
|
||||||
|
this.listeners.set(userId, listeners)
|
||||||
|
}
|
||||||
|
|
||||||
|
listeners.add(listener)
|
||||||
|
return () => {
|
||||||
|
listeners.delete(listener)
|
||||||
|
if (listeners.size === 0) {
|
||||||
|
this.listeners.delete(userId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async notifyUser(userId: string, notification: NotificationPayload): Promise<void> {
|
||||||
|
const listeners = this.listeners.get(userId)
|
||||||
|
if (!listeners) return
|
||||||
|
|
||||||
|
await Promise.allSettled(Array.from(listeners).map((listener) => listener(notification)))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import { createMiddleware } from "hono/factory"
|
|||||||
import { registerAdminHttpHandlers } from "./admin/http.ts"
|
import { registerAdminHttpHandlers } from "./admin/http.ts"
|
||||||
import { createQueryDebugTools } from "./agent/debug-tools.ts"
|
import { createQueryDebugTools } from "./agent/debug-tools.ts"
|
||||||
import { registerAgentHttpHandlers, registerDebugAgentHttpHandlers } from "./agent/http.ts"
|
import { registerAgentHttpHandlers, registerDebugAgentHttpHandlers } from "./agent/http.ts"
|
||||||
|
import { AgentService } from "./agent/service.ts"
|
||||||
import { agentWebSocket, registerAgentWebSocketHandlers } from "./agent/ws.ts"
|
import { agentWebSocket, registerAgentWebSocketHandlers } from "./agent/ws.ts"
|
||||||
import { createRequireAdmin } from "./auth/admin-middleware.ts"
|
import { createRequireAdmin } from "./auth/admin-middleware.ts"
|
||||||
import { registerAuthHandlers } from "./auth/http.ts"
|
import { registerAuthHandlers } from "./auth/http.ts"
|
||||||
@@ -12,6 +13,7 @@ import { createAuth } from "./auth/index.ts"
|
|||||||
import { createRequireSession } from "./auth/session-middleware.ts"
|
import { createRequireSession } from "./auth/session-middleware.ts"
|
||||||
import { CalDavSourceProvider } from "./caldav/provider.ts"
|
import { CalDavSourceProvider } from "./caldav/provider.ts"
|
||||||
import { registerConversationsHttpHandlers } from "./conversations/http.ts"
|
import { registerConversationsHttpHandlers } from "./conversations/http.ts"
|
||||||
|
import { DrizzleConversationStorage } from "./conversations/storage.ts"
|
||||||
import { createDatabase } from "./db/index.ts"
|
import { createDatabase } from "./db/index.ts"
|
||||||
import { registerFeedHttpHandlers } from "./engine/http.ts"
|
import { registerFeedHttpHandlers } from "./engine/http.ts"
|
||||||
import { createFeedEnhancer } from "./enhancement/enhance-feed.ts"
|
import { createFeedEnhancer } from "./enhancement/enhance-feed.ts"
|
||||||
@@ -21,6 +23,7 @@ import { CredentialEncryptor } from "./lib/crypto.ts"
|
|||||||
import { ensureEnv } from "./lib/env.ts"
|
import { ensureEnv } from "./lib/env.ts"
|
||||||
import { registerLocationHttpHandlers } from "./location/http.ts"
|
import { registerLocationHttpHandlers } from "./location/http.ts"
|
||||||
import { LocationSourceProvider } from "./location/provider.ts"
|
import { LocationSourceProvider } from "./location/provider.ts"
|
||||||
|
import { NotificationCentral } from "./notification/notification-central.ts"
|
||||||
import { ReminderSourceProvider } from "./reminders/provider.ts"
|
import { ReminderSourceProvider } from "./reminders/provider.ts"
|
||||||
import { UserSessionManager } from "./session/index.ts"
|
import { UserSessionManager } from "./session/index.ts"
|
||||||
import { registerSourcesHttpHandlers } from "./sources/http.ts"
|
import { registerSourcesHttpHandlers } from "./sources/http.ts"
|
||||||
@@ -32,8 +35,12 @@ function main() {
|
|||||||
const env = ensureEnv(process.env)
|
const env = ensureEnv(process.env)
|
||||||
|
|
||||||
const { db, close: closeDb } = createDatabase(env.databaseUrl)
|
const { db, close: closeDb } = createDatabase(env.databaseUrl)
|
||||||
|
const conversationStorage = new DrizzleConversationStorage(db, false)
|
||||||
|
|
||||||
const auth = createAuth(db)
|
const auth = createAuth(db)
|
||||||
|
|
||||||
|
const abortController = new AbortController()
|
||||||
|
|
||||||
const feedEnhancer = createFeedEnhancer({
|
const feedEnhancer = createFeedEnhancer({
|
||||||
client: createLlmClient({
|
client: createLlmClient({
|
||||||
apiKey: env.openrouterApiKey,
|
apiKey: env.openrouterApiKey,
|
||||||
@@ -73,6 +80,15 @@ function main() {
|
|||||||
console.warn("[query] PI_API_KEY or OPENROUTER_API_KEY not set — query agent unavailable")
|
console.warn("[query] PI_API_KEY or OPENROUTER_API_KEY not set — query agent unavailable")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const notificationCentral = new NotificationCentral()
|
||||||
|
|
||||||
|
const agentService = new AgentService({
|
||||||
|
notificationCentral,
|
||||||
|
storage: conversationStorage,
|
||||||
|
userSessionManager: sessionManager,
|
||||||
|
signal: abortController.signal,
|
||||||
|
})
|
||||||
|
|
||||||
const app = new Hono()
|
const app = new Hono()
|
||||||
|
|
||||||
const isDev = process.env.NODE_ENV !== "production"
|
const isDev = process.env.NODE_ENV !== "production"
|
||||||
@@ -141,17 +157,22 @@ function main() {
|
|||||||
registerAdminHttpHandlers(app, { sessionManager, adminMiddleware, db })
|
registerAdminHttpHandlers(app, { sessionManager, adminMiddleware, db })
|
||||||
|
|
||||||
registerAgentWebSocketHandlers(app, {
|
registerAgentWebSocketHandlers(app, {
|
||||||
sessionManager,
|
agentService,
|
||||||
|
notificationCentral,
|
||||||
|
storage: conversationStorage,
|
||||||
authSessionMiddleware,
|
authSessionMiddleware,
|
||||||
corsMiddleware: agentWebSocketCorsMiddleware,
|
corsMiddleware: agentWebSocketCorsMiddleware,
|
||||||
})
|
})
|
||||||
|
|
||||||
process.on("SIGTERM", async () => {
|
process.on("SIGTERM", async () => {
|
||||||
sessionManager.dispose()
|
sessionManager.dispose()
|
||||||
|
abortController.abort()
|
||||||
await closeDb()
|
await closeDb()
|
||||||
process.exit(0)
|
process.exit(0)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
agentService.start()
|
||||||
|
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ mock.module("../sources/user-sources.ts", () => ({
|
|||||||
}),
|
}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
mock.module("../conversations/storage.ts", () => ({
|
mock.module("../conversations/db-storage.ts", () => ({
|
||||||
conversations: (_db: Database, userId: string) => ({
|
conversations: (_db: Database, userId: string) => ({
|
||||||
async getOrCreateConversation(): Promise<{ id: string }> {
|
async getOrCreateConversation(): Promise<{ id: string }> {
|
||||||
mockConversationCalls.push({ type: "getOrCreate", userId })
|
mockConversationCalls.push({ type: "getOrCreate", userId })
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import type { FeedEnhancer } from "../enhancement/enhance-feed.ts"
|
|||||||
import type { CredentialEncryptor } from "../lib/crypto.ts"
|
import type { CredentialEncryptor } from "../lib/crypto.ts"
|
||||||
import type { FeedSourceProvider } from "./feed-source-provider.ts"
|
import type { FeedSourceProvider } from "./feed-source-provider.ts"
|
||||||
|
|
||||||
import { conversations } from "../conversations/storage.ts"
|
import { conversations } from "../conversations/db-storage.ts"
|
||||||
import {
|
import {
|
||||||
CredentialStorageUnavailableError,
|
CredentialStorageUnavailableError,
|
||||||
InvalidSourceConfigError,
|
InvalidSourceConfigError,
|
||||||
|
|||||||
@@ -263,18 +263,12 @@ export class UserSession {
|
|||||||
const conversation = await conversationStorage.getOrCreateConversation()
|
const conversation = await conversationStorage.getOrCreateConversation()
|
||||||
const entries = await conversationStorage.listEntries(conversation.id)
|
const entries = await conversationStorage.listEntries(conversation.id)
|
||||||
|
|
||||||
this.queryAgent = new ConversationRecordingQueryAgent({
|
this.queryAgent = new PiQueryAgent({
|
||||||
agent: new PiQueryAgent({
|
toolbox: this.toolbox,
|
||||||
toolbox: this.toolbox,
|
apiKey: this.agentConfig?.apiKey,
|
||||||
apiKey: this.agentConfig?.apiKey,
|
cwd: this.agentConfig?.cwd,
|
||||||
cwd: this.agentConfig?.cwd,
|
systemPrompt: this.agentConfig?.systemPrompt,
|
||||||
systemPrompt: this.agentConfig?.systemPrompt,
|
initialEntries: entries,
|
||||||
initialEntries: entries,
|
|
||||||
}),
|
|
||||||
storage: conversationStorage,
|
|
||||||
defaultConversationId: conversation.id,
|
|
||||||
modelProvider: PI_MODEL_PROVIDER,
|
|
||||||
modelId: PI_MODEL_ID,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ mock.module("../sources/user-sources.ts", () => ({
|
|||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
mock.module("../conversations/storage.ts", () => ({
|
mock.module("../conversations/db-storage.ts", () => ({
|
||||||
conversations: (_db: Database, userId: string) => ({
|
conversations: (_db: Database, userId: string) => ({
|
||||||
async getOrCreateConversation() {
|
async getOrCreateConversation() {
|
||||||
return { id: `conversation-${userId}` }
|
return { id: `conversation-${userId}` }
|
||||||
|
|||||||
6
bun.lock
6
bun.lock
@@ -51,7 +51,7 @@
|
|||||||
"version": "0.0.0",
|
"version": "0.0.0",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@freya/agent-protocol": "workspace:*",
|
"@freya/agent-protocol": "workspace:*",
|
||||||
"@nym.sh/jrpc": "^0.1.0",
|
"@nym.sh/jrpc": "1.1.0",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"apps/freya-backend": {
|
"apps/freya-backend": {
|
||||||
@@ -69,7 +69,7 @@
|
|||||||
"@freya/source-tfl": "workspace:*",
|
"@freya/source-tfl": "workspace:*",
|
||||||
"@freya/source-weatherkit": "workspace:*",
|
"@freya/source-weatherkit": "workspace:*",
|
||||||
"@freya/source-web-search": "workspace:*",
|
"@freya/source-web-search": "workspace:*",
|
||||||
"@nym.sh/jrpc": "^0.1.0",
|
"@nym.sh/jrpc": "1.1.0",
|
||||||
"@openrouter/sdk": "^0.9.11",
|
"@openrouter/sdk": "^0.9.11",
|
||||||
"arktype": "^2.1.29",
|
"arktype": "^2.1.29",
|
||||||
"better-auth": "^1",
|
"better-auth": "^1",
|
||||||
@@ -838,7 +838,7 @@
|
|||||||
|
|
||||||
"@nolyfill/is-core-module": ["@nolyfill/is-core-module@1.0.39", "", {}, "sha512-nn5ozdjYQpUCZlWGuxcJY/KpxkWQs4DcbMCmKojjyrYDEAGy4Ce19NN4v5MduafTwJlbKc99UA8YhSVqq9yPZA=="],
|
"@nolyfill/is-core-module": ["@nolyfill/is-core-module@1.0.39", "", {}, "sha512-nn5ozdjYQpUCZlWGuxcJY/KpxkWQs4DcbMCmKojjyrYDEAGy4Ce19NN4v5MduafTwJlbKc99UA8YhSVqq9yPZA=="],
|
||||||
|
|
||||||
"@nym.sh/jrpc": ["@nym.sh/jrpc@0.1.0", "", {}, "sha512-qH+vqKojPrX4RkW67U2R4J98uWHxZOwYxX2J5GLZcfm/yjklCcN5zTfDNLfgAa9jAoOFVscC3DFWhvdZOmN3fA=="],
|
"@nym.sh/jrpc": ["@nym.sh/jrpc@1.1.0", "", {}, "sha512-212SYMB37GdL8enaRTTqG/LNa5bJ7eYth6jfQfECuedQCuaju0bOMUzCN6hvY5KkrxdYuqVKmr2Uz+ZZTjPlaQ=="],
|
||||||
|
|
||||||
"@nym.sh/jrx": ["@nym.sh/jrx@0.2.0", "", { "peerDependencies": { "@json-render/core": ">=0.10.0" } }, "sha512-jd7Z1Q6T21366MtSUnwCFiu6Yl1AdNc9s5m6HxeUg265P+0enZCiyyxOuHsFwvpUcSEs/2DVBsqfMptdca44lA=="],
|
"@nym.sh/jrx": ["@nym.sh/jrx@0.2.0", "", { "peerDependencies": { "@json-render/core": ">=0.10.0" } }, "sha512-jd7Z1Q6T21366MtSUnwCFiu6Yl1AdNc9s5m6HxeUg265P+0enZCiyyxOuHsFwvpUcSEs/2DVBsqfMptdca44lA=="],
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,3 @@
|
|||||||
export interface SendMessageResult {
|
|
||||||
message: string
|
|
||||||
conversationId: string
|
|
||||||
}
|
|
||||||
|
|
||||||
export type AgentEvent =
|
export type AgentEvent =
|
||||||
| { type: "conversation_started"; conversationId: string }
|
| { type: "conversation_started"; conversationId: string }
|
||||||
| { type: "message_created"; text: string }
|
| { type: "message_created"; text: string }
|
||||||
@@ -11,8 +6,11 @@ export type AgentEvent =
|
|||||||
| { type: "message_finished" }
|
| { type: "message_finished" }
|
||||||
| { type: "message_failed"; error: string }
|
| { type: "message_failed"; error: string }
|
||||||
|
|
||||||
|
export type UserEvent = { type: "typing" }
|
||||||
|
|
||||||
export interface AgentServerApi {
|
export interface AgentServerApi {
|
||||||
sendMessage(message: string): Promise<SendMessageResult>
|
sendMessage(message: string): Promise<boolean>
|
||||||
|
notify(event: UserEvent): void
|
||||||
ping(): "pong"
|
ping(): "pong"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,19 @@ export const ConversationEntryMetadata = type({
|
|||||||
/** Metadata bag attached to a conversation entry. */
|
/** Metadata bag attached to a conversation entry. */
|
||||||
export type ConversationEntryMetadata = typeof ConversationEntryMetadata.infer
|
export type ConversationEntryMetadata = typeof ConversationEntryMetadata.infer
|
||||||
|
|
||||||
|
export const ToolCallPayload = type({
|
||||||
|
toolName: "string",
|
||||||
|
})
|
||||||
|
|
||||||
|
export type ToolCallPayload = typeof ToolCallPayload.infer
|
||||||
|
|
||||||
|
export const ToolResultPayload = type({
|
||||||
|
toolName: "string",
|
||||||
|
ok: "boolean",
|
||||||
|
})
|
||||||
|
|
||||||
|
export type ToolResultPayload = typeof ToolResultPayload.infer
|
||||||
|
|
||||||
/** Generic object payload used by operational entries. */
|
/** Generic object payload used by operational entries. */
|
||||||
export const GenericObjectPayload = type("Record<string, unknown>")
|
export const GenericObjectPayload = type("Record<string, unknown>")
|
||||||
|
|
||||||
@@ -158,4 +171,6 @@ export type ConversationEntryPayload =
|
|||||||
| AssistantMessagePayload
|
| AssistantMessagePayload
|
||||||
| AttachmentPayload
|
| AttachmentPayload
|
||||||
| ContextSummaryPayload
|
| ContextSummaryPayload
|
||||||
|
| ToolCallPayload
|
||||||
|
| ToolResultPayload
|
||||||
| GenericObjectPayload
|
| GenericObjectPayload
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ export {
|
|||||||
ModelRunMetadata,
|
ModelRunMetadata,
|
||||||
TextMessagePart,
|
TextMessagePart,
|
||||||
UserMessagePayload,
|
UserMessagePayload,
|
||||||
|
ToolCallPayload,
|
||||||
|
ToolResultPayload,
|
||||||
} from "./conversation"
|
} from "./conversation"
|
||||||
|
|
||||||
// Feed
|
// Feed
|
||||||
|
|||||||
Reference in New Issue
Block a user