feat: modal integration
This commit is contained in:
41
modal_wrapper.py
Normal file
41
modal_wrapper.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from audiocraft.models.musicgen import MusicGen
|
||||||
|
import modal
|
||||||
|
|
||||||
|
|
||||||
|
# Directory inside the container image that is exported as audiocraft's
# cache directory (see the image definition's AUDIOCRAFT_CACHE_DIR).
MODEL_DIR = "/root/model/model_input"

# Identifier of the pretrained MusicGen checkpoint to load.
MODEL_ID = "facebook/musicgen-large"

# Number of GPUs requested for each Modal container.
N_GPUS = 1
GPU_CONFIG = modal.gpu.A100(count=N_GPUS)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model():
    """Fetch the pretrained MusicGen checkpoint identified by MODEL_ID.

    Passed to ``Image.run_function`` in the image definition, so it runs
    while the image is being built. The returned model object is discarded.
    """
    # NOTE(review): torchaudio is imported but never referenced here --
    # presumably a sanity check that it is installed in the image; confirm
    # before removing.
    import torchaudio
    from audiocraft.models.musicgen import MusicGen

    MusicGen.get_pretrained(MODEL_ID)
|
||||||
|
|
||||||
|
|
||||||
|
# Container image: slim Python base + ffmpeg + audiocraft, with the model
# download executed as a build step.
image = (
    modal.Image.from_registry("python:3.9.19-slim-bookworm")
    .apt_install("ffmpeg")
    # Set before the download step runs -- presumably so audiocraft caches
    # the checkpoint under MODEL_DIR inside the image; TODO confirm.
    .env({"AUDIOCRAFT_CACHE_DIR": MODEL_DIR})
    .pip_install("audiocraft==1.3.0", "torchaudio==2.1.0")
    # Allow up to 20 minutes for the checkpoint download.
    .run_function(download_model, timeout=20 * 60)
)

# Modal application; the server looks it up by the name "infinifi".
app = modal.App("infinifi", image=image)
|
||||||
|
|
||||||
|
|
||||||
|
@app.cls(gpu=GPU_CONFIG, container_idle_timeout=15 * 60)
class Model:
    """MusicGen wrapper deployed as a Modal class on a GPU container."""

    @modal.enter()
    def load(self):
        """Load the pretrained model and configure its generation length."""
        self.model = MusicGen.get_pretrained(MODEL_ID)
        self.model.set_generation_params(duration=60)

    @modal.method()
    def generate(self, prompts):
        """Generate one audio clip per prompt.

        Args:
            prompts: Text descriptions passed straight to ``MusicGen.generate``.

        Returns:
            A list with one tensor per generated clip, moved to CPU memory.
        """
        wav = self.model.generate(prompts)
        # The result is serialized and sent back to the caller. Tensors left
        # on the GPU cannot be deserialized on a client without CUDA, so move
        # them to CPU first. Callers that still call ``.cpu()`` on the
        # returned tensors keep working, since that is a no-op on CPU tensors.
        return [one_wav.cpu() for one_wav in wav]
|
51
server.py
51
server.py
@@ -1,16 +1,28 @@
|
|||||||
import threading
|
import threading
|
||||||
|
|
||||||
# from generate import generate
|
import modal
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from audiocraft.data.audio import audio_write
|
||||||
|
|
||||||
# Index of the audio track currently playing (0 to 9); -1 until the first
# call to advance().
current_index = -1

# Timer that periodically calls advance() to move to the next track.
t = None

# Text prompts sent to the remote model; one clip is generated per prompt.
prompts = [
    "gentle, calming lo-fi beats that helps with studying and focusing",
    "calm, piano lo-fi beats to help with studying and focusing",
    "gentle lo-fi hip-hop to relax to",
    "gentle, quiet synthwave lo-fi beats",
    "morning lo-fi beats",
]

# Handle to the "Model" class of the deployed "infinifi" Modal app.
model = modal.Cls.lookup("infinifi", "Model")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -20,20 +32,45 @@ async def lifespan(app: FastAPI):
|
|||||||
t.cancel()
|
t.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_new_audio():
    """Regenerate one half of the track files when playback enters the other half.

    Reads the module-level ``current_index``. When it is 0, files 5-9 are
    regenerated; when it is 5, files 0-4 are regenerated; for any other value
    the function returns without doing anything. One clip is requested per
    entry in ``prompts`` and written with ``audio_write`` under the stem
    ``str(idx + offset)``.
    """
    # Snapshot the index once; advance() may change it while generation runs.
    index = current_index

    if index == 0:
        offset = 5
    elif index == 5:
        offset = 0
    else:
        return

    # Announce before the (blocking) remote call so the log reflects when
    # generation actually starts rather than when it has already finished.
    print("generating new audio...")

    wav = model.generate.remote(prompts)

    for idx, one_wav in enumerate(wav):
        # Saved under the stem "{idx + offset}" in mp3 format, with loudness
        # normalization at -14 db LUFS.
        audio_write(
            f"{idx + offset}",
            one_wav.cpu(),
            # NOTE(review): `model` is the handle returned by
            # modal.Cls.lookup, not a MusicGen instance -- confirm that it
            # actually exposes `sample_rate`.
            model.sample_rate,
            format="mp3",
            strategy="loudness",
            loudness_compressor=True,
        )

    print("audio generated.")
|
||||||
|
|
||||||
|
|
||||||
def advance():
    """Step to the next track, start background regeneration, and re-arm the timer.

    Updates the module-level ``current_index`` and replaces the module-level
    timer ``t`` with a fresh 60-second timer that calls this function again.
    """
    global current_index, t

    # Wrap from the last track (9) back to the first; otherwise step forward.
    current_index = 0 if current_index == 9 else current_index + 1

    # generate_new_audio decides itself whether the new index requires fresh
    # audio; it runs on its own thread, separate from this timer callback.
    worker = threading.Thread(target=generate_new_audio)
    worker.start()

    t = threading.Timer(60, advance)
    t.start()
|
||||||
|
Reference in New Issue
Block a user