feat: create inference ws server
This commit is contained in:
30
inference_server.py
Normal file
30
inference_server.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import asyncio
|
||||||
|
from websockets.server import serve
|
||||||
|
|
||||||
|
# from generate import generate
|
||||||
|
|
||||||
|
|
||||||
|
async def handler(websocket):
|
||||||
|
async for message in websocket:
|
||||||
|
if message != "generate":
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("generating new audio clips...")
|
||||||
|
|
||||||
|
# generate()
|
||||||
|
|
||||||
|
print("audio generated")
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
with open(f"{i + 5}.mp3", "rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
await websocket.send(data)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with serve(handler, "", 8001):
|
||||||
|
await asyncio.Future()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
3
requirements-inference.txt
Normal file
3
requirements-inference.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
audiocraft==1.3.0
|
||||||
|
torchaudio==2.1.0
|
||||||
|
websockets==11.0.3
|
2
requirements-server.txt
Normal file
2
requirements-server.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
fastapi==0.111.1
|
||||||
|
websocket_client==1.8.0
|
35
server.py
35
server.py
@@ -1,16 +1,18 @@
|
|||||||
import threading
|
import threading
|
||||||
|
import os
|
||||||
|
|
||||||
from generate import generate
|
import websocket
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from audiocraft.data.audio import audio_write
|
|
||||||
|
|
||||||
# the index of the current audio track from 0 to 9
|
# the index of the current audio track from 0 to 9
|
||||||
current_index = -1
|
current_index = -1
|
||||||
# the timer that periodically advances the current audio track
|
# the timer that periodically advances the current audio track
|
||||||
t = None
|
t = None
|
||||||
|
# websocket connection to the inference server
|
||||||
|
ws = None
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"gentle, calming lo-fi beats that helps with studying and focusing",
|
"gentle, calming lo-fi beats that helps with studying and focusing",
|
||||||
@@ -23,13 +25,29 @@ prompts = [
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
global ws
|
||||||
|
|
||||||
|
url = os.environ.get("INFERENCE_SERVER_WS_URL")
|
||||||
|
if not url:
|
||||||
|
url = "ws://localhost:8001"
|
||||||
|
|
||||||
|
ws = websocket.create_connection(url)
|
||||||
|
print(f"websocket connected to {url}")
|
||||||
|
|
||||||
advance()
|
advance()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
if ws:
|
||||||
|
ws.close()
|
||||||
if t:
|
if t:
|
||||||
t.cancel()
|
t.cancel()
|
||||||
|
|
||||||
|
|
||||||
def generate_new_audio():
|
def generate_new_audio():
|
||||||
|
if not ws:
|
||||||
|
return
|
||||||
|
|
||||||
global current_index
|
global current_index
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
@@ -42,7 +60,18 @@ def generate_new_audio():
|
|||||||
|
|
||||||
print("generating new audio...")
|
print("generating new audio...")
|
||||||
|
|
||||||
generate(offset)
|
ws.send("generate")
|
||||||
|
|
||||||
|
wavs = []
|
||||||
|
for i in range(5):
|
||||||
|
raw = ws.recv()
|
||||||
|
if isinstance(raw, str):
|
||||||
|
continue
|
||||||
|
wavs.append(raw)
|
||||||
|
|
||||||
|
for i, wav in enumerate(wavs):
|
||||||
|
with open(f"{i + offset}.mp3", "wb") as f:
|
||||||
|
f.write(wav)
|
||||||
|
|
||||||
print("audio generated.")
|
print("audio generated.")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user