feat: add code for finetuning moondream
resnet/augmentation.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

from hyperparams import CROP_SIZE


preprocess_training = A.Compose(
    [
        A.PadIfNeeded(min_width=CROP_SIZE, min_height=CROP_SIZE),
        A.RandomCrop(width=CROP_SIZE, height=CROP_SIZE),
        A.GaussNoise(),
        A.Flip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
)

preprocess_validation = A.Compose(
    [
        A.PadIfNeeded(min_width=CROP_SIZE, min_height=CROP_SIZE),
        A.CenterCrop(width=CROP_SIZE, height=CROP_SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
)


def transform_training(example):
    transformed = []
    for pil_image in example["image"]:
        array = np.array(pil_image.convert("RGB"))
        # make sure the array is in (height, width, channel) layout;
        # transpose if the channel axis comes first
        if array.shape[-1] != 3:
            array = np.transpose(array, (1, 2, 0))
        img = preprocess_training(image=array)["image"]
        transformed.append(img)
    example["pixel_values"] = transformed
    return example


def transform_validation(example):
    transformed = []
    for pil_image in example["image"]:
        array = np.array(pil_image.convert("RGB"))
        if array.shape[-1] != 3:
            array = np.transpose(array, (1, 2, 0))
        img = preprocess_validation(image=array)["image"]
        transformed.append(img)
    example["pixel_values"] = transformed
    return example
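Both pipelines pad and crop to a fixed square, so the output shape does not depend on the input size. A minimal sanity-check sketch, assuming an arbitrary (hypothetical) 300x180 RGB array as input:

# the validation pipeline should always emit a (3, CROP_SIZE, CROP_SIZE) tensor
import numpy as np
from augmentation import preprocess_validation
from hyperparams import CROP_SIZE

dummy = np.random.randint(0, 256, size=(300, 180, 3), dtype=np.uint8)
out = preprocess_validation(image=dummy)["image"]
assert out.shape == (3, CROP_SIZE, CROP_SIZE)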
resnet/hyperparams.py (new file, 5 lines)
@@ -0,0 +1,5 @@
FILTER_COUNT = 32
KERNEL_SIZE = 2
CROP_SIZE = 224
BATCH_SIZE = 256
EPOCHS = 20
resnet/inference.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import numpy as np
import torch
from PIL import Image

from augmentation import preprocess_validation
from model import mai


def load_model_and_run_inference(img):
    mai.load_state_dict(torch.load("mai"))
    mai.eval()
    # preprocess_validation already returns a torch tensor (ToTensorV2),
    # so just add a batch dimension and run a single forward pass
    img_batch = img.unsqueeze(0)
    with torch.no_grad():
        prediction = torch.sigmoid(mai(img_batch))
    print(prediction)


def main():
    img = Image.open("test_images/dog.jpg").convert("RGB")
    img = preprocess_validation(image=np.array(img))["image"]
    load_model_and_run_inference(img)


if __name__ == "__main__":
    main()
resnet/label.py (new file, 8 lines)
@@ -0,0 +1,8 @@
def label_fake(example):
    example["is_synthetic"] = 1.0
    return example


def label_real(example):
    example["is_synthetic"] = 0.0
    return example
resnet/model.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import torch.nn as nn

# small convolutional baseline: three conv/pool stages followed by an MLP head
mai = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    # with CROP_SIZE = 224, the three stride-2 pools give 224 -> 112 -> 56 -> 28,
    # so the flattened feature map is 28 * 28 * 64
    nn.Linear(28 * 28 * 64, 120),
    nn.ReLU(),
    nn.Linear(120, 30),
    nn.ReLU(),
    nn.Linear(30, 1),
)
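A minimal smoke-test sketch for the baseline, assuming 224x224 inputs as set by CROP_SIZE (the random batch is hypothetical):

# 224 -> 112 -> 56 -> 28 across the three pooling stages,
# then 28 * 28 * 64 flattened features feed the MLP head -> one logit per image
import torch
from model import mai

x = torch.randn(2, 3, 224, 224)
assert mai(x).shape == (2, 1)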
resnet/resnet.py (new file, 193 lines)
@@ -0,0 +1,193 @@
import torch.nn as nn


# "the convolutional layers mostly have 3×3 filters and follow two simple design rules: ..."
# He et al., ‘Deep Residual Learning for Image Recognition’
RESNET_KERNEL_SIZE = 3


# used to match the dimensions of the input to the output via a 1x1 convolution
# He et al., ‘Deep Residual Learning for Image Recognition’, page 4
def projection_shortcut(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            # "when the shortcuts go across feature maps of two sizes, they are performed with a stride of 2"
            # He et al., ‘Deep Residual Learning for Image Recognition’.
            stride=2,
            kernel_size=1,
        ),
        nn.BatchNorm2d(out_channels),
    )


class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, stride=1, shortcut=None, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.conv0 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=RESNET_KERNEL_SIZE,
                stride=stride,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=RESNET_KERNEL_SIZE,
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        self.shortcut = shortcut

    def forward(self, x):
        residual = x
        out = self.conv0(x)
        out = self.conv1(out)
        if self.shortcut is not None:
            out += self.shortcut(residual)
        else:
            out += residual
        out = self.relu(out)
        return out


# MAI built on the 34-layer ResNet layout
# He et al., ‘Deep Residual Learning for Image Recognition’.
class MaiRes(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # first 7x7 conv layer (the stem is the one place the paper
        # does not use 3x3 filters)
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            stride=2,
            padding=3,
            kernel_size=7,
        )
        self.maxpool = nn.MaxPool2d(kernel_size=RESNET_KERNEL_SIZE, stride=2)

        # layers are named after the colors used for each group
        # in the diagram presented in the ResNet paper

        # 3 residual blocks for a total of 6 layers
        self.layer_purple = nn.Sequential(
            ResidualBlock(in_channels=64, out_channels=64, stride=1),
            ResidualBlock(in_channels=64, out_channels=64, stride=1),
            ResidualBlock(in_channels=64, out_channels=64, stride=1),
        )

        # 4 residual blocks for a total of 8 layers
        self.layer_green = nn.Sequential(
            ResidualBlock(
                in_channels=64,
                out_channels=128,
                stride=2,
                shortcut=projection_shortcut(in_channels=64, out_channels=128),
            ),
            ResidualBlock(in_channels=128, out_channels=128, stride=1),
            ResidualBlock(in_channels=128, out_channels=128, stride=1),
            ResidualBlock(in_channels=128, out_channels=128, stride=1),
        )

        # 6 residual blocks for a total of 12 layers
        self.layer_red = nn.Sequential(
            ResidualBlock(
                in_channels=128,
                out_channels=256,
                stride=2,
                shortcut=projection_shortcut(in_channels=128, out_channels=256),
            ),
            ResidualBlock(in_channels=256, out_channels=256, stride=1),
            ResidualBlock(in_channels=256, out_channels=256, stride=1),
            ResidualBlock(in_channels=256, out_channels=256, stride=1),
            ResidualBlock(in_channels=256, out_channels=256, stride=1),
            ResidualBlock(in_channels=256, out_channels=256, stride=1),
        )

        # 3 residual blocks for a total of 6 layers
        self.layer_blue = nn.Sequential(
            ResidualBlock(
                in_channels=256,
                out_channels=512,
                stride=2,
                shortcut=projection_shortcut(in_channels=256, out_channels=512),
            ),
            ResidualBlock(in_channels=512, out_channels=512, stride=1),
            ResidualBlock(in_channels=512, out_channels=512, stride=1),
        )

        # a 3x3 average pool over the final 7x7 maps leaves 2x2 positions,
        # hence 512 * 2 * 2 = 2048 features into the classifier
        self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
        self.fc = nn.Linear(in_features=2048, out_features=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.maxpool(x)

        x = self.layer_purple(x)
        x = self.layer_green(x)
        x = self.layer_red(x)
        x = self.layer_blue(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
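A smoke-test sketch of the spatial dimensions, assuming 224x224 inputs: the stem and max pool give 55x55 maps, the purple group keeps 55, the three stride-2 groups reduce to 28, 14, and 7, and the 3x3 average pool leaves 2x2 positions, i.e. 512 * 2 * 2 = 2048 features for the classifier:

import torch
from resnet import MaiRes

model = MaiRes()
logits = model(torch.randn(2, 3, 224, 224))  # random batch, hypothetical data
assert logits.shape == (2, 1)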
resnet/train.py (new file, 208 lines)
@@ -0,0 +1,208 @@
import datasets
import torch
import torch.nn
import torch.optim
import torch.utils.data
from pathlib import Path
from datetime import datetime
from datasets import concatenate_datasets, load_dataset

from augmentation import transform_training, transform_validation
from hyperparams import BATCH_SIZE, EPOCHS
from label import label_fake, label_real
from resnet import MaiRes

TEST_SIZE = 0.1

datasets.logging.set_verbosity(datasets.logging.INFO)


def collate(batch):
    pixel_values = []
    is_synthetic = []
    for row in batch:
        is_synthetic.append(row["is_synthetic"])
        pixel_values.append(torch.tensor(row["pixel_values"]))

    pixel_values = torch.stack(pixel_values, dim=0)
    is_synthetic = torch.tensor(is_synthetic, dtype=torch.float)

    return pixel_values, is_synthetic
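A toy sketch of the batch format collate produces, using two hypothetical rows with 2x2 three-channel "images":

rows = [
    {"pixel_values": [[[0.0, 0.0], [0.0, 0.0]]] * 3, "is_synthetic": 1.0},
    {"pixel_values": [[[1.0, 1.0], [1.0, 1.0]]] * 3, "is_synthetic": 0.0},
]
pixels, labels = collate(rows)
assert pixels.shape == (2, 3, 2, 2)  # (batch, channel, height, width)
assert labels.tolist() == [1.0, 0.0]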
def load_data():
    print("loading datasets...")

    diffusion_db_dataset = load_dataset(
        "poloclub/diffusiondb",
        "2m_random_50k",
        trust_remote_code=True,
        split="train",
    )
    flickr_dataset = load_dataset("nlphuji/flickr30k", split="test")
    painting_dataset = load_dataset(
        "keremberke/painting-style-classification", name="full", split="train"
    )
    anime_scene_dataset = load_dataset(
        "animelover/scenery-images", "0-sfw", split="train"
    )
    movie_poster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
    metal_album_art_dataset = load_dataset(
        "Alphonsce/metal_album_covers", split="train[:50%]"
    )

    # diffusiondb images are synthetic; every other source is real
    diffusion_db_dataset = diffusion_db_dataset.select_columns("image")
    diffusion_db_dataset = diffusion_db_dataset.map(label_fake)

    flickr_dataset = flickr_dataset.select_columns("image")
    flickr_dataset = flickr_dataset.map(label_real)

    painting_dataset = painting_dataset.select_columns("image")
    painting_dataset = painting_dataset.map(label_real)

    anime_scene_dataset = anime_scene_dataset.select_columns("image")
    anime_scene_dataset = anime_scene_dataset.map(label_real)

    movie_poster_dataset = movie_poster_dataset.select_columns("image")
    movie_poster_dataset = movie_poster_dataset.map(label_real)

    metal_album_art_dataset = metal_album_art_dataset.select_columns("image")
    metal_album_art_dataset = metal_album_art_dataset.map(label_real)

    diffusion_split = diffusion_db_dataset.train_test_split(test_size=TEST_SIZE)
    flickr_split = flickr_dataset.train_test_split(test_size=TEST_SIZE)
    painting_split = painting_dataset.train_test_split(test_size=TEST_SIZE)
    anime_scene_split = anime_scene_dataset.train_test_split(test_size=TEST_SIZE)
    movie_poster_split = movie_poster_dataset.train_test_split(test_size=TEST_SIZE)
    metal_album_art_split = metal_album_art_dataset.train_test_split(
        test_size=TEST_SIZE
    )

    training_ds = concatenate_datasets(
        [
            diffusion_split["train"],
            flickr_split["train"],
            painting_split["train"],
            anime_scene_split["train"],
            movie_poster_split["train"],
            metal_album_art_split["train"],
        ]
    )
    validation_ds = concatenate_datasets(
        [
            diffusion_split["test"],
            flickr_split["test"],
            painting_split["test"],
            anime_scene_split["test"],
            movie_poster_split["test"],
            metal_album_art_split["test"],
        ]
    )

    training_ds = training_ds.map(
        transform_training, remove_columns=["image"], batched=True
    )
    validation_ds = validation_ds.map(
        transform_validation, remove_columns=["image"], batched=True
    )

    training_loader = torch.utils.data.DataLoader(
        training_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=6,
        collate_fn=collate,
    )
    # no shuffling for validation; order does not affect the metrics
    validation_loader = torch.utils.data.DataLoader(
        validation_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=6,
        collate_fn=collate,
    )

    return training_loader, validation_loader, len(training_ds)
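Splitting each source with train_test_split before concatenating keeps every source at the same 90/10 train/validation ratio, so neither split is dominated by one dataset. A toy sketch with hypothetical ten-row datasets:

from datasets import Dataset, concatenate_datasets

fake = Dataset.from_dict({"image": list(range(10))}).train_test_split(test_size=0.1)
real = Dataset.from_dict({"image": list(range(10, 20))}).train_test_split(test_size=0.1)
train = concatenate_datasets([fake["train"], real["train"]])
assert len(train) == 18  # 9 rows from each source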
def train():
    training_loader, validation_loader, sample_size = load_data()

    print(f"sample size: {sample_size}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model = MaiRes().cuda()
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
    best_vloss = 1_000_000.0

    Path("models").mkdir(parents=True, exist_ok=True)
    print("models directory created")

    for epoch in range(EPOCHS):
        print(f"EPOCH {epoch + 1}:")

        model.train(True)

        running_loss = 0.0
        last_loss = 0.0
        correct = 0.0
        total = 0.0
        i = 0

        for i, data in enumerate(training_loader):
            inputs, labels = data
            inputs = inputs.cuda()

            labels = labels.cuda()
            labels = labels.view(-1, 1)

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()

            # the model emits logits (BCEWithLogitsLoss applies the sigmoid
            # itself), so squash them before thresholding at 0.5
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()
            if i % 10 == 9:
                last_loss = running_loss / 10  # loss per batch
                print(f" batch {i + 1} loss: {last_loss}")
                running_loss = 0.0

        # epoch accuracy over the samples actually seen
        accuracy = 100 * correct / total
        print(f"ACCURACY {accuracy}")

        # validation step
        running_vloss = 0.0
        model.eval()
        with torch.no_grad():
            for i, validation_data in enumerate(validation_loader):
                vinputs, vlabels = validation_data
                vinputs = vinputs.cuda()
                vlabels = vlabels.cuda()
                vlabels = vlabels.view(-1, 1)

                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)

                running_vloss += vloss.item()

        avg_validation_loss = running_vloss / (i + 1)
        print(f"LOSS train {last_loss} valid {avg_validation_loss}")

        if avg_validation_loss < best_vloss:
            best_vloss = avg_validation_loss
            # save into the "models" directory created above
            model_path = f"models/mai_{timestamp}_{epoch}"
            torch.save(model.state_dict(), model_path)
            print("current model state saved")


if __name__ == "__main__":
    train()