Files
infinifi/server.py

134 lines
3.0 KiB
Python
Raw Normal View History

2024-07-20 17:09:22 +01:00
import threading
2024-07-22 22:24:39 +01:00
import os
2024-07-20 21:18:41 +01:00
2024-07-22 22:24:39 +01:00
import websocket
from contextlib import asynccontextmanager
2024-07-26 22:34:44 +01:00
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2024-07-20 17:09:22 +01:00
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
2024-07-24 17:32:46 +01:00
from logger import log_info, log_warn
2024-07-26 22:34:44 +01:00
from websocket_connection_manager import WebSocketConnectionManager
2024-07-20 17:09:22 +01:00
# Index of the track currently being served, 0 through 9. Starts at -1 so
# that the first advance() call (made during app startup) moves it to 0.
current_index = -1
# The threading.Timer that periodically advances the current audio track;
# kept in a global so the app's shutdown hook can cancel it.
t = None
# Websocket connection to the inference server.
# NOTE(review): nothing in this module appears to assign this global
# (generate_new_audio opens its own local connection), so it seems to always
# stay None — confirm before relying on it.
ws = None
# URL of the inference server's websocket endpoint; filled in at app startup.
ws_url = ""
# Manages the /ws client connections; used to broadcast the listener count.
ws_connection_manager = WebSocketConnectionManager()
# Client host addresses that most recently reported "playing" over /ws.
active_listeners = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up the playlist on startup and tear it down on shutdown.

    Startup resolves the inference server URL from the
    ``INFERENCE_SERVER_WS_URL`` environment variable, falling back to
    ``ws://localhost:8001`` when it is unset or empty. It then calls
    ``advance()`` once to select the first track and schedule the next one.

    Shutdown closes the module-level websocket if one is set and cancels the
    pending advance timer.
    """
    global ws, ws_url

    # `or` covers both "unset" (None) and "set but empty" ("").
    ws_url = os.environ.get("INFERENCE_SERVER_WS_URL") or "ws://localhost:8001"
    advance()

    yield

    if ws:
        ws.close()
    if t:
        t.cancel()
2024-07-20 17:09:22 +01:00
2024-07-21 17:34:09 +01:00
def generate_new_audio():
    """Fetch five freshly generated tracks and cache them as ``N.mp3`` files.

    Generation only happens at two points in the 10-track cycle:

    * at index 0 the back half of the playlist (files ``5.mp3``-``9.mp3``)
      is refreshed;
    * at index 5 the front half (files ``0.mp3``-``4.mp3``) is refreshed.

    In both cases the half that is not being served is the one rewritten.
    Any other index is a no-op, as is an empty ``ws_url``.

    The inference server is sent ``"generate"``. Up to five binary replies
    are written to disk in the current working directory. Any failure is
    logged and swallowed, so the previously cached audio keeps being served.
    """
    from contextlib import suppress

    if not ws_url:
        return

    # Snapshot the index once: advance() mutates it from another thread.
    index = current_index
    if index == 0:
        offset = 5
    elif index == 5:
        offset = 0
    else:
        return

    log_info("generating new audio...")

    conn = None
    try:
        conn = websocket.create_connection(ws_url)
        conn.send("generate")

        wavs = []
        for _ in range(5):
            raw = conn.recv()
            # Text frames carry no audio; skip them (this still uses up one of
            # the five reads, matching the original protocol handling).
            if isinstance(raw, str):
                continue
            wavs.append(raw)

        for i, wav in enumerate(wavs):
            with open(f"{i + offset}.mp3", "wb") as f:
                f.write(wav)

        log_info("audio generated.")
    except Exception:
        # Deliberate best effort: keep serving cached audio if the server or
        # the disk write fails. (Exception, not bare except, so interpreter
        # shutdown signals are not swallowed.)
        log_warn(
            "inference server potentially unreachable. recycling cached audio for now."
        )
    finally:
        # Always release the socket, including when recv/write raised above;
        # a failure while closing a broken connection is irrelevant here.
        if conn is not None:
            with suppress(Exception):
                conn.close()
2024-07-21 17:34:09 +01:00
2024-07-20 17:09:22 +01:00
def advance():
    """Step to the next track, kick off background generation, and reschedule.

    The index wraps from 9 back to 0. ``generate_new_audio`` runs on its own
    thread so this function never waits on the inference server. The next
    call is scheduled 60 seconds out via a ``threading.Timer`` stored in the
    module-level ``t``, so shutdown can cancel it.
    """
    global current_index, t

    current_index = 0 if current_index == 9 else current_index + 1

    threading.Thread(target=generate_new_audio).start()

    t = threading.Timer(60, advance)
    # Daemon timer: if shutdown cancels `t` while this function is running,
    # the timer armed here would otherwise keep the interpreter alive and
    # re-arm itself forever.
    t.daemon = True
    t.start()
app = FastAPI(lifespan=lifespan)


@app.get("/current.mp3")
def get_current_audio():
    """Serve the cached audio file for the current track.

    Returns ``{current_index}.mp3`` from the working directory. If that file
    does not exist (e.g. nothing has been generated or cached yet), it
    responds with 404. Otherwise ``FileResponse`` would fail while sending
    and surface as a server error.
    """
    from fastapi import HTTPException

    # Build the path once so a concurrent advance() cannot change it between
    # the existence check and the response.
    path = f"{current_index}.mp3"
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="audio not available")
    return FileResponse(path)
2024-07-26 22:34:44 +01:00
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Track play/pause state of a client and broadcast the listener count.

    Clients send ``"playing"`` or ``"paused"`` text messages; any other
    message is ignored. Listeners are keyed by client host address in
    ``active_listeners``. The current count is broadcast through
    ``ws_connection_manager`` on connect, on each state change, and when the
    client goes away.

    NOTE(review): keying by host address means several connections from the
    same address count as one listener, and one of them pausing/leaving
    removes the address for all of them — confirm this is intended.
    """
    await ws_connection_manager.connect(ws)

    if not ws.client:
        # No client address to track: drop the connection. Returning here
        # matters — falling through would try to receive on a closed socket.
        await ws.close()
        ws_connection_manager.disconnect(ws)
        return

    addr, _ = ws.client
    await ws_connection_manager.broadcast(f"{len(active_listeners)}")

    try:
        while True:
            msg = await ws.receive_text()
            if msg == "playing":
                active_listeners.add(addr)
            elif msg == "paused":
                active_listeners.discard(addr)
            else:
                continue
            await ws_connection_manager.broadcast(f"{len(active_listeners)}")
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in `finally`.
        pass
    finally:
        # Run cleanup for every exit path (e.g. receive_text failing on a
        # non-text frame), not only clean disconnects, so stale connections
        # and listeners are not left behind.
        active_listeners.discard(addr)
        ws_connection_manager.disconnect(ws)
        await ws_connection_manager.broadcast(f"{len(active_listeners)}")
2024-07-20 21:18:41 +01:00
# Serve the static frontend from ./web; html=True makes directory requests
# (such as "/") return their index.html. This catch-all "/" mount is
# registered last because routes are matched in registration order, so
# /current.mp3 and /ws above must be registered before it.
app.mount("/", StaticFiles(directory="web", html=True), name="web")