some tweaks and fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
FILTER_COUNT = 32
|
FILTER_COUNT = 32
|
||||||
KERNEL_SIZE = 2
|
KERNEL_SIZE = 2
|
||||||
CROP_SIZE = 224
|
CROP_SIZE = 224
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 256
|
||||||
EPOCHS = 5
|
EPOCHS = 20
|
||||||
|
@@ -175,7 +175,7 @@ class MaiRes(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
|
self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
|
||||||
self.fc = nn.Linear(in_features=8192, out_features=1)
|
self.fc = nn.Linear(in_features=2048, out_features=1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
|
16
train.py
16
train.py
@@ -3,6 +3,7 @@ import torch
|
|||||||
import torch.nn
|
import torch.nn
|
||||||
import torch.optim
|
import torch.optim
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datasets import concatenate_datasets, load_dataset
|
from datasets import concatenate_datasets, load_dataset
|
||||||
from label import label_fake, label_real
|
from label import label_fake, label_real
|
||||||
@@ -34,11 +35,11 @@ def load_data():
|
|||||||
|
|
||||||
diffusion_db_dataset = load_dataset(
|
diffusion_db_dataset = load_dataset(
|
||||||
"poloclub/diffusiondb",
|
"poloclub/diffusiondb",
|
||||||
"2m_random_10k",
|
"2m_random_50k",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
split="train",
|
split="train",
|
||||||
)
|
)
|
||||||
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:25%]")
|
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]")
|
||||||
painting_dataset = load_dataset(
|
painting_dataset = load_dataset(
|
||||||
"keremberke/painting-style-classification", name="full", split="train"
|
"keremberke/painting-style-classification", name="full", split="train"
|
||||||
)
|
)
|
||||||
@@ -47,7 +48,7 @@ def load_data():
|
|||||||
)
|
)
|
||||||
movie_poaster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
|
movie_poaster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
|
||||||
metal_album_art_dataset = load_dataset(
|
metal_album_art_dataset = load_dataset(
|
||||||
"Alphonsce/metal_album_covers", split="train[:80%]"
|
"Alphonsce/metal_album_covers", split="train[:50%]"
|
||||||
)
|
)
|
||||||
|
|
||||||
diffusion_db_dataset = diffusion_db_dataset.select_columns("image")
|
diffusion_db_dataset = diffusion_db_dataset.select_columns("image")
|
||||||
@@ -58,7 +59,6 @@ def load_data():
|
|||||||
|
|
||||||
painting_dataset = painting_dataset.select_columns("image")
|
painting_dataset = painting_dataset.select_columns("image")
|
||||||
painting_dataset = painting_dataset.map(label_real)
|
painting_dataset = painting_dataset.map(label_real)
|
||||||
|
|
||||||
anime_scene_datatset = anime_scene_datatset.select_columns("image")
|
anime_scene_datatset = anime_scene_datatset.select_columns("image")
|
||||||
anime_scene_datatset = anime_scene_datatset.map(label_real)
|
anime_scene_datatset = anime_scene_datatset.map(label_real)
|
||||||
|
|
||||||
@@ -109,14 +109,14 @@ def load_data():
|
|||||||
training_ds,
|
training_ds,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
num_workers=6,
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
)
|
)
|
||||||
validation_loader = torch.utils.data.DataLoader(
|
validation_loader = torch.utils.data.DataLoader(
|
||||||
validation_ds,
|
validation_ds,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
num_workers=6,
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -134,6 +134,9 @@ def train():
|
|||||||
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
|
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
|
||||||
best_vloss = 1_000_000.0
|
best_vloss = 1_000_000.0
|
||||||
|
|
||||||
|
Path("models").mkdir(parents=True, exist_ok=True)
|
||||||
|
print("models directory created")
|
||||||
|
|
||||||
for epoch in range(EPOCHS):
|
for epoch in range(EPOCHS):
|
||||||
print(f"EPOCH {epoch + 1}:")
|
print(f"EPOCH {epoch + 1}:")
|
||||||
|
|
||||||
@@ -199,6 +202,7 @@ def train():
|
|||||||
best_vloss = avg_validation_loss
|
best_vloss = avg_validation_loss
|
||||||
model_path = f"model/mai_{timestamp}_{epoch}"
|
model_path = f"model/mai_{timestamp}_{epoch}"
|
||||||
torch.save(model.state_dict(), model_path)
|
torch.save(model.state_dict(), model_path)
|
||||||
|
print("current model state saved")
|
||||||
|
|
||||||
|
|
||||||
train()
|
train()
|
||||||
|
Reference in New Issue
Block a user