fix: serialize tensor in modal before returning

This commit is contained in:
2024-07-21 18:42:11 +01:00
parent 8197873161
commit bf9b2b8bb3
2 changed files with 18 additions and 5 deletions

View File

@@ -1,5 +1,7 @@
from audiocraft.models.musicgen import MusicGen
import modal
import io
import torch
MODEL_DIR = "/root/model/model_input"
@@ -35,7 +37,13 @@ class Model:
self.model = MusicGen.get_pretrained(MODEL_ID)
self.model.set_generation_params(duration=60)
@modal.method()
def sample_rate(self):
return self.model.sample_rate
@modal.method()
def generate(self, prompts):
wav = self.model.generate(prompts)
return [one_wav for i, one_wav in enumerate(wav)]
buf = io.BytesIO()
torch.save(wav, buf)
return buf.getvalue()

View File

@@ -1,5 +1,7 @@
import threading
import io
import torch
import modal
from contextlib import asynccontextmanager
from fastapi import FastAPI
@@ -36,16 +38,19 @@ def generate_new_audio():
global current_index
offset = 0
wav = []
wav_buf = None
if current_index == 0:
offset = 5
wav = model.generate.remote(prompts)
wav_buf = model.generate.remote(prompts)
elif current_index == 5:
offset = 0
wav = model.generate.remote(prompts)
wav_buf = model.generate.remote(prompts)
else:
return
wav = torch.load(io.BytesIO(wav_buf), map_location=torch.device("cpu"))
sample_rate = model.sample_rate.remote()
print("generating new audio...")
for idx, one_wav in enumerate(wav):
@@ -53,7 +58,7 @@ def generate_new_audio():
audio_write(
f"{idx + offset}",
one_wav.cpu(),
model.sample_rate,
sample_rate,
format="mp3",
strategy="loudness",
loudness_compressor=True,