add midjourney dataset
@@ -3,6 +3,7 @@ import torch
 import datasets
 import transformers
 import bitsandbytes
+import pathlib
 from tqdm import tqdm
 from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
 
@@ -22,7 +23,7 @@ diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_
     .train_test_split(test_size=TEST_SIZE)
 
 flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
-    .take(5000)\
+    .take(2500)\
     .select_columns(["image"])\
     .map(lambda row: {
         **row,
@@ -33,8 +34,28 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
     })\
     .train_test_split(test_size=TEST_SIZE)
 
-training_dataset = datasets.concatenate_datasets([diffusion_db_dataset["train"], flickr_dataset["train"]]).shuffle()
-test_dataset = datasets.concatenate_datasets([diffusion_db_dataset["test"], flickr_dataset["test"]]).shuffle()
+wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train")\
+    .take(2500)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Describe the image.",
+            "answer": "This is a real image."
+        }
+    })\
+    .train_test_split(test_size=TEST_SIZE)
+
+training_dataset = datasets.concatenate_datasets([
+    diffusion_db_dataset["train"],
+    flickr_dataset["train"],
+    wiki_art_dataset["train"],
+]).shuffle()
+test_dataset = datasets.concatenate_datasets([
+    diffusion_db_dataset["test"],
+    flickr_dataset["test"],
+    wiki_art_dataset["test"],
+]).shuffle()
 
 tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
 moondream = transformers.AutoModelForCausalLM.from_pretrained(
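Side note (not part of the commit): datasets.concatenate_datasets only accepts splits that share a schema, so the DiffusionDB and Flickr maps presumably produce the same {"image", "qa"} columns that the wikiart block above does. A minimal sketch of that shared row shape, with placeholder strings standing in for the PIL images and a hypothetical answer string for the generated-image sources:

import datasets

real = datasets.Dataset.from_dict({
    "image": ["<PIL image>"],  # the real datasets hold PIL.Image objects here
    "qa": [{"question": "Describe the image.", "answer": "This is a real image."}],
})
generated = datasets.Dataset.from_dict({
    "image": ["<PIL image>"],
    "qa": [{"question": "Describe the image.", "answer": "This is an AI-generated image."}],  # hypothetical label
})
combined = datasets.concatenate_datasets([real, generated]).shuffle()
print(combined.features)  # identical features are what makes the concatenation legal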
@@ -131,7 +152,7 @@ dataloaders = {
         batch_size=BATCH_SIZE,
         shuffle=True,
         collate_fn=collate,
-    )
+    ),
 }
 
 moondream.text_model.train()
@@ -162,3 +183,27 @@ for epoch in range(EPOCHS):
                 param_group["lr"] = lr
 
 moondream.save_pretrained("checkpoints/moondream-mai")
+
+pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
+
+correct_predictions = 0
+for i, sample in enumerate(tqdm(test_dataset, desc="Validation")):
+    md_answer = moondream.answer_question(
+        moondream.encode_image(sample['image']),
+        sample['qa']['question'],
+        tokenizer=tokenizer,
+        num_beams=4,
+        no_repeat_ngram_size=5,
+        early_stopping=True
+    )
+
+    ground_truth = sample["qa"]["answer"]
+    if md_answer == ground_truth:
+        correct_predictions += 1
+
+    if i % 10 == 0:
+        print(f"Question: {sample['qa']['question']}")
+
+accuracy = correct_predictions * 100 / len(test_dataset)
+
+print(f"Model accuracy: {accuracy}%")
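Side note (not part of the commit): md_answer == ground_truth is a strict string comparison, so any difference in casing, punctuation, or surrounding whitespace counts as a miss. If that proves too brittle, a lightly normalized check along these lines (a sketch with a hypothetical is_match helper) could stand in for the equality test:

def is_match(prediction: str, truth: str) -> bool:
    # Hypothetical helper: case- and whitespace-insensitive exact match.
    return prediction.strip().lower() == truth.strip().lower()

# e.g. replacing the check inside the validation loop:
# if is_match(md_answer, ground_truth):
#     correct_predictions += 1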