feat: add code for finetuning moondream

This commit is contained in:
2024-12-08 17:33:37 +00:00
parent 4bf2c89bf8
commit 6b53cb0411
11 changed files with 261 additions and 0 deletions

51
resnet/augmentation.py Normal file
View File

@@ -0,0 +1,51 @@
import albumentations as a
import numpy as np
from albumentations.pytorch import ToTensorV2
from hyperparams import CROP_SIZE
preprocess_training = a.Compose(
[
a.augmentations.PadIfNeeded(min_width=CROP_SIZE, min_height=CROP_SIZE),
a.RandomCrop(width=CROP_SIZE, height=CROP_SIZE),
a.GaussNoise(),
a.Flip(p=0.5),
a.RandomRotate90(p=0.5),
a.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
]
)
preprocess_validation = a.Compose(
[
a.augmentations.PadIfNeeded(min_width=CROP_SIZE, min_height=CROP_SIZE),
a.CenterCrop(width=CROP_SIZE, height=CROP_SIZE),
a.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
]
)
def transform_training(example):
transformed = []
for pil_image in example["image"]:
array = np.array(pil_image.convert("RGB"))
# check if image is in (height, width, channel) shape
# if not, do a transpose
if array.shape[-1] != 3:
array = np.transpose(array, (1, 2, 0))
img = preprocess_training(image=array)["image"]
transformed.append(img)
example["pixel_values"] = transformed
return example
def transform_validation(example):
transformed = []
for pil_image in example["image"]:
array = np.array(pil_image.convert("RGB"))
if array.shape[-1] != 3:
array = np.transpose(array, (1, 2, 0))
img = preprocess_validation(image=array)["image"]
transformed.append(img)
example["pixel_values"] = transformed
return example

5
resnet/hyperparams.py Normal file
View File

@@ -0,0 +1,5 @@
FILTER_COUNT = 32
KERNEL_SIZE = 2
CROP_SIZE = 224
BATCH_SIZE = 256
EPOCHS = 20

21
resnet/inference.py Normal file
View File

@@ -0,0 +1,21 @@
import torch
import numpy as np
from PIL import Image
from model import mai
from augmentation import preprocess_validation
def load_model_and_run_inference(img):
mai.load_state_dict(torch.load("mai"))
mai.eval()
img_batch = np.expand_dims(img, axis=0)
img_batch = torch.tensor(img_batch)
prediction = mai(img_batch)
prediction = torch.sigmoid(prediction)
print(prediction)
def main():
img = Image.open("test_images/dog.jpg")
img = preprocess_validation(image=np.array(img))["image"]
load_model_and_run_inference(img)

8
resnet/label.py Normal file
View File

@@ -0,0 +1,8 @@
def label_fake(example):
example["is_synthetic"] = 1.0
return example
def label_real(example):
example["is_synthetic"] = 0.0
return example

19
resnet/model.py Normal file
View File

@@ -0,0 +1,19 @@
import torch.nn as nn
mai = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(25 * 25 * 64, 120),
nn.ReLU(),
nn.Linear(120, 30),
nn.ReLU(),
nn.Linear(30, 1),
)

193
resnet/resnet.py Normal file
View File

