Files
infinifi/fal_app.py

62 lines
1.5 KiB
Python
Raw Normal View History

import io
from pathlib import Path
from audiocraft.data.audio import audio_write
import fal
2024-07-25 11:00:52 +01:00
from fastapi import WebSocket
import torch
from prompts import PROMPTS
# Directory the WebSocket endpoint writes its generated MP3 clips into.
DATA_DIR = Path("/data") / "audio"
class InfinifiFalApp(fal.App, keep_alive=300):
    """fal app that generates music with MusicGen for every prompt in ``PROMPTS``.

    Endpoints:

    * ``/generate`` -- returns each generated waveform serialized with
      ``torch.save``.
    * ``/ws`` -- streams each generated clip as MP3 bytes over a WebSocket.
    """

    machine_type = "GPU-A6000"
    # Pinned packages installed into the fal runtime environment.
    requirements = [
        "torch==2.1.0",
        "audiocraft==1.3.0",
        "torchaudio==2.1.0",
        "websockets==11.0.3",
    ]

    def setup(self):
        """Load the MusicGen large checkpoint once and set 60-second generations."""
        # NOTE(review): torchaudio is imported but never referenced in this
        # method; kept in case the import is relied on for side effects --
        # confirm before removing.
        import torchaudio
        from audiocraft.models.musicgen import MusicGen

        self.model = MusicGen.get_pretrained("facebook/musicgen-large")
        self.model.set_generation_params(duration=60)

    @fal.endpoint("/generate")
    def run(self):
        """Generate one clip per prompt and return them as ``torch.save`` payloads.

        Returns:
            A list containing one ``bytes`` object per generated waveform,
            each produced by ``torch.save`` on the CPU copy of the tensor.
        """
        # NOTE(review): returning raw bytes from an HTTP endpoint depends on how
        # fal/FastAPI serializes the response; binary torch.save output may not
        # survive JSON encoding -- confirm with a client.
        wav = self.model.generate(PROMPTS)
        serialized = []
        for one_wav in wav:
            buf = io.BytesIO()
            # Save a CPU copy so the payload is not tied to the GPU device.
            torch.save(one_wav.cpu(), buf)
            serialized.append(buf.getvalue())
        return serialized

    @fal.endpoint("/ws")
    async def run_ws(self, ws: WebSocket):
        """Generate one clip per prompt and stream each as MP3 bytes over ``ws``.

        Each clip is written under ``DATA_DIR`` (named by its index) as an MP3
        with loudness normalization and compression, read back from disk, and
        sent as one binary WebSocket message. The socket is closed after the
        last clip is sent.
        """
        await ws.accept()
        # NOTE(review): model.generate is a synchronous call made directly on the
        # event loop, so the loop is blocked for the whole generation. Also,
        # file names are just the clip index, so concurrent requests would
        # overwrite each other's files.
        wav = self.model.generate(PROMPTS)
        for i, one_wav in enumerate(wav):
            # audio_write appends the format suffix (".mp3") to the stem by
            # default and returns the path it actually wrote. Read that path:
            # the bare stem DATA_DIR/<i> is never created.
            written_path = audio_write(
                DATA_DIR.joinpath(f"{i}"),
                one_wav.cpu(),
                self.model.sample_rate,
                format="mp3",
                strategy="loudness",
                loudness_compressor=True,
            )
            await ws.send_bytes(written_path.read_bytes())
        await ws.close()