2024-07-20 17:09:22 +01:00
|
|
|
import threading
|
2024-07-22 22:24:39 +01:00
|
|
|
import os
|
2024-08-22 23:48:16 +01:00
|
|
|
from time import sleep
|
|
|
|
import requests
|
2024-07-20 21:18:41 +01:00
|
|
|
|
2024-07-20 21:47:57 +01:00
|
|
|
from contextlib import asynccontextmanager
|
2024-11-25 00:22:44 +00:00
|
|
|
from fastapi import (
|
|
|
|
FastAPI,
|
2024-11-25 17:53:37 +00:00
|
|
|
WebSocket,
|
2024-11-25 00:22:44 +00:00
|
|
|
status,
|
|
|
|
)
|
2024-07-20 17:09:22 +01:00
|
|
|
from fastapi.responses import FileResponse
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
2024-11-25 17:53:37 +00:00
|
|
|
from listener_counter import ListenerCounter
|
2024-07-24 17:32:46 +01:00
|
|
|
from logger import log_info, log_warn
|
2024-07-26 22:34:44 +01:00
|
|
|
from websocket_connection_manager import WebSocketConnectionManager
|
2024-07-20 17:09:22 +01:00
|
|
|
|
2024-07-20 21:47:57 +01:00
|
|
|
# Index (0-9) of the track slot served at /current.mp3. It starts at -1 so
# that the first advance() call, made during startup, moves it to slot 0.
current_index: int = -1

# Pending threading.Timer that will call advance() next; lifespan() cancels
# it on shutdown.
t: threading.Timer | None = None

# Base URL of the inference server and the key sent in its Authorization
# header; both are (re)assigned from the environment by lifespan().
inference_url: str = ""
api_key: str | None = ""

# Registry of open websocket connections; used to connect/disconnect sockets
# and to broadcast messages to all of them.
ws_connection_manager = WebSocketConnectionManager()

