refactor: use websocket again

This commit is contained in:
2024-11-25 17:53:37 +00:00
parent b50190f67e
commit bd4827dd99
5 changed files with 240 additions and 61 deletions

19
listener_counter.py Normal file
View File

@@ -0,0 +1,19 @@
import threading
class ListenerCounter:
def __init__(self) -> None:
self.__listener = set()
self.__lock = threading.Lock()
def add_listener(self, listener_id: str):
with self.__lock:
self.__listener.add(listener_id)
def remove_listener(self, listener_id: str):
with self.__lock:
self.__listener.discard(listener_id)
def count(self) -> int:
with self.__lock:
return len(self.__listener)

View File

@@ -1,4 +1,4 @@
fastapi==0.115.5
websockets==14.1
logger==1.4
Requests==2.32.3
sse_starlette==2.1.3

View File

@@ -1,22 +1,19 @@
import asyncio
import threading
import os
import json
from time import sleep
import requests
from contextlib import asynccontextmanager
from fastapi import (
FastAPI,
Request,
HTTPException,
WebSocket,
status,
)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from listener_counter import ListenerCounter
from logger import log_info, log_warn
from websocket_connection_manager import WebSocketConnectionManager
from sse_starlette.sse import EventSourceResponse
# the index of the current audio track from 0 to 9
current_index = -1
@@ -25,7 +22,7 @@ t = None
inference_url = ""
api_key = ""
ws_connection_manager = WebSocketConnectionManager()
active_listeners = set()
listener_counter = ListenerCounter()
@asynccontextmanager
@@ -138,46 +135,32 @@ def get_current_audio():
return FileResponse(f"{current_index}.mp3")
@app.get("/status")
def status_stream(request: Request):
async def status_generator():
last_listener_count = len(active_listeners)
yield json.dumps({"listeners": last_listener_count})
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
await ws_connection_manager.connect(ws)
while True:
if await request.is_disconnected():
break
listener_count = len(active_listeners)
if listener_count != last_listener_count:
last_listener_count = listener_count
yield json.dumps({"listeners": listener_count})
await asyncio.sleep(1)
return EventSourceResponse(status_generator())
@app.post("/client-status")
async def change_status(request: Request):
body = await request.json()
addr = ""
if ws.client:
addr, _ = ws.client
else:
await ws.close()
ws_connection_manager.disconnect(ws)
return
try:
is_listening = body["isListening"]
client = request.client
if not client:
raise HTTPException(status_code=400, detail="ip address unavailable.")
if is_listening:
active_listeners.add(client.host)
else:
active_listeners.discard(client.host)
return {"isListening": is_listening}
except KeyError:
raise HTTPException(status_code=400, detail="'isListening' must be a boolean")
while True:
msg = await ws.receive_text()
match msg:
case "listening":
listener_counter.add_listener(addr)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
case "paused":
listener_counter.remove_listener(addr)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
except:
listener_counter.remove_listener(addr)
ws_connection_manager.disconnect(ws)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
app.mount("/", StaticFiles(directory="web", html=True), name="web")

166
tmp/server.py Normal file
View File

@@ -0,0 +1,166 @@
import threading
import os
from time import sleep
import requests
from contextlib import asynccontextmanager
from fastapi import (
FastAPI,
WebSocket,
status,
)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from listener_counter import ListenerCounter
from logger import log_info, log_warn
from websocket_connection_manager import WebSocketConnectionManager
# the index of the current audio track from 0 to 9
current_index = -1
# the timer that periodically advances the current audio track
t = None
inference_url = ""
api_key = ""
ws_connection_manager = WebSocketConnectionManager()
listener_counter = ListenerCounter()
@asynccontextmanager
async def lifespan(app: FastAPI):
global ws, inference_url, api_key
inference_url = os.environ.get("INFERENCE_SERVER_URL")
api_key = os.environ.get("API_KEY")
if not inference_url:
inference_url = "http://localhost:8001"
advance()
yield
if t:
t.cancel()
def generate_new_audio():
if not inference_url:
return
global current_index
offset = 0
if current_index == 0:
offset = 5
elif current_index == 5:
offset = 0
else:
return
log_info("requesting new audio...")
try:
requests.post(
f"{inference_url}/generate",
headers={"Authorization": f"key {api_key}"},
)
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
is_available = False
while not is_available:
try:
res = requests.post(
f"{inference_url}/clips/0",
stream=True,
headers={"Authorization": f"key {api_key}"},
)
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return
if res.status_code != status.HTTP_200_OK:
print(res.status_code)
print("still generating...")
sleep(30)
continue
print("inference complete! downloading new clips")
is_available = True
with open(f"{offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
for i in range(4):
res = requests.post(
f"{inference_url}/clips/{i + 1}",
stream=True,
headers={"Authorization": f"key {api_key}"},
)
if res.status_code != status.HTTP_200_OK:
continue
with open(f"{i + 1 + offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)
log_info("audio generated.")
def advance():
global current_index, t
if current_index == 9:
current_index = 0
else:
current_index = current_index + 1
threading.Thread(target=generate_new_audio).start()
t = threading.Timer(60, advance)
t.start()
app = FastAPI(lifespan=lifespan)
@app.get("/current.mp3")
def get_current_audio():
return FileResponse(f"{current_index}.mp3")
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
await ws_connection_manager.connect(ws)
addr = ""
if ws.client:
addr, _ = ws.client
else:
await ws.close()
ws_connection_manager.disconnect(ws)
return
try:
while True:
msg = await ws.receive_text()
match msg:
case "listening":
listener_counter.add_listener(addr)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
case "paused":
listener_counter.remove_listener(addr)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
except:
listener_counter.remove_listener(addr)
ws_connection_manager.disconnect(ws)
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
app.mount("/", StaticFiles(directory="web", html=True), name="web")

View File

@@ -22,6 +22,8 @@ const achievementUnlockedAudio = document.getElementById(
"achievement-unlocked-audio",
);
const ws = initializeWebSocket();
let isPlaying = false;
let isFading = false;
let currentAudio;
@@ -188,14 +190,6 @@ function showNotification(title, content, duration) {
}, duration);
}
function listenToServerStatusEvent() {
const statusEvent = new EventSource("/status");
statusEvent.addEventListener("message", (event) => {
const data = JSON.parse(event.data);
updateListenerCountLabel(data.listeners);
});
}
function updateListenerCountLabel(newCount) {
if (newCount <= 1) {
listenerCountLabel.innerText = `${newCount} person tuned in`;
@@ -205,14 +199,32 @@ function updateListenerCountLabel(newCount) {
}
async function updateClientStatus(status) {
await fetch("/client-status", {
method: "POST",
body: JSON.stringify({ isListening: status.isListening }),
headers: {
"Content-Type": "application/json",
},
keepalive: true,
});
if (status.isListening) {
ws.send("listening");
} else {
ws.send("paused");
}
}
function initializeWebSocket() {
const ws = new WebSocket(
`${location.protocol === "https:" ? "wss:" : "ws:"}//${location.host}/ws`,
);
ws.onmessage = (event) => {
if (typeof event.data !== "string") {
return;
}
const listenerCount = Number.parseInt(event.data);
if (Number.isNaN(listenerCount)) {
return;
}
updateListenerCountLabel(listenerCount);
};
return ws;
}
window.addEventListener("beforeunload", (e) => {
@@ -281,4 +293,3 @@ achievementUnlockedAudio.volume = 0.05;
loadMeowCount();
loadInitialVolume();
enableSpaceBarControl();
listenToServerStatusEvent();