From 5d7960d4c76d54eef20920621b4ce8bd0b0298ba Mon Sep 17 00:00:00 2001
From: Kenneth
Date: Sat, 18 May 2024 22:27:05 +0100
Subject: [PATCH] some tweaks and fixes

---
 hyperparams.py |  4 ++--
 resnet.py      |  2 +-
 train.py       | 16 ++++++++++------
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/hyperparams.py b/hyperparams.py
index 3c11f3f..a9c7218 100644
--- a/hyperparams.py
+++ b/hyperparams.py
@@ -1,5 +1,5 @@
 FILTER_COUNT = 32
 KERNEL_SIZE = 2
 CROP_SIZE = 224
-BATCH_SIZE = 4
-EPOCHS = 5
+BATCH_SIZE = 256
+EPOCHS = 20
diff --git a/resnet.py b/resnet.py
index e3b4402..a98e3cb 100644
--- a/resnet.py
+++ b/resnet.py
@@ -175,7 +175,7 @@ class MaiRes(nn.Module):
         )
 
         self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
-        self.fc = nn.Linear(in_features=8192, out_features=1)
+        self.fc = nn.Linear(in_features=2048, out_features=1)
 
     def forward(self, x):
         x = self.conv(x)
diff --git a/train.py b/train.py
index a2d5880..bd775a5 100644
--- a/train.py
+++ b/train.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn
 import torch.optim
 import torch.utils.data
+from pathlib import Path
 from datetime import datetime
 from datasets import concatenate_datasets, load_dataset
 from label import label_fake, label_real
@@ -34,11 +35,11 @@ def load_data():
 
     diffusion_db_dataset = load_dataset(
         "poloclub/diffusiondb",
-        "2m_random_10k",
+        "2m_random_50k",
        trust_remote_code=True,
         split="train",
     )
-    flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:25%]")
+    flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]")
     painting_dataset = load_dataset(
         "keremberke/painting-style-classification", name="full", split="train"
     )
@@ -47,7 +48,7 @@
     )
     movie_poaster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
     metal_album_art_dataset = load_dataset(
-        "Alphonsce/metal_album_covers", split="train[:80%]"
+        "Alphonsce/metal_album_covers", split="train[:50%]"
     )
 
     diffusion_db_dataset = diffusion_db_dataset.select_columns("image")
@@ -58,7 +59,6 @@
     painting_dataset = painting_dataset.select_columns("image")
     painting_dataset = painting_dataset.map(label_real)
 
-
     anime_scene_datatset = anime_scene_datatset.select_columns("image")
     anime_scene_datatset = anime_scene_datatset.map(label_real)
 
@@ -109,14 +109,14 @@
         training_ds,
         batch_size=BATCH_SIZE,
         shuffle=True,
-        num_workers=4,
+        num_workers=6,
         collate_fn=collate,
     )
     validation_loader = torch.utils.data.DataLoader(
         validation_ds,
         batch_size=BATCH_SIZE,
         shuffle=True,
-        num_workers=4,
+        num_workers=6,
         collate_fn=collate,
     )
 
@@ -134,6 +134,9 @@
     optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
     best_vloss = 1_000_000.0
 
+    Path("model").mkdir(parents=True, exist_ok=True)
+    print("model directory created")
+
     for epoch in range(EPOCHS):
         print(f"EPOCH {epoch + 1}:")
 
@@ -199,6 +202,7 @@
             best_vloss = avg_validation_loss
             model_path = f"model/mai_{timestamp}_{epoch}"
             torch.save(model.state_dict(), model_path)
+            print("current model state saved")
 
 
 train()