refactor: use http polling instead of websocket

This commit is contained in:
2024-08-22 23:48:16 +01:00
parent 13ae315b7d
commit 58b47201a6
2 changed files with 78 additions and 45 deletions

View File

@@ -1,9 +1,11 @@
import threading
import os
from time import sleep
import requests
import websocket
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from logger import log_info, log_warn
@@ -15,18 +17,18 @@ current_index = -1
t = None
# websocket connection to the inference server
ws = None
ws_url = ""
inference_url = ""
ws_connection_manager = WebSocketConnectionManager()
active_listeners = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
global ws, ws_url
global ws, inference_url
ws_url = os.environ.get("INFERENCE_SERVER_WS_URL")
if not ws_url:
ws_url = "ws://localhost:8001"
inference_url = os.environ.get("INFERENCE_SERVER_URL")
if not inference_url:
inference_url = "ws://localhost:8001"
advance()
@@ -39,7 +41,7 @@ async def lifespan(app: FastAPI):
def generate_new_audio():
if not ws_url:
if not inference_url:
return
global current_index
@@ -52,31 +54,50 @@ def generate_new_audio():
else:
return
log_info("generating new audio...")
log_info("requesting new audio...")
try:
ws = websocket.create_connection(ws_url)
ws.send("generate")
wavs = []
for i in range(5):
raw = ws.recv()
if isinstance(raw, str):
continue
wavs.append(raw)
for i, wav in enumerate(wavs):
with open(f"{i + offset}.mp3", "wb") as f:
f.write(wav)
log_info("audio generated.")
ws.close()
print(f"{inference_url}/generate")
requests.post(f"{inference_url}/generate")
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
is_available = False
while not is_available:
try:
res = requests.post(f"{inference_url}/clips/0", stream=True)
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
if res.status_code != status.HTTP_200_OK:
print("still generating...")
sleep(5)
continue
print("inference complete! downloading new clips")
is_available = True
with open(f"{offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
for i in range(4):
res = requests.post(f"{inference_url}/clips/{i + 1}", stream=True)
if res.status_code != status.HTTP_200_OK:
continue
with open(f"{i + 1 + offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
log_info("audio generated.")
def advance():