refactor: use websocket again
This commit is contained in:
19
listener_counter.py
Normal file
19
listener_counter.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class ListenerCounter:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.__listener = set()
|
||||||
|
self.__lock = threading.Lock()
|
||||||
|
|
||||||
|
def add_listener(self, listener_id: str):
|
||||||
|
with self.__lock:
|
||||||
|
self.__listener.add(listener_id)
|
||||||
|
|
||||||
|
def remove_listener(self, listener_id: str):
|
||||||
|
with self.__lock:
|
||||||
|
self.__listener.discard(listener_id)
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
with self.__lock:
|
||||||
|
return len(self.__listener)
|
@@ -1,4 +1,4 @@
|
|||||||
fastapi==0.115.5
|
fastapi==0.115.5
|
||||||
|
websockets==14.1
|
||||||
logger==1.4
|
logger==1.4
|
||||||
Requests==2.32.3
|
Requests==2.32.3
|
||||||
sse_starlette==2.1.3
|
|
||||||
|
69
server.py
69
server.py
@@ -1,22 +1,19 @@
|
|||||||
import asyncio
|
|
||||||
import threading
|
import threading
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
from time import sleep
|
from time import sleep
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
FastAPI,
|
FastAPI,
|
||||||
Request,
|
WebSocket,
|
||||||
HTTPException,
|
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from listener_counter import ListenerCounter
|
||||||
from logger import log_info, log_warn
|
from logger import log_info, log_warn
|
||||||
from websocket_connection_manager import WebSocketConnectionManager
|
from websocket_connection_manager import WebSocketConnectionManager
|
||||||
from sse_starlette.sse import EventSourceResponse
|
|
||||||
|
|
||||||
# the index of the current audio track from 0 to 9
|
# the index of the current audio track from 0 to 9
|
||||||
current_index = -1
|
current_index = -1
|
||||||
@@ -25,7 +22,7 @@ t = None
|
|||||||
inference_url = ""
|
inference_url = ""
|
||||||
api_key = ""
|
api_key = ""
|
||||||
ws_connection_manager = WebSocketConnectionManager()
|
ws_connection_manager = WebSocketConnectionManager()
|
||||||
active_listeners = set()
|
listener_counter = ListenerCounter()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -138,46 +135,32 @@ def get_current_audio():
|
|||||||
return FileResponse(f"{current_index}.mp3")
|
return FileResponse(f"{current_index}.mp3")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/status")
|
@app.websocket("/ws")
|
||||||
def status_stream(request: Request):
|
async def ws_endpoint(ws: WebSocket):
|
||||||
async def status_generator():
|
await ws_connection_manager.connect(ws)
|
||||||
last_listener_count = len(active_listeners)
|
|
||||||
yield json.dumps({"listeners": last_listener_count})
|
|
||||||
|
|
||||||
while True:
|
addr = ""
|
||||||
if await request.is_disconnected():
|
if ws.client:
|
||||||
break
|
addr, _ = ws.client
|
||||||
|
else:
|
||||||
listener_count = len(active_listeners)
|
await ws.close()
|
||||||
if listener_count != last_listener_count:
|
ws_connection_manager.disconnect(ws)
|
||||||
last_listener_count = listener_count
|
return
|
||||||
yield json.dumps({"listeners": listener_count})
|
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
return EventSourceResponse(status_generator())
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/client-status")
|
|
||||||
async def change_status(request: Request):
|
|
||||||
body = await request.json()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
is_listening = body["isListening"]
|
while True:
|
||||||
|
msg = await ws.receive_text()
|
||||||
client = request.client
|
match msg:
|
||||||
if not client:
|
case "listening":
|
||||||
raise HTTPException(status_code=400, detail="ip address unavailable.")
|
listener_counter.add_listener(addr)
|
||||||
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
||||||
if is_listening:
|
case "paused":
|
||||||
active_listeners.add(client.host)
|
listener_counter.remove_listener(addr)
|
||||||
else:
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
||||||
active_listeners.discard(client.host)
|
except:
|
||||||
|
listener_counter.remove_listener(addr)
|
||||||
return {"isListening": is_listening}
|
ws_connection_manager.disconnect(ws)
|
||||||
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
||||||
except KeyError:
|
|
||||||
raise HTTPException(status_code=400, detail="'isListening' must be a boolean")
|
|
||||||
|
|
||||||
|
|
||||||
app.mount("/", StaticFiles(directory="web", html=True), name="web")
|
app.mount("/", StaticFiles(directory="web", html=True), name="web")
|
||||||
|
166
tmp/server.py
Normal file
166
tmp/server.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
import threading
|
||||||
|
import os
|
||||||
|
from time import sleep
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import (
|
||||||
|
FastAPI,
|
||||||
|
WebSocket,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from listener_counter import ListenerCounter
|
||||||
|
from logger import log_info, log_warn
|
||||||
|
from websocket_connection_manager import WebSocketConnectionManager
|
||||||
|
|
||||||
|
# the index of the current audio track from 0 to 9
|
||||||
|
current_index = -1
|
||||||
|
# the timer that periodically advances the current audio track
|
||||||
|
t = None
|
||||||
|
inference_url = ""
|
||||||
|
api_key = ""
|
||||||
|
ws_connection_manager = WebSocketConnectionManager()
|
||||||
|
listener_counter = ListenerCounter()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
global ws, inference_url, api_key
|
||||||
|
|
||||||
|
inference_url = os.environ.get("INFERENCE_SERVER_URL")
|
||||||
|
api_key = os.environ.get("API_KEY")
|
||||||
|
|
||||||
|
if not inference_url:
|
||||||
|
inference_url = "http://localhost:8001"
|
||||||
|
|
||||||
|
advance()
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
if t:
|
||||||
|
t.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_new_audio():
|
||||||
|
if not inference_url:
|
||||||
|
return
|
||||||
|
|
||||||
|
global current_index
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
if current_index == 0:
|
||||||
|
offset = 5
|
||||||
|
elif current_index == 5:
|
||||||
|
offset = 0
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
log_info("requesting new audio...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
requests.post(
|
||||||
|
f"{inference_url}/generate",
|
||||||
|
headers={"Authorization": f"key {api_key}"},
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
log_warn(
|
||||||
|
"inference server potentially unreachable. recycling cached audio for now."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
is_available = False
|
||||||
|
while not is_available:
|
||||||
|
try:
|
||||||
|
res = requests.post(
|
||||||
|
f"{inference_url}/clips/0",
|
||||||
|
stream=True,
|
||||||
|
headers={"Authorization": f"key {api_key}"},
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
log_warn(
|
||||||
|
"inference server potentially unreachable. recycling cached audio for now."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if res.status_code != status.HTTP_200_OK:
|
||||||
|
print(res.status_code)
|
||||||
|
print("still generating...")
|
||||||
|
sleep(30)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("inference complete! downloading new clips")
|
||||||
|
|
||||||
|
is_available = True
|
||||||
|
with open(f"{offset}.mp3", "wb") as f:
|
||||||
|
for chunk in res.iter_content(chunk_size=128):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
res = requests.post(
|
||||||
|
f"{inference_url}/clips/{i + 1}",
|
||||||
|
stream=True,
|
||||||
|
headers={"Authorization": f"key {api_key}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if res.status_code != status.HTTP_200_OK:
|
||||||
|
continue
|
||||||
|
|
||||||
|
with open(f"{i + 1 + offset}.mp3", "wb") as f:
|
||||||
|
for chunk in res.iter_content(chunk_size=128):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
log_info("audio generated.")
|
||||||
|
|
||||||
|
|
||||||
|
def advance():
|
||||||
|
global current_index, t
|
||||||
|
|
||||||
|
if current_index == 9:
|
||||||
|
current_index = 0
|
||||||
|
else:
|
||||||
|
current_index = current_index + 1
|
||||||
|
threading.Thread(target=generate_new_audio).start()
|
||||||
|
|
||||||
|
t = threading.Timer(60, advance)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/current.mp3")
|
||||||
|
def get_current_audio():
|
||||||
|
return FileResponse(f"{current_index}.mp3")
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def ws_endpoint(ws: WebSocket):
|
||||||
|
await ws_connection_manager.connect(ws)
|
||||||
|
|
||||||
|
addr = ""
|
||||||
|
if ws.client:
|
||||||
|
addr, _ = ws.client
|
||||||
|
else:
|
||||||
|
await ws.close()
|
||||||
|
ws_connection_manager.disconnect(ws)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = await ws.receive_text()
|
||||||
|
match msg:
|
||||||
|
case "listening":
|
||||||
|
listener_counter.add_listener(addr)
|
||||||
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
||||||
|
case "paused":
|
||||||
|
listener_counter.remove_listener(addr)
|
||||||
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
||||||
|
except:
|
||||||
|
listener_counter.remove_listener(addr)
|
||||||
|
ws_connection_manager.disconnect(ws)
|
||||||
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
||||||
|
|
||||||
|
|
||||||
|
app.mount("/", StaticFiles(directory="web", html=True), name="web")
|
@@ -22,6 +22,8 @@ const achievementUnlockedAudio = document.getElementById(
|
|||||||
"achievement-unlocked-audio",
|
"achievement-unlocked-audio",
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const ws = initializeWebSocket();
|
||||||
|
|
||||||
let isPlaying = false;
|
let isPlaying = false;
|
||||||
let isFading = false;
|
let isFading = false;
|
||||||
let currentAudio;
|
let currentAudio;
|
||||||
@@ -188,14 +190,6 @@ function showNotification(title, content, duration) {
|
|||||||
}, duration);
|
}, duration);
|
||||||
}
|
}
|
||||||
|
|
||||||
function listenToServerStatusEvent() {
|
|
||||||
const statusEvent = new EventSource("/status");
|
|
||||||
statusEvent.addEventListener("message", (event) => {
|
|
||||||
const data = JSON.parse(event.data);
|
|
||||||
updateListenerCountLabel(data.listeners);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateListenerCountLabel(newCount) {
|
function updateListenerCountLabel(newCount) {
|
||||||
if (newCount <= 1) {
|
if (newCount <= 1) {
|
||||||
listenerCountLabel.innerText = `${newCount} person tuned in`;
|
listenerCountLabel.innerText = `${newCount} person tuned in`;
|
||||||
@@ -205,14 +199,32 @@ function updateListenerCountLabel(newCount) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function updateClientStatus(status) {
|
async function updateClientStatus(status) {
|
||||||
await fetch("/client-status", {
|
if (status.isListening) {
|
||||||
method: "POST",
|
ws.send("listening");
|
||||||
body: JSON.stringify({ isListening: status.isListening }),
|
} else {
|
||||||
headers: {
|
ws.send("paused");
|
||||||
"Content-Type": "application/json",
|
}
|
||||||
},
|
}
|
||||||
keepalive: true,
|
|
||||||
});
|
function initializeWebSocket() {
|
||||||
|
const ws = new WebSocket(
|
||||||
|
`${location.protocol === "https:" ? "wss:" : "ws:"}//${location.host}/ws`,
|
||||||
|
);
|
||||||
|
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
if (typeof event.data !== "string") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const listenerCount = Number.parseInt(event.data);
|
||||||
|
if (Number.isNaN(listenerCount)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
updateListenerCountLabel(listenerCount);
|
||||||
|
};
|
||||||
|
|
||||||
|
return ws;
|
||||||
}
|
}
|
||||||
|
|
||||||
window.addEventListener("beforeunload", (e) => {
|
window.addEventListener("beforeunload", (e) => {
|
||||||
@@ -281,4 +293,3 @@ achievementUnlockedAudio.volume = 0.05;
|
|||||||
loadMeowCount();
|
loadMeowCount();
|
||||||
loadInitialVolume();
|
loadInitialVolume();
|
||||||
enableSpaceBarControl();
|
enableSpaceBarControl();
|
||||||
listenToServerStatusEvent();
|
|
||||||
|
Reference in New Issue
Block a user