refactor: use http polling instead of websocket

2024-08-22 23:48:16 +01:00
parent 13ae315b7d
commit 58b47201a6
2 changed files with 78 additions and 45 deletions

View File

@@ -1,14 +1,21 @@
-import io
+import datetime
 from pathlib import Path
+import threading
 from audiocraft.data.audio import audio_write
 import fal
-from fastapi import WebSocket
+from fastapi import Response, status
 import torch
-from prompts import PROMPTS
 DATA_DIR = Path("/data/audio")
+PROMPTS = [
+    "Create a futuristic lo-fi beat that blends modern electronic elements with synthwave influences. Incorporate smooth, atmospheric synths and gentle, relaxing rhythms to evoke a sense of a serene, neon-lit future. Ensure the track is continuous with no background noise or interruptions, maintaining a calm and tranquil atmosphere throughout while adding a touch of retro-futuristic vibes.",
+    "gentle lo-fi beat with a smooth, mellow piano melody in the background. Ensure there are no background noises or interruptions, maintaining a continuous and seamless flow throughout the track. The beat should be relaxing and tranquil, perfect for a calm and reflective atmosphere.",
+    "Create an earthy lo-fi beat that evokes a natural, grounded atmosphere. Incorporate organic sounds like soft percussion, rustling leaves, and gentle acoustic instruments. The track should have a warm, soothing rhythm with a continuous flow and no background noise or interruptions, maintaining a calm and reflective ambiance throughout.",
+    "Create a soothing lo-fi beat featuring gentle, melodic guitar riffs. The guitar should be the focal point, supported by subtle, ambient electronic elements and a smooth, relaxed rhythm. Ensure the track is continuous with no background noise or interruptions, maintaining a warm and mellow atmosphere throughout.",
+    "Create an ambient lo-fi beat with a tranquil and ethereal atmosphere. Use soft, atmospheric pads, gentle melodies, and minimalistic percussion to evoke a sense of calm and serenity. Ensure the track is continuous with no background noise or interruptions, maintaining a soothing and immersive ambiance throughout.",
+]
 class InfinifiFalApp(fal.App, keep_alive=300):
     machine_type = "GPU-A6000"
@@ -17,8 +24,11 @@ class InfinifiFalApp(fal.App, keep_alive=300):
"audiocraft==1.3.0", "audiocraft==1.3.0",
"torchaudio==2.1.0", "torchaudio==2.1.0",
"websockets==11.0.3", "websockets==11.0.3",
"numpy==1.26.4",
] ]
__is_generating = False
def setup(self): def setup(self):
import torchaudio import torchaudio
from audiocraft.models.musicgen import MusicGen from audiocraft.models.musicgen import MusicGen
@@ -28,22 +38,26 @@ class InfinifiFalApp(fal.App, keep_alive=300):
@fal.endpoint("/generate") @fal.endpoint("/generate")
def run(self): def run(self):
wav = self.model.generate(PROMPTS) if self.__is_generating:
return Response(status_code=status.HTTP_409_CONFLICT)
threading.Thread(target=self.__generate_audio).start()
serialized = [] @fal.endpoint("/clips/{index}")
for one_wav in wav: def get_clips(self, index):
buf = io.BytesIO() if self.__is_generating:
torch.save(one_wav.cpu(), buf) return Response(status_code=status.HTTP_404_NOT_FOUND)
serialized.append(buf.getvalue())
return serialized path = DATA_DIR.joinpath(f"{index}")
with open(path.with_suffix(".mp3"), "rb") as f:
data = f.read()
return Response(content=data)
@fal.endpoint("/ws") def __generate_audio(self):
async def run_ws(self, ws: WebSocket): self.__is_generating = True
await ws.accept()
print(f"[INFO] {datetime.datetime.now()}: generating audio...")
wav = self.model.generate(PROMPTS) wav = self.model.generate(PROMPTS)
for i, one_wav in enumerate(wav): for i, one_wav in enumerate(wav):
path = DATA_DIR.joinpath(f"{i}") path = DATA_DIR.joinpath(f"{i}")
audio_write( audio_write(
@@ -53,9 +67,7 @@ class InfinifiFalApp(fal.App, keep_alive=300):
format="mp3", format="mp3",
strategy="loudness", strategy="loudness",
loudness_compressor=True, loudness_compressor=True,
make_parent_dir=True,
) )
with open(path, "rb") as f:
data = f.read()
await ws.send_bytes(data)
await ws.close() self.__is_generating = False
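
Taken together, the new endpoints replace the old `/ws` stream with a simple request/poll protocol: `POST /generate` starts generation in a background thread (409 if a run is already in flight), and `POST /clips/{index}` returns 404 until the clips are written, then serves each mp3. Below is a minimal client sketch of that flow; the base URL and the five-clip count (one clip per prompt) are assumptions drawn from the orchestrator changes in the second file, not code from this commit.

```python
import time
import requests

BASE_URL = "http://localhost:8001"  # assumed; a fal deployment exposes its own URL

# Start a generation run; 409 means one is already in progress.
if requests.post(f"{BASE_URL}/generate").status_code == 409:
    print("generation already in progress")

# Poll clip 0 until it stops returning 404, i.e. generation has finished.
while True:
    res = requests.post(f"{BASE_URL}/clips/0")
    if res.status_code == 200:
        break
    time.sleep(5)

# Fetch the remaining clips (five prompts -> clips 0 through 4).
clips = [res.content]
for i in range(1, 5):
    clips.append(requests.post(f"{BASE_URL}/clips/{i}").content)
```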

View File

@@ -1,9 +1,11 @@
 import threading
 import os
+from time import sleep
+import requests
 import websocket
 from contextlib import asynccontextmanager
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status
 from fastapi.responses import FileResponse
 from fastapi.staticfiles import StaticFiles
 from logger import log_info, log_warn
@@ -15,18 +17,18 @@ current_index = -1
 t = None
 # websocket connection to the inference server
 ws = None
-ws_url = ""
+inference_url = ""
 ws_connection_manager = WebSocketConnectionManager()
 active_listeners = set()
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    global ws, ws_url
-    ws_url = os.environ.get("INFERENCE_SERVER_WS_URL")
-    if not ws_url:
-        ws_url = "ws://localhost:8001"
+    global ws, inference_url
+    inference_url = os.environ.get("INFERENCE_SERVER_URL")
+    if not inference_url:
+        inference_url = "ws://localhost:8001"
     advance()
@@ -39,7 +41,7 @@ async def lifespan(app: FastAPI):
 def generate_new_audio():
-    if not ws_url:
+    if not inference_url:
         return
     global current_index
@@ -52,31 +54,50 @@ def generate_new_audio():
     else:
         return
-    log_info("generating new audio...")
+    log_info("requesting new audio...")
     try:
-        ws = websocket.create_connection(ws_url)
-        ws.send("generate")
-        wavs = []
-        for i in range(5):
-            raw = ws.recv()
-            if isinstance(raw, str):
-                continue
-            wavs.append(raw)
-        for i, wav in enumerate(wavs):
-            with open(f"{i + offset}.mp3", "wb") as f:
-                f.write(wav)
-        log_info("audio generated.")
-        ws.close()
+        print(f"{inference_url}/generate")
+        requests.post(f"{inference_url}/generate")
     except:
         log_warn(
             "inference server potentially unreachable. recycling cached audio for now."
         )
+        return
+    is_available = False
+    while not is_available:
+        try:
+            res = requests.post(f"{inference_url}/clips/0", stream=True)
+        except:
+            log_warn(
+                "inference server potentially unreachable. recycling cached audio for now."
+            )
+            return
+        if res.status_code != status.HTTP_200_OK:
+            print("still generating...")
+            sleep(5)
+            continue
+        print("inference complete! downloading new clips")
+        is_available = True
+        with open(f"{offset}.mp3", "wb") as f:
+            for chunk in res.iter_content(chunk_size=128):
+                f.write(chunk)
+    for i in range(4):
+        res = requests.post(f"{inference_url}/clips/{i + 1}", stream=True)
+        if res.status_code != status.HTTP_200_OK:
+            continue
+        with open(f"{i + 1 + offset}.mp3", "wb") as f:
+            for chunk in res.iter_content(chunk_size=128):
+                f.write(chunk)
+    log_info("audio generated.")
 def advance():
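
The orchestrator side of the change boils down to a poll-then-download loop. Restated as a self-contained sketch for clarity (the function name and return flag are assumed names, not part of this commit; the 5-second interval, 128-byte chunks, and five clips come from the diff above):

```python
from time import sleep
import requests

def fetch_new_clips(inference_url: str, offset: int) -> bool:
    """Poll the inference server until clips are ready, then save them locally."""
    # Clip 0 doubles as the readiness probe: the server answers 404 while it is
    # still generating, so retry every few seconds until it returns 200.
    while True:
        try:
            res = requests.post(f"{inference_url}/clips/0", stream=True)
        except requests.RequestException:
            return False  # unreachable; the caller keeps recycling cached audio
        if res.status_code == 200:
            break
        sleep(5)

    # Stream each clip to disk in small chunks; clip i is stored as "{i + offset}.mp3".
    for i in range(5):
        if i > 0:
            res = requests.post(f"{inference_url}/clips/{i}", stream=True)
            if res.status_code != 200:
                continue
        with open(f"{i + offset}.mp3", "wb") as f:
            for chunk in res.iter_content(chunk_size=128):
                f.write(chunk)
    return True
```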