some tweaks and fixes

This commit is contained in:
2024-05-18 22:27:05 +01:00
parent 9922c1ba85
commit 5d7960d4c7
3 changed files with 13 additions and 9 deletions

View File

@@ -1,5 +1,5 @@
FILTER_COUNT = 32
KERNEL_SIZE = 2
CROP_SIZE = 224
BATCH_SIZE = 4
EPOCHS = 5
BATCH_SIZE = 256
EPOCHS = 20

View File

@@ -175,7 +175,7 @@ class MaiRes(nn.Module):
)
self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
self.fc = nn.Linear(in_features=8192, out_features=1)
self.fc = nn.Linear(in_features=2048, out_features=1)
def forward(self, x):
x = self.conv(x)

View File

@@ -3,6 +3,7 @@ import torch
import torch.nn
import torch.optim
import torch.utils.data
from pathlib import Path
from datetime import datetime
from datasets import concatenate_datasets, load_dataset
from label import label_fake, label_real
@@ -34,11 +35,11 @@ def load_data():
diffusion_db_dataset = load_dataset(
"poloclub/diffusiondb",
"2m_random_10k",
"2m_random_50k",
trust_remote_code=True,
split="train",
)
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:25%]")
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]")
painting_dataset = load_dataset(
"keremberke/painting-style-classification", name="full", split="train"
)
@@ -47,7 +48,7 @@ def load_data():
)
movie_poaster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
metal_album_art_dataset = load_dataset(
"Alphonsce/metal_album_covers", split="train[:80%]"
"Alphonsce/metal_album_covers", split="train[:50%]"
)
diffusion_db_dataset = diffusion_db_dataset.select_columns("image")
@@ -58,7 +59,6 @@ def load_data():
painting_dataset = painting_dataset.select_columns("image")
painting_dataset = painting_dataset.map(label_real)
anime_scene_datatset = anime_scene_datatset.select_columns("image")
anime_scene_datatset = anime_scene_datatset.map(label_real)
@@ -109,14 +109,14 @@ def load_data():
training_ds,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=4,
num_workers=6,
collate_fn=collate,
)
validation_loader = torch.utils.data.DataLoader(
validation_ds,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=4,
num_workers=6,
collate_fn=collate,
)
@@ -134,6 +134,9 @@ def train():
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
best_vloss = 1_000_000.0
Path("models").mkdir(parents=True, exist_ok=True)
print("models directory created")
for epoch in range(EPOCHS):
print(f"EPOCH {epoch + 1}:")
@@ -199,6 +202,7 @@ def train():
best_vloss = avg_validation_loss
model_path = f"model/mai_{timestamp}_{epoch}"
torch.save(model.state_dict(), model_path)
print("current model state saved")
train()