refactor: use http polling instead of websocket
This commit is contained in:
73
server.py
73
server.py
@@ -1,9 +1,11 @@
|
||||
import threading
|
||||
import os
|
||||
from time import sleep
|
||||
import requests
|
||||
|
||||
import websocket
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from logger import log_info, log_warn
|
||||
@@ -15,18 +17,18 @@ current_index = -1
|
||||
t = None
|
||||
# websocket connection to the inference server
|
||||
ws = None
|
||||
ws_url = ""
|
||||
inference_url = ""
|
||||
ws_connection_manager = WebSocketConnectionManager()
|
||||
active_listeners = set()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global ws, ws_url
|
||||
global ws, inference_url
|
||||
|
||||
ws_url = os.environ.get("INFERENCE_SERVER_WS_URL")
|
||||
if not ws_url:
|
||||
ws_url = "ws://localhost:8001"
|
||||
inference_url = os.environ.get("INFERENCE_SERVER_URL")
|
||||
if not inference_url:
|
||||
inference_url = "ws://localhost:8001"
|
||||
|
||||
advance()
|
||||
|
||||
@@ -39,7 +41,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
def generate_new_audio():
|
||||
if not ws_url:
|
||||
if not inference_url:
|
||||
return
|
||||
|
||||
global current_index
|
||||
@@ -52,31 +54,50 @@ def generate_new_audio():
|
||||
else:
|
||||
return
|
||||
|
||||
log_info("generating new audio...")
|
||||
log_info("requesting new audio...")
|
||||
|
||||
try:
|
||||
ws = websocket.create_connection(ws_url)
|
||||
|
||||
ws.send("generate")
|
||||
|
||||
wavs = []
|
||||
for i in range(5):
|
||||
raw = ws.recv()
|
||||
if isinstance(raw, str):
|
||||
continue
|
||||
wavs.append(raw)
|
||||
|
||||
for i, wav in enumerate(wavs):
|
||||
with open(f"{i + offset}.mp3", "wb") as f:
|
||||
f.write(wav)
|
||||
|
||||
log_info("audio generated.")
|
||||
|
||||
ws.close()
|
||||
print(f"{inference_url}/generate")
|
||||
requests.post(f"{inference_url}/generate")
|
||||
except:
|
||||
log_warn(
|
||||
"inference server potentially unreachable. recycling cached audio for now."
|
||||
)
|
||||
return
|
||||
|
||||
is_available = False
|
||||
while not is_available:
|
||||
try:
|
||||
res = requests.post(f"{inference_url}/clips/0", stream=True)
|
||||
except:
|
||||
log_warn(
|
||||
"inference server potentially unreachable. recycling cached audio for now."
|
||||
)
|
||||
return
|
||||
|
||||
if res.status_code != status.HTTP_200_OK:
|
||||
print("still generating...")
|
||||
sleep(5)
|
||||
continue
|
||||
|
||||
print("inference complete! downloading new clips")
|
||||
|
||||
is_available = True
|
||||
with open(f"{offset}.mp3", "wb") as f:
|
||||
for chunk in res.iter_content(chunk_size=128):
|
||||
f.write(chunk)
|
||||
|
||||
for i in range(4):
|
||||
res = requests.post(f"{inference_url}/clips/{i + 1}", stream=True)
|
||||
|
||||
if res.status_code != status.HTTP_200_OK:
|
||||
continue
|
||||
|
||||
with open(f"{i + 1 + offset}.mp3", "wb") as f:
|
||||
for chunk in res.iter_content(chunk_size=128):
|
||||
f.write(chunk)
|
||||
|
||||
log_info("audio generated.")
|
||||
|
||||
|
||||
def advance():
|
||||
|
Reference in New Issue
Block a user