feat: add websocket endpoint for fal

This commit is contained in:
2024-07-25 11:00:52 +01:00
parent 0e744af73f
commit 37cf800d8d

View File

@@ -1,5 +1,6 @@
import io
import fal
from fastapi import WebSocket
import torch
from fal.toolkit import File
@@ -33,3 +34,16 @@ class InfinifiFalApp(fal.App, keep_alive=300):
serialized.append(buf.getvalue())
return serialized
@fal.endpoint("/ws")
async def run_ws(self, ws: WebSocket):
await ws.accept()
wav = self.model.generate(PROMPTS)
for one_wav in enumerate(wav):
buf = io.BytesIO()
torch.save(one_wav, buf)
await ws.send_bytes(buf.getvalue())
await ws.close()