feat: add code for finetuning moondream

This commit is contained in:
2024-12-08 17:33:37 +00:00
parent 4bf2c89bf8
commit 6b53cb0411
11 changed files with 261 additions and 0 deletions

64
moondream/test.py Normal file
View File

@@ -0,0 +1,64 @@
import torch
import datasets
import transformers
import pathlib
DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
MD_REVISION = "2024-07-23"
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = transformers.AutoModelForCausalLM.from_pretrained(
"./checkpoints/moondream-mai",
trust_remote_code=True,
attn_implementation="flash_attention_2",
torch_dtype=DTYPE,
device_map={"": DEVICE},
)
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is an AI image."
}
})
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is a real image."
}
})
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
for i, sample in enumerate(dataset):
sample['image'].save(f"samples/{i}.png", "PNG")
md_answer = moondream.answer_question(
moondream.encode_image(sample['image']),
sample['qa']['question'],
tokenizer=tokenizer,
num_beams=4,
no_repeat_ngram_size=5,
early_stopping=True
)
if i < 3:
print('Question:', sample['qa']['question'])
print('Ground Truth:', sample['qa']['answer'])
print('Moondream:', md_answer)
else:
break