diff --git a/apps/agent-test-cli/package.json b/apps/agent-test-cli/package.json index 5a7dee6..3bd5e48 100644 --- a/apps/agent-test-cli/package.json +++ b/apps/agent-test-cli/package.json @@ -10,6 +10,6 @@ }, "dependencies": { "@freya/agent-protocol": "workspace:*", - "@nym.sh/jrpc": "^0.1.0" + "@nym.sh/jrpc": "1.1.0" } } diff --git a/apps/freya-backend/package.json b/apps/freya-backend/package.json index 5d8d1d3..9b8f26b 100644 --- a/apps/freya-backend/package.json +++ b/apps/freya-backend/package.json @@ -26,7 +26,7 @@ "@freya/source-tfl": "workspace:*", "@freya/source-weatherkit": "workspace:*", "@freya/source-web-search": "workspace:*", - "@nym.sh/jrpc": "^0.1.0", + "@nym.sh/jrpc": "1.1.0", "@openrouter/sdk": "^0.9.11", "arktype": "^2.1.29", "better-auth": "^1", diff --git a/apps/freya-backend/src/admin/http.test.ts b/apps/freya-backend/src/admin/http.test.ts index 7837777..d1dfa17 100644 --- a/apps/freya-backend/src/admin/http.test.ts +++ b/apps/freya-backend/src/admin/http.test.ts @@ -44,7 +44,7 @@ mock.module("../sources/user-sources.ts", () => ({ }), })) -mock.module("../conversations/storage.ts", () => ({ +mock.module("../conversations/db-storage.ts", () => ({ conversations: (_db: Database, userId: string) => ({ async getOrCreateConversation() { return { id: `conversation-${userId}` } diff --git a/apps/freya-backend/src/agent/job.ts b/apps/freya-backend/src/agent/job.ts new file mode 100644 index 0000000..29c426b --- /dev/null +++ b/apps/freya-backend/src/agent/job.ts @@ -0,0 +1,145 @@ +import type { AgentEvent } from "@freya/agent-protocol" + +import { + AssistantMessagePayload, + ConversationEntryKind, + UserMessagePayload, + ToolCallPayload, + ToolResultPayload, +} from "@freya/core" +import { type } from "arktype" + +import type { ConversationStorage } from "../conversations/storage" +import type { Job } from "../lib/job" +import type { JobExecutor } from "../lib/worker" +import type { NotificationCentral } from "../notification/notification-central" +import type { UserSessionManager } from "../session" + +import { ConversationResponseStateStatus } from "../db/schema" +import { streamAgentResponse } from "./streaming" + +export interface AgentResponseJobPayload { + conversationId: string +} + +interface AgentResponseWorkerConfig { + conversationStorage: ConversationStorage + userSessionManager: UserSessionManager + notificationCentral: NotificationCentral +} + +export class AgentResponseJobExecutor implements JobExecutor { + private conversationStorage: ConversationStorage + private userSessionManager: UserSessionManager + private notificationCentral: NotificationCentral + + constructor({ + conversationStorage, + userSessionManager, + notificationCentral, + }: AgentResponseWorkerConfig) { + this.conversationStorage = conversationStorage + this.userSessionManager = userSessionManager + this.notificationCentral = notificationCentral + } + + async execute(job: Job): Promise { + const conversation = await this.conversationStorage.findConversation(job.payload.conversationId) + if (!conversation) { + return + } + + const claimed = await this.conversationStorage.claimPendingConversationResponseState( + job.payload.conversationId, + ) + if (!claimed) { + // conversation response state not found or already claimed + return + } + + const pendingEntries = await this.conversationStorage.listPendingUserConversationEntries( + conversation.userId, + conversation.id, + ) + if (pendingEntries.length === 0) { + await this.conversationStorage.clearConversationResponseState(job.payload.conversationId) + return + } + + const message = pendingEntries.reduce((acc, entry) => { + const payload = UserMessagePayload(entry.payload) + if (payload instanceof type.errors) { + return acc + } + return ( + acc + "\n" + payload.parts.reduce((msg, p) => (p.type === "text" ? msg + p.text : msg), "") + ) + }, "") + + const session = await this.userSessionManager.getOrCreate(conversation.userId) + + try { + for await (const event of streamAgentResponse({ + agent: session.agent, + input: { message, signal: job.signal }, + })) { + if (job.signal.aborted) { + break + } + + await this.recordAgentEvent(event, conversation.id) + await this.notificationCentral.notifyUser(conversation.userId, { + kind: "agent", + payload: event, + }) + } + + // if job is aborted, stop everything immediately, including clean up. + // the aborter is assumed responsibility on how to proceed. + if (!job.signal.aborted) { + await this.conversationStorage.clearConversationResponseState(job.payload.conversationId) + } + } catch (err) { + console.error("[agent job executor] error streaming agent response:", err) + if (!job.signal.aborted) { + await this.conversationStorage.markResponseStateStatus( + [job.payload.conversationId], + ConversationResponseStateStatus.Failed, + ) + } + } + } + + private async recordAgentEvent(event: AgentEvent, conversationId: string) { + switch (event.type) { + case "message_created": + await this.conversationStorage.appendEntry(conversationId, { + kind: ConversationEntryKind.AssistantMessage, + payload: { + role: "assistant", + parts: [{ type: "text", text: event.text }], + } satisfies AssistantMessagePayload, + }) + break + + case "tool_started": + await this.conversationStorage.appendEntry(conversationId, { + kind: ConversationEntryKind.ToolCall, + payload: { + toolName: event.toolName, + } satisfies ToolCallPayload, + }) + break + + case "tool_finished": + await this.conversationStorage.appendEntry(conversationId, { + kind: ConversationEntryKind.ToolResult, + payload: { + toolName: event.toolName, + ok: event.ok, + } satisfies ToolResultPayload, + }) + break + } + } +} diff --git a/apps/freya-backend/src/agent/pi-query-agent.ts b/apps/freya-backend/src/agent/pi-query-agent.ts index 8c285ca..50abe64 100644 --- a/apps/freya-backend/src/agent/pi-query-agent.ts +++ b/apps/freya-backend/src/agent/pi-query-agent.ts @@ -166,6 +166,16 @@ export class PiQueryAgent implements QueryAgent { this.handlePiEvent(event, pushRunEvent) }) + input.signal?.addEventListener( + "abort", + async () => { + await session.abort() + close() + unsubscribe() + }, + { once: true }, + ) + session .prompt(input.message) .then(() => { diff --git a/apps/freya-backend/src/agent/query-agent.ts b/apps/freya-backend/src/agent/query-agent.ts index 54bd132..14589a9 100644 --- a/apps/freya-backend/src/agent/query-agent.ts +++ b/apps/freya-backend/src/agent/query-agent.ts @@ -2,6 +2,7 @@ export interface QueryAgentAsk { message: string conversationId?: string userMessageEntry?: QueryAgentConversationEntryRef + signal?: AbortSignal } export type QueryAgentStreamEvent = diff --git a/apps/freya-backend/src/agent/reconciler.ts b/apps/freya-backend/src/agent/reconciler.ts new file mode 100644 index 0000000..6b60a1a --- /dev/null +++ b/apps/freya-backend/src/agent/reconciler.ts @@ -0,0 +1,70 @@ +import type { ConversationStorage } from "../conversations/storage" +import type { AgentWorkScheduler } from "./scheduler" + +interface AgentResponseReconcilerConfig { + storage: ConversationStorage + interval: number + scheduler: AgentWorkScheduler + signal: AbortSignal +} + +export class AgentResponseReconciler { + private storage: ConversationStorage + private interval: number + private scheduler: AgentWorkScheduler + private signal: AbortSignal + + private stopLoop: ReturnType | null = null + + constructor({ storage, interval, scheduler, signal }: AgentResponseReconcilerConfig) { + this.storage = storage + this.interval = interval + this.scheduler = scheduler + this.signal = signal + } + + start() { + this.signal.throwIfAborted() + + this.signal.addEventListener( + "abort", + () => { + if (this.stopLoop !== null) { + clearInterval(this.stopLoop) + this.stopLoop = null + } + }, + { once: true }, + ) + + this.stopLoop = setInterval(this.reconcile.bind(this), this.interval) + } + + private async reconcile() { + // enqueue pending responses + const pendingStates = await this.storage.listPendingResponseStates() + const now = new Date().getTime() + for (const state of pendingStates) { + if (state.maxWaitUntil.getTime() < now) { + this.scheduler.enqueueAgentResponse(state.conversationId) + } + } + + // re-enqueue stuck responses + const runningStates = await this.storage.listRunningResponseStates() + const stuckIds: string[] = [] + for (const state of runningStates) { + if (state.runningSince && Math.max(now - state.runningSince.getTime(), 0) > 5 * 1000 * 60) { + // if the response is running for more than 5 minutes + // we assume that its stuck and enqueue it for retry + stuckIds.push(state.conversationId) + } + } + if (stuckIds.length > 0) { + await this.storage.markResponseStateStatus(stuckIds, "pending") + for (const id of stuckIds) { + this.scheduler.enqueueAgentResponse(id) + } + } + } +} diff --git a/apps/freya-backend/src/agent/scheduler.ts b/apps/freya-backend/src/agent/scheduler.ts new file mode 100644 index 0000000..fe31471 --- /dev/null +++ b/apps/freya-backend/src/agent/scheduler.ts @@ -0,0 +1,160 @@ +import type { UserEvent } from "@freya/agent-protocol" + +import { ConversationEntryKind, UserMessagePayload } from "@freya/core" + +import type { ConversationStorage } from "../conversations/storage" +import type { Job, JobRegistry } from "../lib/job" +import type { AgentResponseJobPayload } from "./job" +import { ConversationNotFoundError } from "../conversations/errors"; +import { ConversationResponseStateStatus } from "../db/schema"; + +interface AgentMessageSchedulerConfig { + storage: ConversationStorage + maxWaitTime: number + + /** + * How long to wait before responding to the user. + */ + waitTIme: number + + jobRegistry: JobRegistry +} + +/** + * Schedules and manages the flow of messages between the user and the query agent for a specific conversation. + */ +export class AgentWorkScheduler { + private conversationStorage: ConversationStorage + private jobRegistry: JobRegistry + + private timing: { + maxWaitTime: number + waitTime: number + } + + private timers = new Map>() + private runningJobs = new Map>() + + constructor(config: AgentMessageSchedulerConfig) { + this.conversationStorage = config.storage + this.jobRegistry = config.jobRegistry + this.timing = { + maxWaitTime: config.maxWaitTime, + waitTime: config.waitTIme, + } + + this.jobRegistry.addEventListener("settled", this.eraseJob.bind(this)) + this.jobRegistry.addEventListener("cancelled", this.eraseJob.bind(this)) + } + + async receiveMessage(conversationId: string, message: string) { + await this.conversationStorage.transaction(async (storage) => { + const now = new Date() + + const entry = await storage.appendEntry(conversationId, { + kind: ConversationEntryKind.UserMessage, + payload: { + role: "user", + parts: [{ type: "text", text: message }], + } satisfies UserMessagePayload, + }) + + await storage.upsertConversationResponseState(conversationId, { + maxWaitUntil: new Date(now.getTime() + this.timing.maxWaitTime), + pendingSinceEntryId: entry.id, + status: "pending", + }) + + return entry + }) + this.scheduleAgentResponse(conversationId, this.timing.waitTime) + } + + async receiveUserEvent(conversationId: string, event: UserEvent) { + if (event.type === "typing") { + await this.delayAgentResponse(conversationId) + } + } + + enqueueAgentResponse(conversationId: string): void { + const existing = this.timers.get(conversationId) + if (existing) { + clearTimeout(existing) + this.timers.delete(conversationId) + } + + this.cancelCurrentJob(conversationId) + + const job = this.jobRegistry.addJob({ + payload: { conversationId }, + }) + this.runningJobs.set(conversationId, job) + } + + private async delayAgentResponse(conversationId: string) { + this.cancelCurrentJob(conversationId); + + try { + const ok = await this.conversationStorage.transaction(async (storage) => { + const state = await storage.findConversationResponseState(conversationId); + if (state && state.status !== ConversationResponseStateStatus.Failed) { + await storage.updateConversationResponseState(conversationId, { + status: ConversationResponseStateStatus.Pending, + // the agent response was cancelled, so its no longer running + // clear runningSince timestamp + runningSince: null, + }) + return true + } + return false + }) + if (ok) { + await this.scheduleAgentResponse(conversationId, this.timing.waitTime) + } + } catch (error) { + if (error instanceof ConversationNotFoundError) { + // the user is typing but there isn't a scheduled agent response yet + // which means the user is typing their first message after the agent has previously responded + // swallow the error + } else { + console.error("[agent response scheduler] error delaying agent response", error) + } + return + } + } + + private async scheduleAgentResponse(conversationId: string, delay: number) { + const existing = this.timers.get(conversationId) + if (existing) { + clearTimeout(existing) + } + + this.cancelCurrentJob(conversationId) + + this.timers.set( + conversationId, + setTimeout(() => { + this.enqueueAgentResponse(conversationId) + }, delay), + ) + } + + /** + * cancels the current job for agent response for the given conversation id + * no-op if there is no active job for the conversation. + */ + private cancelCurrentJob(conversationId: string): void { + const job = this.runningJobs.get(conversationId) + if (!job) return + + // If an active response is working on stale context, abort it so the next + // job can answer using the latest pending user messages. + this.jobRegistry.cancelJob(job) + } + + private eraseJob(job: Job) { + if (this.runningJobs.get(job.payload.conversationId) === job) { + this.runningJobs.delete(job.payload.conversationId) + } + } +} diff --git a/apps/freya-backend/src/agent/service.ts b/apps/freya-backend/src/agent/service.ts new file mode 100644 index 0000000..4889d96 --- /dev/null +++ b/apps/freya-backend/src/agent/service.ts @@ -0,0 +1,66 @@ +import type { UserEvent } from "@freya/agent-protocol" + +import type { ConversationStorage } from "../conversations/storage" +import type { NotificationCentral } from "../notification/notification-central" +import type { UserSessionManager } from "../session" + +import { JobRegistry } from "../lib/job" +import { Worker } from "../lib/worker" +import { AgentResponseJobExecutor, type AgentResponseJobPayload } from "./job" +import { AgentResponseReconciler } from "./reconciler" +import { AgentWorkScheduler } from "./scheduler" + +interface AgentServiceConfig { + storage: ConversationStorage + userSessionManager: UserSessionManager + notificationCentral: NotificationCentral + signal: AbortSignal +} + +export class AgentService { + private readonly storage: ConversationStorage + private readonly scheduler: AgentWorkScheduler + private readonly reconciler: AgentResponseReconciler + private readonly worker: Worker + + private readonly jobRegistry = new JobRegistry() + + constructor({ storage, userSessionManager, notificationCentral, signal }: AgentServiceConfig) { + this.storage = storage + this.scheduler = new AgentWorkScheduler({ + storage, + jobRegistry: this.jobRegistry, + waitTIme: 5 * 1000, + maxWaitTime: 5 * 1000 * 60, + }) + this.reconciler = new AgentResponseReconciler({ + signal, + storage: this.storage, + interval: 60 * 1000, + scheduler: this.scheduler, + }) + this.worker = new Worker({ + signal, + concurrency: 10, + registry: this.jobRegistry, + runner: new AgentResponseJobExecutor({ + conversationStorage: storage, + notificationCentral, + userSessionManager, + }), + }) + } + + start() { + this.worker.start() + this.reconciler.start() + } + + async scheduleAgentResponse(conversationId: string, message: string) { + await this.scheduler.receiveMessage(conversationId, message) + } + + async handleUserEvent(conversationId: string, event: UserEvent) { + await this.scheduler.receiveUserEvent(conversationId, event) + } +} diff --git a/apps/freya-backend/src/agent/streaming.test.ts b/apps/freya-backend/src/agent/streaming.test.ts index 3e4f2d3..2c1aef1 100644 --- a/apps/freya-backend/src/agent/streaming.test.ts +++ b/apps/freya-backend/src/agent/streaming.test.ts @@ -9,7 +9,6 @@ import type { QueryAgentEventListener, QueryAgentStreamEvent, } from "./query-agent.ts" -import type { AgentResponseStreamItem } from "./streaming.ts" import { streamAgentResponse } from "./streaming.ts" @@ -47,17 +46,13 @@ describe("streamAgentResponse", () => { { type: "done" }, ]) - const { events, result } = await collectStreamAgentResponse( + const events = await collectStreamAgentResponse( streamAgentResponse({ agent, input: { message: "hello" }, }), ) - expect(result).toEqual({ - conversationId: "conversation-1", - message: "First message\nSecond message\nThird message", - }) expect(events).toEqual([ { type: "conversation_started", conversationId: "conversation-1" }, { type: "message_created", text: "First message" }, @@ -74,17 +69,13 @@ describe("streamAgentResponse", () => { { type: "done" }, ]) - const { events, result } = await collectStreamAgentResponse( + const events = await collectStreamAgentResponse( streamAgentResponse({ agent, input: { message: "hello" }, }), ) - expect(result).toEqual({ - conversationId: "conversation-1", - message: " const value = 1 \n\n return value", - }) expect(events).toEqual([ { type: "conversation_started", conversationId: "conversation-1" }, { type: "message_created", text: " const value = 1 " }, @@ -122,28 +113,12 @@ describe("streamAgentResponse", () => { }) async function collectStreamAgentResponse( - stream: AsyncIterable, + stream: AsyncIterable, events: AgentEvent[] = [], -): Promise<{ - events: AgentEvent[] - result: { message: string; conversationId: string } -}> { - let result: { message: string; conversationId: string } | null = null - - for await (const item of stream) { - switch (item.type) { - case "event": - events.push(item.event) - break - case "result": - result = item.result - break - } +): Promise { + for await (const event of stream) { + events.push(event) } - if (!result) { - throw new Error("Expected stream result") - } - - return { events, result } + return events } diff --git a/apps/freya-backend/src/agent/streaming.ts b/apps/freya-backend/src/agent/streaming.ts index f80abda..80e1260 100644 --- a/apps/freya-backend/src/agent/streaming.ts +++ b/apps/freya-backend/src/agent/streaming.ts @@ -1,10 +1,8 @@ -import type { AgentEvent, SendMessageResult } from "@freya/agent-protocol" +import type { AgentEvent } from "@freya/agent-protocol" import type { QueryAgent, QueryAgentAsk } from "./query-agent.ts" -export type AgentResponseStreamItem = - | { type: "event"; event: AgentEvent } - | { type: "result"; result: SendMessageResult } +export type AgentResponseStreamItem = { type: "event"; event: AgentEvent } export async function* streamAgentResponse({ agent, @@ -12,18 +10,18 @@ export async function* streamAgentResponse({ }: { agent: QueryAgent input: QueryAgentAsk -}): AsyncGenerator { +}): AsyncGenerator { let message = "" let conversationId: string | null = null const splitter = new AgentMessageSplitter() - function messageEvent(text: string): AgentResponseStreamItem | null { + function messageEvent(text: string): AgentEvent | null { if (text.trim() === "") return null - return { type: "event", event: { type: "message_created", text } } + return { type: "message_created", text } } - function flushPendingMessage(): AgentResponseStreamItem | null { + function flushPendingMessage(): AgentEvent | null { const text = splitter.flush() if (text === null) return null @@ -31,10 +29,14 @@ export async function* streamAgentResponse({ } for await (const event of agent.ask(input)) { + if (input.signal?.aborted) { + break + } + switch (event.type) { case "conversation": conversationId = event.conversationId - yield { type: "event", event: { type: "conversation_started", conversationId } } + yield { type: "conversation_started", conversationId } break case "text_delta": @@ -50,7 +52,7 @@ export async function* streamAgentResponse({ const item = flushPendingMessage() if (item) yield item } - yield { type: "event", event: { type: "tool_started", toolName: event.toolName } } + yield { type: "tool_started", toolName: event.toolName } break case "tool_end": @@ -59,12 +61,9 @@ export async function* streamAgentResponse({ if (item) yield item } yield { - type: "event", - event: { - type: "tool_finished", - toolName: event.toolName, - ok: event.ok, - }, + type: "tool_finished", + toolName: event.toolName, + ok: event.ok, } break @@ -73,7 +72,7 @@ export async function* streamAgentResponse({ const item = flushPendingMessage() if (item) yield item } - yield { type: "event", event: { type: "message_failed", error: event.message } } + yield { type: "message_failed", error: event.message } throw new Error(event.message) case "done": @@ -81,26 +80,15 @@ export async function* streamAgentResponse({ const item = flushPendingMessage() if (item) yield item } - const result = createResult(message, conversationId) - yield { type: "event", event: { type: "message_finished" } } - yield { type: "result", result } + yield { type: "message_finished" } return } } const item = flushPendingMessage() if (item) yield item - const result = createResult(message, conversationId) - yield { type: "event", event: { type: "message_finished" } } - yield { type: "result", result } -} -function createResult(message: string, conversationId: string | null): SendMessageResult { - if (!conversationId) { - throw new Error("Agent response stream ended without a conversation id") - } - - return { message, conversationId } + yield { type: "message_finished" } } class AgentMessageSplitter { diff --git a/apps/freya-backend/src/agent/ws.test.ts b/apps/freya-backend/src/agent/ws.test.ts index 74ceb8a..f13867f 100644 --- a/apps/freya-backend/src/agent/ws.test.ts +++ b/apps/freya-backend/src/agent/ws.test.ts @@ -1,8 +1,10 @@ import { describe, expect, test } from "bun:test" import { Hono } from "hono" -import type { UserSessionManager } from "../session/index.ts" +import type { ConversationStorage } from "../conversations/storage.ts" +import type { NotificationCentral } from "../notification/notification-central.ts" +import type { AgentService } from "./service.ts" import { registerAgentWebSocketHandlers } from "./ws.ts" describe("agent websocket handler", () => { @@ -11,7 +13,9 @@ describe("agent websocket handler", () => { const app = new Hono() registerAgentWebSocketHandlers(app, { - sessionManager: {} as UserSessionManager, + agentService: {} as AgentService, + storage: {} as ConversationStorage, + notificationCentral: {} as NotificationCentral, corsMiddleware: async (c, next) => { const origin = c.req.header("origin") if (origin && origin !== "https://app.freya.test") { @@ -44,7 +48,9 @@ describe("agent websocket handler", () => { const app = new Hono() registerAgentWebSocketHandlers(app, { - sessionManager: {} as UserSessionManager, + agentService: {} as AgentService, + storage: {} as ConversationStorage, + notificationCentral: {} as NotificationCentral, corsMiddleware: async (_c, next) => { await next() }, diff --git a/apps/freya-backend/src/agent/ws.ts b/apps/freya-backend/src/agent/ws.ts index 9d66fb0..03ab819 100644 --- a/apps/freya-backend/src/agent/ws.ts +++ b/apps/freya-backend/src/agent/ws.ts @@ -1,53 +1,58 @@ -import type { AgentClientApi, AgentServerApi, SendMessageResult } from "@freya/agent-protocol" +import type { AgentClientApi, AgentServerApi, UserEvent } from "@freya/agent-protocol" import type { JrpcChannel, JrpcMessage, JsonRpcMessage } from "@nym.sh/jrpc" import type { Hono, MiddlewareHandler } from "hono" import type { WSContext } from "hono/ws" -import { JsonRpcClient, JsonRpcServer } from "@nym.sh/jrpc" -import { type } from "arktype" +import { JsonRpcClient, JsonRpcServer, deserializeJrpcMessage } from "@nym.sh/jrpc" import { upgradeWebSocket, websocket } from "hono/bun" import type { AuthSessionMiddleware } from "../auth/session-middleware.ts" -import type { UserSessionManager } from "../session/index.ts" - -import { streamAgentResponse } from "./streaming.ts" +import type { ConversationStorage } from "../conversations/storage.ts" +import type { + NotificationCentral, + NotificationPayload, +} from "../notification/notification-central.ts" +import type { AgentService } from "./service.ts" interface AgentWebSocketHandlerDeps { - sessionManager: UserSessionManager + agentService: AgentService + storage: ConversationStorage + notificationCentral: NotificationCentral authSessionMiddleware: AuthSessionMiddleware corsMiddleware: MiddlewareHandler } -interface ValidSendMessageInput { - message: string -} - export const agentWebSocket = websocket -const SendMessageInputBody = type({ - "+": "reject", - message: "string", -}) - export function registerAgentWebSocketHandlers( app: Hono, - { sessionManager, authSessionMiddleware, corsMiddleware }: AgentWebSocketHandlerDeps, + { + agentService, + storage, + notificationCentral, + authSessionMiddleware, + corsMiddleware, + }: AgentWebSocketHandlerDeps, ): void { app.get( "/api/agent/ws", corsMiddleware, authSessionMiddleware, - upgradeWebSocket((c) => { + upgradeWebSocket(async (c) => { const user = c.get("user") if (!user) { throw new Error("Authenticated WebSocket user missing") } + const conversation = await storage.getOrCreateConversation(user.id) + const channel = new HonoWebSocketJrpcChannel() const connection = new AgentRpcConnection({ channel, - sessionManager, + notificationCentral, + agentService, userId: user.id, + conversationId: conversation.id, }) return { @@ -64,6 +69,7 @@ export function registerAgentWebSocketHandlers( }, onClose() { + connection.close() channel.close() }, } @@ -74,54 +80,52 @@ export function registerAgentWebSocketHandlers( class AgentRpcConnection implements AgentServerApi { private readonly client: JsonRpcClient private readonly server: JsonRpcServer - private activeMessage: Promise | null = null - private readonly sessionManager: UserSessionManager + private readonly agentService: AgentService + private readonly notificationCentral: NotificationCentral private readonly userId: string + private readonly conversationId: string + + private cleanup: (() => void) | null = null constructor({ + agentService, + notificationCentral, channel, - sessionManager, userId, + conversationId, }: { + agentService: AgentService + notificationCentral: NotificationCentral channel: JrpcChannel - sessionManager: UserSessionManager userId: string + conversationId: string }) { - this.sessionManager = sessionManager - this.userId = userId this.client = new JsonRpcClient(channel) + this.agentService = agentService + this.notificationCentral = notificationCentral + this.userId = userId + this.conversationId = conversationId this.server = new JsonRpcServer( { sendMessage: this.sendMessage.bind(this), + notify: this.notify.bind(this), ping: this.ping.bind(this), }, channel, ) } - start(): Promise { - return this.server.start() + notify(event: UserEvent): void { + this.agentService.handleUserEvent(this.conversationId, event) } - async sendMessage(message: string): Promise { - const parsed = SendMessageInputBody({ message }) - if (parsed instanceof type.errors) { - throw new Error(parsed.summary) - } - - if (this.activeMessage) { - throw new Error("A message is already running") - } - - const run = this.runMessage(parsed) - this.activeMessage = run - + async sendMessage(message: string): Promise { try { - return await run - } finally { - if (this.activeMessage === run) { - this.activeMessage = null - } + await this.agentService.scheduleAgentResponse(this.conversationId, message) + return true + } catch (error) { + console.log("[agent rpc connection] error when scheduling agent response", error) + return false } } @@ -129,26 +133,22 @@ class AgentRpcConnection implements AgentServerApi { return "pong" } - private async runMessage(input: ValidSendMessageInput): Promise { - const session = await this.sessionManager.getOrCreate(this.userId) - let result: SendMessageResult | null = null + async start() { + this.cleanup = this.notificationCentral.registerListenerForUser( + this.userId, + this.onNotificationReceived.bind(this), + ) + await this.server.start() + } - for await (const item of streamAgentResponse({ agent: session.agent, input })) { - switch (item.type) { - case "event": - await this.client.call("notify", item.event) - break - case "result": - result = item.result - break - } + close() { + this.cleanup?.() + } + + private async onNotificationReceived(notification: NotificationPayload) { + if (notification.kind === "agent") { + await this.client.call("notify", notification.payload) } - - if (!result) { - throw new Error("Agent response stream ended without a result") - } - - return result } } @@ -171,7 +171,11 @@ class HonoWebSocketJrpcChannel implements JrpcChannel { } receive(message: unknown): void { - const parsed = parseJrpcMessage(message) + if (typeof message !== "string") { + return + } + + const parsed = deserializeJrpcMessage(message) if (!parsed) { this.ws?.close(1003, "Invalid JSON-RPC message") return @@ -236,52 +240,6 @@ class HonoWebSocketJrpcChannel implements JrpcChannel { } } -function parseJrpcMessage(message: unknown): JrpcMessage | null { - const text = webSocketMessageText(message) - if (text === null) return null - - try { - const value: unknown = JSON.parse(text) - return isJrpcMessage(value) ? value : null - } catch { - return null - } -} - -function webSocketMessageText(message: unknown): string | null { - if (typeof message === "string") return message - if (message instanceof ArrayBuffer) return Buffer.from(message).toString("utf8") - if (ArrayBuffer.isView(message)) { - return Buffer.from(message.buffer, message.byteOffset, message.byteLength).toString("utf8") - } - - return null -} - -function isJrpcMessage(value: unknown): value is JrpcMessage { - if (typeof value !== "object" || value === null) return false - if (!("jsonrpc" in value) || value.jsonrpc !== "2.0") return false - - if ("method" in value) { - return "id" in value && typeof value.id === "number" && typeof value.method === "string" - } - - if ("result" in value) { - return "id" in value && typeof value.id === "number" - } - - if ("error" in value) { - return ( - "id" in value && - typeof value.id === "number" && - typeof value.error === "object" && - value.error !== null - ) - } - - return false -} - function errorMessage(error: unknown): string { return error instanceof Error ? error.message : String(error) } diff --git a/apps/freya-backend/src/conversations/db-storage.ts b/apps/freya-backend/src/conversations/db-storage.ts new file mode 100644 index 0000000..32e2d26 --- /dev/null +++ b/apps/freya-backend/src/conversations/db-storage.ts @@ -0,0 +1,686 @@ +import { + AssistantMessagePayload, + AttachmentPayload, + ConversationEntryKind, + ConversationEntryMetadata, + ConversationEntryVisibility, + ContextSummaryPayload, + GenericObjectPayload, + UserMessagePayload, + type ConversationEntryPayload, +} from "@freya/core" +import { type } from "arktype" +import { and, asc, desc, eq, gte, inArray } from "drizzle-orm" +import { alias } from "drizzle-orm/pg-core" + +import type { Database } from "../db/index.ts" +import type { + AppendAttachmentEntryInput, + AppendAttachmentEntryResult, + AppendConversationEntryInput, + ConversationEntryRow, + ConversationResponseStateRow, + ConversationRow, + ConversationStorage, + CreateFileInput, + FileRow, + ListConversationEntriesParams, + UpdateConversationResponseStateInput, + UpsertConversationResponseStateInput, +} from "./storage.ts" + +import { + conversationEntries, + ConversationResponseStateStatus, + conversationResponseState as conversationResponseStateTable, + conversations as conversationsTable, + files, + user, +} from "../db/schema.ts" +import { ConversationNotFoundError } from "./errors.ts" + +const conversationEntryKind = type.enumerated(...Object.values(ConversationEntryKind)) +const conversationEntryVisibility = type.enumerated(...Object.values(ConversationEntryVisibility)) +const pendingSinceEntry = alias(conversationEntries, "pending_since_entry") + +export class DrizzleConversationStorage implements ConversationStorage { + private readonly db: Database + private readonly inTransaction: boolean + + constructor(db: Database, inTransaction = false) { + this.db = db + this.inTransaction = inTransaction + } + + async transaction(tx: (storage: ConversationStorage) => T | Promise): Promise { + if (this.inTransaction) return tx(this) + + return this.db.transaction(async (transactionDb) => + tx(new DrizzleConversationStorage(transactionDb, true)), + ) + } + + async createConversation(userId: string): Promise { + return insertConversation(this.db, userId) + } + + async listUserConversations(userId: string): Promise { + return this.db + .select() + .from(conversationsTable) + .where(eq(conversationsTable.userId, userId)) + .orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt)) + } + + async findConversation(conversationId: string): Promise { + return findConversation(this.db, conversationId) + } + + async getOrCreateConversation(userId: string): Promise { + return this.write(async (db) => { + await requireUserForUpdate(db, userId) + const existing = await latestConversation(db, userId) + if (existing) return existing + + return insertConversation(db, userId) + }) + } + + async createFile(userId: string, input: CreateFileInput): Promise { + return insertFile(this.db, userId, input) + } + + async appendEntry( + conversationId: string, + input: AppendConversationEntryInput, + ): Promise { + return this.write((db) => appendEntryToConversation(db, null, conversationId, input)) + } + + async appendAttachmentEntry( + conversationId: string, + input: AppendAttachmentEntryInput, + ): Promise { + return this.write((db) => appendAttachmentEntryToConversation(db, null, conversationId, input)) + } + + async nextSequence(conversationId: string): Promise { + return nextSequence(this.db, conversationId) + } + + async listUserConversationEntries( + userId: string, + conversationId: string, + params: ListConversationEntriesParams = {}, + ): Promise { + if (!(await findUserConversation(this.db, userId, conversationId))) { + throw new ConversationNotFoundError(conversationId, userId) + } + + if (params.visibility) { + return this.db + .select() + .from(conversationEntries) + .where( + and( + eq(conversationEntries.conversationId, conversationId), + eq(conversationEntries.visibility, params.visibility), + ), + ) + .orderBy(asc(conversationEntries.sequence)) + } + + return this.db + .select() + .from(conversationEntries) + .where(eq(conversationEntries.conversationId, conversationId)) + .orderBy(asc(conversationEntries.sequence)) + } + + async listPendingUserConversationEntries( + userId: string, + conversationId: string, + ): Promise { + const entries = await this.db + .select({ entry: conversationEntries }) + .from(conversationResponseStateTable) + .innerJoin( + conversationsTable, + and( + eq(conversationsTable.id, conversationResponseStateTable.conversationId), + eq(conversationsTable.userId, userId), + ), + ) + .innerJoin( + pendingSinceEntry, + and( + eq(pendingSinceEntry.id, conversationResponseStateTable.pendingSinceEntryId), + eq(pendingSinceEntry.conversationId, conversationResponseStateTable.conversationId), + ), + ) + .innerJoin( + conversationEntries, + and( + eq(conversationEntries.conversationId, conversationResponseStateTable.conversationId), + eq(conversationEntries.kind, ConversationEntryKind.UserMessage), + gte(conversationEntries.sequence, pendingSinceEntry.sequence), + ), + ) + .where( + and( + eq(conversationResponseStateTable.conversationId, conversationId), + eq(conversationEntries.conversationId, conversationId), + ), + ) + .orderBy(asc(conversationEntries.sequence)) + + if (entries.length > 0) return entries.map(({ entry }) => entry) + if (await findUserConversation(this.db, userId, conversationId)) return [] + + throw new ConversationNotFoundError(conversationId, userId) + } + + async findConversationResponseState( + conversationId: string, + ): Promise { + const rows = await this.db + .select() + .from(conversationResponseStateTable) + .where(eq(conversationResponseStateTable.conversationId, conversationId)) + .limit(1) + + return rows[0] ?? null + } + + async listPendingResponseStates(): Promise { + const rows = await this.db + .select() + .from(conversationResponseStateTable) + .where(eq(conversationResponseStateTable.status, ConversationResponseStateStatus.Pending)) + + return rows + } + + async listRunningResponseStates(): Promise { + const rows = await this.db + .select() + .from(conversationResponseStateTable) + .where(eq(conversationResponseStateTable.status, ConversationResponseStateStatus.Running)) + + return rows + } + + async upsertConversationResponseState( + conversationId: string, + input: UpsertConversationResponseStateInput, + ): Promise { + const now = new Date() + + return this.write(async (db) => { + if (!(await findConversationByIdForUpdate(db, conversationId))) { + throw new ConversationNotFoundError(conversationId, "") + } + + const rows = await db + .insert(conversationResponseStateTable) + .values({ + conversationId, + status: input.status ?? ConversationResponseStateStatus.Pending, + pendingSinceEntryId: input.pendingSinceEntryId, + maxWaitUntil: input.maxWaitUntil, + runningSince: input.runningSince ?? null, + updatedAt: now, + }) + .onConflictDoUpdate({ + target: conversationResponseStateTable.conversationId, + set: { + status: input.status ?? ConversationResponseStateStatus.Pending, + maxWaitUntil: input.maxWaitUntil, + runningSince: input.runningSince ?? null, + updatedAt: now, + }, + }) + .returning() + + return requireRow(rows) + }) + } + + async updateConversationResponseState( + conversationId: string, + input: UpdateConversationResponseStateInput, + ): Promise { + return this.write(async (db) => { + if (!(await findConversationByIdForUpdate(db, conversationId))) { + throw new ConversationNotFoundError(conversationId, "") + } + + const rows = await db + .update(conversationResponseStateTable) + .set({ + status: input.status, + pendingSinceEntryId: input.pendingSinceEntryId, + maxWaitUntil: input.maxWaitUntil, + runningSince: input.runningSince, + updatedAt: new Date(), + }) + .where(eq(conversationResponseStateTable.conversationId, conversationId)) + .returning() + + return rows[0] ?? null + }) + } + + async markResponseStateStatus( + conversationIds: string[], + status: ConversationResponseStateStatus, + ): Promise { + return this.write(async (db) => { + const now = new Date() + + let runningSince: Date | null + switch (status) { + case "pending": + case "failed": + runningSince = null + break + case "running": + runningSince = now + break + } + + const rows = await db + .update(conversationResponseStateTable) + .set({ + status, + runningSince, + updatedAt: now, + }) + .where(inArray(conversationResponseStateTable.conversationId, conversationIds)) + .returning() + + return rows + }) + } + + async claimPendingConversationResponseState( + conversationId: string, + ): Promise { + return this.write(async (db) => { + const now = new Date() + const rows = await db + .update(conversationResponseStateTable) + .set({ + status: "running", + runningSince: now, + updatedAt: now, + }) + .where( + and( + eq(conversationResponseStateTable.conversationId, conversationId), + eq(conversationResponseStateTable.status, "pending"), + ), + ) + .returning() + + return rows[0] ?? null + }) + } + + async clearConversationResponseState(conversationId: string): Promise { + await this.write(async (db) => { + if (!(await findConversationByIdForUpdate(db, conversationId))) { + throw new ConversationNotFoundError(conversationId, "") + } + + await db + .delete(conversationResponseStateTable) + .where(eq(conversationResponseStateTable.conversationId, conversationId)) + }) + } + + private async write(fn: (db: Database) => Promise): Promise { + if (this.inTransaction) return fn(this.db) + + return this.db.transaction(fn) + } +} + +export function createConversationStorage(db: Database): ConversationStorage { + return new DrizzleConversationStorage(db) +} + +export function conversations(db: Database, userId: string) { + const storage = createConversationStorage(db) + + return { + createConversation(): Promise { + return storage.createConversation(userId) + }, + + listConversations(): Promise { + return storage.listUserConversations(userId) + }, + + getConversation(conversationId: string): Promise { + return findUserConversation(db, userId, conversationId) + }, + + getOrCreateConversation(): Promise { + return storage.getOrCreateConversation(userId) + }, + + createFile(input: CreateFileInput): Promise { + return storage.createFile(userId, input) + }, + + appendEntry( + conversationId: string, + input: AppendConversationEntryInput, + ): Promise { + return db.transaction((tx) => appendEntryToConversation(tx, userId, conversationId, input)) + }, + + appendAttachmentEntry( + conversationId: string, + input: AppendAttachmentEntryInput, + ): Promise { + return db.transaction((tx) => + appendAttachmentEntryToConversation(tx, userId, conversationId, input), + ) + }, + + listEntries( + conversationId: string, + params: ListConversationEntriesParams = {}, + ): Promise { + return storage.listUserConversationEntries(userId, conversationId, params) + }, + } +} + +export function conversationResponse(db: Database, _userId: string, conversationId: string) { + const storage = createConversationStorage(db) + + return { + get(): Promise { + return storage.findConversationResponseState(conversationId) + }, + + upsert(input: UpsertConversationResponseStateInput): Promise { + return storage.upsertConversationResponseState(conversationId, input) + }, + + update( + input: UpdateConversationResponseStateInput, + ): Promise { + return storage.updateConversationResponseState(conversationId, input) + }, + + clear(): Promise { + return storage.clearConversationResponseState(conversationId) + }, + } +} + +function payloadForKind( + kind: ConversationEntryKind, + payload: AppendConversationEntryInput["payload"], +): ConversationEntryPayload { + switch (kind) { + case ConversationEntryKind.UserMessage: + return UserMessagePayload.assert(payload) + case ConversationEntryKind.AssistantMessage: + return AssistantMessagePayload.assert(payload) + case ConversationEntryKind.Attachment: + return AttachmentPayload.assert(payload) + case ConversationEntryKind.ContextSummary: + return ContextSummaryPayload.assert(payload) + case ConversationEntryKind.ToolCall: + case ConversationEntryKind.ToolResult: + case ConversationEntryKind.SystemNote: + return GenericObjectPayload.assert(payload) + } +} + +async function appendEntryToConversation( + db: Database, + userId: string | null, + conversationId: string, + input: AppendConversationEntryInput, +): Promise { + const kind = conversationEntryKind.assert(input.kind) + const visibility = conversationEntryVisibility.assert( + input.visibility ?? defaultVisibilityForKind(kind), + ) + const payload = payloadForKind(kind, input.payload) + const metadata = ConversationEntryMetadata.assert(input.metadata ?? {}) + let fileId: string | null = null + + if (input.kind === ConversationEntryKind.Attachment) { + fileId = input.fileId + } + + const conversation = userId + ? await findConversationForUpdate(db, userId, conversationId) + : await findConversationByIdForUpdate(db, conversationId) + if (!conversation) { + throw new ConversationNotFoundError(conversationId, userId ?? "") + } + if (fileId) await requireFile(db, conversation.userId, fileId) + + const sequence = await nextSequence(db, conversationId) + const rows = await db + .insert(conversationEntries) + .values({ + conversationId, + sequence, + kind, + visibility, + fileId, + payload, + metadata, + }) + .returning() + + await touchConversation(db, conversation.userId, conversationId) + return requireRow(rows) +} + +async function appendAttachmentEntryToConversation( + db: Database, + userId: string | null, + conversationId: string, + input: AppendAttachmentEntryInput, +): Promise { + const payload = AttachmentPayload.assert(input.payload) + const visibility = conversationEntryVisibility.assert( + input.visibility ?? defaultVisibilityForKind(ConversationEntryKind.Attachment), + ) + const metadata = ConversationEntryMetadata.assert(input.metadata ?? {}) + const conversation = userId + ? await findConversationForUpdate(db, userId, conversationId) + : await findConversationByIdForUpdate(db, conversationId) + + if (!conversation) { + throw new ConversationNotFoundError(conversationId, userId ?? "") + } + + const file = await insertFile(db, conversation.userId, input.file) + const sequence = await nextSequence(db, conversationId) + const rows = await db + .insert(conversationEntries) + .values({ + conversationId, + sequence, + kind: ConversationEntryKind.Attachment, + visibility, + fileId: file.id, + payload, + metadata, + }) + .returning() + + await touchConversation(db, conversation.userId, conversationId) + return { + file, + entry: requireRow(rows), + } +} + +async function requireUserForUpdate(db: Database, userId: string): Promise { + const rows = await db + .select({ id: user.id }) + .from(user) + .where(eq(user.id, userId)) + .limit(1) + .for("update") + + requireRow(rows, `User not found: ${userId}`) +} + +export async function findConversation( + db: Database, + conversationId: string, +): Promise { + const rows = await db + .select() + .from(conversationsTable) + .where(eq(conversationsTable.id, conversationId)) + .limit(1) + + return rows[0] ?? null +} + +async function findUserConversation( + db: Database, + userId: string, + conversationId: string, +): Promise { + const rows = await db + .select() + .from(conversationsTable) + .where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId))) + .limit(1) + + return rows[0] ?? null +} + +async function findConversationForUpdate( + db: Database, + userId: string, + conversationId: string, +): Promise { + const rows = await db + .select() + .from(conversationsTable) + .where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId))) + .limit(1) + .for("update") + + return rows[0] ?? null +} + +async function findConversationByIdForUpdate( + db: Database, + conversationId: string, +): Promise { + const rows = await db + .select() + .from(conversationsTable) + .where(eq(conversationsTable.id, conversationId)) + .limit(1) + .for("update") + + return rows[0] ?? null +} + +async function latestConversation(db: Database, userId: string): Promise { + const rows = await db + .select() + .from(conversationsTable) + .where(eq(conversationsTable.userId, userId)) + .orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt)) + .limit(1) + + return rows[0] ?? null +} + +async function insertConversation(db: Database, userId: string): Promise { + const rows = await db + .insert(conversationsTable) + .values({ + userId, + }) + .returning() + + return requireRow(rows) +} + +async function requireFile(db: Database, userId: string, fileId: string): Promise { + const rows = await db + .select() + .from(files) + .where(and(eq(files.id, fileId), eq(files.userId, userId))) + .limit(1) + + return requireRow(rows, `File not found: ${fileId}`) +} + +async function insertFile(db: Database, userId: string, input: CreateFileInput): Promise { + const rows = await db + .insert(files) + .values({ + userId, + storageKey: input.storageKey, + originalName: input.originalName ?? null, + mimeType: input.mimeType, + sizeBytes: input.sizeBytes, + metadata: input.metadata ?? {}, + }) + .returning() + + return requireRow(rows) +} + +async function touchConversation( + db: Database, + userId: string, + conversationId: string, +): Promise { + await db + .update(conversationsTable) + .set({ updatedAt: new Date() }) + .where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId))) +} + +async function nextSequence(db: Database, conversationId: string): Promise { + const rows = await db + .select({ sequence: conversationEntries.sequence }) + .from(conversationEntries) + .where(eq(conversationEntries.conversationId, conversationId)) + .orderBy(desc(conversationEntries.sequence)) + .limit(1) + + return (rows[0]?.sequence ?? 0) + 1 +} + +function requireRow(rows: T[], message = "Expected database row"): T { + const row = rows[0] + if (!row) throw new Error(message) + return row +} + +function defaultVisibilityForKind(kind: ConversationEntryKind): ConversationEntryVisibility { + switch (kind) { + case ConversationEntryKind.UserMessage: + case ConversationEntryKind.AssistantMessage: + case ConversationEntryKind.Attachment: + return ConversationEntryVisibility.UserVisible + case ConversationEntryKind.ToolCall: + case ConversationEntryKind.ToolResult: + case ConversationEntryKind.ContextSummary: + case ConversationEntryKind.SystemNote: + return ConversationEntryVisibility.Internal + } +} diff --git a/apps/freya-backend/src/conversations/http.test.ts b/apps/freya-backend/src/conversations/http.test.ts index fc093e8..680356c 100644 --- a/apps/freya-backend/src/conversations/http.test.ts +++ b/apps/freya-backend/src/conversations/http.test.ts @@ -25,7 +25,7 @@ const listEntriesCalls: Array<{ params: ListConversationEntriesParams }> = [] -mock.module("./storage.ts", () => ({ +mock.module("./db-storage.ts", () => ({ conversations: (_db: Database, userId: string) => ({ async listConversations(): Promise { return conversationRowsByUser.get(userId) ?? [] diff --git a/apps/freya-backend/src/conversations/http.ts b/apps/freya-backend/src/conversations/http.ts index 5cc8d86..2f7d86b 100644 --- a/apps/freya-backend/src/conversations/http.ts +++ b/apps/freya-backend/src/conversations/http.ts @@ -8,8 +8,8 @@ import type { AuthSessionMiddleware } from "../auth/session-middleware.ts" import type { Database } from "../db/index.ts" import type { ConversationRow } from "./storage.ts" +import { conversations } from "./db-storage.ts" import { ConversationNotFoundError } from "./errors.ts" -import { conversations } from "./storage.ts" /** Hono environment populated by the conversations route middleware. */ type Env = { diff --git a/apps/freya-backend/src/conversations/storage.ts b/apps/freya-backend/src/conversations/storage.ts index 67b31a7..a7747ca 100644 --- a/apps/freya-backend/src/conversations/storage.ts +++ b/apps/freya-backend/src/conversations/storage.ts @@ -2,28 +2,70 @@ import { AssistantMessagePayload, AttachmentPayload, ConversationEntryKind, + ConversationEntryMetadata, ConversationEntryVisibility, ContextSummaryPayload, - ConversationEntryMetadata, GenericObjectPayload, UserMessagePayload, - type ConversationEntryPayload, } from "@freya/core" -import { type } from "arktype" -import { and, asc, desc, eq } from "drizzle-orm" - -import type { Database } from "../db/index.ts" import { conversationEntries, + conversationResponseState as conversationResponseStateTable, conversations as conversationsTable, files, - user, + type ConversationResponseStateStatus, } from "../db/schema.ts" -import { ConversationNotFoundError } from "./errors.ts" -const conversationEntryKind = type.enumerated(...Object.values(ConversationEntryKind)) -const conversationEntryVisibility = type.enumerated(...Object.values(ConversationEntryVisibility)) +export interface ConversationStorage { + transaction(tx: (storage: ConversationStorage) => T | Promise): Promise + createConversation(userId: string): Promise + listUserConversations(userId: string): Promise + findConversation(conversationId: string): Promise + getOrCreateConversation(userId: string): Promise + createFile(userId: string, input: CreateFileInput): Promise + appendEntry( + conversationId: string, + input: AppendConversationEntryInput, + ): Promise + appendAttachmentEntry( + conversationId: string, + input: AppendAttachmentEntryInput, + ): Promise + nextSequence(conversationId: string): Promise + listUserConversationEntries( + userId: string, + conversationId: string, + params?: ListConversationEntriesParams, + ): Promise + listPendingUserConversationEntries( + userId: string, + conversationId: string, + ): Promise + findConversationResponseState( + conversationId: string, + ): Promise + // TODO: add pagination support + listPendingResponseStates(): Promise + // TODO: add pagination support + listRunningResponseStates(): Promise + upsertConversationResponseState( + conversationId: string, + input: UpsertConversationResponseStateInput, + ): Promise + updateConversationResponseState( + conversationId: string, + input: UpdateConversationResponseStateInput, + ): Promise + markResponseStateStatus( + conversationIds: string[], + status: ConversationResponseStateStatus, + ): Promise + claimPendingConversationResponseState( + conversationId: string, + ): Promise + clearConversationResponseState(conversationId: string): Promise +} /** Database row shape for a conversation owned by a user. */ export type ConversationRow = typeof conversationsTable.$inferSelect @@ -31,6 +73,9 @@ export type ConversationRow = typeof conversationsTable.$inferSelect /** Database row shape for an entry in a conversation timeline. */ export type ConversationEntryRow = typeof conversationEntries.$inferSelect +/** Database row shape for pending assistant response state in a conversation. */ +export type ConversationResponseStateRow = typeof conversationResponseStateTable.$inferSelect + /** Database row shape for an uploaded file referenced by conversations. */ export type FileRow = typeof files.$inferSelect @@ -99,291 +144,26 @@ export interface ListConversationEntriesParams { visibility?: ConversationEntryVisibility } -export function conversations(db: Database, userId: string) { - const storage = { - async createConversation(): Promise { - return insertConversation(db, userId) - }, - - async listConversations(): Promise { - return db - .select() - .from(conversationsTable) - .where(eq(conversationsTable.userId, userId)) - .orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt)) - }, - - async getConversation(conversationId: string): Promise { - const rows = await db - .select() - .from(conversationsTable) - .where( - and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId)), - ) - .limit(1) - - return rows[0] ?? null - }, - - async getOrCreateConversation(): Promise { - return db.transaction(async (tx) => { - await requireUserForUpdate(tx, userId) - const existing = await latestConversation(tx, userId) - if (existing) return existing - - return insertConversation(tx, userId) - }) - }, - - async createFile(input: CreateFileInput): Promise { - return insertFile(db, userId, input) - }, - - async appendEntry( - conversationId: string, - input: AppendConversationEntryInput, - ): Promise { - const kind = conversationEntryKind.assert(input.kind) - const visibility = conversationEntryVisibility.assert( - input.visibility ?? defaultVisibilityForKind(kind), - ) - const payload = payloadForKind(kind, input.payload) - const metadata = ConversationEntryMetadata.assert(input.metadata ?? {}) - let fileId: string | null = null - - if (input.kind === ConversationEntryKind.Attachment) { - fileId = input.fileId - await requireFile(db, userId, fileId) - } - - const rows = await db.transaction(async (tx) => { - if (!(await findConversationForUpdate(tx, userId, conversationId))) { - throw new ConversationNotFoundError(conversationId, userId) - } - const sequence = await nextSequence(tx, conversationId) - - const rows = await tx - .insert(conversationEntries) - .values({ - conversationId, - sequence, - kind, - visibility, - fileId, - payload, - metadata, - }) - .returning() - - await touchConversation(tx, userId, conversationId) - return rows - }) - - return requireRow(rows) - }, - - async appendAttachmentEntry( - conversationId: string, - input: AppendAttachmentEntryInput, - ): Promise { - const payload = AttachmentPayload.assert(input.payload) - const visibility = conversationEntryVisibility.assert( - input.visibility ?? defaultVisibilityForKind(ConversationEntryKind.Attachment), - ) - const metadata = ConversationEntryMetadata.assert(input.metadata ?? {}) - - return db.transaction(async (tx) => { - if (!(await findConversationForUpdate(tx, userId, conversationId))) { - throw new ConversationNotFoundError(conversationId, userId) - } - - const file = await insertFile(tx, userId, input.file) - const sequence = await nextSequence(tx, conversationId) - const rows = await tx - .insert(conversationEntries) - .values({ - conversationId, - sequence, - kind: ConversationEntryKind.Attachment, - visibility, - fileId: file.id, - payload, - metadata, - }) - .returning() - - await touchConversation(tx, userId, conversationId) - return { - file, - entry: requireRow(rows), - } - }) - }, - - async listEntries( - conversationId: string, - params: ListConversationEntriesParams = {}, - ): Promise { - if (!(await storage.getConversation(conversationId))) { - throw new ConversationNotFoundError(conversationId, userId) - } - - if (params.visibility) { - return db - .select() - .from(conversationEntries) - .where( - and( - eq(conversationEntries.conversationId, conversationId), - eq(conversationEntries.visibility, params.visibility), - ), - ) - .orderBy(asc(conversationEntries.sequence)) - } - - return db - .select() - .from(conversationEntries) - .where(eq(conversationEntries.conversationId, conversationId)) - .orderBy(asc(conversationEntries.sequence)) - }, - } - - return storage +/** Input for creating or replacing pending assistant response state. */ +export interface UpsertConversationResponseStateInput { + status?: ConversationResponseStateStatus + pendingSinceEntryId: string + maxWaitUntil: Date + runningSince?: Date | null } -function payloadForKind( - kind: ConversationEntryKind, - payload: AppendConversationEntryInput["payload"], -): ConversationEntryPayload { - switch (kind) { - case ConversationEntryKind.UserMessage: - return UserMessagePayload.assert(payload) - case ConversationEntryKind.AssistantMessage: - return AssistantMessagePayload.assert(payload) - case ConversationEntryKind.Attachment: - return AttachmentPayload.assert(payload) - case ConversationEntryKind.ContextSummary: - return ContextSummaryPayload.assert(payload) - case ConversationEntryKind.ToolCall: - case ConversationEntryKind.ToolResult: - case ConversationEntryKind.SystemNote: - return GenericObjectPayload.assert(payload) - } +/** Input for patching pending assistant response state. */ +export interface UpdateConversationResponseStateInput { + status?: ConversationResponseStateStatus + pendingSinceEntryId?: string + maxWaitUntil?: Date + runningSince?: Date | null } -async function requireUserForUpdate(db: Database, userId: string): Promise { - const rows = await db - .select({ id: user.id }) - .from(user) - .where(eq(user.id, userId)) - .limit(1) - .for("update") - - requireRow(rows, `User not found: ${userId}`) -} - -async function findConversationForUpdate( - db: Database, - userId: string, - conversationId: string, -): Promise { - const rows = await db - .select() - .from(conversationsTable) - .where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId))) - .limit(1) - .for("update") - - return rows[0] ?? null -} - -async function latestConversation(db: Database, userId: string): Promise { - const rows = await db - .select() - .from(conversationsTable) - .where(eq(conversationsTable.userId, userId)) - .orderBy(desc(conversationsTable.updatedAt), desc(conversationsTable.createdAt)) - .limit(1) - - return rows[0] ?? null -} - -async function insertConversation(db: Database, userId: string): Promise { - const rows = await db - .insert(conversationsTable) - .values({ - userId, - }) - .returning() - - return requireRow(rows) -} - -async function requireFile(db: Database, userId: string, fileId: string): Promise { - const rows = await db - .select() - .from(files) - .where(and(eq(files.id, fileId), eq(files.userId, userId))) - .limit(1) - - return requireRow(rows, `File not found: ${fileId}`) -} - -async function insertFile(db: Database, userId: string, input: CreateFileInput): Promise { - const rows = await db - .insert(files) - .values({ - userId, - storageKey: input.storageKey, - originalName: input.originalName ?? null, - mimeType: input.mimeType, - sizeBytes: input.sizeBytes, - metadata: input.metadata ?? {}, - }) - .returning() - - return requireRow(rows) -} - -async function touchConversation( - db: Database, - userId: string, - conversationId: string, -): Promise { - await db - .update(conversationsTable) - .set({ updatedAt: new Date() }) - .where(and(eq(conversationsTable.id, conversationId), eq(conversationsTable.userId, userId))) -} - -async function nextSequence(db: Database, conversationId: string): Promise { - const rows = await db - .select({ sequence: conversationEntries.sequence }) - .from(conversationEntries) - .where(eq(conversationEntries.conversationId, conversationId)) - .orderBy(desc(conversationEntries.sequence)) - .limit(1) - - return (rows[0]?.sequence ?? 0) + 1 -} - -function requireRow(rows: T[], message = "Expected database row"): T { - const row = rows[0] - if (!row) throw new Error(message) - return row -} - -function defaultVisibilityForKind(kind: ConversationEntryKind): ConversationEntryVisibility { - switch (kind) { - case ConversationEntryKind.UserMessage: - case ConversationEntryKind.AssistantMessage: - case ConversationEntryKind.Attachment: - return ConversationEntryVisibility.UserVisible - case ConversationEntryKind.ToolCall: - case ConversationEntryKind.ToolResult: - case ConversationEntryKind.ContextSummary: - case ConversationEntryKind.SystemNote: - return ConversationEntryVisibility.Internal - } -} +export { + createConversationStorage, + conversationResponse, + conversations, + DrizzleConversationStorage, + findConversation, +} from "./db-storage.ts" diff --git a/apps/freya-backend/src/db/schema.ts b/apps/freya-backend/src/db/schema.ts index d822c90..cfb252a 100644 --- a/apps/freya-backend/src/db/schema.ts +++ b/apps/freya-backend/src/db/schema.ts @@ -48,6 +48,15 @@ const bytea = customType<{ data: Buffer }>({ }, }) +export const ConversationResponseStateStatus = { + Pending: "pending", + Running: "running", + Failed: "failed", +} as const + +export type ConversationResponseStateStatus = + (typeof ConversationResponseStateStatus)[keyof typeof ConversationResponseStateStatus] + export const userSources = pgTable( "user_sources", { @@ -146,6 +155,38 @@ export const conversationEntries = pgTable( ], ) +export const conversationResponseState = pgTable( + "conversation_response_state", + { + conversationId: uuid("conversation_id") + .primaryKey() + .references(() => conversations.id, { onDelete: "cascade" }), + status: text("status") + .$type() + .notNull() + .default(ConversationResponseStateStatus.Pending), + pendingSinceEntryId: uuid("pending_since_entry_id") + .notNull() + .references(() => conversationEntries.id, { onDelete: "cascade" }), + maxWaitUntil: timestamp("max_wait_until").notNull(), + runningSince: timestamp("running_since"), + createdAt: timestamp("created_at").notNull().defaultNow(), + updatedAt: timestamp("updated_at") + .notNull() + .defaultNow() + .$onUpdate(() => new Date()), + }, + (t) => [ + index("conversation_response_state_status_max_wait_until_idx").on(t.status, t.maxWaitUntil), + index("conversation_response_state_running_since_idx").on(t.runningSince), + index("conversation_response_state_pending_since_entry_id_idx").on(t.pendingSinceEntryId), + check( + "conversation_response_state_status_check", + sql`${t.status} in ('pending', 'running', 'failed')`, + ), + ], +) + // --------------------------------------------------------------------------- // FREYA — reminders source storage // --------------------------------------------------------------------------- diff --git a/apps/freya-backend/src/engine/http.test.ts b/apps/freya-backend/src/engine/http.test.ts index f2ebb31..1a205b0 100644 --- a/apps/freya-backend/src/engine/http.test.ts +++ b/apps/freya-backend/src/engine/http.test.ts @@ -14,7 +14,6 @@ interface FeedResponse { items: Array<{ id: string type: string - priority: number timestamp: string data: Record }> @@ -85,7 +84,7 @@ mock.module("../sources/user-sources.ts", () => ({ }), })) -mock.module("../conversations/storage.ts", () => ({ +mock.module("../conversations/db-storage.ts", () => ({ conversations: (_db: Database, userId: string) => ({ async getOrCreateConversation() { return { id: `conversation-${userId}` } @@ -118,7 +117,6 @@ describe("GET /api/feed", () => { id: "item-1", sourceId: "test", type: "test", - priority: 0.8, timestamp: new Date("2025-01-01T00:00:00.000Z"), data: { value: 42 }, }, @@ -149,7 +147,6 @@ describe("GET /api/feed", () => { expect(body.items).toHaveLength(1) expect(body.items[0]!.id).toBe("item-1") expect(body.items[0]!.type).toBe("test") - expect(body.items[0]!.priority).toBe(0.8) expect(body.items[0]!.timestamp).toBe("2025-01-01T00:00:00.000Z") expect(body.errors).toHaveLength(0) }) @@ -160,7 +157,6 @@ describe("GET /api/feed", () => { id: "fresh-1", sourceId: "test", type: "test", - priority: 0.5, timestamp: new Date("2025-06-01T12:00:00.000Z"), data: { fresh: true }, }, diff --git a/apps/freya-backend/src/enhancement/schema.test.ts b/apps/freya-backend/src/enhancement/schema.test.ts index 418de6c..01fee02 100644 --- a/apps/freya-backend/src/enhancement/schema.test.ts +++ b/apps/freya-backend/src/enhancement/schema.test.ts @@ -135,8 +135,9 @@ describe("schema sync", () => { // JSON Schema structure matches const jsonSchema = enhancementResultJsonSchema + const payloadKeys = Object.keys(payload).sort() as Array<(typeof jsonSchema.required)[number]> expect(Object.keys(jsonSchema.properties).sort()).toEqual(Object.keys(payload).sort()) - expect([...jsonSchema.required].sort()).toEqual(Object.keys(payload).sort()) + expect([...jsonSchema.required].sort()).toEqual(payloadKeys) // syntheticItems item schema has the right required fields const itemSchema = jsonSchema.properties.syntheticItems.items diff --git a/apps/freya-backend/src/lib/job.ts b/apps/freya-backend/src/lib/job.ts new file mode 100644 index 0000000..78b856a --- /dev/null +++ b/apps/freya-backend/src/lib/job.ts @@ -0,0 +1,116 @@ +import { Queue } from "./queue" + +const JobStatus = { + Pending: "pending", + Running: "running", +} as const +type JobStatus = (typeof JobStatus)[keyof typeof JobStatus] + +export interface Job { + id: number + payload: Payload + signal: AbortSignal +} + +interface PendingJob { + status: typeof JobStatus.Pending + controller: AbortController + job: Job +} + +interface RunningJob { + status: typeof JobStatus.Running + controller: AbortController + job: Job +} + +type JobState = PendingJob | RunningJob + +type JobEventListener = (job: Job) => void + +type JobEvent = "settled" | "cancelled" + +export class JobRegistry { + private queue = new Queue>() + + private states = new Map>() + + private listeners: Record[]> = { + settled: [], + cancelled: [], + } + + addJob({ payload }: { payload: Payload }): Job { + const controller = new AbortController() + const job: Job = { + id: this.generateJobId(), + payload, + signal: controller.signal, + } + this.queue.enqueue(job) + this.states.set(job.id, { status: JobStatus.Pending, controller, job }) + return job + } + + async nextJob(signal?: AbortSignal): Promise | null> { + while (true) { + const job = await this.queue.next(signal) + if (!job) { + return null + } + + const state = this.states.get(job.id) + + if (!state || state.job !== job || state.status === JobStatus.Running) { + continue + } + if (state.controller.signal.aborted) { + this.states.delete(job.id) + continue + } + + this.states.set(job.id, { status: JobStatus.Running, controller: state.controller, job }) + + return job + } + } + + cancelJob(job: Job): void { + const state = this.states.get(job.id) + if (state?.job === job) { + state?.controller.abort() + this.notifyListeners("cancelled", job.id) + this.states.delete(job.id) + } + } + + markJobAsCompleted(job: Job): void { + const state = this.states.get(job.id) + if (state?.job === job) { + this.notifyListeners("settled", job.id) + this.states.delete(job.id) + } + } + + addEventListener(event: JobEvent, listener: JobEventListener): () => void { + this.listeners[event].push(listener) + return () => { + this.listeners[event] = this.listeners[event].filter((l) => l !== listener) + } + } + + private generateJobId(): number { + let id: number + do { + id = Math.floor(Math.random() * 1000000) + } while (this.states.has(id)) + return id + } + + private notifyListeners(event: JobEvent, id: number): void { + const job = this.states.get(id)?.job + if (job) { + this.listeners[event].forEach((listener) => listener(job)) + } + } +} diff --git a/apps/freya-backend/src/lib/queue.ts b/apps/freya-backend/src/lib/queue.ts new file mode 100644 index 0000000..b4b8e29 --- /dev/null +++ b/apps/freya-backend/src/lib/queue.ts @@ -0,0 +1,69 @@ +interface Item { + value: T + next: Item | null +} + +export class Queue { + private front: Item | null = null + private back: Item | null = null + private waiters: Array<(value: T) => void> = [] + + enqueue(value: T): void { + const waiter = this.waiters.shift() + if (waiter) { + waiter(value) + return + } + + const newItem: Item = { value, next: null } + if (this.back) { + this.back.next = newItem + } else { + this.front = newItem + } + this.back = newItem + } + + dequeue(): T | null { + if (!this.front) return null + const value = this.front.value + this.front = this.front.next + if (!this.front) this.back = null + return value + } + + next(signal?: AbortSignal): Promise { + const value = this.dequeue() + if (value !== null) return Promise.resolve(value) + + return new Promise((resolve) => { + if (signal) { + if (signal.aborted) { + resolve(null) + } else { + let _resolve: (v: T) => void + + const onAbort = () => { + this.waiters = this.waiters.filter((w) => w !== _resolve) + resolve(null) + } + + signal.addEventListener( + "abort", + onAbort, + { once: true }, + ) + + _resolve = (v: T) => { + signal.removeEventListener("abort", onAbort) + resolve(v) + } + + this.waiters.push(_resolve) + } + } else { + this.waiters.push(resolve) + } + }) + } +} diff --git a/apps/freya-backend/src/lib/worker.ts b/apps/freya-backend/src/lib/worker.ts new file mode 100644 index 0000000..b7c733a --- /dev/null +++ b/apps/freya-backend/src/lib/worker.ts @@ -0,0 +1,51 @@ +import type { Job, JobRegistry } from "./job" +import type { Queue } from "./queue" + +export interface JobExecutor { + execute(job: Job): Promise +} + +export interface WorkerConfig { + concurrency: number + registry: JobRegistry + runner: JobExecutor + signal: AbortSignal +} + +export class Worker { + private concurrency: number + private registry: JobRegistry + private runner: JobExecutor + private signal: AbortSignal + + constructor({ concurrency, registry, runner, signal }: WorkerConfig) { + this.concurrency = concurrency + this.registry = registry + this.runner = runner + this.signal = signal + } + + start() { + if (this.signal.aborted) return + for (let i = 0; i < this.concurrency; i++) { + void this.pollJobFromRegistry() + } + } + + private async pollJobFromRegistry() { + while (!this.signal.aborted) { + const job = await this.registry.nextJob(this.signal) + if (!job) { + return + } + + try { + await this.runner.execute(job) + } catch { + // TODO: handle logging of job execution errors + } finally { + this.registry.markJobAsCompleted(job) + } + } + } +} diff --git a/apps/freya-backend/src/notification/notification-central.ts b/apps/freya-backend/src/notification/notification-central.ts new file mode 100644 index 0000000..a3d87d0 --- /dev/null +++ b/apps/freya-backend/src/notification/notification-central.ts @@ -0,0 +1,36 @@ +import type { AgentEvent } from "@freya/agent-protocol" + +export interface AgentNotification { + kind: "agent" + payload: AgentEvent +} + +export type NotificationPayload = AgentNotification +export type NotificationListener = (notification: NotificationPayload) => Promise + +export class NotificationCentral { + private listeners: Map> = new Map() + + registerListenerForUser(userId: string, listener: NotificationListener): () => void { + let listeners = this.listeners.get(userId) + if (!listeners) { + listeners = new Set() + this.listeners.set(userId, listeners) + } + + listeners.add(listener) + return () => { + listeners.delete(listener) + if (listeners.size === 0) { + this.listeners.delete(userId) + } + } + } + + async notifyUser(userId: string, notification: NotificationPayload): Promise { + const listeners = this.listeners.get(userId) + if (!listeners) return + + await Promise.allSettled(Array.from(listeners).map((listener) => listener(notification))) + } +} diff --git a/apps/freya-backend/src/server.ts b/apps/freya-backend/src/server.ts index 34ab18d..23faf5b 100644 --- a/apps/freya-backend/src/server.ts +++ b/apps/freya-backend/src/server.ts @@ -5,6 +5,7 @@ import { createMiddleware } from "hono/factory" import { registerAdminHttpHandlers } from "./admin/http.ts" import { createQueryDebugTools } from "./agent/debug-tools.ts" import { registerAgentHttpHandlers, registerDebugAgentHttpHandlers } from "./agent/http.ts" +import { AgentService } from "./agent/service.ts" import { agentWebSocket, registerAgentWebSocketHandlers } from "./agent/ws.ts" import { createRequireAdmin } from "./auth/admin-middleware.ts" import { registerAuthHandlers } from "./auth/http.ts" @@ -12,6 +13,7 @@ import { createAuth } from "./auth/index.ts" import { createRequireSession } from "./auth/session-middleware.ts" import { CalDavSourceProvider } from "./caldav/provider.ts" import { registerConversationsHttpHandlers } from "./conversations/http.ts" +import { DrizzleConversationStorage } from "./conversations/storage.ts" import { createDatabase } from "./db/index.ts" import { registerFeedHttpHandlers } from "./engine/http.ts" import { createFeedEnhancer } from "./enhancement/enhance-feed.ts" @@ -21,6 +23,7 @@ import { CredentialEncryptor } from "./lib/crypto.ts" import { ensureEnv } from "./lib/env.ts" import { registerLocationHttpHandlers } from "./location/http.ts" import { LocationSourceProvider } from "./location/provider.ts" +import { NotificationCentral } from "./notification/notification-central.ts" import { ReminderSourceProvider } from "./reminders/provider.ts" import { UserSessionManager } from "./session/index.ts" import { registerSourcesHttpHandlers } from "./sources/http.ts" @@ -32,8 +35,12 @@ function main() { const env = ensureEnv(process.env) const { db, close: closeDb } = createDatabase(env.databaseUrl) + const conversationStorage = new DrizzleConversationStorage(db, false) + const auth = createAuth(db) + const abortController = new AbortController() + const feedEnhancer = createFeedEnhancer({ client: createLlmClient({ apiKey: env.openrouterApiKey, @@ -73,6 +80,15 @@ function main() { console.warn("[query] PI_API_KEY or OPENROUTER_API_KEY not set — query agent unavailable") } + const notificationCentral = new NotificationCentral() + + const agentService = new AgentService({ + notificationCentral, + storage: conversationStorage, + userSessionManager: sessionManager, + signal: abortController.signal, + }) + const app = new Hono() const isDev = process.env.NODE_ENV !== "production" @@ -141,17 +157,22 @@ function main() { registerAdminHttpHandlers(app, { sessionManager, adminMiddleware, db }) registerAgentWebSocketHandlers(app, { - sessionManager, + agentService, + notificationCentral, + storage: conversationStorage, authSessionMiddleware, corsMiddleware: agentWebSocketCorsMiddleware, }) process.on("SIGTERM", async () => { sessionManager.dispose() + abortController.abort() await closeDb() process.exit(0) }) + agentService.start() + return app } diff --git a/apps/freya-backend/src/session/user-session-manager.test.ts b/apps/freya-backend/src/session/user-session-manager.test.ts index 48f5764..c3a2e21 100644 --- a/apps/freya-backend/src/session/user-session-manager.test.ts +++ b/apps/freya-backend/src/session/user-session-manager.test.ts @@ -120,7 +120,7 @@ mock.module("../sources/user-sources.ts", () => ({ }), })) -mock.module("../conversations/storage.ts", () => ({ +mock.module("../conversations/db-storage.ts", () => ({ conversations: (_db: Database, userId: string) => ({ async getOrCreateConversation(): Promise<{ id: string }> { mockConversationCalls.push({ type: "getOrCreate", userId }) diff --git a/apps/freya-backend/src/session/user-session-manager.ts b/apps/freya-backend/src/session/user-session-manager.ts index 18143d1..221db6d 100644 --- a/apps/freya-backend/src/session/user-session-manager.ts +++ b/apps/freya-backend/src/session/user-session-manager.ts @@ -8,7 +8,7 @@ import type { FeedEnhancer } from "../enhancement/enhance-feed.ts" import type { CredentialEncryptor } from "../lib/crypto.ts" import type { FeedSourceProvider } from "./feed-source-provider.ts" -import { conversations } from "../conversations/storage.ts" +import { conversations } from "../conversations/db-storage.ts" import { CredentialStorageUnavailableError, InvalidSourceConfigError, diff --git a/apps/freya-backend/src/session/user-session.ts b/apps/freya-backend/src/session/user-session.ts index 2ee5c95..a6aa719 100644 --- a/apps/freya-backend/src/session/user-session.ts +++ b/apps/freya-backend/src/session/user-session.ts @@ -263,18 +263,12 @@ export class UserSession { const conversation = await conversationStorage.getOrCreateConversation() const entries = await conversationStorage.listEntries(conversation.id) - this.queryAgent = new ConversationRecordingQueryAgent({ - agent: new PiQueryAgent({ - toolbox: this.toolbox, - apiKey: this.agentConfig?.apiKey, - cwd: this.agentConfig?.cwd, - systemPrompt: this.agentConfig?.systemPrompt, - initialEntries: entries, - }), - storage: conversationStorage, - defaultConversationId: conversation.id, - modelProvider: PI_MODEL_PROVIDER, - modelId: PI_MODEL_ID, + this.queryAgent = new PiQueryAgent({ + toolbox: this.toolbox, + apiKey: this.agentConfig?.apiKey, + cwd: this.agentConfig?.cwd, + systemPrompt: this.agentConfig?.systemPrompt, + initialEntries: entries, }) } diff --git a/apps/freya-backend/src/sources/http.test.ts b/apps/freya-backend/src/sources/http.test.ts index abee1f8..b63fdef 100644 --- a/apps/freya-backend/src/sources/http.test.ts +++ b/apps/freya-backend/src/sources/http.test.ts @@ -128,7 +128,7 @@ mock.module("../sources/user-sources.ts", () => ({ }, })) -mock.module("../conversations/storage.ts", () => ({ +mock.module("../conversations/db-storage.ts", () => ({ conversations: (_db: Database, userId: string) => ({ async getOrCreateConversation() { return { id: `conversation-${userId}` } diff --git a/bun.lock b/bun.lock index 0f31646..fca2798 100644 --- a/bun.lock +++ b/bun.lock @@ -51,7 +51,7 @@ "version": "0.0.0", "dependencies": { "@freya/agent-protocol": "workspace:*", - "@nym.sh/jrpc": "^0.1.0", + "@nym.sh/jrpc": "1.1.0", }, }, "apps/freya-backend": { @@ -69,7 +69,7 @@ "@freya/source-tfl": "workspace:*", "@freya/source-weatherkit": "workspace:*", "@freya/source-web-search": "workspace:*", - "@nym.sh/jrpc": "^0.1.0", + "@nym.sh/jrpc": "1.1.0", "@openrouter/sdk": "^0.9.11", "arktype": "^2.1.29", "better-auth": "^1", @@ -838,7 +838,7 @@ "@nolyfill/is-core-module": ["@nolyfill/is-core-module@1.0.39", "", {}, "sha512-nn5ozdjYQpUCZlWGuxcJY/KpxkWQs4DcbMCmKojjyrYDEAGy4Ce19NN4v5MduafTwJlbKc99UA8YhSVqq9yPZA=="], - "@nym.sh/jrpc": ["@nym.sh/jrpc@0.1.0", "", {}, "sha512-qH+vqKojPrX4RkW67U2R4J98uWHxZOwYxX2J5GLZcfm/yjklCcN5zTfDNLfgAa9jAoOFVscC3DFWhvdZOmN3fA=="], + "@nym.sh/jrpc": ["@nym.sh/jrpc@1.1.0", "", {}, "sha512-212SYMB37GdL8enaRTTqG/LNa5bJ7eYth6jfQfECuedQCuaju0bOMUzCN6hvY5KkrxdYuqVKmr2Uz+ZZTjPlaQ=="], "@nym.sh/jrx": ["@nym.sh/jrx@0.2.0", "", { "peerDependencies": { "@json-render/core": ">=0.10.0" } }, "sha512-jd7Z1Q6T21366MtSUnwCFiu6Yl1AdNc9s5m6HxeUg265P+0enZCiyyxOuHsFwvpUcSEs/2DVBsqfMptdca44lA=="], diff --git a/packages/freya-agent-protocol/src/index.ts b/packages/freya-agent-protocol/src/index.ts index 0c1004a..0418489 100644 --- a/packages/freya-agent-protocol/src/index.ts +++ b/packages/freya-agent-protocol/src/index.ts @@ -1,8 +1,3 @@ -export interface SendMessageResult { - message: string - conversationId: string -} - export type AgentEvent = | { type: "conversation_started"; conversationId: string } | { type: "message_created"; text: string } @@ -11,8 +6,11 @@ export type AgentEvent = | { type: "message_finished" } | { type: "message_failed"; error: string } +export type UserEvent = { type: "typing" } + export interface AgentServerApi { - sendMessage(message: string): Promise + sendMessage(message: string): Promise + notify(event: UserEvent): void ping(): "pong" } diff --git a/packages/freya-core/src/conversation.ts b/packages/freya-core/src/conversation.ts index 276e34a..d107f6f 100644 --- a/packages/freya-core/src/conversation.ts +++ b/packages/freya-core/src/conversation.ts @@ -146,6 +146,19 @@ export const ConversationEntryMetadata = type({ /** Metadata bag attached to a conversation entry. */ export type ConversationEntryMetadata = typeof ConversationEntryMetadata.infer +export const ToolCallPayload = type({ + toolName: "string", +}) + +export type ToolCallPayload = typeof ToolCallPayload.infer + +export const ToolResultPayload = type({ + toolName: "string", + ok: "boolean", +}) + +export type ToolResultPayload = typeof ToolResultPayload.infer + /** Generic object payload used by operational entries. */ export const GenericObjectPayload = type("Record") @@ -158,4 +171,6 @@ export type ConversationEntryPayload = | AssistantMessagePayload | AttachmentPayload | ContextSummaryPayload + | ToolCallPayload + | ToolResultPayload | GenericObjectPayload diff --git a/packages/freya-core/src/index.ts b/packages/freya-core/src/index.ts index 97a64cf..4ae7629 100644 --- a/packages/freya-core/src/index.ts +++ b/packages/freya-core/src/index.ts @@ -23,6 +23,8 @@ export { ModelRunMetadata, TextMessagePart, UserMessagePayload, + ToolCallPayload, + ToolResultPayload, } from "./conversation" // Feed