add midjourney dataset
This commit is contained in:
@@ -2,6 +2,8 @@ import torch
|
||||
import datasets
|
||||
import transformers
|
||||
import pathlib
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
|
||||
@@ -40,12 +42,38 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
|
||||
}
|
||||
})
|
||||
|
||||
# Stream the Midjourney image set (nothing is downloaded up front), keep only
# the "image" column, and attach the same fixed question/answer pair to every
# row.
midjourney_dataset = (
    datasets.load_dataset(
        "ehristoforu/midjourney-images",
        split="train",
        streaming=True,
    )
    .select_columns(["image"])
    .map(
        lambda example: dict(
            example,
            qa={
                "question": "Describe this image.",
                "answer": "This is an AI image.",
            },
        )
    )
)
|
||||
|
||||
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
|
||||
|
||||
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i, sample in enumerate(dataset):
|
||||
sample['image'].save(f"samples/{i}.png", "PNG")
|
||||
img = Image.open("samples/frames_3.jpg")
|
||||
md_answer = moondream.answer_question(
|
||||
moondream.encode_image(img),
|
||||
"Describe this image.",
|
||||
tokenizer=tokenizer,
|
||||
num_beams=4,
|
||||
no_repeat_ngram_size=5,
|
||||
early_stopping=True,
|
||||
)
|
||||
|
||||
print(md_answer)
|
||||
|
||||
correct_predictions = 0
|
||||
for i, sample in enumerate(midjourney_dataset):
|
||||
if i > 4:
|
||||
break
|
||||
|
||||
sample["image"].save(f"samples/{i}.png", "PNG")
|
||||
|
||||
md_answer = moondream.answer_question(
|
||||
moondream.encode_image(sample['image']),
|
||||
@@ -56,9 +84,12 @@ for i, sample in enumerate(dataset):
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
if i < 3:
|
||||
print('Question:', sample['qa']['question'])
|
||||
print('Ground Truth:', sample['qa']['answer'])
|
||||
print('Moondream:', md_answer)
|
||||
else:
|
||||
break
|
||||
print(f"Question: {sample['qa']['question']}")
|
||||
print(f"Ground truth: {sample['qa']['answer']}")
|
||||
print(f"Moondream: {md_answer}")
|
||||
print()
|
||||
|
||||
if md_answer.lower() == sample['qa']['answer'].lower():
|
||||
correct_predictions += 1
|
||||
|
||||
# The evaluation loop stops once i > 4, so exactly five samples (indices 0-4)
# are scored. Dividing by 10 would cap the reported accuracy at 50%, so the
# denominator must match the number of samples actually evaluated.
# NOTE(review): keep this value in sync with the loop's `i > 4` cut-off.
print(f"Accuracy: {correct_predictions * 100 / 5}%")
|
||||
|
Reference in New Issue
Block a user