implement ResNet model
This commit is contained in:
5
train.py
5
train.py
@@ -9,6 +9,7 @@ from label import label_fake, label_real
|
||||
from model import mai
|
||||
from augmentation import transform_training, transform_validation
|
||||
from hyperparams import BATCH_SIZE, EPOCHS
|
||||
from resnet import MaiRes
|
||||
|
||||
TEST_SIZE = 0.1
|
||||
|
||||
@@ -37,7 +38,7 @@ def load_data():
|
||||
trust_remote_code=True,
|
||||
split="train",
|
||||
)
|
||||
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]")
|
||||
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:25%]")
|
||||
painting_dataset = load_dataset(
|
||||
"keremberke/painting-style-classification", name="full", split="train"
|
||||
)
|
||||
@@ -128,7 +129,7 @@ def train():
|
||||
print(f"sample size: {sample_size}")
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model = mai.cuda()
|
||||
model = MaiRes().cuda()
|
||||
loss_fn = torch.nn.BCEWithLogitsLoss()
|
||||
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
|
||||
best_vloss = 1_000_000.0
|
||||
|
Reference in New Issue
Block a user