increase data diversity

This commit is contained in:
2024-12-13 22:19:53 +00:00
parent 46b60151c7
commit 14c6f26ddc
4 changed files with 150 additions and 84 deletions

View File

@@ -18,45 +18,9 @@ moondream = transformers.AutoModelForCausalLM.from_pretrained(
device_map={"": DEVICE},
)
# Load the "2m_random_5k" configuration of poloclub/diffusiondb (train split),
# shuffle it, keep 100 rows and only the "image" column, then attach a fixed
# question/answer pair to every row under the "qa" key.
# NOTE(review): shuffle() is called without a seed, so the selected rows are
# presumably not reproducible between runs -- confirm whether that is intended.
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
# Keep the row's existing columns and add the label used as ground truth.
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is an AI image."
}
})
# Load the test split of nlphuji/flickr30k, shuffle it, keep 100 rows and only
# the "image" column, then attach a fixed question/answer pair to every row
# under the "qa" key. Rows from this dataset get the "real image" answer.
# NOTE(review): shuffle() is called without a seed, so the selected rows are
# presumably not reproducible between runs -- confirm whether that is intended.
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
# Keep the row's existing columns and add the label used as ground truth.
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is a real image."
}
})
# Open the train split of ehristoforu/midjourney-images with streaming=True,
# keep only the "image" column and attach a fixed question/answer pair to
# every row under the "qa" key.
# Unlike the two datasets above, no shuffle() or take() is applied here, so any
# consumer has to limit how many rows it reads itself (the evaluation loop
# further down breaks once its index exceeds 4).
midjourney_dataset = datasets.load_dataset("ehristoforu/midjourney-images", split="train", streaming=True)\
.select_columns(["image"])\
.map(lambda row: {
# Keep the row's existing columns and add the label used as ground truth.
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is an AI image."
}
})
# Combine the DiffusionDB and Flickr samples into one dataset and shuffle the
# merged rows. midjourney_dataset is not part of this concatenation.
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
# Make sure the ./samples directory exists; no error is raised if it already does.
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
# NOTE(review): img is assigned twice in a row, so only the second path takes
# effect. This looks like the old and new side of a diff rendered without +/-
# markers -- confirm which sample file is the intended one.
img = Image.open("samples/frames_3.jpg")
img = Image.open("samples/Untitled.jpg")
md_answer = moondream.answer_question(
moondream.encode_image(img),
"Describe this image.",
@@ -68,28 +32,28 @@ md_answer = moondream.answer_question(
print(md_answer)
# Small evaluation loop: ask Moondream the stored question for a handful of
# samples, print its answer next to the stored ground truth, and count the
# answers that match exactly (ignoring case).
# NOTE(review): every active statement below is followed by a commented-out
# copy and the loop body carries no indentation. This looks like both sides of
# a diff rendered without +/- markers -- confirm which version is current
# before relying on this section.
correct_predictions = 0
# Iterate the streaming dataset; indices 0 through 4 are processed, then the
# loop stops.
for i, sample in enumerate(midjourney_dataset):
if i > 4:
break
# correct_predictions = 0
# for i, sample in enumerate(flickr_dataset):
# if i > 4:
# break
# Write the sample image to samples/<index>.png so it can be inspected by hand.
sample["image"].save(f"samples/{i}.png", "PNG")
# sample["image"].save(f"samples/{i}.png", "PNG")
# Encode the image and ask the sample's question, passing the beam-search
# related keyword arguments shown (num_beams, no_repeat_ngram_size,
# early_stopping) through to answer_question.
md_answer = moondream.answer_question(
moondream.encode_image(sample['image']),
sample['qa']['question'],
tokenizer=tokenizer,
num_beams=4,
no_repeat_ngram_size=5,
early_stopping=True
)
# md_answer = moondream.answer_question(
# moondream.encode_image(sample['image']),
# sample['qa']['question'],
# tokenizer=tokenizer,
# num_beams=4,
# no_repeat_ngram_size=5,
# early_stopping=True
# )
# Show the question, the stored label and the model output, followed by a
# blank separator line.
print(f"Question: {sample['qa']['question']}")
print(f"Ground truth: {sample['qa']['answer']}")
print(f"Moondream: {md_answer}")
print()
# print(f"Question: {sample['qa']['question']}")
# print(f"Ground truth: {sample['qa']['answer']}")
# print(f"Moondream: {md_answer}")
# print()
# A prediction only counts when the whole answer string equals the label after
# lower-casing both sides.
# NOTE(review): the model is asked for a free-form description, so an exact
# match against the fixed label sentence is presumably rare -- confirm that
# this is the intended scoring rule.
if md_answer.lower() == sample['qa']['answer'].lower():
correct_predictions += 1
# if md_answer.lower() == sample['qa']['answer'].lower():
# correct_predictions += 1
# NOTE(review): the percentage divides by a hard-coded 10, while the loop above
# processes at most 5 samples (indices 0-4), so the printed value cannot exceed
# 50% -- the divisor looks stale; confirm the intended sample count.
print(f"Accuracy: {correct_predictions * 100 / 10}%")
# print(f"Accuracy: {correct_predictions * 100 / 10}%")