add midjourney dataset

2024-12-09 00:41:54 +00:00
parent 6b53cb0411
commit 46b60151c7
2 changed files with 88 additions and 12 deletions

View File

@@ -2,6 +2,8 @@ import torch
import datasets
import transformers
import pathlib
from PIL import Image
from tqdm import tqdm
DEVICE = "cuda" DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16 DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
@@ -40,12 +42,38 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
        }
    })
midjourney_dataset = datasets.load_dataset("ehristoforu/midjourney-images", split="train", streaming=True)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is an AI image."
        }
    })

dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
img = Image.open("samples/frames_3.jpg")
md_answer = moondream.answer_question(
    moondream.encode_image(img),
    "Describe this image.",
    tokenizer=tokenizer,
    num_beams=4,
    no_repeat_ngram_size=5,
    early_stopping=True,
)
print(md_answer)

correct_predictions = 0
for i, sample in enumerate(midjourney_dataset):
    if i > 4:
        break
    sample["image"].save(f"samples/{i}.png", "PNG")
    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
@@ -56,9 +84,12 @@ for i, sample in enumerate(dataset):
        early_stopping=True
    )
    print(f"Question: {sample['qa']['question']}")
    print(f"Ground truth: {sample['qa']['answer']}")
    print(f"Moondream: {md_answer}")
    print()

    if md_answer.lower() == sample['qa']['answer'].lower():
        correct_predictions += 1

print(f"Accuracy: {correct_predictions * 100 / 10}%")

View File

@@ -3,6 +3,7 @@ import torch
import datasets
import transformers
import bitsandbytes
import pathlib
from tqdm import tqdm
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
@@ -22,7 +23,7 @@ diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_
    .train_test_split(test_size=TEST_SIZE)
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
    .take(2500)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
@@ -33,8 +34,28 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
    })\
    .train_test_split(test_size=TEST_SIZE)
wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train")\
    .take(2500)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
"question": "Describe thie image.",
"answer": "This is a real image."
}
})\
.train_test_split(test_size=TEST_SIZE)
training_dataset = datasets.concatenate_datasets([
diffusion_db_dataset["train"],
flickr_dataset["train"],
wiki_art_dataset["train"],
]).shuffle()
test_dataset = datasets.concatenate_datasets([
diffusion_db_dataset["test"],
flickr_dataset["test"],
wiki_art_dataset["test"],
]).shuffle()
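
As an aside, a quick sanity check of the real/AI class balance in the concatenated splits above could look like this (a sketch, not part of the commit; the qa/answer fields come from the map calls above):

from collections import Counter

# Reading the "qa" column directly avoids decoding the image column.
label_counts = Counter(qa["answer"] for qa in training_dataset["qa"])
print(label_counts)  # e.g. counts of "This is an AI image." vs "This is a real image."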
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = transformers.AutoModelForCausalLM.from_pretrained(
@@ -131,7 +152,7 @@ dataloaders = {
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate,
    ),
}

moondream.text_model.train()
@@ -162,3 +183,27 @@ for epoch in range(EPOCHS):
param_group["lr"] = lr param_group["lr"] = lr
moondream.save_pretrained("checkpoints/moondream-mai") moondream.save_pretrained("checkpoints/moondream-mai")
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
correct_predictions = 0
for i, sample in enumerate(tqdm(test_dataset, desc="Validation")):
    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
        sample['qa']['question'],
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True
    )
    ground_truth = sample["qa"]["answer"]
    if md_answer == ground_truth:
        correct_predictions += 1
    if i % 10 == 0:
        print(f"Ground truth: {ground_truth}")
        print(f"Moondream: {md_answer}")
accuracy = correct_predictions * 100 / len(test_dataset)
print(f"Model accuracy: f{accuracy}%")