Add code for fal.ai integration #12

Merged
kennethnym merged 8 commits from fal-integration into main 2024-08-25 17:08:48 +01:00
5 changed files with 133 additions and 42 deletions
Showing only changes of commit 58b47201a6 - Show all commits

View File

@@ -1,14 +1,21 @@
import io
import datetime
from pathlib import Path
import threading
from audiocraft.data.audio import audio_write
import fal
from fastapi import WebSocket
from fastapi import Response, status
import torch
from prompts import PROMPTS
DATA_DIR = Path("/data/audio")
PROMPTS = [
"Create a futuristic lo-fi beat that blends modern electronic elements with synthwave influences. Incorporate smooth, atmospheric synths and gentle, relaxing rhythms to evoke a sense of a serene, neon-lit future. Ensure the track is continuous with no background noise or interruptions, maintaining a calm and tranquil atmosphere throughout while adding a touch of retro-futuristic vibes.",
"gentle lo-fi beat with a smooth, mellow piano melody in the background. Ensure there are no background noises or interruptions, maintaining a continuous and seamless flow throughout the track. The beat should be relaxing and tranquil, perfect for a calm and reflective atmosphere.",
"Create an earthy lo-fi beat that evokes a natural, grounded atmosphere. Incorporate organic sounds like soft percussion, rustling leaves, and gentle acoustic instruments. The track should have a warm, soothing rhythm with a continuous flow and no background noise or interruptions, maintaining a calm and reflective ambiance throughout.",
"Create a soothing lo-fi beat featuring gentle, melodic guitar riffs. The guitar should be the focal point, supported by subtle, ambient electronic elements and a smooth, relaxed rhythm. Ensure the track is continuous with no background noise or interruptions, maintaining a warm and mellow atmosphere throughout.",
"Create an ambient lo-fi beat with a tranquil and ethereal atmosphere. Use soft, atmospheric pads, gentle melodies, and minimalistic percussion to evoke a sense of calm and serenity. Ensure the track is continuous with no background noise or interruptions, maintaining a soothing and immersive ambiance throughout.",
]
class InfinifiFalApp(fal.App, keep_alive=300):
machine_type = "GPU-A6000"
@@ -17,8 +24,11 @@ class InfinifiFalApp(fal.App, keep_alive=300):
"audiocraft==1.3.0",
"torchaudio==2.1.0",
"websockets==11.0.3",
"numpy==1.26.4",
]
__is_generating = False
def setup(self):
import torchaudio
from audiocraft.models.musicgen import MusicGen
@@ -28,22 +38,26 @@ class InfinifiFalApp(fal.App, keep_alive=300):
@fal.endpoint("/generate")
def run(self):
wav = self.model.generate(PROMPTS)
if self.__is_generating:
return Response(status_code=status.HTTP_409_CONFLICT)
threading.Thread(target=self.__generate_audio).start()
serialized = []
for one_wav in wav:
buf = io.BytesIO()
torch.save(one_wav.cpu(), buf)
serialized.append(buf.getvalue())
@fal.endpoint("/clips/{index}")
def get_clips(self, index):
if self.__is_generating:
return Response(status_code=status.HTTP_404_NOT_FOUND)
return serialized
path = DATA_DIR.joinpath(f"{index}")
with open(path.with_suffix(".mp3"), "rb") as f:
data = f.read()
return Response(content=data)
@fal.endpoint("/ws")
async def run_ws(self, ws: WebSocket):
await ws.accept()
def __generate_audio(self):
self.__is_generating = True
print(f"[INFO] {datetime.datetime.now()}: generating audio...")
wav = self.model.generate(PROMPTS)
for i, one_wav in enumerate(wav):
path = DATA_DIR.joinpath(f"{i}")
audio_write(
@@ -53,9 +67,7 @@ class InfinifiFalApp(fal.App, keep_alive=300):
format="mp3",
strategy="loudness",
loudness_compressor=True,
make_parent_dir=True,
)
with open(path, "rb") as f:
data = f.read()
await ws.send_bytes(data)
await ws.close()
self.__is_generating = False

View File

@@ -1,9 +1,11 @@
import threading
import os
from time import sleep
import requests
import websocket
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from logger import log_info, log_warn
@@ -15,18 +17,18 @@ current_index = -1
t = None
# websocket connection to the inference server
ws = None
ws_url = ""
inference_url = ""
ws_connection_manager = WebSocketConnectionManager()
active_listeners = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
global ws, ws_url
global ws, inference_url
ws_url = os.environ.get("INFERENCE_SERVER_WS_URL")
if not ws_url:
ws_url = "ws://localhost:8001"
inference_url = os.environ.get("INFERENCE_SERVER_URL")
if not inference_url:
inference_url = "ws://localhost:8001"
advance()
@@ -39,7 +41,7 @@ async def lifespan(app: FastAPI):
def generate_new_audio():
if not ws_url:
if not inference_url:
return
global current_index
@@ -52,31 +54,50 @@ def generate_new_audio():
else:
return
log_info("generating new audio...")
log_info("requesting new audio...")
try:
ws = websocket.create_connection(ws_url)
ws.send("generate")
wavs = []
for i in range(5):
raw = ws.recv()
if isinstance(raw, str):
continue
wavs.append(raw)
for i, wav in enumerate(wavs):
with open(f"{i + offset}.mp3", "wb") as f:
f.write(wav)
log_info("audio generated.")
ws.close()
print(f"{inference_url}/generate")
requests.post(f"{inference_url}/generate")
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
is_available = False
while not is_available:
try:
res = requests.post(f"{inference_url}/clips/0", stream=True)
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
if res.status_code != status.HTTP_200_OK:
print("still generating...")
sleep(5)
continue
print("inference complete! downloading new clips")
is_available = True
with open(f"{offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
for i in range(4):
res = requests.post(f"{inference_url}/clips/{i + 1}", stream=True)
if res.status_code != status.HTTP_200_OK:
continue
with open(f"{i + 1 + offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
log_info("audio generated.")
def advance():