feat: add code for finetuning moondream
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -250,3 +250,5 @@ $RECYCLE.BIN/
|
|||||||
# End of https://www.toptal.com/developers/gitignore/api/python,macos,linux,windows
|
# End of https://www.toptal.com/developers/gitignore/api/python,macos,linux,windows
|
||||||
|
|
||||||
test_images/
|
test_images/
|
||||||
|
checkpoints/
|
||||||
|
samples/
|
||||||
|
31
moondream/hyperparams.py
Normal file
31
moondream/hyperparams.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Hyperparameters and shared constants for moondream finetuning."""

# Fraction of each source dataset held out as the evaluation split.
TEST_SIZE = 0.2

# Number of times to repeat the training dataset. Increasing this may cause the model to overfit or
# lose generalization due to catastrophic forgetting. Decreasing it may cause the model to underfit.
EPOCHS = 1

# Number of samples to process in each batch. Set this to the highest value that doesn't cause an
# out-of-memory error. Decrease it if you're running out of memory.
BATCH_SIZE = 8

# Number of batches to process before updating the model. You can use this to simulate a higher batch
# size than your GPU can handle. Set this to 1 to disable gradient accumulation.
GRAD_ACCUM_STEPS = 2

# Learning rate for the Adam optimizer. Needs to be tuned on a case-by-case basis. As a general rule
# of thumb, increase it by 1.4 times each time you double the effective batch size.
#
# Source: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
#
# Note that we linearly warm the learning rate up from 0.1 * LR to LR over the first 10% of the
# training run, and then decay it back to 0.1 * LR over the last 90% of the training run using a
# cosine schedule.
LR = 1e-5

# Whether to use Weights and Biases for logging training metrics.
USE_WANDB = False

# End-of-answer marker appended to every training answer so the model learns
# when to stop generating.
ANSWER_EOS = "<|endoftext|>"

# Number of tokens used to represent each image.
IMG_TOKENS = 729
|
64
moondream/test.py
Normal file
64
moondream/test.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import torch
import datasets
import transformers
import pathlib

# Evaluation runs on GPU; fall back to float32 on CPU because half precision
# is poorly supported there.
DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16  # CPU doesn't support float16
MD_REVISION = "2024-07-23"  # hub revision of vikhyatk/moondream2 this run is pinned to

# Pin the tokenizer to MD_REVISION: previously the constant was defined but
# never used, so the tokenizer silently tracked the hub's latest revision.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2", revision=MD_REVISION
)
# The model itself is loaded from the local finetuned checkpoint, so no hub
# revision applies to it.
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "./checkpoints/moondream-mai",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)
|
||||||
|
|
||||||
|
# Build a small mixed evaluation set: 100 AI-generated images and 100 real
# photos, each tagged with a fixed question/answer pair, merged and shuffled.
def _with_qa(answer):
    """Return a map function attaching a fixed describe-this-image QA pair."""
    return lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": answer,
        },
    }

diffusion_db_dataset = (
    datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")
    .shuffle()
    .take(100)
    .select_columns(["image"])
    .map(_with_qa("This is an AI image."))
)

flickr_dataset = (
    datasets.load_dataset("nlphuji/flickr30k", split="test")
    .shuffle()
    .take(100)
    .select_columns(["image"])
    .map(_with_qa("This is a real image."))
)

dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
|
||||||
|
|
||||||
|
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)

# Spot-check the finetuned model on the first 3 samples of the shuffled set.
for i, sample in enumerate(dataset):
    # Bail out before doing any work on sample 3: the original code saved its
    # image and ran a full beam-search inference, then discarded the result
    # in the `else: break` branch.
    if i >= 3:
        break

    # Keep a copy of each evaluated image for manual inspection.
    sample['image'].save(f"samples/{i}.png", "PNG")

    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
        sample['qa']['question'],
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True
    )

    print('Question:', sample['qa']['question'])
    print('Ground Truth:', sample['qa']['answer'])
    print('Moondream:', md_answer)
|
164
moondream/train.py
Normal file
164
moondream/train.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
import math
import torch
import datasets
import transformers
import bitsandbytes
from tqdm import tqdm
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS

# Training runs on GPU; fall back to float32 on CPU because half precision
# is poorly supported there.
DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
# Hub revision of vikhyatk/moondream2 intended for this run.
MD_REVISION = "2024-07-23"
|
||||||
|
|
||||||
|
# Build the finetuning data: AI-generated images from DiffusionDB and real
# photos from Flickr30k, each tagged with a fixed question/answer pair and
# split into train/test portions before being merged.
def _with_qa(answer):
    """Return a map function attaching a fixed describe-this-image QA pair."""
    return lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": answer,
        },
    }

diffusion_db_dataset = (
    datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")
    .select_columns(["image"])
    .map(_with_qa("This is an AI image."))
    .train_test_split(test_size=TEST_SIZE)
)

flickr_dataset = (
    datasets.load_dataset("nlphuji/flickr30k", split="test")
    .take(5000)
    .select_columns(["image"])
    .map(_with_qa("This is a real image."))
    .train_test_split(test_size=TEST_SIZE)
)

training_dataset = datasets.concatenate_datasets([diffusion_db_dataset["train"], flickr_dataset["train"]]).shuffle()
test_dataset = datasets.concatenate_datasets([diffusion_db_dataset["test"], flickr_dataset["test"]]).shuffle()
|
||||||
|
|
||||||
|
# Pin both tokenizer and model to MD_REVISION: previously the constant was
# defined but never passed, so `trust_remote_code=True` executed whatever
# code the hub's latest revision shipped — a reproducibility and
# supply-chain hazard.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2", revision=MD_REVISION
)
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)
|
||||||
|
|
||||||
|
def collate(batch):
    """Collate samples into (images, token ids, labels, attention masks).

    Each label row carries an extra ``IMG_TOKENS``-long ``-100`` prefix
    relative to its token row because the image embeddings are spliced in
    right after <BOS> downstream; only the answer tokens contribute to the
    loss. All rows are right-padded to the longest label length.
    """
    images = [sample["image"] for sample in batch]

    all_tokens = []
    all_labels = []
    for sample in batch:
        qa = sample["qa"]
        q_t = tokenizer(
            f"\n\nQuestion: {qa['question']}\n\nAnswer:",
            add_special_tokens=False,
        ).input_ids
        a_t = tokenizer(
            f" {qa['answer']}{ANSWER_EOS}",
            add_special_tokens=False,
        ).input_ids

        # Token ids: <BOS> + question + answer.
        # Labels: mask (<BOS> + image slots + question), supervise the answer.
        all_tokens.append([tokenizer.bos_token_id] + q_t + a_t)
        all_labels.append([-100] * (IMG_TOKENS + 1 + len(q_t)) + a_t)

    # Pad to the longest label row; after the image embeddings are inserted
    # the input length equals the label length, so one pad width fits both.
    longest = max((len(labels) for labels in all_labels), default=-1)

    all_attn_masks = []
    for tokens, labels in zip(all_tokens, all_labels):
        real_len = len(labels)
        pad = longest - real_len

        labels.extend([-100] * pad)
        tokens.extend([tokenizer.eos_token_id] * pad)
        all_attn_masks.append([1] * real_len + [0] * pad)

    return (
        images,
        torch.stack([torch.tensor(row, dtype=torch.long) for row in all_tokens]),
        torch.stack([torch.tensor(row, dtype=torch.long) for row in all_labels]),
        torch.stack([torch.tensor(row, dtype=torch.bool) for row in all_attn_masks]),
    )
|
||||||
|
|
||||||
|
def compute_loss(batch):
    """Run one forward pass over a collated batch and return the LM loss.

    The vision encoder runs under ``no_grad`` — image features are treated
    as fixed inputs and only the text model receives gradients.
    """
    images, token_ids, label_ids, attn_masks = batch

    token_ids = token_ids.to(DEVICE)
    label_ids = label_ids.to(DEVICE)
    attn_masks = attn_masks.to(DEVICE)

    # Don't backprop into the vision encoder.
    with torch.no_grad():
        img_embeds = moondream.vision_encoder(images)

    embed_fn = moondream.text_model.get_input_embeddings()
    token_embeds = embed_fn(token_ids)

    # Assemble the full input sequence: the <BOS> embedding, then the image
    # embeddings, then the remaining text-token embeddings — matching the
    # label layout produced by collate().
    inputs_embeds = torch.cat(
        (token_embeds[:, :1, :], img_embeds, token_embeds[:, 1:, :]),
        dim=1,
    )

    outputs = moondream.text_model(
        inputs_embeds=inputs_embeds,
        labels=label_ids,
        attention_mask=attn_masks,
    )
    return outputs.loss
|
||||||
|
|
||||||
|
def lr_schedule(step, max_steps):
    """Learning-rate schedule: linear warmup, then cosine decay.

    Warms up linearly from 0.1*LR to LR over the first 10% of training,
    then decays back down to 0.1*LR over the remaining 90% with a
    half-cosine, as documented next to LR in hyperparams.py.
    """
    x = step / max_steps
    if x < 0.1:
        # Linear warmup: 0.1*LR at x == 0, reaching LR at x == 0.1.
        return 0.1 * LR + 0.9 * LR * x / 0.1
    else:
        # Cosine decay. The phase must be rescaled by the remaining 0.9
        # fraction so the half-cosine completes exactly at x == 1; the
        # previous `math.pi * (x - 0.1)` only reached 0.9*pi, so the
        # schedule never decayed all the way to 0.1*LR (it stopped at
        # roughly 0.12*LR).
        return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1) / 0.9)) / 2
|
||||||
|
|
||||||
|
# Only a training dataloader is built here; test_dataset is left for
# separate evaluation.
dataloaders = {
    "train": torch.utils.data.DataLoader(
        training_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate,
    )
}

moondream.text_model.train()
# Trade compute for memory: recompute activations in the backward pass.
moondream.text_model.transformer.gradient_checkpointing_enable()

# One optimizer step per GRAD_ACCUM_STEPS batches.
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
# 8-bit Adam to cut optimizer-state memory; initial lr matches the warmup
# start of lr_schedule (0.1 * LR).
optimizer = bitsandbytes.optim.Adam8bit(
    [{"params": moondream.text_model.parameters()}],
    lr=LR*0.1,
    betas=(0.9, 0.95),
    eps=1e-6,
)

i = 0  # global batch counter across epochs
for epoch in range(EPOCHS):
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        i += 1

        # Gradients accumulate across batches until the step below.
        loss = compute_loss(batch)
        loss.backward()

        if i % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

            # NOTE(review): the lr is updated *after* optimizer.step(), so
            # each step uses the rate computed for the previous update (the
            # first step uses the LR*0.1 set above) — schedule lags by one
            # step; confirm this is intended. Batches left over when the
            # epoch length isn't divisible by GRAD_ACCUM_STEPS never get
            # stepped.
            lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

moondream.save_pretrained("checkpoints/moondream-mai")
|
Reference in New Issue
Block a user