feat: modal integration

This commit is contained in:
2024-07-21 17:34:09 +01:00
parent 437987d90a
commit 8197873161
2 changed files with 85 additions and 7 deletions

41
modal_wrapper.py Normal file
View File

@@ -0,0 +1,41 @@
from audiocraft.models.musicgen import MusicGen
import modal
# Directory inside the image where audiocraft caches the model weights
# (wired up via the AUDIOCRAFT_CACHE_DIR env var on the image below).
MODEL_DIR = "/root/model/model_input"
# Hugging Face model ID of the pretrained MusicGen checkpoint to serve.
MODEL_ID = "facebook/musicgen-large"
N_GPUS = 1
# NOTE(review): modal.gpu.A100 is the legacy GPU-spec API — confirm it
# matches the pinned modal version in use.
GPU_CONFIG = modal.gpu.A100(count=N_GPUS)
def download_model():
    """Download and cache the pretrained MusicGen weights.

    Executed during the Modal image build (via ``run_function``), so the
    checkpoint is baked into the image and containers start without
    re-downloading it.
    """
    # Local import: audiocraft only needs to resolve inside the image.
    from audiocraft.models.musicgen import MusicGen

    MusicGen.get_pretrained(MODEL_ID)
# Container image: slim Python base, ffmpeg (required by audiocraft for
# audio encoding), pinned ML dependencies, and the MusicGen weights
# pre-fetched at build time so cold starts skip the download.
image = (
    modal.Image.from_registry("python:3.9.19-slim-bookworm")
    .apt_install("ffmpeg")
    .env({"AUDIOCRAFT_CACHE_DIR": MODEL_DIR})
    .pip_install("audiocraft==1.3.0", "torchaudio==2.1.0")
    .run_function(download_model, timeout=20 * 60)
)

app = modal.App("infinifi", image=image)
@app.cls(gpu=GPU_CONFIG, container_idle_timeout=15 * 60)
class Model:
    """Modal container class that serves MusicGen text-to-music requests."""

    @modal.enter()
    def load(self):
        """Load the pretrained model once per container start."""
        self.model = MusicGen.get_pretrained(MODEL_ID)
        # Every generated clip is 60 seconds long.
        self.model.set_generation_params(duration=60)

    @modal.method()
    def generate(self, prompts):
        """Generate one waveform per prompt.

        Args:
            prompts: list of text prompts, one generated track each.

        Returns:
            A list of per-prompt waveform tensors as produced by
            ``MusicGen.generate``.
        """
        wav = self.model.generate(prompts)
        # The original enumerate-based comprehension discarded its index;
        # list() expresses the same materialization directly.
        return list(wav)

View File

@@ -1,16 +1,28 @@
import threading import threading
# from generate import generate import modal
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from audiocraft.data.audio import audio_write
# Index of the audio track currently being served (0-9); -1 means
# playback has not started yet.
current_index = -1

# Self-rescheduling timer that periodically advances the current track.
t = None
# Text prompts sent to the remote model; one track is generated per prompt.
prompts = [
    "gentle, calming lo-fi beats that helps with studying and focusing",
    "calm, piano lo-fi beats to help with studying and focusing",
    "gentle lo-fi hip-hop to relax to",
    "gentle, quiet synthwave lo-fi beats",
    "morning lo-fi beats",
]

# Handle to the deployed Modal class (app "infinifi", class "Model").
model = modal.Cls.lookup("infinifi", "Model")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -20,20 +32,45 @@ async def lifespan(app: FastAPI):
t.cancel() t.cancel()
def generate_new_audio():
    """Regenerate the half of the 10-track ring that is not playing.

    When playback enters the first half (current_index == 0) the files
    5.mp3-9.mp3 are regenerated; when it enters the second half
    (current_index == 5), 0.mp3-4.mp3 are. Any other index is a no-op.
    """
    # Map the current position to the file-index offset of the idle half.
    if current_index == 0:
        offset = 5
    elif current_index == 5:
        offset = 0
    else:
        return

    # Log before the slow remote call (the original printed this after the
    # generation had already finished), and make the call once instead of
    # duplicating it in both branches.
    print("generating new audio...")
    wav = model.generate.remote(prompts)

    for idx, one_wav in enumerate(wav):
        # Saves under {idx + offset}.mp3 with loudness normalization at
        # -14 dB LUFS.
        # NOTE(review): `model` is the modal.Cls.lookup handle, not the
        # MusicGen instance — confirm `model.sample_rate` actually resolves
        # here (musicgen-large's native rate is 32 kHz).
        audio_write(
            f"{idx + offset}",
            one_wav.cpu(),
            model.sample_rate,
            format="mp3",
            strategy="loudness",
            loudness_compressor=True,
        )
    print("audio generated.")
def advance():
    """Advance playback to the next track and kick off regeneration.

    Reschedules itself every 60 seconds via a threading.Timer.
    """
    global current_index, t

    # Wrap around after the last of the ten tracks.
    current_index = 0 if current_index == 9 else current_index + 1

    # Regenerate the idle half of the ring without blocking this timer tick.
    threading.Thread(target=generate_new_audio).start()

    # Schedule the next advance.
    t = threading.Timer(60, advance)
    t.start()