Add code for fal.ai integration #12

Merged
kennethnym merged 8 commits from fal-integration into main 2024-08-25 17:08:48 +01:00
4 changed files with 74 additions and 16 deletions
Showing only changes of commit 13ae315b7d - Show all commits

View File

@@ -1,11 +1,14 @@
import io
from pathlib import Path
from audiocraft.data.audio import audio_write
import fal
from fastapi import WebSocket
import torch
from fal.toolkit import File
from prompts import PROMPTS
DATA_DIR = Path("/data/audio")
class InfinifiFalApp(fal.App, keep_alive=300):
machine_type = "GPU-A6000"
@@ -41,9 +44,18 @@ class InfinifiFalApp(fal.App, keep_alive=300):
wav = self.model.generate(PROMPTS)
for one_wav in enumerate(wav):
buf = io.BytesIO()
torch.save(one_wav, buf)
await ws.send_bytes(buf.getvalue())
for i, one_wav in enumerate(wav):
path = DATA_DIR.joinpath(f"{i}")
audio_write(
path,
one_wav.cpu(),
self.model.sample_rate,
format="mp3",
strategy="loudness",
loudness_compressor=True,
)
with open(path, "rb") as f:
data = f.read()
await ws.send_bytes(data)
await ws.close()