refactor: use websocket again

This commit is contained in:
2024-11-25 17:53:37 +00:00
parent b50190f67e
commit bd4827dd99
5 changed files with 240 additions and 61 deletions

19
listener_counter.py Normal file
View File

@@ -0,0 +1,19 @@
import threading
class ListenerCounter:
def __init__(self) -> None:
self.__listener = set()
self.__lock = threading.Lock()
def add_listener(self, listener_id: str):
with self.__lock:
self.__listener.add(listener_id)
def remove_listener(self, listener_id: str):
with self.__lock:
self.__listener.discard(listener_id)
def count(self) -> int:
with self.__lock:
return len(self.__listener)

View File

@@ -1,4 +1,4 @@
fastapi==0.115.5 fastapi==0.115.5
websockets==14.1
logger==1.4 logger==1.4
Requests==2.32.3 Requests==2.32.3
sse_starlette==2.1.3

View File

@@ -1,22 +1,19 @@
import asyncio
import threading import threading
import os import os
import json
from time import sleep from time import sleep
import requests import requests
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Request, WebSocket,
HTTPException,
status, status,
) )
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from listener_counter import ListenerCounter
from logger import log_info, log_warn from logger import log_info, log_warn
from websocket_connection_manager import WebSocketConnectionManager from websocket_connection_manager import WebSocketConnectionManager
from sse_starlette.sse import EventSourceResponse
# the index of the current audio track from 0 to 9 # the index of the current audio track from 0 to 9
current_index = -1 current_index = -1
@@ -25,7 +22,7 @@ t = None
inference_url = "" inference_url = ""
api_key = "" api_key = ""
ws_connection_manager = WebSocketConnectionManager() ws_connection_manager = WebSocketConnectionManager()
active_listeners = set() listener_counter = ListenerCounter()
@asynccontextmanager @asynccontextmanager
@@ -138,46 +135,32 @@ def get_current_audio():
return FileResponse(f"{current_index}.mp3") return FileResponse(f"{current_index}.mp3")
@app.get("/status") @app.websocket("/ws")
def status_stream(request: Request): async def ws_endpoint(ws: WebSocket):
async def status_generator(): await ws_connection_manager.connect(ws)
last_listener_count = len(active_listeners)
yield json.dumps({"listeners": last_listener_count})
while True: addr = ""
if await request.is_disconnected(): if ws.client:
break addr, _ = ws.client
else:
listener_count = len(active_listeners) await ws.close()
if listener_count != last_listener_count: ws_connection_manager.disconnect(ws)
last_listener_count = listener_count return
yield json.dumps({"listeners": listener_count})
await asyncio.sleep(1)
return EventSourceResponse(status_generator())
@app.post("/client-status")
async def change_status(request: Request):
body = await request.json()
try: try:
is_listening = body["isListening"] while True:
msg = await ws.receive_text()
client = request.client match msg:
if not client: case "listening":
raise HTTPException(status_code=400, detail="ip address unavailable.") listener_counter.add_listener(addr)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
if is_listening: case "paused":
active_listeners.add(client.host) listener_counter.remove_listener(addr)
else: await ws_connection_manager.broadcast(f"{listener_counter.count()}")
active_listeners.discard(client.host) except:
listener_counter.remove_listener(addr)
return {"isListening": is_listening} ws_connection_manager.disconnect(ws)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
except KeyError:
raise HTTPException(status_code=400, detail="'isListening' must be a boolean")
app.mount("/", StaticFiles(directory="web", html=True), name="web") app.mount("/", StaticFiles(directory="web", html=True), name="web")

166
tmp/server.py Normal file
View File

