some tweaks and fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
FILTER_COUNT = 32
|
||||
KERNEL_SIZE = 2
|
||||
CROP_SIZE = 224
|
||||
BATCH_SIZE = 4
|
||||
EPOCHS = 5
|
||||
BATCH_SIZE = 256
|
||||
EPOCHS = 20
|
||||
|
@@ -175,7 +175,7 @@ class MaiRes(nn.Module):
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
|
||||
self.fc = nn.Linear(in_features=8192, out_features=1)
|
||||
self.fc = nn.Linear(in_features=2048, out_features=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
|
16
train.py
16
train.py
@@ -3,6 +3,7 @@ import torch
|
||||
import torch.nn
|
||||
import torch.optim
|
||||
import torch.utils.data
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from label import label_fake, label_real
|
||||
@@ -34,11 +35,11 @@ def load_data():
|
||||
|
||||
diffusion_db_dataset = load_dataset(
|
||||
"poloclub/diffusiondb",
|
||||
"2m_random_10k",
|
||||
"2m_random_50k",
|
||||
trust_remote_code=True,
|
||||
split="train",
|
||||
)
|
||||
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:25%]")
|
||||
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]")
|
||||
painting_dataset = load_dataset(
|
||||
"keremberke/painting-style-classification", name="full", split="train"
|
||||
)
|
||||
@@ -47,7 +48,7 @@ def load_data():
|
||||
)
|
||||
movie_poaster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
|
||||
metal_album_art_dataset = load_dataset(
|
||||
"Alphonsce/metal_album_covers", split="train[:80%]"
|
||||
"Alphonsce/metal_album_covers", split="train[:50%]"
|
||||
)
|
||||
|
||||
diffusion_db_dataset = diffusion_db_dataset.select_columns("image")
|
||||
@@ -58,7 +59,6 @@ def load_data():
|
||||
|
||||
painting_dataset = painting_dataset.select_columns("image")
|
||||
painting_dataset = painting_dataset.map(label_real)
|
||||
|
||||
anime_scene_datatset = anime_scene_datatset.select_columns("image")
|
||||
anime_scene_datatset = anime_scene_datatset.map(label_real)
|
||||
|
||||
@@ -109,14 +109,14 @@ def load_data():
|
||||
training_ds,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
num_workers=6,
|
||||
collate_fn=collate,
|
||||
)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
validation_ds,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
num_workers=6,
|
||||
collate_fn=collate,
|
||||
)
|
||||
|
||||
@@ -134,6 +134,9 @@ def train():
|
||||
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
|
||||
best_vloss = 1_000_000.0
|
||||
|
||||
Path("models").mkdir(parents=True, exist_ok=True)
|
||||
print("models directory created")
|
||||
|
||||
for epoch in range(EPOCHS):
|
||||
print(f"EPOCH {epoch + 1}:")
|
||||
|
||||
@@ -199,6 +202,7 @@ def train():
|
||||
best_vloss = avg_validation_loss
|
||||
model_path = f"model/mai_{timestamp}_{epoch}"
|
||||
torch.save(model.state_dict(), model_path)
|
||||
print("current model state saved")
|
||||
|
||||
|
||||
train()
|
||||
|
Reference in New Issue
Block a user