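# Sanity-check a fine-tuned Moondream checkpoint: ask it to describe a mix of
# AI-generated images (DiffusionDB) and real photographs (Flickr30k), and compare
# its answers against the expected "AI image" / "real image" responses.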
import pathlib

import datasets
import torch
import transformers

DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16  # CPU doesn't support float16
MD_REVISION = "2024-07-23"
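
# The tokenizer comes from the base vikhyatk/moondream2 repo (pinned to MD_REVISION);
# the fine-tuned weights are loaded from the local checkpoint. Note that
# flash_attention_2 requires a CUDA device.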
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision=MD_REVISION)
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "./checkpoints/moondream-mai",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)
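
# 100 random AI-generated images from DiffusionDB, labelled with the expected "AI image" answer.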
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
    .shuffle()\
    .take(100)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is an AI image."
        }
    })
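
# 100 random real photographs from Flickr30k, labelled with the expected "real image" answer.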
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
    .shuffle()\
    .take(100)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is a real image."
        }
    })
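
# Combine both sources into a single shuffled evaluation set of 200 samples.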
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
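
# Output directory for the images evaluated in this run.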
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
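
# Save each sampled image and ask the model the evaluation question.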
for i, sample in enumerate(dataset):
    sample['image'].save(f"samples/{i}.png", "PNG")
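
    # Generate an answer with beam search (4 beams), a no-repeat 5-gram constraint,
    # and early stopping.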
    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
        sample['qa']['question'],
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True
    )
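
    # Print the first three predictions alongside the ground-truth labels, then stop.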
    if i < 3:
        print('Question:', sample['qa']['question'])
        print('Ground Truth:', sample['qa']['answer'])
        print('Moondream:', md_answer)
    else:
        break