From bd4827dd993c38d7097f3394aadbf65342a8de74 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Mon, 25 Nov 2024 17:53:37 +0000 Subject: [PATCH] refactor: use websocket again --- listener_counter.py | 19 +++++ requirements-server.txt | 2 +- server.py | 69 +++++++---------- tmp/server.py | 166 ++++++++++++++++++++++++++++++++++++++++ web/script.js | 45 +++++++---- 5 files changed, 240 insertions(+), 61 deletions(-) create mode 100644 listener_counter.py create mode 100644 tmp/server.py diff --git a/listener_counter.py b/listener_counter.py new file mode 100644 index 0000000..c06445c --- /dev/null +++ b/listener_counter.py @@ -0,0 +1,19 @@ +import threading + + +class ListenerCounter: + def __init__(self) -> None: + self.__listener = set() + self.__lock = threading.Lock() + + def add_listener(self, listener_id: str): + with self.__lock: + self.__listener.add(listener_id) + + def remove_listener(self, listener_id: str): + with self.__lock: + self.__listener.discard(listener_id) + + def count(self) -> int: + with self.__lock: + return len(self.__listener) diff --git a/requirements-server.txt b/requirements-server.txt index fa5b71d..441e7c3 100644 --- a/requirements-server.txt +++ b/requirements-server.txt @@ -1,4 +1,4 @@ fastapi==0.115.5 +websockets==14.1 logger==1.4 Requests==2.32.3 -sse_starlette==2.1.3 diff --git a/server.py b/server.py index 1845343..6d4a3fc 100644 --- a/server.py +++ b/server.py @@ -1,22 +1,19 @@ -import asyncio import threading import os -import json from time import sleep import requests from contextlib import asynccontextmanager from fastapi import ( FastAPI, - Request, - HTTPException, + WebSocket, status, ) from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles +from listener_counter import ListenerCounter from logger import log_info, log_warn from websocket_connection_manager import WebSocketConnectionManager -from sse_starlette.sse import EventSourceResponse # the index of the current audio track from 0 to 9 current_index = -1 @@ -25,7 +22,7 @@ t = None inference_url = "" api_key = "" ws_connection_manager = WebSocketConnectionManager() -active_listeners = set() +listener_counter = ListenerCounter() @asynccontextmanager @@ -138,46 +135,32 @@ def get_current_audio(): return FileResponse(f"{current_index}.mp3") -@app.get("/status") -def status_stream(request: Request): - async def status_generator(): - last_listener_count = len(active_listeners) - yield json.dumps({"listeners": last_listener_count}) +@app.websocket("/ws") +async def ws_endpoint(ws: WebSocket): + await ws_connection_manager.connect(ws) - while True: - if await request.is_disconnected(): - break - - listener_count = len(active_listeners) - if listener_count != last_listener_count: - last_listener_count = listener_count - yield json.dumps({"listeners": listener_count}) - - await asyncio.sleep(1) - - return EventSourceResponse(status_generator()) - - -@app.post("/client-status") -async def change_status(request: Request): - body = await request.json() + addr = "" + if ws.client: + addr, _ = ws.client + else: + await ws.close() + ws_connection_manager.disconnect(ws) + return try: - is_listening = body["isListening"] - - client = request.client - if not client: - raise HTTPException(status_code=400, detail="ip address unavailable.") - - if is_listening: - active_listeners.add(client.host) - else: - active_listeners.discard(client.host) - - return {"isListening": is_listening} - - except KeyError: - raise HTTPException(status_code=400, detail="'isListening' must be a boolean") + while True: + msg = await ws.receive_text() + match msg: + case "listening": + listener_counter.add_listener(addr) + await ws_connection_manager.broadcast(f"{listener_counter.count()}") + case "paused": + listener_counter.remove_listener(addr) + await ws_connection_manager.broadcast(f"{listener_counter.count()}") + except: + listener_counter.remove_listener(addr) + ws_connection_manager.disconnect(ws) + await ws_connection_manager.broadcast(f"{listener_counter.count()}") app.mount("/", StaticFiles(directory="web", html=True), name="web") diff --git a/tmp/server.py b/tmp/server.py new file mode 100644 index 0000000..6d4a3fc --- /dev/null +++ b/tmp/server.py @@ -0,0 +1,166 @@ +import threading +import os +from time import sleep +import requests + +from contextlib import asynccontextmanager +from fastapi import ( + FastAPI, + WebSocket, + status, +) +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles +from listener_counter import ListenerCounter +from logger import log_info, log_warn +from websocket_connection_manager import WebSocketConnectionManager + +# the index of the current audio track from 0 to 9 +current_index = -1 +# the timer that periodically advances the current audio track +t = None +inference_url = "" +api_key = "" +ws_connection_manager = WebSocketConnectionManager() +listener_counter = ListenerCounter() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global ws, inference_url, api_key + + inference_url = os.environ.get("INFERENCE_SERVER_URL") + api_key = os.environ.get("API_KEY") + + if not inference_url: + inference_url = "http://localhost:8001" + + advance() + + yield + + if t: + t.cancel() + + +def generate_new_audio(): + if not inference_url: + return + + global current_index + + offset = 0 + if current_index == 0: + offset = 5 + elif current_index == 5: + offset = 0 + else: + return + + log_info("requesting new audio...") + + try: + requests.post( + f"{inference_url}/generate", + headers={"Authorization": f"key {api_key}"}, + ) + except: + log_warn( + "inference server potentially unreachable. recycling cached audio for now." + ) + return + + is_available = False + while not is_available: + try: + res = requests.post( + f"{inference_url}/clips/0", + stream=True, + headers={"Authorization": f"key {api_key}"}, + ) + except: + log_warn( + "inference server potentially unreachable. recycling cached audio for now." + ) + return + + if res.status_code != status.HTTP_200_OK: + print(res.status_code) + print("still generating...") + sleep(30) + continue + + print("inference complete! downloading new clips") + + is_available = True + with open(f"{offset}.mp3", "wb") as f: + for chunk in res.iter_content(chunk_size=128): + f.write(chunk) + + for i in range(4): + res = requests.post( + f"{inference_url}/clips/{i + 1}", + stream=True, + headers={"Authorization": f"key {api_key}"}, + ) + + if res.status_code != status.HTTP_200_OK: + continue + + with open(f"{i + 1 + offset}.mp3", "wb") as f: + for chunk in res.iter_content(chunk_size=128): + f.write(chunk) + + log_info("audio generated.") + + +def advance(): + global current_index, t + + if current_index == 9: + current_index = 0 + else: + current_index = current_index + 1 + threading.Thread(target=generate_new_audio).start() + + t = threading.Timer(60, advance) + t.start() + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/current.mp3") +def get_current_audio(): + return FileResponse(f"{current_index}.mp3") + + +@app.websocket("/ws") +async def ws_endpoint(ws: WebSocket): + await ws_connection_manager.connect(ws) + + addr = "" + if ws.client: + addr, _ = ws.client + else: + await ws.close() + ws_connection_manager.disconnect(ws) + return + + try: + while True: + msg = await ws.receive_text() + match msg: + case "listening": + listener_counter.add_listener(addr) + await ws_connection_manager.broadcast(f"{listener_counter.count()}") + case "paused": + listener_counter.remove_listener(addr) + await ws_connection_manager.broadcast(f"{listener_counter.count()}") + except: + listener_counter.remove_listener(addr) + ws_connection_manager.disconnect(ws) + await ws_connection_manager.broadcast(f"{listener_counter.count()}") + + +app.mount("/", StaticFiles(directory="web", html=True), name="web") diff --git a/web/script.js b/web/script.js index d7ff580..16b0225 100644 --- a/web/script.js +++ b/web/script.js @@ -22,6 +22,8 @@ const achievementUnlockedAudio = document.getElementById( "achievement-unlocked-audio", ); +const ws = initializeWebSocket(); + let isPlaying = false; let isFading = false; let currentAudio; @@ -188,14 +190,6 @@ function showNotification(title, content, duration) { }, duration); } -function listenToServerStatusEvent() { - const statusEvent = new EventSource("/status"); - statusEvent.addEventListener("message", (event) => { - const data = JSON.parse(event.data); - updateListenerCountLabel(data.listeners); - }); -} - function updateListenerCountLabel(newCount) { if (newCount <= 1) { listenerCountLabel.innerText = `${newCount} person tuned in`; @@ -205,14 +199,32 @@ function updateListenerCountLabel(newCount) { } async function updateClientStatus(status) { - await fetch("/client-status", { - method: "POST", - body: JSON.stringify({ isListening: status.isListening }), - headers: { - "Content-Type": "application/json", - }, - keepalive: true, - }); + if (status.isListening) { + ws.send("listening"); + } else { + ws.send("paused"); + } +} + +function initializeWebSocket() { + const ws = new WebSocket( + `${location.protocol === "https:" ? "wss:" : "ws:"}//${location.host}/ws`, + ); + + ws.onmessage = (event) => { + if (typeof event.data !== "string") { + return; + } + + const listenerCount = Number.parseInt(event.data); + if (Number.isNaN(listenerCount)) { + return; + } + + updateListenerCountLabel(listenerCount); + }; + + return ws; } window.addEventListener("beforeunload", (e) => { @@ -281,4 +293,3 @@ achievementUnlockedAudio.volume = 0.05; loadMeowCount(); loadInitialVolume(); enableSpaceBarControl(); -listenToServerStatusEvent();