Files
infinifi/server.py

94 lines
2.1 KiB
Python
Raw Normal View History

2024-07-20 17:09:22 +01:00
import threading
import io
2024-07-20 21:18:41 +01:00
import torch
2024-07-21 17:34:09 +01:00
import modal
from contextlib import asynccontextmanager
2024-07-20 17:09:22 +01:00
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
2024-07-21 17:34:09 +01:00
from audiocraft.data.audio import audio_write
2024-07-20 17:09:22 +01:00
# Index of the track currently being served; advance() cycles it through
# 0..9 and wraps back to 0. It starts at -1 as a "not started" sentinel:
# the first advance() call moves it to 0.
current_index: int = -1
# The pending timer that will next call advance(); None until advance() has
# run once. advance() rebinds it every time it re-arms itself.
t: "threading.Timer | None" = None

# Prompts sent to the remote model on every regeneration. generate_new_audio
# writes one file per returned clip into slots 0..4 or 5..9, so this list
# should hold exactly 5 prompts to fill half of the 10-slot ring (assuming the
# model returns one clip per prompt -- TODO confirm).
prompts = [
    "gentle, calming lo-fi beats that helps with studying and focusing",
    "calm, piano lo-fi beats to help with studying and focusing",
    "gentle lo-fi hip-hop to relax to",
    "gentle, quiet synthwave lo-fi beats",
    "morning lo-fi beats",
]
# Handle to the "Model" class of the deployed Modal app "infinifi". It is
# looked up once, at import time, and its methods are invoked with
# `.remote(...)` by generate_new_audio.
# NOTE(review): presumably this needs Modal credentials/network access when
# the module is imported -- confirm for local runs.
model = modal.Cls.lookup("infinifi", "Model")
2024-07-21 17:34:09 +01:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the track rotation on startup and stop it on shutdown.

    Calls ``advance()`` once before the app begins serving. On its first call
    this moves ``current_index`` from -1 to 0 and arms the rotation timer.
    When the lifespan exits, the pending timer is cancelled so it does not
    fire again.

    Args:
        app: The FastAPI application. It is unused here; the parameter is
            required by the lifespan protocol.
    """
    advance()
    try:
        yield
    finally:
        # Cancel in ``finally`` so the timer is stopped even when the lifespan
        # is exited through an exception or cancellation. threading.Timer
        # threads created from non-daemon threads are non-daemon, and would
        # otherwise keep the process alive.
        # ``t`` is read here, not captured earlier, because advance() rebinds
        # the module-level timer every time it re-arms.
        if t is not None:
            t.cancel()
def generate_new_audio():
    """Regenerate whichever half of the 10-slot track ring is not coming up next.

    Runs in a background thread started by ``advance``. It only does work at
    two points in the rotation:

    * ``current_index == 0``: slots 0-4 play over the next few minutes, so
      slots 5-9 are regenerated.
    * ``current_index == 5``: slots 5-9 play next, so slots 0-4 are
      regenerated.

    At any other index it returns immediately. Each clip returned by the
    remote model (presumably one per entry in ``prompts`` -- TODO confirm) is
    written to ``"{position + offset}.mp3"`` in the working directory. That is
    the file name ``get_current_audio`` serves.
    """
    # Read the shared index once so both comparisons see the same value.
    index = current_index
    if index == 0:
        offset = 5
    elif index == 5:
        offset = 0
    else:
        return

    # Log before the slow remote call, not after it has already finished.
    print("generating new audio...")
    wav_buf = model.generate.remote(prompts)
    # NOTE(review): torch.load unpickles its input. That is acceptable only
    # because wav_buf comes from our own Modal "Model" service; never feed it
    # externally supplied bytes.
    wav = torch.load(io.BytesIO(wav_buf), map_location=torch.device("cpu"))
    sample_rate = model.sample_rate.remote()

    for idx, one_wav in enumerate(wav):
        # Saved as "{idx + offset}.mp3" (audio_write appends the format
        # suffix), with loudness normalization at -14 dB LUFS plus compression.
        audio_write(
            f"{idx + offset}",
            one_wav.cpu(),
            sample_rate,
            format="mp3",
            strategy="loudness",
            loudness_compressor=True,
        )
    print("audio generated.")
def advance():
    """Step to the next track and schedule the following step.

    Moves ``current_index`` forward by one, wrapping from 9 back to 0. It then
    starts ``generate_new_audio`` on a background thread and re-arms the
    module-level timer ``t`` so that this function runs again in 60 seconds.
    """
    global current_index, t

    current_index = 0 if current_index == 9 else current_index + 1

    # Regeneration calls the remote model. Running it on its own thread lets
    # this function return, and the next timer be armed, without waiting.
    worker = threading.Thread(target=generate_new_audio)
    worker.start()

    # Each call schedules exactly one follow-up call of itself.
    t = threading.Timer(60, advance)
    t.start()
# The ASGI application. `lifespan` starts the track rotation on startup and
# cancels the rotation timer on shutdown.
app = FastAPI(lifespan=lifespan)
2024-07-20 21:18:41 +01:00
@app.get("/current.mp3")
def get_current_audio():
    """Serve the MP3 of the track that is currently playing.

    The file name is derived from the module-level ``current_index``, which
    ``advance`` moves forward every 60 seconds. The files themselves are
    written by ``generate_new_audio``.
    """
    # The leftover debug ``print("hello")`` was removed: it wrote to stdout on
    # every request.
    return FileResponse(f"{current_index}.mp3")
# Serve the static frontend from the ./web directory (relative to the working
# directory) at the site root. With html=True, a request for a directory such
# as "/" returns its index.html. This mount must stay after the /current.mp3
# route: routes are matched in registration order, and this catch-all mount
# would otherwise capture that path too.
app.mount("/", StaticFiles(directory="web", html=True), name="web")