# Counts listening clients, keyed by client host address (see ws_endpoint).
listener_counter = ListenerCounter()
|
2024-07-23 21:43:35 +01:00
|
|
|
|
2024-07-20 21:47:57 +01:00
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Configure the inference client, start the playlist timer, stop it on shutdown.

    On startup:
      * ``inference_url`` is read from ``INFERENCE_SERVER_URL``; if unset or
        empty it falls back to ``"http://localhost:8001"``.
      * ``api_key`` is read from ``API_KEY``; if unset, it stays ``""``.
      * ``advance()`` is called to select the first track and schedule the
        recurring timer.

    After the application shuts down, the pending timer stored in ``t`` is
    cancelled.
    """
    global inference_url, api_key

    inference_url = os.environ.get("INFERENCE_SERVER_URL") or "http://localhost:8001"
    # Default to "" (the module-level default) rather than None, which would
    # otherwise be sent as the literal header value "key None".
    api_key = os.environ.get("API_KEY", "")

    advance()

    yield

    if t:
        t.cancel()
|
2024-07-20 17:09:22 +01:00
|
|
|
|
|
|
|
|
2024-07-21 17:34:09 +01:00
|
|
|
# Seconds allowed to establish a TCP connection to the inference server.
_CONNECT_TIMEOUT = 10
# Seconds allowed between received bytes while downloading a clip.
_READ_TIMEOUT = 120
# Seconds to wait between polls while the first clip is still being generated.
_POLL_INTERVAL = 30
# Number of clips fetched per generation run (clips 0.._CLIPS_PER_BATCH-1).
_CLIPS_PER_BATCH = 5
_UNREACHABLE_MSG = (
    "inference server potentially unreachable. recycling cached audio for now."
)


def _auth_headers():
    """Return the Authorization header sent with every inference-server request."""
    return {"Authorization": f"key {api_key}"}


def _download_clip(clip_number, dest_path):
    """Request clip ``clip_number`` and, if the server answers 200, save it to ``dest_path``.

    The body is streamed into ``dest_path + ".part"`` and then moved over
    ``dest_path`` with ``os.replace``, which is atomic on POSIX. As a result,
    ``/current.mp3`` never serves a half-written file. A failed download
    leaves the previous clip in place; at most a stale ``.part`` file remains,
    and the next attempt overwrites it.

    Returns the HTTP status code; the file is written only when it is 200.
    Raises requests.RequestException on network errors and OSError on disk
    errors (note that RequestException is itself an OSError subclass).
    """
    tmp_path = f"{dest_path}.part"
    # The context manager closes the streamed response and returns the
    # connection to the pool even when we bail out early.
    with requests.post(
        f"{inference_url}/clips/{clip_number}",
        stream=True,
        headers=_auth_headers(),
        timeout=(_CONNECT_TIMEOUT, _READ_TIMEOUT),
    ) as res:
        if res.status_code != status.HTTP_200_OK:
            return res.status_code
        with open(tmp_path, "wb") as f:
            for chunk in res.iter_content(chunk_size=64 * 1024):
                f.write(chunk)
    os.replace(tmp_path, dest_path)
    return status.HTTP_200_OK


def generate_new_audio():
    """Ask the inference server for a new batch and refill the idle half of the playlist.

    The track slots ``0.mp3``..``9.mp3`` are split into two halves:

      * when playback has just reached slot 0, slots 5-9 are refilled;
      * when playback has just reached slot 5, slots 0-4 are refilled;
      * at any other index this function does nothing.

    Steps:
      1. POST ``/generate`` to start a new batch.
      2. Poll ``/clips/0`` every ``_POLL_INTERVAL`` seconds until it returns
         200, and save it as the first slot of the idle half.
      3. Fetch ``/clips/1``..``/clips/4`` into the remaining slots, skipping
         (with a warning) any that do not return 200.

    Network failures are logged and abort the run, leaving the cached audio
    in place. The function blocks while polling, so it is meant to run on a
    background thread (``advance`` starts it that way).
    """
    if not inference_url:
        return

    # Read the shared index once: the timer thread may advance it while we work.
    index = current_index
    if index == 0:
        offset = 5
    elif index == 5:
        offset = 0
    else:
        return

    log_info("requesting new audio...")

    try:
        # Only the connect phase is bounded: how long the server takes to
        # acknowledge /generate is not known here, so no read timeout is set.
        requests.post(
            f"{inference_url}/generate",
            headers=_auth_headers(),
            timeout=(_CONNECT_TIMEOUT, None),
        ).close()
    except requests.RequestException:
        log_warn(_UNREACHABLE_MSG)
        return

    # The first clip doubles as the readiness signal: any non-200 answer
    # means generation is still in progress.
    while True:
        try:
            code = _download_clip(0, f"{offset}.mp3")
        except requests.RequestException:
            log_warn(_UNREACHABLE_MSG)
            return
        except OSError as err:
            log_warn(f"could not save clip 0: {err}")
            return
        if code == status.HTTP_200_OK:
            break
        log_info(f"still generating... (status {code})")
        sleep(_POLL_INTERVAL)

    log_info("inference complete! downloading new clips")

    for clip_number in range(1, _CLIPS_PER_BATCH):
        try:
            code = _download_clip(clip_number, f"{clip_number + offset}.mp3")
        except requests.RequestException:
            log_warn(_UNREACHABLE_MSG)
            return
        except OSError as err:
            log_warn(f"could not save clip {clip_number}: {err}")
            continue
        if code != status.HTTP_200_OK:
            log_warn(
                f"clip {clip_number} unavailable (status {code}); keeping cached audio"
            )

    log_info("audio generated.")
|
2024-07-23 21:43:35 +01:00
|
|
|
|
2024-07-21 17:34:09 +01:00
|
|
|
|
2024-07-20 17:09:22 +01:00
|
|
|
def advance():
    """Move playback to the next of the 10 track slots and reschedule itself.

    * Wraps from slot 9 back to 0; the initial -1 becomes 0.
    * Starts ``generate_new_audio()`` on a background thread.
    * Arms a 60-second ``threading.Timer`` that calls ``advance()`` again.

    The timer handle is stored in the global ``t`` so that ``lifespan()`` can
    cancel it on shutdown.
    """
    global current_index, t

    current_index = 0 if current_index == 9 else current_index + 1

    # Daemon threads: generate_new_audio() may keep polling the inference
    # server indefinitely. A timer re-armed concurrently with shutdown would
    # escape lifespan()'s cancel(). Either would keep a non-daemon
    # interpreter alive after the app has stopped.
    threading.Thread(target=generate_new_audio, daemon=True).start()

    t = threading.Timer(60, advance)
    t.daemon = True
    t.start()
|
|
|
|
|
|
|
|
|
2024-07-20 21:47:57 +01:00
|
|
|
app = FastAPI(lifespan=lifespan)


@app.get("/current.mp3")
def get_current_audio():
    """Serve the MP3 for the track slot that is currently playing.

    Responds with 404 while that slot's file does not exist yet (e.g. before
    any clip has been downloaded), rather than letting FileResponse fail at
    send time with a server error.
    """
    # Snapshot the index: the timer thread may advance it concurrently.
    path = f"{current_index}.mp3"
    if not os.path.isfile(path):
        from fastapi import HTTPException

        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="audio not available yet",
        )
    return FileResponse(path)
|
|
|
|
|
|
|
|
|
2024-11-25 17:53:37 +00:00
|
|
|
@app.websocket("/ws")
|
|
|
|
async def ws_endpoint(ws: WebSocket):
|
|
|
|
await ws_connection_manager.connect(ws)
|
2024-11-25 00:22:44 +00:00
|
|
|
|
2024-11-25 17:53:37 +00:00
|
|
|
addr = ""
|
|
|
|
if ws.client:
|
|
|
|
addr, _ = ws.client
|
|
|
|
else:
|
|
|
|
await ws.close()
|
|
|
|
ws_connection_manager.disconnect(ws)
|
|
|
|
return
|
2024-07-27 15:54:09 +01:00
|
|
|
|
2024-11-26 18:10:13 +00:00
|
|
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
|
|
|
|
2024-07-26 22:34:44 +01:00
|
|
|
try:
|
2024-11-25 17:53:37 +00:00
|
|
|
while True:
|
|
|
|
msg = await ws.receive_text()
|
|
|
|
match msg:
|
|
|
|
case "listening":
|
|
|
|
listener_counter.add_listener(addr)
|
|
|
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
|
|
|
case "paused":
|
|
|
|
listener_counter.remove_listener(addr)
|
|
|
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
|
|
|
except:
|
|
|
|
listener_counter.remove_listener(addr)
|
|
|
|
ws_connection_manager.disconnect(ws)
|
|
|
|
await ws_connection_manager.broadcast(f"{listener_counter.count()}")
|
2024-07-26 22:34:44 +01:00
|
|
|
|
|
|
|
|
2024-07-20 21:18:41 +01:00
|
|
|
app.mount("/", StaticFiles(directory="web", html=True), name="web")
|