diff --git a/fal_app.py b/fal_app.py index bbb1bf2..dd4ade7 100644 --- a/fal_app.py +++ b/fal_app.py @@ -1,5 +1,6 @@ import io import fal +from fastapi import WebSocket import torch from fal.toolkit import File @@ -33,3 +34,16 @@ class InfinifiFalApp(fal.App, keep_alive=300): serialized.append(buf.getvalue()) return serialized + + @fal.endpoint("/ws") + async def run_ws(self, ws: WebSocket): + await ws.accept() + + wav = self.model.generate(PROMPTS) + + for one_wav in enumerate(wav): + buf = io.BytesIO() + torch.save(one_wav, buf) + await ws.send_bytes(buf.getvalue()) + + await ws.close()