fix: serialize tensor in modal before returning
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
from audiocraft.models.musicgen import MusicGen
|
from audiocraft.models.musicgen import MusicGen
|
||||||
import modal
|
import modal
|
||||||
|
import io
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
MODEL_DIR = "/root/model/model_input"
|
MODEL_DIR = "/root/model/model_input"
|
||||||
@@ -35,7 +37,13 @@ class Model:
|
|||||||
self.model = MusicGen.get_pretrained(MODEL_ID)
|
self.model = MusicGen.get_pretrained(MODEL_ID)
|
||||||
self.model.set_generation_params(duration=60)
|
self.model.set_generation_params(duration=60)
|
||||||
|
|
||||||
|
@modal.method()
|
||||||
|
def sample_rate(self):
|
||||||
|
return self.model.sample_rate
|
||||||
|
|
||||||
@modal.method()
|
@modal.method()
|
||||||
def generate(self, prompts):
|
def generate(self, prompts):
|
||||||
wav = self.model.generate(prompts)
|
wav = self.model.generate(prompts)
|
||||||
return [one_wav for i, one_wav in enumerate(wav)]
|
buf = io.BytesIO()
|
||||||
|
torch.save(wav, buf)
|
||||||
|
return buf.getvalue()
|
||||||
|
13
server.py
13
server.py
@@ -1,5 +1,7 @@
|
|||||||
import threading
|
import threading
|
||||||
|
import io
|
||||||
|
|
||||||
|
import torch
|
||||||
import modal
|
import modal
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@@ -36,16 +38,19 @@ def generate_new_audio():
|
|||||||
global current_index
|
global current_index
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
wav = []
|
wav_buf = None
|
||||||
if current_index == 0:
|
if current_index == 0:
|
||||||
offset = 5
|
offset = 5
|
||||||
wav = model.generate.remote(prompts)
|
wav_buf = model.generate.remote(prompts)
|
||||||
elif current_index == 5:
|
elif current_index == 5:
|
||||||
offset = 0
|
offset = 0
|
||||||
wav = model.generate.remote(prompts)
|
wav_buf = model.generate.remote(prompts)
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
wav = torch.load(io.BytesIO(wav_buf), map_location=torch.device("cpu"))
|
||||||
|
sample_rate = model.sample_rate.remote()
|
||||||
|
|
||||||
print("generating new audio...")
|
print("generating new audio...")
|
||||||
|
|
||||||
for idx, one_wav in enumerate(wav):
|
for idx, one_wav in enumerate(wav):
|
||||||
@@ -53,7 +58,7 @@ def generate_new_audio():
|
|||||||
audio_write(
|
audio_write(
|
||||||
f"{idx + offset}",
|
f"{idx + offset}",
|
||||||
one_wav.cpu(),
|
one_wav.cpu(),
|
||||||
model.sample_rate,
|
sample_rate,
|
||||||
format="mp3",
|
format="mp3",
|
||||||
strategy="loudness",
|
strategy="loudness",
|
||||||
loudness_compressor=True,
|
loudness_compressor=True,
|
||||||
|
Reference in New Issue
Block a user