feat: start adding server code
This commit is contained in:
34
generate.py
34
generate.py
@@ -1,13 +1,10 @@
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
import time
|
|
||||||
from audiocraft.models.musicgen import MusicGen
|
from audiocraft.models.musicgen import MusicGen
|
||||||
from audiocraft.data.audio import audio_write
|
from audiocraft.data.audio import audio_write
|
||||||
|
|
||||||
MODEL_NAME = "facebook/musicgen-large"
|
MODEL_NAME = "facebook/musicgen-large"
|
||||||
MUSIC_DURATION_SECONDS = 60
|
MUSIC_DURATION_SECONDS = 60
|
||||||
|
|
||||||
print(f"getting {MODEL_NAME}...")
|
|
||||||
|
|
||||||
model = MusicGen.get_pretrained(MODEL_NAME)
|
model = MusicGen.get_pretrained(MODEL_NAME)
|
||||||
model.set_generation_params(duration=MUSIC_DURATION_SECONDS)
|
model.set_generation_params(duration=MUSIC_DURATION_SECONDS)
|
||||||
descriptions = [
|
descriptions = [
|
||||||
@@ -15,25 +12,20 @@ descriptions = [
|
|||||||
"calm, piano lo-fi beats to help with studying and focusing",
|
"calm, piano lo-fi beats to help with studying and focusing",
|
||||||
"gentle lo-fi hip-hop to relax to",
|
"gentle lo-fi hip-hop to relax to",
|
||||||
"gentle, quiet synthwave lo-fi beats",
|
"gentle, quiet synthwave lo-fi beats",
|
||||||
"morning lo-fi beats"
|
"morning lo-fi beats",
|
||||||
]
|
]
|
||||||
|
|
||||||
print("model obtained. generating wav files...")
|
|
||||||
|
|
||||||
a = time.time()
|
def generate(offset=0):
|
||||||
|
wav = model.generate(descriptions)
|
||||||
|
|
||||||
wav = model.generate(descriptions)
|
for idx, one_wav in enumerate(wav):
|
||||||
|
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
||||||
b = time.time()
|
audio_write(
|
||||||
|
f"{idx + offset}",
|
||||||
print(f"{len(wav)} generated. took {b - a} seconds.")
|
one_wav.cpu(),
|
||||||
|
model.sample_rate,
|
||||||
for idx, one_wav in enumerate(wav):
|
format="mp3",
|
||||||
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
strategy="loudness",
|
||||||
audio_write(
|
loudness_compressor=True,
|
||||||
f"{idx}",
|
)
|
||||||
one_wav.cpu(),
|
|
||||||
model.sample_rate,
|
|
||||||
strategy="loudness",
|
|
||||||
loudness_compressor=True,
|
|
||||||
)
|
|
||||||
|
38
generate_manual.py
Normal file
38
generate_manual.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import torchaudio
|
||||||
|
import time
|
||||||
|
from audiocraft.models.musicgen import MusicGen
|
||||||
|
from audiocraft.data.audio import audio_write
|
||||||
|
|
||||||
|
MODEL_NAME = "facebook/musicgen-large"
|
||||||
|
MUSIC_DURATION_SECONDS = 60
|
||||||
|
|
||||||
|
print("obtaining model...")
|
||||||
|
|
||||||
|
model = MusicGen.get_pretrained(MODEL_NAME)
|
||||||
|
model.set_generation_params(duration=MUSIC_DURATION_SECONDS)
|
||||||
|
descriptions = [
|
||||||
|
"gentle, calming lo-fi beats that helps with studying and focusing",
|
||||||
|
"calm, piano lo-fi beats to help with studying and focusing",
|
||||||
|
"gentle lo-fi hip-hop to relax to",
|
||||||
|
"gentle, quiet synthwave lo-fi beats",
|
||||||
|
"morning lo-fi beats",
|
||||||
|
]
|
||||||
|
|
||||||
|
print("model obtained. generating audio...")
|
||||||
|
|
||||||
|
a = time.time()
|
||||||
|
wav = model.generate(descriptions)
|
||||||
|
b = time.time()
|
||||||
|
|
||||||
|
print(f"audio generated. took {b - a} seconds.")
|
||||||
|
|
||||||
|
for idx, one_wav in enumerate(wav):
|
||||||
|
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
||||||
|
audio_write(
|
||||||
|
f"{idx}",
|
||||||
|
one_wav.cpu(),
|
||||||
|
model.sample_rate,
|
||||||
|
format="mp3",
|
||||||
|
strategy="loudness",
|
||||||
|
loudness_compressor=True,
|
||||||
|
)
|
30
server.py
Normal file
30
server.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import threading
|
||||||
|
from .generate import generate
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
current_index = 0
|
||||||
|
|
||||||
|
app.mount("/", StaticFiles(directory="web", html=True), name="web")
|
||||||
|
|
||||||
|
|
||||||
|
def advance():
|
||||||
|
global current_index
|
||||||
|
|
||||||
|
# if current_index == 0:
|
||||||
|
# generate(offset=5)
|
||||||
|
# elif current_index == 5:
|
||||||
|
# generate(offset=0)
|
||||||
|
|
||||||
|
if current_index == 9:
|
||||||
|
current_index = 0
|
||||||
|
else:
|
||||||
|
current_index = current_index + 1
|
||||||
|
|
||||||
|
t = threading.Timer(60, advance)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
advance()
|
12
web/index.html
Normal file
12
web/index.html
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>infinifi</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<main>
|
||||||
|
<p>test</p>
|
||||||
|
</main>
|
||||||
|
<script type="text/javascript" src="/script.js"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
0
web/script.js
Normal file
0
web/script.js
Normal file
Reference in New Issue
Block a user