implement ResNet model

This commit is contained in:
2024-05-18 01:07:06 +01:00
parent eaa6cf38ca
commit 0ca7d5b8ec
3 changed files with 194 additions and 4 deletions

View File

@@ -9,6 +9,7 @@ from label import label_fake, label_real
from model import mai
from augmentation import transform_training, transform_validation
from hyperparams import BATCH_SIZE, EPOCHS
from resnet import MaiRes
TEST_SIZE = 0.1
@@ -37,7 +38,7 @@ def load_data():
trust_remote_code=True,
split="train",
)
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]")
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:25%]")
painting_dataset = load_dataset(
"keremberke/painting-style-classification", name="full", split="train"
)
@@ -128,7 +129,7 @@ def train():
print(f"sample size: {sample_size}")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model = mai.cuda()
model = MaiRes().cuda()
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
best_vloss = 1_000_000.0