@@ -0,0 +1,193 @@
import torch
import torch.nn as nn
# "the convolutional layers mostly have 3×3 filters and follow two simple design rules: ..."
# He et al., Deep Residual Learning for Image Recognition
RESNET_KERNEL_SIZE = 3
# used to match dimensions of input to output, done by a 1x1 convolution
# He et al., Deep Residual Learning for Image Recognition page 4
def projection_shortcut(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
# "when the shortcuts go across feature maps of two sizes, they are performed with a stride of 2"
# He et al., Deep Residual Learning for Image Recognition.
stride=2,
kernel_size=1,
),
nn.BatchNorm2d(out_channels),
)
class ResidualBlock(nn.Module):
def __init__(
self, in_channels, out_channels, stride=1, shortcut=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.conv0 = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=stride, padding=1
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
self.conv1 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
)
self.relu = nn.ReLU()
self.out_channels = out_channels
self.shortcut = shortcut
def forward(self, x):
residual = x
out = self.conv0(x)
out = self.conv1(out)
if self.shortcut:
out += self.shortcut(residual)
else:
out += residual
out = self.relu(out)
return out
# MAI in ResNet with 34 layers
# He et al., Deep Residual Learning for Image Recognition.
class MaiRes(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# first 7x7 conv layer
self.conv = nn.Conv2d(
in_channels=3,
out_channels=64,
stride=2,
padding=3,
kernel_size=RESNET_KERNEL_SIZE,
)
self.maxpool = nn.MaxPool2d(kernel_size=RESNET_KERNEL_SIZE, stride=2)
# layers are named after the colors used for each group
# in the diagram presented in the ResNet paper
# 3 residual blocks for a total of 6 layers
self.layer_purple = nn.Sequential(
ResidualBlock(
in_channels=64,
out_channels=64,
stride=1,
),
ResidualBlock(
in_channels=64,
out_channels=64,
stride=1,
),
ResidualBlock(
in_channels=64,
out_channels=64,
stride=1,
),
)
# 4 residual blocks for a total of 8 layers
self.layer_green = nn.Sequential(
ResidualBlock(
in_channels=64,
out_channels=128,
stride=2,
shortcut=projection_shortcut(in_channels=64, out_channels=128),
),
ResidualBlock(
in_channels=128,
out_channels=128,
stride=1,
),
ResidualBlock(
in_channels=128,
out_channels=128,
stride=1,
),
ResidualBlock(
in_channels=128,
out_channels=128,
stride=1,
),
)
# 6 residual blocks for a total of 12 layers
self.layer_red = nn.Sequential(
ResidualBlock(
in_channels=128,
out_channels=256,
stride=2,
shortcut=projection_shortcut(in_channels=128, out_channels=256),
),
ResidualBlock(
in_channels=256,
out_channels=256,
stride=1,
),
ResidualBlock(
in_channels=256,
out_channels=256,
stride=1,
),
ResidualBlock(
in_channels=256,
out_channels=256,
stride=1,
),
ResidualBlock(
in_channels=256,
out_channels=256,
stride=1,
),
ResidualBlock(
in_channels=256,
out_channels=256,
stride=1,
),
)
# 3 residual blocks for a total of 6 layers
self.layer_blue = nn.Sequential(
ResidualBlock(
in_channels=256,
out_channels=512,
stride=2,
shortcut=projection_shortcut(in_channels=256, out_channels=512),
),
ResidualBlock(
in_channels=512,
out_channels=512,
stride=1,
),
ResidualBlock(
in_channels=512,
out_channels=512,
stride=1,
),
)
self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
self.fc = nn.Linear(in_features=2048, out_features=1)
def forward(self, x):
x = self.conv(x)
x = self.maxpool(x)
x = self.layer_purple(x)
x = self.layer_green(x)
x = self.layer_red(x)
x = self.layer_blue(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x

208
resnet/train.py Normal file
View File

@@ -0,0 +1,208 @@
import datasets
import torch
import torch.nn
import torch.optim
import torch.utils.data
from pathlib import Path
from datetime import datetime
from datasets import concatenate_datasets, load_dataset
from label import label_fake, label_real
from model import mai
from augmentation import transform_training, transform_validation
from hyperparams import BATCH_SIZE, EPOCHS
from resnet import MaiRes
TEST_SIZE = 0.1
datasets.logging.set_verbosity(datasets.logging.INFO)
def collate(batch):
pixel_values = []
is_synthetic = []
for row in batch:
is_synthetic.append(row["is_synthetic"])
pixel_values.append(torch.tensor(row["pixel_values"]))
pixel_values = torch.stack(pixel_values, dim=0)
is_synthetic = torch.tensor(is_synthetic, dtype=torch.float)
return pixel_values, is_synthetic
def load_data():
print("loading datasets...")
diffusion_db_dataset = load_dataset(
"poloclub/diffusiondb",
"2m_random_50k",
trust_remote_code=True,
split="train",
)
flickr_dataset = load_dataset("nlphuji/flickr30k", split="test")
painting_dataset = load_dataset(
"keremberke/painting-style-classification", name="full", split="train"
)
anime_scene_datatset = load_dataset(
"animelover/scenery-images", "0-sfw", split="train"
)
movie_poaster_dataset = load_dataset("nanxstats/movie-poster-5k", split="train")
metal_album_art_dataset = load_dataset(
"Alphonsce/metal_album_covers", split="train[:50%]"
)
diffusion_db_dataset = diffusion_db_dataset.select_columns("image")
diffusion_db_dataset = diffusion_db_dataset.map(label_fake)
flickr_dataset = flickr_dataset.select_columns("image")
flickr_dataset = flickr_dataset.map(label_real)
painting_dataset = painting_dataset.select_columns("image")
painting_dataset = painting_dataset.map(label_real)
anime_scene_datatset = anime_scene_datatset.select_columns("image")
anime_scene_datatset = anime_scene_datatset.map(label_real)
movie_poaster_dataset = movie_poaster_dataset.select_columns("image")
movie_poaster_dataset = movie_poaster_dataset.map(label_real)
metal_album_art_dataset = metal_album_art_dataset.select_columns("image")
metal_album_art_dataset = metal_album_art_dataset.map(label_real)
diffusion_split = diffusion_db_dataset.train_test_split(test_size=TEST_SIZE)
flickr_split = flickr_dataset.train_test_split(test_size=TEST_SIZE)
painting_split = painting_dataset.train_test_split(test_size=TEST_SIZE)
anime_scene_split = anime_scene_datatset.train_test_split(test_size=TEST_SIZE)
movie_poaster_split = movie_poaster_dataset.train_test_split(test_size=TEST_SIZE)
metal_album_art_split = metal_album_art_dataset.train_test_split(
test_size=TEST_SIZE
)
training_ds = concatenate_datasets(
[
diffusion_split["train"],
flickr_split["train"],
painting_split["train"],
anime_scene_split["train"],
movie_poaster_split["train"],
metal_album_art_split["train"],
]
)
validation_ds = concatenate_datasets(
[
diffusion_split["test"],
flickr_split["test"],
painting_split["test"],
anime_scene_split["test"],
movie_poaster_split["test"],
metal_album_art_split["test"],
]
)
training_ds = training_ds.map(
transform_training, remove_columns=["image"], batched=True
)
validation_ds = validation_ds.map(
transform_validation, remove_columns=["image"], batched=True
)
training_loader = torch.utils.data.DataLoader(
training_ds,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=6,
collate_fn=collate,
)
validation_loader = torch.utils.data.DataLoader(
validation_ds,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=6,
collate_fn=collate,
)
return training_loader, validation_loader, len(training_ds)
def train():
training_loader, validation_loader, sample_size = load_data()
print(f"sample size: {sample_size}")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model = MaiRes().cuda()
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
best_vloss = 1_000_000.0
Path("models").mkdir(parents=True, exist_ok=True)
print("models directory created")
for epoch in range(EPOCHS):
print(f"EPOCH {epoch + 1}:")
model.train(True)
running_loss = 0.0
last_loss = 0.0
correct = 0.0
total = 0.0
accuracy = 0.0
i = 0
for i, data in enumerate(training_loader):
inputs, labels = data
inputs = inputs.cuda()
labels = labels.cuda()
labels = labels.view(-1, 1)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
# Calculate accuracy
predicted = (outputs > 0.5).float() # Applying a threshold of 0.5
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / sample_size
running_loss += loss.item()
if i % 10 == 9:
last_loss = running_loss / 10 # loss per batch
print(" batch {} loss: {}".format(i + 1, last_loss))
running_loss = 0.0
print(f"ACCURACY {accuracy}")
# validation step
running_vloss = 0.0
model.eval()
with torch.no_grad():
for i, validation_data in enumerate(validation_loader):
vinputs, vlabels = validation_data
vinputs = vinputs.cuda()
vlabels = vlabels.cuda()
vlabels = vlabels.view(-1, 1)
voutputs = model(vinputs)
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss
avg_validation_loss = running_vloss / (i + 1)
print("LOSS train {} valid {}".format(last_loss, avg_validation_loss))
if avg_validation_loss < best_vloss:
best_vloss = avg_validation_loss
model_path = f"model/mai_{timestamp}_{epoch}"
torch.save(model.state_dict(), model_path)
print("current model state saved")
train()