@@ -0,0 +1,166 @@
import threading
import os
from time import sleep
import requests
from contextlib import asynccontextmanager
from fastapi import (
FastAPI,
WebSocket,
status,
)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from listener_counter import ListenerCounter
from logger import log_info, log_warn
from websocket_connection_manager import WebSocketConnectionManager
# the index of the current audio track from 0 to 9
current_index = -1
# the timer that periodically advances the current audio track
t = None
inference_url = ""
api_key = ""
ws_connection_manager = WebSocketConnectionManager()
listener_counter = ListenerCounter()
@asynccontextmanager
async def lifespan(app: FastAPI):
global ws, inference_url, api_key
inference_url = os.environ.get("INFERENCE_SERVER_URL")
api_key = os.environ.get("API_KEY")
if not inference_url:
inference_url = "http://localhost:8001"
advance()
yield
if t:
t.cancel()
def generate_new_audio():
if not inference_url:
return
global current_index
offset = 0
if current_index == 0:
offset = 5
elif current_index == 5:
offset = 0
else:
return
log_info("requesting new audio...")
try:
requests.post(
f"{inference_url}/generate",
headers={"Authorization": f"key {api_key}"},
)
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
is_available = False
while not is_available:
try:
res = requests.post(
f"{inference_url}/clips/0",
stream=True,
headers={"Authorization": f"key {api_key}"},
)
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
if res.status_code != status.HTTP_200_OK:
print(res.status_code)
print("still generating...")
sleep(30)
continue
print("inference complete! downloading new clips")
is_available = True
with open(f"{offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
for i in range(4):
res = requests.post(
f"{inference_url}/clips/{i + 1}",
stream=True,
headers={"Authorization": f"key {api_key}"},
)
if res.status_code != status.HTTP_200_OK:
continue
with open(f"{i + 1 + offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
log_info("audio generated.")
def advance():
global current_index, t
if current_index == 9:
current_index = 0
else:
current_index = current_index + 1
threading.Thread(target=generate_new_audio).start()
t = threading.Timer(60, advance)
t.start()
app = FastAPI(lifespan=lifespan)
@app.get("/current.mp3")
def get_current_audio():
return FileResponse(f"{current_index}.mp3")
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
await ws_connection_manager.connect(ws)
addr = ""
if ws.client:
addr, _ = ws.client
else:
await ws.close()
ws_connection_manager.disconnect(ws)
return
try:
while True:
msg = await ws.receive_text()
match msg:
case "listening":
listener_counter.add_listener(addr)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
case "paused":
listener_counter.remove_listener(addr)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
except:
listener_counter.remove_listener(addr)
ws_connection_manager.disconnect(ws)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
app.mount("/", StaticFiles(directory="web", html=True), name="web")

View File

@@ -22,6 +22,8 @@ const achievementUnlockedAudio = document.getElementById(
"achievement-unlocked-audio", "achievement-unlocked-audio",
); );
const ws = initializeWebSocket();
let isPlaying = false; let isPlaying = false;
let isFading = false; let isFading = false;
let currentAudio; let currentAudio;
@@ -188,14 +190,6 @@ function showNotification(title, content, duration) {
}, duration); }, duration);
} }
function listenToServerStatusEvent() {
const statusEvent = new EventSource("/status");
statusEvent.addEventListener("message", (event) => {
const data = JSON.parse(event.data);
updateListenerCountLabel(data.listeners);
});
}
function updateListenerCountLabel(newCount) { function updateListenerCountLabel(newCount) {
if (newCount <= 1) { if (newCount <= 1) {
listenerCountLabel.innerText = `${newCount} person tuned in`; listenerCountLabel.innerText = `${newCount} person tuned in`;
@@ -205,14 +199,32 @@ function updateListenerCountLabel(newCount) {
} }
async function updateClientStatus(status) { async function updateClientStatus(status) {
await fetch("/client-status", { if (status.isListening) {
method: "POST", ws.send("listening");
body: JSON.stringify({ isListening: status.isListening }), } else {
headers: { ws.send("paused");
"Content-Type": "application/json", }
}, }
keepalive: true,
}); function initializeWebSocket() {
const ws = new WebSocket(
`${location.protocol === "https:" ? "wss:" : "ws:"}//${location.host}/ws`,
);
ws.onmessage = (event) => {
if (typeof event.data !== "string") {
return;
}
const listenerCount = Number.parseInt(event.data);
if (Number.isNaN(listenerCount)) {
return;
}
updateListenerCountLabel(listenerCount);
};
return ws;
} }
window.addEventListener("beforeunload", (e) => { window.addEventListener("beforeunload", (e) => {
@@ -281,4 +293,3 @@ achievementUnlockedAudio.volume = 0.05;
loadMeowCount(); loadMeowCount();
loadInitialVolume(); loadInitialVolume();
enableSpaceBarControl(); enableSpaceBarControl();
listenToServerStatusEvent();