# mai/moondream/train.py
# Fine-tunes moondream2 to answer "Is this image AI generated?" on a mix of
# AI-generated (DiffusionDB, Midjourney) and real (Flickr, WikiArt, ...) images.
# Standard library
import io
import math
import pathlib

# Third-party
import bitsandbytes
import datasets
import PIL
import torch
import transformers
from tqdm import tqdm

# Local
import utils.datasets
from .hyperparams import (
    TEST_SIZE,
    ANSWER_EOS,
    IMG_TOKENS,
    LR,
    BATCH_SIZE,
    EPOCHS,
    GRAD_ACCUM_STEPS,
)

DEVICE = "cuda"
# CPU kernels don't support half precision, so fall back to float32 there.
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
# Pinned moondream2 revision.
# NOTE(review): MD_REVISION is never passed as revision= to from_pretrained
# below — confirm whether the pin is intended to take effect.
MD_REVISION = "2024-07-23"

# Total samples across all interleaved sources (5000 + 5000 + 8 * 1250).
TOTAL_DATA_SIZE = 20000
# DiffusionDB: AI-generated images -> positive ("Yes.") examples.
diffusion_db_dataset = datasets.load_dataset(
    "poloclub/diffusiondb", "2m_random_5k", split="train",
    trust_remote_code=True, streaming=True,
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "Yes.",
    },
})
diffusion_db_dataset = utils.datasets.split_streaming_dataset(
    diffusion_db_dataset, total_size=5000, test_size=TEST_SIZE
)
# Midjourney v6: AI-generated images -> positive ("Yes.") examples.
midjourney_dataset = datasets.load_dataset(
    "brivangl/midjourney-v6-llava", split="train", streaming=True
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "Yes.",
    },
})
midjourney_dataset = utils.datasets.split_streaming_dataset(
    midjourney_dataset, total_size=5000, test_size=TEST_SIZE
)
# Flickr30k photographs: real images -> negative ("No.") examples.
flickr_dataset = datasets.load_dataset(
    "nlphuji/flickr30k", split="test", streaming=True
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "No.",
    },
})
flickr_dataset = utils.datasets.split_streaming_dataset(
    flickr_dataset, total_size=1250, test_size=TEST_SIZE
)
# WikiArt paintings: real (human-made) images -> negative ("No.") examples.
wiki_art_dataset = datasets.load_dataset(
    "huggan/wikiart", split="train", streaming=True
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "No.",
    },
})
wiki_art_dataset = utils.datasets.split_streaming_dataset(
    wiki_art_dataset, total_size=1250, test_size=TEST_SIZE
)
# Danbooru anime art: real (human-made) images -> negative ("No.") examples.
anime_dataset = datasets.load_dataset(
    "animelover/danbooru2022", "1-full",
    trust_remote_code=True, split="train", streaming=True,
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "No.",
    },
})
anime_dataset = utils.datasets.split_streaming_dataset(
    anime_dataset, total_size=1250, test_size=TEST_SIZE
)
# COCO photographs: real images -> negative ("No.") examples.
coco_dataset = datasets.load_dataset(
    "detection-datasets/coco", split="train", streaming=True
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "No.",
    },
})
coco_dataset = utils.datasets.split_streaming_dataset(
    coco_dataset, total_size=1250, test_size=TEST_SIZE
)
# Movie posters: real (designed) images -> negative ("No.") examples.
movie_poster_dataset = datasets.load_dataset(
    "skvarre/movie_posters-100k", split="train", streaming=True
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "No.",
    },
})
movie_poster_dataset = utils.datasets.split_streaming_dataset(
    movie_poster_dataset, total_size=1250, test_size=TEST_SIZE
)
# Stanford Cars photographs: real images -> negative ("No.") examples.
cars_dataset = datasets.load_dataset(
    "tanganke/stanford_cars", split="train", streaming=True
).select_columns(["image"]).map(
    lambda row: {
        **row,
        "qa": {
            "question": "Is this image AI generated?",
            "answer": "No.",
        },
    }
)
cars_dataset = utils.datasets.split_streaming_dataset(
    cars_dataset, total_size=1250, test_size=TEST_SIZE
)
# Website screenshots: real (rendered) images -> negative ("No.") examples.
website_dataset = datasets.load_dataset(
    "silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True
).select_columns(["image"]).map(
    lambda row: {
        **row,
        "qa": {
            "question": "Is this image AI generated?",
            "answer": "No.",
        },
    }
)
website_dataset = utils.datasets.split_streaming_dataset(
    website_dataset, total_size=1250, test_size=TEST_SIZE
)
# Movie scene stills: real images -> negative ("No.") examples.
movie_scene_dataset = datasets.load_dataset(
    "unography/movie-scenes-resized-captioned", split="train", streaming=True
).select_columns(["image"]).map(lambda row: {
    **row,
    "qa": {
        "question": "Is this image AI generated?",
        "answer": "No.",
    },
})
movie_scene_dataset = utils.datasets.split_streaming_dataset(
    movie_scene_dataset, total_size=1250, test_size=TEST_SIZE
)
# Interleave AI-generated ("Yes.") and real ("No.") sources so every batch
# mixes both classes; "all_exhausted" keeps sampling until every source has
# been fully consumed. Images are decoded lazily on access.
training_dataset = datasets.interleave_datasets([
    diffusion_db_dataset["train"],
    midjourney_dataset["train"],
    flickr_dataset["train"],
    wiki_art_dataset["train"],
    anime_dataset["train"],
    coco_dataset["train"],
    movie_poster_dataset["train"],
    cars_dataset["train"],
    website_dataset["train"],
    movie_scene_dataset["train"],
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))

