add midjourney dataset
This commit is contained in:
@@ -2,6 +2,8 @@ import torch
|
|||||||
import datasets
|
import datasets
|
||||||
import transformers
|
import transformers
|
||||||
import pathlib
|
import pathlib
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
DEVICE = "cuda"
|
DEVICE = "cuda"
|
||||||
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
|
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
|
||||||
@@ -40,12 +42,38 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Midjourney images, loaded with streaming=True; every row keeps its image and
# gains the fixed question/answer pair that labels it as AI-generated.
midjourney_dataset = (
    datasets.load_dataset(
        "ehristoforu/midjourney-images",
        split="train",
        streaming=True,
    )
    .select_columns(["image"])
    .map(
        lambda row: {
            **row,
            "qa": {
                "question": "Describe this image.",
                "answer": "This is an AI image."
            }
        }
    )
)
|
||||||
|
|
||||||
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
|
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
|
||||||
|
|
||||||
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
for i, sample in enumerate(dataset):
|
img = Image.open("samples/frames_3.jpg")
|
||||||
sample['image'].save(f"samples/{i}.png", "PNG")
|
md_answer = moondream.answer_question(
|
||||||
|
moondream.encode_image(img),
|
||||||
|
"Describe this image.",
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_beams=4,
|
||||||
|
no_repeat_ngram_size=5,
|
||||||
|
early_stopping=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(md_answer)
|
||||||
|
|
||||||
|
correct_predictions = 0
|
||||||
|
for i, sample in enumerate(midjourney_dataset):
|
||||||
|
if i > 4:
|
||||||
|
break
|
||||||
|
|
||||||
|
sample["image"].save(f"samples/{i}.png", "PNG")
|
||||||
|
|
||||||
md_answer = moondream.answer_question(
|
md_answer = moondream.answer_question(
|
||||||
moondream.encode_image(sample['image']),
|
moondream.encode_image(sample['image']),
|
||||||
@@ -56,9 +84,12 @@ for i, sample in enumerate(dataset):
|
|||||||
early_stopping=True
|
early_stopping=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if i < 3:
|
print(f"Question: {sample['qa']['question']}")
|
||||||
print('Question:', sample['qa']['question'])
|
print(f"Ground truth: {sample['qa']['answer']}")
|
||||||
print('Ground Truth:', sample['qa']['answer'])
|
print(f"Moondream: {md_answer}")
|
||||||
print('Moondream:', md_answer)
|
print()
|
||||||
else:
|
|
||||||
break
|
if md_answer.lower() == sample['qa']['answer'].lower():
|
||||||
|
correct_predictions += 1
|
||||||
|
|
||||||
|
# The loop above breaks once i > 4, so exactly 5 samples (indices 0-4) are
# scored; dividing by 10 capped the reported accuracy at 50%.
print(f"Accuracy: {correct_predictions * 100 / 5}%")
|
||||||
|
@@ -3,6 +3,7 @@ import torch
|
|||||||
import datasets
|
import datasets
|
||||||
import transformers
|
import transformers
|
||||||
import bitsandbytes
|
import bitsandbytes
|
||||||
|
import pathlib
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
|
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_
|
|||||||
.train_test_split(test_size=TEST_SIZE)
|
.train_test_split(test_size=TEST_SIZE)
|
||||||
|
|
||||||
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
|
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
|
||||||
.take(5000)\
|
.take(2500)\
|
||||||
.select_columns(["image"])\
|
.select_columns(["image"])\
|
||||||
.map(lambda row: {
|
.map(lambda row: {
|
||||||
**row,
|
**row,
|
||||||
@@ -33,8 +34,28 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
|
|||||||
})\
|
})\
|
||||||
.train_test_split(test_size=TEST_SIZE)
|
.train_test_split(test_size=TEST_SIZE)
|
||||||
|
|
||||||
training_dataset = datasets.concatenate_datasets([diffusion_db_dataset["train"], flickr_dataset["train"]]).shuffle()
|
wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train")\
|
||||||
test_dataset = datasets.concatenate_datasets([diffusion_db_dataset["test"], flickr_dataset["test"]]).shuffle()
|
.take(2500)\
|
||||||
|
.select_columns(["image"])\
|
||||||
|
.map(lambda row: {
|
||||||
|
**row,
|
||||||
|
"qa": {
|
||||||
|
"question": "Describe thie image.",
|
||||||
|
"answer": "This is a real image."
|
||||||
|
}
|
||||||
|
})\
|
||||||
|
.train_test_split(test_size=TEST_SIZE)
|
||||||
|
|
||||||
|
training_dataset = datasets.concatenate_datasets([
|
||||||
|
diffusion_db_dataset["train"],
|
||||||
|
flickr_dataset["train"],
|
||||||
|
wiki_art_dataset["train"],
|
||||||
|
]).shuffle()
|
||||||
|
test_dataset = datasets.concatenate_datasets([
|
||||||
|
diffusion_db_dataset["test"],
|
||||||
|
flickr_dataset["test"],
|
||||||
|
wiki_art_dataset["test"],
|
||||||
|
]).shuffle()
|
||||||
|
|
||||||
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
|
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
|
||||||
moondream = transformers.AutoModelForCausalLM.from_pretrained(
|
moondream = transformers.AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -131,7 +152,7 @@ dataloaders = {
|
|||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
)
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
moondream.text_model.train()
|
moondream.text_model.train()
|
||||||
@@ -162,3 +183,27 @@ for epoch in range(EPOCHS):
|
|||||||
param_group["lr"] = lr
|
param_group["lr"] = lr
|
||||||
|
|
||||||
moondream.save_pretrained("checkpoints/moondream-mai")
|
moondream.save_pretrained("checkpoints/moondream-mai")
|
||||||
|
|
||||||
|
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Validation pass: ask the model each test question and count exact-match
# answers against the ground truth stored on the sample.
# NOTE(review): moondream.text_model.train() is called earlier in this script;
# confirm whether the model should be switched to eval mode before this loop.
correct_predictions = 0
for i, sample in enumerate(tqdm(test_dataset, desc="Validation")):
    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
        sample['qa']['question'],
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True
    )

    ground_truth = sample["qa"]["answer"]
    if md_answer == ground_truth:
        correct_predictions += 1

    # Echo every 10th sample so predictions can be spot-checked by eye.
    # `i` now comes from enumerate(); the original loop never bound it.
    if i % 10 == 0:
        print(f"Question: {sample['qa']['question']}")
        print(f"Ground truth: {ground_truth}")
        print(f"Moondream: {md_answer}")
        print()

# Guard against an empty test split so the report never divides by zero.
total_samples = len(test_dataset)
accuracy = correct_predictions * 100 / total_samples if total_samples else 0.0

print(f"Model accuracy: {accuracy}%")
|
||||||
|
Reference in New Issue
Block a user