feat: modal integration

commit 8197873161
parent 437987d90a
2024-07-21 17:34:09 +01:00

Offload music generation to a Modal GPU worker: modal_wrapper.py deploys
facebook/musicgen-large behind a Modal app, and the server now generates
audio through a remote call instead of in-process.

2 changed files with 85 additions and 7 deletions

modal_wrapper.py (new file, +41)

@@ -0,0 +1,41 @@
from audiocraft.models.musicgen import MusicGen
import modal

MODEL_DIR = "/root/model/model_input"
MODEL_ID = "facebook/musicgen-large"

N_GPUS = 1
GPU_CONFIG = modal.gpu.A100(count=N_GPUS)


def download_model():
    # runs at image build time so the weights are baked into the image;
    # imports live inside the function so they resolve in the container
    import torchaudio
    from audiocraft.models.musicgen import MusicGen

    MusicGen.get_pretrained(MODEL_ID)


image = modal.Image.from_registry("python:3.9.19-slim-bookworm")
image = (
    image.apt_install("ffmpeg")
    .env({"AUDIOCRAFT_CACHE_DIR": MODEL_DIR})
    .pip_install("audiocraft==1.3.0", "torchaudio==2.1.0")
    .run_function(download_model, timeout=20 * 60)
)

app = modal.App("infinifi", image=image)


@app.cls(gpu=GPU_CONFIG, container_idle_timeout=15 * 60)
class Model:
    @modal.enter()
    def load(self):
        # load the pretrained weights once per container, not per request
        self.model = MusicGen.get_pretrained(MODEL_ID)
        self.model.set_generation_params(duration=60)

    @modal.method()
    def generate(self, prompts):
        wav = self.model.generate(prompts)
        # move tensors to CPU before returning so they can be deserialized
        # on a client without a GPU
        return [one_wav.cpu() for one_wav in wav]
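
For context: the file above is deployed once with the Modal CLI (modal deploy modal_wrapper.py), after which any client can reach the class by name. A minimal client sketch, mirroring the lookup the server below performs (the single prompt is illustrative only):

import modal

# look up the deployed class by app name and class name
model = modal.Cls.lookup("infinifi", "Model")
# blocks until the GPU container has produced one 60-second clip per prompt
clips = model.generate.remote(["gentle lo-fi hip-hop to relax to"])
print(len(clips))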

(second file, +44, -7)

@@ -1,16 +1,28 @@
import threading

# from generate import generate
import modal
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from audiocraft.data.audio import audio_write

# the index of the current audio track, from 0 to 9 (-1 until playback starts)
current_index = -1
# the timer that periodically advances the current audio track
t = None

prompts = [
    "gentle, calming lo-fi beats that help with studying and focusing",
    "calm, piano lo-fi beats to help with studying and focusing",
    "gentle lo-fi hip-hop to relax to",
    "gentle, quiet synthwave lo-fi beats",
    "morning lo-fi beats",
]

# MusicGen outputs 32 kHz audio. The remote class handle below does not expose
# the underlying model's sample_rate attribute, so the rate is pinned here.
SAMPLE_RATE = 32000

# handle to the Model class deployed by modal_wrapper.py
model = modal.Cls.lookup("infinifi", "Model")


@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -20,20 +32,45 @@ async def lifespan(app: FastAPI):
    t.cancel()


def generate_new_audio():
    global current_index

    # double-buffering: when playback wraps to track 0, refresh tracks 5-9;
    # when it reaches track 5, refresh tracks 0-4; otherwise nothing to do
    if current_index == 0:
        offset = 5
    elif current_index == 5:
        offset = 0
    else:
        return

    print("generating new audio...")
    wav = model.generate.remote(prompts)
    for idx, one_wav in enumerate(wav):
        # saves under {idx + offset}.mp3, with loudness normalization
        # at -14 dB LUFS
        audio_write(
            f"{idx + offset}",
            one_wav.cpu(),
            SAMPLE_RATE,
            format="mp3",
            strategy="loudness",
            loudness_compressor=True,
        )
    print("audio generated.")


def advance():
    global current_index, t

    # if current_index == 0:
    #     generate(offset=5)
    # elif current_index == 5:
    #     generate(offset=0)

    if current_index == 9:
        current_index = 0
    else:
        current_index += 1

    print(f"advancing, current index {current_index}")
    threading.Thread(target=generate_new_audio).start()
    t = threading.Timer(60, advance)
    t.start()
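
The timer above rotates through ten one-minute slots and double-buffers them: when playback wraps to slot 0 the idle half (5-9) is regenerated, and when it reaches slot 5 the other half (0-4) is. A standalone sketch of just that scheduling rule, with hypothetical names, for illustration:

def next_state(current_index):
    # mirror advance() + generate_new_audio(): step the index, then report
    # which half of the playlist (if any) should be regenerated
    next_index = 0 if current_index == 9 else current_index + 1
    if next_index == 0:
        return next_index, 5  # slots 5-9 are idle while 0-4 play
    if next_index == 5:
        return next_index, 0  # slots 0-4 are idle while 5-9 play
    return next_index, None

idx = -1  # same starting value the server uses
for _ in range(11):
    idx, offset = next_state(idx)
    print(idx, offset)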