feat: create inference ws server

This commit is contained in:
2024-07-22 22:24:39 +01:00
parent e4e4fc5f45
commit e1866e08f5
4 changed files with 67 additions and 3 deletions

30
inference_server.py Normal file
View File

@@ -0,0 +1,30 @@
import asyncio
from websockets.server import serve
# from generate import generate
async def handler(websocket):
async for message in websocket:
if message != "generate":
continue
print("generating new audio clips...")
# generate()
print("audio generated")
for i in range(5):
with open(f"{i + 5}.mp3", "rb") as f:
data = f.read()
await websocket.send(data)
async def main():
async with serve(handler, "", 8001):
await asyncio.Future()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,3 @@
audiocraft==1.3.0
torchaudio==2.1.0
websockets==11.0.3

2
requirements-server.txt Normal file
View File

@@ -0,0 +1,2 @@
fastapi==0.111.1
websocket_client==1.8.0

View File

@@ -1,16 +1,18 @@
import threading import threading
import os
from generate import generate import websocket
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from audiocraft.data.audio import audio_write
# the index of the current audio track from 0 to 9 # the index of the current audio track from 0 to 9
current_index = -1 current_index = -1
# the timer that periodically advances the current audio track # the timer that periodically advances the current audio track
t = None t = None
# websocket connection to the inference server
ws = None
prompts = [ prompts = [
"gentle, calming lo-fi beats that helps with studying and focusing", "gentle, calming lo-fi beats that helps with studying and focusing",
@@ -23,13 +25,29 @@ prompts = [
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global ws
url = os.environ.get("INFERENCE_SERVER_WS_URL")
if not url:
url = "ws://localhost:8001"
ws = websocket.create_connection(url)
print(f"websocket connected to {url}")
advance() advance()
yield yield
if ws:
ws.close()
if t: if t:
t.cancel() t.cancel()
def generate_new_audio(): def generate_new_audio():
if not ws:
return
global current_index global current_index
offset = 0 offset = 0
@@ -42,7 +60,18 @@ def generate_new_audio():
print("generating new audio...") print("generating new audio...")
generate(offset) ws.send("generate")
wavs = []
for i in range(5):
raw = ws.recv()
if isinstance(raw, str):
continue
wavs.append(raw)
for i, wav in enumerate(wavs):
with open(f"{i + offset}.mp3", "wb") as f:
f.write(wav)
print("audio generated.") print("audio generated.")