diff --git a/inference_server.py b/inference_server.py index 04bbdb1..87de50d 100644 --- a/inference_server.py +++ b/inference_server.py @@ -22,7 +22,7 @@ async def handler(websocket): async def main(): - async with serve(handler, "", 8001): + async with serve(handler, "localhost", 8001, ping_interval=5): await asyncio.Future() diff --git a/server.py b/server.py index b52353d..51c8cd6 100644 --- a/server.py +++ b/server.py @@ -13,17 +13,16 @@ current_index = -1 t = None # websocket connection to the inference server ws = None +ws_url = "" + @asynccontextmanager async def lifespan(app: FastAPI): - global ws + global ws, ws_url - url = os.environ.get("INFERENCE_SERVER_WS_URL") - if not url: - url = "ws://localhost:8001" - - ws = websocket.create_connection(url) - print(f"websocket connected to {url}") + ws_url = os.environ.get("INFERENCE_SERVER_WS_URL") + if not ws_url: + ws_url = "ws://localhost:8001" advance() @@ -36,7 +35,7 @@ async def lifespan(app: FastAPI): def generate_new_audio(): - if not ws: + if not ws_url: return global current_index @@ -51,6 +50,9 @@ def generate_new_audio(): print("generating new audio...") + ws = websocket.create_connection(ws_url) + print(f"websocket connected to {ws_url}") + ws.send("generate") wavs = [] @@ -66,6 +68,8 @@ def generate_new_audio(): print("audio generated.") + ws.close() + def advance(): global current_index, t