"""Modal wrapper that runs MusicGen remotely on a GPU to generate lo-fi tracks.

Deployed as the "infinifi" Modal app; the server looks it up with
``modal.Cls.lookup("infinifi", "Model")`` and calls ``generate`` remotely.

NOTE: audiocraft is only pip-installed *inside* the container image, so it
must never be imported at module top level — this file is also imported on
the local machine by ``modal deploy``, where audiocraft is absent.
"""

import modal


# Cache directory for the pretrained weights inside the container image.
MODEL_DIR = "/root/model/model_input"
MODEL_ID = "facebook/musicgen-large"
# MusicGen emits mono audio at 32 kHz; exported so callers don't need to
# reach into the (remote-only) loaded model for it.
SAMPLE_RATE = 32000
N_GPUS = 1
GPU_CONFIG = modal.gpu.A100(count=N_GPUS)


def download_model():
    """Download the MusicGen weights at image-build time.

    Baking the weights into the image means containers don't pay the
    multi-GB download cost on every cold start.
    """
    # Image-only dependency; see module docstring.
    from audiocraft.models.musicgen import MusicGen

    MusicGen.get_pretrained(MODEL_ID)


image = (
    modal.Image.from_registry("python:3.9.19-slim-bookworm")
    .apt_install("ffmpeg")  # audiocraft shells out to ffmpeg for encoding
    .env({"AUDIOCRAFT_CACHE_DIR": MODEL_DIR})
    .pip_install("audiocraft==1.3.0", "torchaudio==2.1.0")
    .run_function(download_model, timeout=20 * 60)
)


app = modal.App("infinifi", image=image)


@app.cls(gpu=GPU_CONFIG, container_idle_timeout=15 * 60)
class Model:
    @modal.enter()
    def load(self):
        """Load the model once per container start, not once per request."""
        # Image-only dependency; see module docstring.
        from audiocraft.models.musicgen import MusicGen

        self.model = MusicGen.get_pretrained(MODEL_ID)
        self.model.set_generation_params(duration=60)

    @modal.method()
    def generate(self, prompts):
        """Generate one 60-second clip per prompt.

        Returns a list of CPU tensors (one per prompt). Tensors are moved
        off the GPU before returning — pickled CUDA tensors cannot be
        deserialized on the CPU-only caller.
        """
        wav = self.model.generate(prompts)
        return [one_wav.cpu() for one_wav in wav]
# Prompts fed to the remote MusicGen model; one 60-second track is generated
# per prompt, filling five consecutive slots of the 10-track ring buffer.
prompts = [
    "gentle, calming lo-fi beats that helps with studying and focusing",
    "calm, piano lo-fi beats to help with studying and focusing",
    "gentle lo-fi hip-hop to relax to",
    "gentle, quiet synthwave lo-fi beats",
    "morning lo-fi beats",
]

# MusicGen's fixed output rate. The remote handle below is a Modal stub, so
# the loaded model's `sample_rate` attribute is NOT available on it locally.
SAMPLE_RATE = 32000

# Remote handle to the Model class deployed by modal_wrapper.py.
model = modal.Cls.lookup("infinifi", "Model")


def generate_new_audio():
    """Regenerate the half of the 10-track ring buffer that is not playing.

    Double-buffering: when playback has just wrapped to track 0, slots 5-9
    are rewritten; when it reaches track 5, slots 0-4 are rewritten. At any
    other index this is a no-op.
    """
    global current_index

    if current_index == 0:
        offset = 5
    elif current_index == 5:
        offset = 0
    else:
        return

    # Log *before* the slow remote GPU call, not after it finishes.
    print("generating new audio...")

    wav = model.generate.remote(prompts)

    for idx, one_wav in enumerate(wav):
        # Saves as {idx + offset}.mp3 with loudness normalization at -14 dB LUFS.
        audio_write(
            f"{idx + offset}",
            one_wav.cpu(),
            # Bug fix: was `model.sample_rate`, which raises AttributeError —
            # `model` is a modal.Cls stub, not the loaded MusicGen model.
            SAMPLE_RATE,
            format="mp3",
            strategy="loudness",
            loudness_compressor=True,
        )

    print("audio generated.")


def advance():
    """Advance the playhead to the next track and re-arm the 60 s timer.

    Called once per track length. Kicks off background regeneration of the
    idle half of the buffer whenever the playhead crosses a half boundary.
    """
    global current_index, t

    # Wrap 9 -> 0; also takes the initial -1 to 0 on the first tick.
    current_index = (current_index + 1) % 10

    print(f"advancing, current index {current_index}")

    # Generation blocks on a remote GPU call — keep it off the timer thread.
    threading.Thread(target=generate_new_audio).start()

    t = threading.Timer(60, advance)
    t.start()