test_dataset = datasets.interleave_datasets([
    diffusion_db_dataset["test"],
    midjourney_dataset["test"],
    flickr_dataset["test"],
    wiki_art_dataset["test"],
    anime_dataset["test"],
    coco_dataset["test"],
    movie_poster_dataset["test"],
    cars_dataset["test"],
    website_dataset["test"],
    movie_scene_dataset["test"],
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))

print("Training and test dataset prepared.")
# Load the moondream2 tokenizer and model onto the training device.
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
# NOTE(review): MD_REVISION is defined above but not passed as revision= here,
# so the latest remote code/weights are fetched — confirm the pin is intended.
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)
def collate(batch):
    """Collate QA samples into (images, token_ids, labels, attention_masks).

    Labels are -100 (ignored by the loss) for <BOS>, the image-token slots,
    and the question; only the answer tokens contribute to the loss.
    """
    images = [sample["image"] for sample in batch]

    all_tokens = []
    all_labels = []
    for sample in batch:
        qa = sample["qa"]
        tokens = [tokenizer.bos_token_id]
        # Image embeddings get spliced in right after <BOS> at forward time,
        # so labels reserve IMG_TOKENS + 1 ignored positions up front.
        labels = [-100] * (IMG_TOKENS + 1)

        question_ids = tokenizer(
            f"\n\nQuestion: {qa['question']}\n\nAnswer:",
            add_special_tokens=False,
        ).input_ids
        tokens += question_ids
        labels += [-100] * len(question_ids)

        answer_ids = tokenizer(
            f" {qa['answer']}{ANSWER_EOS}",
            add_special_tokens=False,
        ).input_ids
        tokens += answer_ids
        labels += answer_ids

        all_tokens.append(tokens)
        all_labels.append(labels)

    # Right-pad every sequence to the longest label length in the batch.
    max_len = max((len(labels) for labels in all_labels), default=-1)
    all_attn_masks = []
    for tokens, labels in zip(all_tokens, all_labels):
        pad = max_len - len(labels)
        real = max_len - pad
        labels += [-100] * pad
        tokens += [tokenizer.eos_token_id] * pad
        all_attn_masks.append([1] * real + [0] * pad)

    return (
        images,
        torch.stack([torch.tensor(seq, dtype=torch.long) for seq in all_tokens]),
        torch.stack([torch.tensor(seq, dtype=torch.long) for seq in all_labels]),
        torch.stack([torch.tensor(seq, dtype=torch.bool) for seq in all_attn_masks]),
    )
