diff --git a/modal_wrapper.py b/modal_wrapper.py
index 805f7c9..592b0c2 100644
--- a/modal_wrapper.py
+++ b/modal_wrapper.py
@@ -1,5 +1,7 @@
 from audiocraft.models.musicgen import MusicGen
 import modal
+import io
+import torch
 
 MODEL_DIR = "/root/model/model_input"
 
@@ -35,7 +37,13 @@ class Model:
         self.model = MusicGen.get_pretrained(MODEL_ID)
         self.model.set_generation_params(duration=60)
 
+    @modal.method()
+    def sample_rate(self):
+        return self.model.sample_rate
+
     @modal.method()
     def generate(self, prompts):
         wav = self.model.generate(prompts)
-        return [one_wav for i, one_wav in enumerate(wav)]
+        buf = io.BytesIO()
+        torch.save(wav, buf)
+        return buf.getvalue()
diff --git a/server.py b/server.py
index b305286..d73ea07 100644
--- a/server.py
+++ b/server.py
@@ -1,5 +1,7 @@
 import threading
+import io
+import torch
 import modal
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
@@ -36,16 +38,19 @@ def generate_new_audio():
     global current_index
 
     offset = 0
-    wav = []
+    wav_buf = None
     if current_index == 0:
         offset = 5
-        wav = model.generate.remote(prompts)
+        wav_buf = model.generate.remote(prompts)
    elif current_index == 5:
        offset = 0
-        wav = model.generate.remote(prompts)
+        wav_buf = model.generate.remote(prompts)
    else:
        return
 
+    wav = torch.load(io.BytesIO(wav_buf), map_location=torch.device("cpu"))
+    sample_rate = model.sample_rate.remote()
+
     print("generating new audio...")
 
     for idx, one_wav in enumerate(wav):
@@ -53,7 +58,7 @@
         audio_write(
             f"{idx + offset}",
             one_wav.cpu(),
-            model.sample_rate,
+            sample_rate,
             format="mp3",
             strategy="loudness",
             loudness_compressor=True,
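
The changes above hinge on a plain serialize/deserialize round trip: the Modal container saves the batch of generated waveform tensors into an in-memory buffer with torch.save and returns raw bytes, and the server loads those bytes back with torch.load, mapping everything onto the CPU. Below is a minimal standalone sketch of that round trip, assuming nothing beyond stock PyTorch; the tensor shape is a stand-in for illustration, not MusicGen's actual output shape.

    import io
    import torch

    # Producer side (inside the Modal container): serialize a batch of
    # waveforms -- e.g. a (batch, channels, samples) float tensor -- to bytes.
    wav = torch.zeros(2, 1, 32000)  # stand-in for the model's output batch
    buf = io.BytesIO()
    torch.save(wav, buf)
    payload = buf.getvalue()  # plain bytes, safe to return from a remote call

    # Consumer side (the server): load the bytes back, forcing any tensors
    # that were saved from a GPU onto the CPU.
    restored = torch.load(io.BytesIO(payload), map_location=torch.device("cpu"))
    assert torch.equal(restored, wav)

Returning bytes rather than a list of live tensors sidesteps pickling possibly CUDA-backed tensors across the remote boundary, and map_location ensures the server never needs a GPU to consume the result.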