feat: create inference ws server
This commit is contained in:
30
inference_server.py
Normal file
30
inference_server.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import asyncio
|
||||
from websockets.server import serve
|
||||
|
||||
# from generate import generate
|
||||
|
||||
|
||||
async def handler(websocket):
|
||||
async for message in websocket:
|
||||
if message != "generate":
|
||||
continue
|
||||
|
||||
print("generating new audio clips...")
|
||||
|
||||
# generate()
|
||||
|
||||
print("audio generated")
|
||||
|
||||
for i in range(5):
|
||||
with open(f"{i + 5}.mp3", "rb") as f:
|
||||
data = f.read()
|
||||
await websocket.send(data)
|
||||
|
||||
|
||||
async def main():
|
||||
async with serve(handler, "", 8001):
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
3
requirements-inference.txt
Normal file
3
requirements-inference.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
audiocraft==1.3.0
|
||||
torchaudio==2.1.0
|
||||
websockets==11.0.3
|
2
requirements-server.txt
Normal file
2
requirements-server.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
fastapi==0.111.1
|
||||
websocket_client==1.8.0
|
35
server.py
35
server.py
@@ -1,16 +1,18 @@
|
||||
import threading
|
||||
import os
|
||||
|
||||
from generate import generate
|
||||
import websocket
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from audiocraft.data.audio import audio_write
|
||||
|
||||
# the index of the current audio track from 0 to 9
|
||||
current_index = -1
|
||||
# the timer that periodically advances the current audio track
|
||||
t = None
|
||||
# websocket connection to the inference server
|
||||
ws = None
|
||||
|
||||
prompts = [
|
||||
"gentle, calming lo-fi beats that helps with studying and focusing",
|
||||
@@ -23,13 +25,29 @@ prompts = [
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global ws
|
||||
|
||||
url = os.environ.get("INFERENCE_SERVER_WS_URL")
|
||||
if not url:
|
||||
url = "ws://localhost:8001"
|
||||
|
||||
ws = websocket.create_connection(url)
|
||||
print(f"websocket connected to {url}")
|
||||
|
||||
advance()
|
||||
|
||||
yield
|
||||
|
||||
if ws:
|
||||
ws.close()
|
||||
if t:
|
||||
t.cancel()
|
||||
|
||||
|
||||
def generate_new_audio():
|
||||
if not ws:
|
||||
return
|
||||
|
||||
global current_index
|
||||
|
||||
offset = 0
|
||||
@@ -42,7 +60,18 @@ def generate_new_audio():
|
||||
|
||||
print("generating new audio...")
|
||||
|
||||
generate(offset)
|
||||
ws.send("generate")
|
||||
|
||||
wavs = []
|
||||
for i in range(5):
|
||||
raw = ws.recv()
|
||||
if isinstance(raw, str):
|
||||
continue
|
||||
wavs.append(raw)
|
||||
|
||||
for i, wav in enumerate(wavs):
|
||||
with open(f"{i + offset}.mp3", "wb") as f:
|
||||
f.write(wav)
|
||||
|
||||
print("audio generated.")
|
||||
|
||||
|
Reference in New Issue
Block a user