feat: create inference ws server

This commit is contained in:
2024-07-22 22:24:39 +01:00
parent e4e4fc5f45
commit e1866e08f5
4 changed files with 67 additions and 3 deletions

View File

@@ -1,16 +1,18 @@
import threading
import os
from generate import generate
import websocket
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from audiocraft.data.audio import audio_write
# the index of the current audio track from 0 to 9
current_index = -1
# the timer that periodically advances the current audio track
t = None
# websocket connection to the inference server
ws = None
prompts = [
"gentle, calming lo-fi beats that helps with studying and focusing",
@@ -23,13 +25,29 @@ prompts = [
@asynccontextmanager
async def lifespan(app: FastAPI):
global ws
url = os.environ.get("INFERENCE_SERVER_WS_URL")
if not url:
url = "ws://localhost:8001"
ws = websocket.create_connection(url)
print(f"websocket connected to {url}")
advance()
yield
if ws:
ws.close()
if t:
t.cancel()
def generate_new_audio():
if not ws:
return
global current_index
offset = 0
@@ -42,7 +60,18 @@ def generate_new_audio():
print("generating new audio...")
generate(offset)
ws.send("generate")
wavs = []
for i in range(5):
raw = ws.recv()
if isinstance(raw, str):
continue
wavs.append(raw)
for i, wav in enumerate(wavs):
with open(f"{i + offset}.mp3", "wb") as f:
f.write(wav)
print("audio generated.")