def compute_loss(batch):
    """Run a forward pass and return the language-model loss for one batch."""
    images, tokens, labels, attn_masks = batch
    tokens = tokens.to(DEVICE)
    labels = labels.to(DEVICE)
    attn_masks = attn_masks.to(DEVICE)

    # Vision encoder is frozen: no gradients flow through the image features.
    with torch.no_grad():
        img_embeds = moondream.vision_encoder(images)

    tok_embeds = moondream.text_model.get_input_embeddings()(tokens)
    # Sequence layout: <BOS> embedding, then the image embeddings, then the
    # remaining text-token embeddings (matches the label layout in collate).
    inputs_embeds = torch.cat(
        (tok_embeds[:, 0:1, :], img_embeds, tok_embeds[:, 1:, :]),
        dim=1,
    )

    outputs = moondream.text_model(
        inputs_embeds=inputs_embeds,
        labels=labels,
        attention_mask=attn_masks,
    )
    return outputs.loss
def lr_schedule(step, max_steps, *, base_lr=None):
    """Learning-rate schedule: linear warmup, then cosine decay.

    Warms up linearly from 0.1*lr to lr over the first 10% of steps, then
    cosine-decays back down to 0.1*lr over the remaining 90%.

    Args:
        step: Current optimizer step (may be fractional).
        max_steps: Total number of optimizer steps.
        base_lr: Peak learning rate; defaults to the module-level LR.

    Returns:
        The learning rate for this step.
    """
    if base_lr is None:
        base_lr = LR
    x = step / max_steps
    if x < 0.1:
        # Linear warmup: 0.1*lr at x=0 up to lr at x=0.1.
        return 0.1 * base_lr + 0.9 * base_lr * x / 0.1
    # Cosine decay. Normalizing by 0.9 maps the remaining 90% of training onto
    # [0, pi], so the schedule actually reaches the 0.1*lr floor at x=1.
    # (Without the /0.9 the phase only reaches 0.9*pi and the final LR is
    # ~0.12*lr instead of 0.1*lr.)
    return 0.1 * base_lr + 0.9 * base_lr * (1 + math.cos(math.pi * (x - 0.1) / 0.9)) / 2
dataloaders = {
    "train": torch.utils.data.DataLoader(
        training_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate,
    ),
}

# Only the text model is trained; gradient checkpointing trades compute
# for activation memory.
moondream.text_model.train()
moondream.text_model.transformer.gradient_checkpointing_enable()

# Total optimizer steps: one per GRAD_ACCUM_STEPS batches, each batch holding
# BATCH_SIZE samples. (The original omitted the BATCH_SIZE divisor, which left
# lr_schedule's progress fraction near zero — LR never left warmup.)
total_steps = EPOCHS * (TOTAL_DATA_SIZE * (1 - TEST_SIZE)) // (BATCH_SIZE * GRAD_ACCUM_STEPS)

optimizer = bitsandbytes.optim.Adam8bit(
    [{"params": moondream.text_model.parameters()}],
    lr=LR * 0.1,  # start at the schedule's warmup floor of 0.1 * LR
    betas=(0.9, 0.95),
    eps=1e-6,
)
# Training loop with gradient accumulation: backward() every batch, but
# optimizer.step() only every GRAD_ACCUM_STEPS batches.
i = 0
for epoch in range(EPOCHS):
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        i += 1
        loss = compute_loss(batch)
        loss.backward()
        if i % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
            # Update LR once per optimizer step (i / GRAD_ACCUM_STEPS steps so far).
            lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

# Persist the fine-tuned weights once training completes.
moondream.save_pretrained("checkpoints/moondream-mai")
# Evaluate exact-match accuracy ("Yes." / "No.") on the held-out test split.
moondream.eval()
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)

total = 0
correct_predictions = 0
for sample in tqdm(test_dataset, desc="Validation"):
    total += 1
    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
        sample['qa']['question'],
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True
    )
    ground_truth = sample["qa"]["answer"]
    if md_answer == ground_truth:
        correct_predictions += 1

# Guard against an empty test split to avoid ZeroDivisionError.
accuracy = correct_predictions * 100 / total if total else 0.0
# Fixed: original printed a stray 'f' inside the string ("f{accuracy}").
print(f"Model accuracy: {accuracy}%")