From 0e744af73f7fd086713b4246f4499d58f4814e12 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Thu, 25 Jul 2024 10:36:24 +0100 Subject: [PATCH 1/7] feat: initial code for fal.ai integration --- fal_app.py | 35 +++++++++++++++++++++++++++++++++++ generate.py | 4 +++- generate_manual.py | 11 +++-------- prompts.py | 7 +++++++ 4 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 fal_app.py create mode 100644 prompts.py diff --git a/fal_app.py b/fal_app.py new file mode 100644 index 0000000..bbb1bf2 --- /dev/null +++ b/fal_app.py @@ -0,0 +1,35 @@ +import io +import fal +import torch +from fal.toolkit import File + +from prompts import PROMPTS + + +class InfinifiFalApp(fal.App, keep_alive=300): + machine_type = "GPU-A6000" + requirements = [ + "torch==2.1.0", + "audiocraft==1.3.0", + "torchaudio==2.1.0", + "websockets==11.0.3", + ] + + def setup(self): + import torchaudio + from audiocraft.models.musicgen import MusicGen + + self.model = MusicGen.get_pretrained("facebook/musicgen-large") + self.model.set_generation_params(duration=60) + + @fal.endpoint("/generate") + def run(self): + wav = self.model.generate(PROMPTS) + + serialized = [] + for one_wav in wav: + buf = io.BytesIO() + torch.save(one_wav.cpu(), buf) + serialized.append(buf.getvalue()) + + return serialized diff --git a/generate.py b/generate.py index 6912c9c..146fa03 100644 --- a/generate.py +++ b/generate.py @@ -2,6 +2,8 @@ import torchaudio from audiocraft.models.musicgen import MusicGen from audiocraft.data.audio import audio_write +from prompts import PROMPTS + MODEL_NAME = "facebook/musicgen-large" MUSIC_DURATION_SECONDS = 60 @@ -17,7 +19,7 @@ descriptions = [ def generate(offset=0): - wav = model.generate(descriptions) + wav = model.generate(PROMPTS) for idx, one_wav in enumerate(wav): # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. diff --git a/generate_manual.py b/generate_manual.py index 536e02f..fe2b34a 100644 --- a/generate_manual.py +++ b/generate_manual.py @@ -3,6 +3,8 @@ import time from audiocraft.models.musicgen import MusicGen from audiocraft.data.audio import audio_write +from prompts import PROMPTS + MODEL_NAME = "facebook/musicgen-large" MUSIC_DURATION_SECONDS = 60 @@ -10,18 +12,11 @@ print("obtaining model...") model = MusicGen.get_pretrained(MODEL_NAME) model.set_generation_params(duration=MUSIC_DURATION_SECONDS) -descriptions = [ - "Create a futuristic lo-fi beat that blends modern electronic elements with synthwave influences. Incorporate smooth, atmospheric synths and gentle, relaxing rhythms to evoke a sense of a serene, neon-lit future. Ensure the track is continuous with no background noise or interruptions, maintaining a calm and tranquil atmosphere throughout while adding a touch of retro-futuristic vibes.", - "gentle lo-fi beat with a smooth, mellow piano melody in the background. Ensure there are no background noises or interruptions, maintaining a continuous and seamless flow throughout the track. The beat should be relaxing and tranquil, perfect for a calm and reflective atmosphere.", - "Create an earthy lo-fi beat that evokes a natural, grounded atmosphere. Incorporate organic sounds like soft percussion, rustling leaves, and gentle acoustic instruments. The track should have a warm, soothing rhythm with a continuous flow and no background noise or interruptions, maintaining a calm and reflective ambiance throughout.", - "Create a soothing lo-fi beat featuring gentle, melodic guitar riffs. The guitar should be the focal point, supported by subtle, ambient electronic elements and a smooth, relaxed rhythm. Ensure the track is continuous with no background noise or interruptions, maintaining a warm and mellow atmosphere throughout.", - "Create an ambient lo-fi beat with a tranquil and ethereal atmosphere. Use soft, atmospheric pads, gentle melodies, and minimalistic percussion to evoke a sense of calm and serenity. Ensure the track is continuous with no background noise or interruptions, maintaining a soothing and immersive ambiance throughout.", -] print("model obtained. generating audio...") a = time.time() -wav = model.generate(descriptions) +wav = model.generate(PROMPTS) b = time.time() print(f"audio generated. took {b - a} seconds.") diff --git a/prompts.py b/prompts.py new file mode 100644 index 0000000..6f42bf8 --- /dev/null +++ b/prompts.py @@ -0,0 +1,7 @@ +PROMPTS = [ + "Create a futuristic lo-fi beat that blends modern electronic elements with synthwave influences. Incorporate smooth, atmospheric synths and gentle, relaxing rhythms to evoke a sense of a serene, neon-lit future. Ensure the track is continuous with no background noise or interruptions, maintaining a calm and tranquil atmosphere throughout while adding a touch of retro-futuristic vibes.", + "gentle lo-fi beat with a smooth, mellow piano melody in the background. Ensure there are no background noises or interruptions, maintaining a continuous and seamless flow throughout the track. The beat should be relaxing and tranquil, perfect for a calm and reflective atmosphere.", + "Create an earthy lo-fi beat that evokes a natural, grounded atmosphere. Incorporate organic sounds like soft percussion, rustling leaves, and gentle acoustic instruments. The track should have a warm, soothing rhythm with a continuous flow and no background noise or interruptions, maintaining a calm and reflective ambiance throughout.", + "Create a soothing lo-fi beat featuring gentle, melodic guitar riffs. The guitar should be the focal point, supported by subtle, ambient electronic elements and a smooth, relaxed rhythm. Ensure the track is continuous with no background noise or interruptions, maintaining a warm and mellow atmosphere throughout.", + "Create an ambient lo-fi beat with a tranquil and ethereal atmosphere. Use soft, atmospheric pads, gentle melodies, and minimalistic percussion to evoke a sense of calm and serenity. Ensure the track is continuous with no background noise or interruptions, maintaining a soothing and immersive ambiance throughout.", +] From 37cf800d8d5ad4279fd34e1c8144b0c64deb40c7 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Thu, 25 Jul 2024 11:00:52 +0100 Subject: [PATCH 2/7] feat: add websocket endpoint for fal --- fal_app.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/fal_app.py b/fal_app.py index bbb1bf2..dd4ade7 100644 --- a/fal_app.py +++ b/fal_app.py @@ -1,5 +1,6 @@ import io import fal +from fastapi import WebSocket import torch from fal.toolkit import File @@ -33,3 +34,16 @@ class InfinifiFalApp(fal.App, keep_alive=300): serialized.append(buf.getvalue()) return serialized + + @fal.endpoint("/ws") + async def run_ws(self, ws: WebSocket): + await ws.accept() + + wav = self.model.generate(PROMPTS) + + for one_wav in enumerate(wav): + buf = io.BytesIO() + torch.save(one_wav, buf) + await ws.send_bytes(buf.getvalue()) + + await ws.close() From 74d2378ef4a3d10d2d5cdb09b85fc6a8d336f0a5 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Fri, 26 Jul 2024 18:12:40 +0100 Subject: [PATCH 3/7] chore: remove redundant prompt strings --- generate.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/generate.py b/generate.py index 146fa03..2f3b677 100644 --- a/generate.py +++ b/generate.py @@ -9,13 +9,6 @@ MUSIC_DURATION_SECONDS = 60 model = MusicGen.get_pretrained(MODEL_NAME) model.set_generation_params(duration=MUSIC_DURATION_SECONDS) -descriptions = [ - "Create a futuristic lo-fi beat that blends modern electronic elements with synthwave influences. Incorporate smooth, atmospheric synths and gentle, relaxing rhythms to evoke a sense of a serene, neon-lit future. Ensure the track is continuous with no background noise or interruptions, maintaining a calm and tranquil atmosphere throughout while adding a touch of retro-futuristic vibes.", - "gentle lo-fi beat with a smooth, mellow piano melody in the background. Ensure there are no background noises or interruptions, maintaining a continuous and seamless flow throughout the track. The beat should be relaxing and tranquil, perfect for a calm and reflective atmosphere.", - "Create an earthy lo-fi beat that evokes a natural, grounded atmosphere. Incorporate organic sounds like soft percussion, rustling leaves, and gentle acoustic instruments. The track should have a warm, soothing rhythm with a continuous flow and no background noise or interruptions, maintaining a calm and reflective ambiance throughout.", - "Create a soothing lo-fi beat featuring gentle, melodic guitar riffs. The guitar should be the focal point, supported by subtle, ambient electronic elements and a smooth, relaxed rhythm. Ensure the track is continuous with no background noise or interruptions, maintaining a warm and mellow atmosphere throughout.", - "Create an ambient lo-fi beat with a tranquil and ethereal atmosphere. Use soft, atmospheric pads, gentle melodies, and minimalistic percussion to evoke a sense of calm and serenity. Ensure the track is continuous with no background noise or interruptions, maintaining a soothing and immersive ambiance throughout.", -] def generate(offset=0): From 13ae315b7dfd4719b0d14ccef42d05fafd452e0d Mon Sep 17 00:00:00 2001 From: Kenneth Date: Tue, 20 Aug 2024 21:12:02 +0100 Subject: [PATCH 4/7] fix: write generated audio to fal storage --- fal_app.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/fal_app.py b/fal_app.py index dd4ade7..3176cad 100644 --- a/fal_app.py +++ b/fal_app.py @@ -1,11 +1,14 @@ import io +from pathlib import Path +from audiocraft.data.audio import audio_write import fal from fastapi import WebSocket import torch -from fal.toolkit import File from prompts import PROMPTS +DATA_DIR = Path("/data/audio") + class InfinifiFalApp(fal.App, keep_alive=300): machine_type = "GPU-A6000" @@ -41,9 +44,18 @@ class InfinifiFalApp(fal.App, keep_alive=300): wav = self.model.generate(PROMPTS) - for one_wav in enumerate(wav): - buf = io.BytesIO() - torch.save(one_wav, buf) - await ws.send_bytes(buf.getvalue()) + for i, one_wav in enumerate(wav): + path = DATA_DIR.joinpath(f"{i}") + audio_write( + path, + one_wav.cpu(), + self.model.sample_rate, + format="mp3", + strategy="loudness", + loudness_compressor=True, + ) + with open(path, "rb") as f: + data = f.read() + await ws.send_bytes(data) await ws.close() From 58b47201a6c59dcc7bf667ae7044876f49ed20ce Mon Sep 17 00:00:00 2001 From: Kenneth Date: Thu, 22 Aug 2024 23:48:16 +0100 Subject: [PATCH 5/7] refactor: use http polling instead of websocket --- fal_app.py | 50 +++++++++++++++++++++++-------------- server.py | 73 +++++++++++++++++++++++++++++++++++------------------- 2 files changed, 78 insertions(+), 45 deletions(-) diff --git a/fal_app.py b/fal_app.py index 3176cad..126e848 100644 --- a/fal_app.py +++ b/fal_app.py @@ -1,14 +1,21 @@ -import io +import datetime from pathlib import Path +import threading from audiocraft.data.audio import audio_write import fal -from fastapi import WebSocket +from fastapi import Response, status import torch -from prompts import PROMPTS - DATA_DIR = Path("/data/audio") +PROMPTS = [ + "Create a futuristic lo-fi beat that blends modern electronic elements with synthwave influences. Incorporate smooth, atmospheric synths and gentle, relaxing rhythms to evoke a sense of a serene, neon-lit future. Ensure the track is continuous with no background noise or interruptions, maintaining a calm and tranquil atmosphere throughout while adding a touch of retro-futuristic vibes.", + "gentle lo-fi beat with a smooth, mellow piano melody in the background. Ensure there are no background noises or interruptions, maintaining a continuous and seamless flow throughout the track. The beat should be relaxing and tranquil, perfect for a calm and reflective atmosphere.", + "Create an earthy lo-fi beat that evokes a natural, grounded atmosphere. Incorporate organic sounds like soft percussion, rustling leaves, and gentle acoustic instruments. The track should have a warm, soothing rhythm with a continuous flow and no background noise or interruptions, maintaining a calm and reflective ambiance throughout.", + "Create a soothing lo-fi beat featuring gentle, melodic guitar riffs. The guitar should be the focal point, supported by subtle, ambient electronic elements and a smooth, relaxed rhythm. Ensure the track is continuous with no background noise or interruptions, maintaining a warm and mellow atmosphere throughout.", + "Create an ambient lo-fi beat with a tranquil and ethereal atmosphere. Use soft, atmospheric pads, gentle melodies, and minimalistic percussion to evoke a sense of calm and serenity. Ensure the track is continuous with no background noise or interruptions, maintaining a soothing and immersive ambiance throughout.", +] + class InfinifiFalApp(fal.App, keep_alive=300): machine_type = "GPU-A6000" @@ -17,8 +24,11 @@ class InfinifiFalApp(fal.App, keep_alive=300): "audiocraft==1.3.0", "torchaudio==2.1.0", "websockets==11.0.3", + "numpy==1.26.4", ] + __is_generating = False + def setup(self): import torchaudio from audiocraft.models.musicgen import MusicGen @@ -28,22 +38,26 @@ class InfinifiFalApp(fal.App, keep_alive=300): @fal.endpoint("/generate") def run(self): - wav = self.model.generate(PROMPTS) + if self.__is_generating: + return Response(status_code=status.HTTP_409_CONFLICT) + threading.Thread(target=self.__generate_audio).start() - serialized = [] - for one_wav in wav: - buf = io.BytesIO() - torch.save(one_wav.cpu(), buf) - serialized.append(buf.getvalue()) + @fal.endpoint("/clips/{index}") + def get_clips(self, index): + if self.__is_generating: + return Response(status_code=status.HTTP_404_NOT_FOUND) - return serialized + path = DATA_DIR.joinpath(f"{index}") + with open(path.with_suffix(".mp3"), "rb") as f: + data = f.read() + return Response(content=data) - @fal.endpoint("/ws") - async def run_ws(self, ws: WebSocket): - await ws.accept() + def __generate_audio(self): + self.__is_generating = True + + print(f"[INFO] {datetime.datetime.now()}: generating audio...") wav = self.model.generate(PROMPTS) - for i, one_wav in enumerate(wav): path = DATA_DIR.joinpath(f"{i}") audio_write( @@ -53,9 +67,7 @@ class InfinifiFalApp(fal.App, keep_alive=300): format="mp3", strategy="loudness", loudness_compressor=True, + make_parent_dir=True, ) - with open(path, "rb") as f: - data = f.read() - await ws.send_bytes(data) - await ws.close() + self.__is_generating = False diff --git a/server.py b/server.py index e5649b4..1d14725 100644 --- a/server.py +++ b/server.py @@ -1,9 +1,11 @@ import threading import os +from time import sleep +import requests import websocket from contextlib import asynccontextmanager -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from logger import log_info, log_warn @@ -15,18 +17,18 @@ current_index = -1 t = None # websocket connection to the inference server ws = None -ws_url = "" +inference_url = "" ws_connection_manager = WebSocketConnectionManager() active_listeners = set() @asynccontextmanager async def lifespan(app: FastAPI): - global ws, ws_url + global ws, inference_url - ws_url = os.environ.get("INFERENCE_SERVER_WS_URL") - if not ws_url: - ws_url = "ws://localhost:8001" + inference_url = os.environ.get("INFERENCE_SERVER_URL") + if not inference_url: + inference_url = "ws://localhost:8001" advance() @@ -39,7 +41,7 @@ async def lifespan(app: FastAPI): def generate_new_audio(): - if not ws_url: + if not inference_url: return global current_index @@ -52,31 +54,50 @@ def generate_new_audio(): else: return - log_info("generating new audio...") + log_info("requesting new audio...") try: - ws = websocket.create_connection(ws_url) - - ws.send("generate") - - wavs = [] - for i in range(5): - raw = ws.recv() - if isinstance(raw, str): - continue - wavs.append(raw) - - for i, wav in enumerate(wavs): - with open(f"{i + offset}.mp3", "wb") as f: - f.write(wav) - - log_info("audio generated.") - - ws.close() + print(f"{inference_url}/generate") + requests.post(f"{inference_url}/generate") except: log_warn( "inference server potentially unreachable. recycling cached audio for now." ) + return + + is_available = False + while not is_available: + try: + res = requests.post(f"{inference_url}/clips/0", stream=True) + except: + log_warn( + "inference server potentially unreachable. recycling cached audio for now." + ) + return + + if res.status_code != status.HTTP_200_OK: + print("still generating...") + sleep(5) + continue + + print("inference complete! downloading new clips") + + is_available = True + with open(f"{offset}.mp3", "wb") as f: + for chunk in res.iter_content(chunk_size=128): + f.write(chunk) + + for i in range(4): + res = requests.post(f"{inference_url}/clips/{i + 1}", stream=True) + + if res.status_code != status.HTTP_200_OK: + continue + + with open(f"{i + 1 + offset}.mp3", "wb") as f: + for chunk in res.iter_content(chunk_size=128): + f.write(chunk) + + log_info("audio generated.") def advance(): From b398f4d025a0dbab4d2f6688bc35ba5629c772a6 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Thu, 22 Aug 2024 23:50:17 +0100 Subject: [PATCH 6/7] fix: remove unused websocket obj in web server --- server.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/server.py b/server.py index 1d14725..92410c8 100644 --- a/server.py +++ b/server.py @@ -3,7 +3,6 @@ import os from time import sleep import requests -import websocket from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status from fastapi.responses import FileResponse @@ -15,8 +14,6 @@ from websocket_connection_manager import WebSocketConnectionManager current_index = -1 # the timer that periodically advances the current audio track t = None -# websocket connection to the inference server -ws = None inference_url = "" ws_connection_manager = WebSocketConnectionManager() active_listeners = set() @@ -34,8 +31,6 @@ async def lifespan(app: FastAPI): yield - if ws: - ws.close() if t: t.cancel() From 3299f2dcc6c88b329a1685de64c097fc22dbe98a Mon Sep 17 00:00:00 2001 From: Kenneth Date: Thu, 22 Aug 2024 23:50:58 +0100 Subject: [PATCH 7/7] fix: change local inference server url --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 92410c8..d705b81 100644 --- a/server.py +++ b/server.py @@ -25,7 +25,7 @@ async def lifespan(app: FastAPI): inference_url = os.environ.get("INFERENCE_SERVER_URL") if not inference_url: - inference_url = "ws://localhost:8001" + inference_url = "http://localhost:8001" advance()