Add code for fal.ai integration #12
22
fal_app.py
22
fal_app.py
@@ -1,11 +1,14 @@
|
|||||||
import io
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
from audiocraft.data.audio import audio_write
|
||||||
import fal
|
import fal
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
import torch
|
import torch
|
||||||
from fal.toolkit import File
|
|
||||||
|
|
||||||
from prompts import PROMPTS
|
from prompts import PROMPTS
|
||||||
|
|
||||||
|
DATA_DIR = Path("/data/audio")
|
||||||
|
|
||||||
|
|
||||||
class InfinifiFalApp(fal.App, keep_alive=300):
|
class InfinifiFalApp(fal.App, keep_alive=300):
|
||||||
machine_type = "GPU-A6000"
|
machine_type = "GPU-A6000"
|
||||||
@@ -41,9 +44,18 @@ class InfinifiFalApp(fal.App, keep_alive=300):
|
|||||||
|
|
||||||
wav = self.model.generate(PROMPTS)
|
wav = self.model.generate(PROMPTS)
|
||||||
|
|
||||||
for one_wav in enumerate(wav):
|
for i, one_wav in enumerate(wav):
|
||||||
buf = io.BytesIO()
|
path = DATA_DIR.joinpath(f"{i}")
|
||||||
torch.save(one_wav, buf)
|
audio_write(
|
||||||
await ws.send_bytes(buf.getvalue())
|
path,
|
||||||
|
one_wav.cpu(),
|
||||||
|
self.model.sample_rate,
|
||||||
|
format="mp3",
|
||||||
|
strategy="loudness",
|
||||||
|
loudness_compressor=True,
|
||||||
|
)
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
await ws.send_bytes(data)
|
||||||
|
|
||||||
await ws.close()
|
await ws.close()
|
||||||
|
Reference in New Issue
Block a user