diff --git a/moondream/test.py b/moondream/test.py
index cc08f29..558b31a 100644
--- a/moondream/test.py
+++ b/moondream/test.py
@@ -2,6 +2,8 @@ import torch
 import datasets
 import transformers
 import pathlib
+from PIL import Image
+from tqdm import tqdm
 
 DEVICE = "cuda"
 DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
@@ -40,12 +42,38 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
         }
     })
 
+midjourney_dataset = datasets.load_dataset("ehristoforu/midjourney-images", split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Describe this image.",
+            "answer": "This is an AI image."
+        }
+    })
+
 dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
 
 pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
 
-for i, sample in enumerate(dataset):
-    sample['image'].save(f"samples/{i}.png", "PNG")
+img = Image.open("samples/frames_3.jpg")
+md_answer = moondream.answer_question(
+    moondream.encode_image(img),
+    "Describe this image.",
+    tokenizer=tokenizer,
+    num_beams=4,
+    no_repeat_ngram_size=5,
+    early_stopping=True,
+)
+
+print(md_answer)
+
+correct_predictions = 0
+for i, sample in enumerate(midjourney_dataset):
+    if i > 4:
+        break
+
+    sample["image"].save(f"samples/{i}.png", "PNG")
 
     md_answer = moondream.answer_question(
         moondream.encode_image(sample['image']),
@@ -56,9 +84,12 @@ for i, sample in enumerate(dataset):
         early_stopping=True
     )
 
-    if i < 3:
-        print('Question:', sample['qa']['question'])
-        print('Ground Truth:', sample['qa']['answer'])
-        print('Moondream:', md_answer)
-    else:
-        break
+    print(f"Question: {sample['qa']['question']}")
+    print(f"Ground truth: {sample['qa']['answer']}")
+    print(f"Moondream: {md_answer}")
+    print()
+
+    if md_answer.lower() == sample['qa']['answer'].lower():
+        correct_predictions += 1
+
+print(f"Accuracy: {correct_predictions * 100 / 5}%")  # 5 samples evaluated above
diff --git a/moondream/train.py b/moondream/train.py
index b9c159f..7194cd4 100644
--- a/moondream/train.py
+++ b/moondream/train.py
@@ -3,6 +3,7 @@ import torch
 import datasets
 import transformers
 import bitsandbytes
+import pathlib
 from tqdm import tqdm
 
 from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
@@ -22,7 +23,7 @@ diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_
     .train_test_split(test_size=TEST_SIZE)
 
 flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
-    .take(5000)\
+    .take(2500)\
     .select_columns(["image"])\
     .map(lambda row: {
         **row,
@@ -33,8 +34,28 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
     })\
     .train_test_split(test_size=TEST_SIZE)
 
-training_dataset = datasets.concatenate_datasets([diffusion_db_dataset["train"], flickr_dataset["train"]]).shuffle()
-test_dataset = datasets.concatenate_datasets([diffusion_db_dataset["test"], flickr_dataset["test"]]).shuffle()
+wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train")\
+    .take(2500)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Describe this image.",
+            "answer": "This is a real image."
+        }
+    })\
+    .train_test_split(test_size=TEST_SIZE)
+
+training_dataset = datasets.concatenate_datasets([
+    diffusion_db_dataset["train"],
+    flickr_dataset["train"],
+    wiki_art_dataset["train"],
+]).shuffle()
+test_dataset = datasets.concatenate_datasets([
+    diffusion_db_dataset["test"],
+    flickr_dataset["test"],
+    wiki_art_dataset["test"],
+]).shuffle()
 
 tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
 moondream = transformers.AutoModelForCausalLM.from_pretrained(
@@ -131,7 +152,7 @@ dataloaders = {
         batch_size=BATCH_SIZE,
         shuffle=True,
         collate_fn=collate,
-    )
+    ),
 }
 
 moondream.text_model.train()
@@ -162,3 +183,27 @@ for epoch in range(EPOCHS):
             param_group["lr"] = lr
 
 moondream.save_pretrained("checkpoints/moondream-mai")
+
+pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
+
+correct_predictions = 0
+for i, sample in enumerate(tqdm(test_dataset, desc="Validation")):
+    md_answer = moondream.answer_question(
+        moondream.encode_image(sample['image']),
+        sample['qa']['question'],
+        tokenizer=tokenizer,
+        num_beams=4,
+        no_repeat_ngram_size=5,
+        early_stopping=True
+    )
+
+    ground_truth = sample["qa"]["answer"]
+    if md_answer == ground_truth:
+        correct_predictions += 1
+
+    if i % 10 == 0:
+        print(f"Ground truth: {ground_truth} | Moondream: {md_answer}")
+
+accuracy = correct_predictions * 100 / len(test_dataset)
+
+print(f"Model accuracy: {accuracy}%")