Add code for fal.ai integration #12
22
fal_app.py
22
fal_app.py
@@ -1,11 +1,14 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from audiocraft.data.audio import audio_write
|
||||
import fal
|
||||
from fastapi import WebSocket
|
||||
import torch
|
||||
from fal.toolkit import File
|
||||
|
||||
from prompts import PROMPTS
|
||||
|
||||
DATA_DIR = Path("/data/audio")
|
||||
|
||||
|
||||
class InfinifiFalApp(fal.App, keep_alive=300):
|
||||
machine_type = "GPU-A6000"
|
||||
@@ -41,9 +44,18 @@ class InfinifiFalApp(fal.App, keep_alive=300):
|
||||
|
||||
wav = self.model.generate(PROMPTS)
|
||||
|
||||
for one_wav in enumerate(wav):
|
||||
buf = io.BytesIO()
|
||||
torch.save(one_wav, buf)
|
||||
await ws.send_bytes(buf.getvalue())
|
||||
for i, one_wav in enumerate(wav):
|
||||
path = DATA_DIR.joinpath(f"{i}")
|
||||
audio_write(
|
||||
path,
|
||||
one_wav.cpu(),
|
||||
self.model.sample_rate,
|
||||
format="mp3",
|
||||
strategy="loudness",
|
||||
loudness_compressor=True,
|
||||
)
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
await ws.send_bytes(data)
|
||||
|
||||
await ws.close()
|
||||
|
Reference in New Issue
Block a user