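"""Train a binary classifier that separates AI-generated (synthetic) images
from real ones.

Synthetic examples come from DiffusionDB; real examples come from five
datasets of photos, paintings, anime scenery, movie posters, and album art.
The checkpoint with the best validation loss is saved after each improvement.
"""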
from datetime import datetime

import datasets
import torch
import torch.nn
import torch.optim
import torch.utils.data
from datasets import concatenate_datasets, load_dataset

from augmentation import transform_training, transform_validation
from hyperparams import BATCH_SIZE, EPOCHS
from label import label_fake, label_real
from model import mai

# Fraction of each dataset held out for validation.
TEST_SIZE = 0.1

datasets.logging.set_verbosity(datasets.logging.INFO)

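
# Hugging Face datasets yield dict rows, so the DataLoader needs a custom
# collate_fn that stacks each row's "pixel_values" into a batch tensor and
# collects the "is_synthetic" labels into a float tensor.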
def collate(batch):
    pixel_values = []
    is_synthetic = []
    for row in batch:
        is_synthetic.append(row["is_synthetic"])
        pixel_values.append(torch.tensor(row["pixel_values"]))

    pixel_values = torch.stack(pixel_values, dim=0)
    is_synthetic = torch.tensor(is_synthetic, dtype=torch.float)

    return pixel_values, is_synthetic

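
# DiffusionDB supplies the synthetic (AI-generated) images; the other five
# datasets supply real images in varied styles. label_fake / label_real are
# assumed to tag each row with is_synthetic = 1.0 / 0.0 respectively.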
def load_data():
    print("loading datasets...")

    diffusion_db_dataset = load_dataset(
        "poloclub/diffusiondb",
        "2m_random_10k",
        trust_remote_code=True,
        split="train",
    )
    # Split slicing ("test[:50%]", "train[:80%]") loads only part of a dataset.
    flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]")
    painting_dataset = load_dataset(
        "keremberke/painting-style-classification", name="full", split="train"
    )
    anime_scene_dataset = load_dataset(
        "animelover/scenery-images", "0-sfw", split="train"
    )
    movie_poster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
    metal_album_art_dataset = load_dataset(
        "Alphonsce/metal_album_covers", split="train[:80%]"
    )

    # Keep only the image column, then attach the binary is_synthetic label.
    diffusion_db_dataset = diffusion_db_dataset.select_columns("image").map(
        label_fake
    )

    flickr_dataset = flickr_dataset.select_columns("image").map(label_real)
    painting_dataset = painting_dataset.select_columns("image").map(label_real)
    anime_scene_dataset = anime_scene_dataset.select_columns("image").map(label_real)
    movie_poster_dataset = movie_poster_dataset.select_columns("image").map(label_real)
    metal_album_art_dataset = metal_album_art_dataset.select_columns("image").map(
        label_real
    )

    # Split each dataset separately so the validation set keeps the same
    # real/synthetic mix as the training set.
    diffusion_split = diffusion_db_dataset.train_test_split(test_size=TEST_SIZE)
    flickr_split = flickr_dataset.train_test_split(test_size=TEST_SIZE)
    painting_split = painting_dataset.train_test_split(test_size=TEST_SIZE)
    anime_scene_split = anime_scene_dataset.train_test_split(test_size=TEST_SIZE)
    movie_poster_split = movie_poster_dataset.train_test_split(test_size=TEST_SIZE)
    metal_album_art_split = metal_album_art_dataset.train_test_split(
        test_size=TEST_SIZE
    )

    training_ds = concatenate_datasets(
        [
            diffusion_split["train"],
            flickr_split["train"],
            painting_split["train"],
            anime_scene_split["train"],
            movie_poster_split["train"],
            metal_album_art_split["train"],
        ]
    )
    validation_ds = concatenate_datasets(
        [
            diffusion_split["test"],
            flickr_split["test"],
            painting_split["test"],
            anime_scene_split["test"],
            movie_poster_split["test"],
            metal_album_art_split["test"],
        ]
    )

    # Apply the image transforms up front and drop the raw image column.
    training_ds = training_ds.map(
        transform_training, remove_columns=["image"], batched=True
    )
    validation_ds = validation_ds.map(
        transform_validation, remove_columns=["image"], batched=True
    )
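    # Note: .map() bakes the transform results into the dataset cache, so any
    # random augmentation in transform_training is sampled once rather than
    # re-sampled every epoch; datasets' set_transform() would apply it lazily
    # per batch instead.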

    training_loader = torch.utils.data.DataLoader(
        training_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        collate_fn=collate,
    )
    validation_loader = torch.utils.data.DataLoader(
        validation_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,  # evaluation order doesn't matter; skip the shuffle
        num_workers=4,
        collate_fn=collate,
    )

    return training_loader, validation_loader, len(training_ds)

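
# BCEWithLogitsLoss expects raw logits, so the model is assumed to end without
# a sigmoid; probabilities are recovered with torch.sigmoid() only when
# computing accuracy. After each epoch, the checkpoint is saved if validation
# loss improved.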
def train():
    training_loader, validation_loader, sample_size = load_data()

    print(f"sample size: {sample_size}")

    # Timestamp the run so checkpoints from different runs don't collide.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model = mai.cuda()
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
    best_vloss = float("inf")

    for epoch in range(EPOCHS):
        print(f"EPOCH {epoch + 1}:")

        model.train(True)

        running_loss = 0.0
        last_loss = 0.0
        correct = 0.0
        total = 0.0

        for i, data in enumerate(training_loader):
            inputs, labels = data
            inputs = inputs.cuda()

            labels = labels.cuda()
            labels = labels.view(-1, 1)  # match the (batch, 1) logit shape

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()

            # Accuracy bookkeeping: outputs are logits, so pass them through
            # a sigmoid before thresholding at probability 0.5.
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()
            if i % 1000 == 999:
                last_loss = running_loss / 1000  # average loss per batch
                print("  batch {} loss: {}".format(i + 1, last_loss))
                running_loss = 0.0

        accuracy = 100 * correct / total
        print(f"ACCURACY {accuracy}")

        # Validation pass: eval mode, no gradient tracking.
        running_vloss = 0.0
        model.eval()
        with torch.no_grad():
            for validation_data in validation_loader:
                vinputs, vlabels = validation_data
                vinputs = vinputs.cuda()
                vlabels = vlabels.cuda()
                vlabels = vlabels.view(-1, 1)

                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)

                running_vloss += vloss.item()

        avg_validation_loss = running_vloss / len(validation_loader)
        print("LOSS train {} valid {}".format(last_loss, avg_validation_loss))

        # Save a checkpoint whenever validation loss improves.
        if avg_validation_loss < best_vloss:
            best_vloss = avg_validation_loss
            model_path = f"model/mai_{timestamp}_{epoch}"
            torch.save(model.state_dict(), model_path)


if __name__ == "__main__":
    train()
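
# Reloading a saved checkpoint for inference might look like this (a sketch;
# assumes `mai` from model.py matches the architecture the weights came from):
#
#   model = mai.cuda()
#   model.load_state_dict(torch.load("model/mai_<timestamp>_<epoch>"))
#   model.eval()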