add script for generating dataset
@@ -13,7 +13,7 @@ from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOC
 DEVICE = "cuda"
 DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
 MD_REVISION = "2024-07-23"
-TOTAL_DATA_SIZE = 8000
+TOTAL_DATA_SIZE = 20000
 
 diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", split="train", trust_remote_code=True, streaming=True)\
     .select_columns(["image"])\
@@ -24,7 +24,7 @@ diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_
             "answer": "Yes."
         }
     })
-diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=2000, test_size=TEST_SIZE)
+diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=5000, test_size=TEST_SIZE)
 
 midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split="train", streaming=True)\
     .select_columns(["image"])\
@@ -35,7 +35,7 @@ midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split
             "answer": "Yes."
         }
     })
-midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=2000, test_size=TEST_SIZE)
+midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=5000, test_size=TEST_SIZE)
 
 flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
     .select_columns(["image"])\
@@ -46,7 +46,7 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", stream
             "answer": "No."
         }
     })
-flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=800, test_size=TEST_SIZE)
+flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=1250, test_size=TEST_SIZE)
 
 wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
     .select_columns(["image"])\
@@ -57,7 +57,7 @@ wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", stream
             "answer": "No."
         }
     })
-wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=800, test_size=TEST_SIZE)
+wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=1250, test_size=TEST_SIZE)
 
 anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
     .select_columns(["image"])\
@@ -68,7 +68,7 @@ anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust
             "answer": "No."
         }
     })
-anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=800, test_size=TEST_SIZE)
+anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=1250, test_size=TEST_SIZE)
 
 coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
     .select_columns(["image"])\
@@ -79,10 +79,10 @@ coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", s
             "answer": "No."
         }
     })
-coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=800, test_size=TEST_SIZE)
+coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=1250, test_size=TEST_SIZE)
 
 movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
-    .select_columns(["age"])\
+    .select_columns(["image"])\
     .map(lambda row: {
         **row,
         "qa": {
@@ -90,7 +90,40 @@ movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split
             "answer": "No."
         }
     })
-movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=800, test_size=TEST_SIZE)
+movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=1250, test_size=TEST_SIZE)
 
+cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Is this image AI generated?",
+            "answer": "No."
+        }
+    })
+cars_dataset = utils.datasets.split_streaming_dataset(cars_dataset, total_size=1250, test_size=TEST_SIZE)
+
+website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Is this image AI generated?",
+            "answer": "No.",
+        }
+    })
+website_dataset = utils.datasets.split_streaming_dataset(website_dataset, total_size=1250, test_size=TEST_SIZE)
+
+movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Is this image AI generated?",
+            "answer": "No.",
+        }
+    })
+movie_scene_dataset = utils.datasets.split_streaming_dataset(movie_scene_dataset, total_size=1250, test_size=TEST_SIZE)
+
 training_dataset = datasets.interleave_datasets([
     diffusion_db_dataset["train"],
@@ -100,6 +133,9 @@ training_dataset = datasets.interleave_datasets([
     anime_dataset["train"],
     coco_dataset["train"],
     movie_poster_dataset["train"],
+    cars_dataset["train"],
+    website_dataset["train"],
+    movie_scene_dataset["train"],
 ], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
 test_dataset = datasets.interleave_datasets([
     diffusion_db_dataset["test"],
@@ -109,6 +145,9 @@ test_dataset = datasets.interleave_datasets([
     anime_dataset["test"],
     coco_dataset["test"],
     movie_poster_dataset["test"],
+    cars_dataset["test"],
+    website_dataset["test"],
+    movie_scene_dataset["test"],
 ], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
 
 print("Training and test dataset prepared.")
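Note: every source above is passed through utils.datasets.split_streaming_dataset, which is not part of this diff. Below is a minimal hypothetical sketch of what such a helper could look like for streaming inputs, assuming test_size is a fraction and the helper returns a dict with "train" and "test" splits (both assumptions inferred from how the script calls it).

    # Hypothetical sketch -- utils.datasets.split_streaming_dataset is not shown
    # in this commit; the signature and behaviour below are assumptions based on
    # how the script uses it.
    from datasets import IterableDataset

    def split_streaming_dataset(ds: IterableDataset, total_size: int, test_size: float) -> dict:
        """Take the first `total_size` rows of a streaming dataset and reserve a
        `test_size` fraction of them for the test split."""
        n_test = int(total_size * test_size)
        n_train = total_size - n_test
        return {
            "train": ds.take(n_train),              # first n_train rows
            "test": ds.skip(n_train).take(n_test),  # next n_test rows
        }

With the new sizes, the interleaved data is label-balanced: 2 x 5,000 "Yes." examples (diffusiondb, midjourney) and 8 x 1,250 "No." examples, i.e. 10,000 of each, matching the new TOTAL_DATA_SIZE = 20000.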
@@ -242,8 +281,11 @@ moondream.save_pretrained("checkpoints/moondream-mai")
 moondream.eval()
 pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
 
+total = 0
 correct_predictions = 0
 for sample in tqdm(test_dataset, desc="Validation"):
+    total += 1
+
     md_answer = moondream.answer_question(
         moondream.encode_image(sample['image']),
         sample['qa']['question'],
@@ -257,6 +299,6 @@ for sample in tqdm(test_dataset, desc="Validation"):
     if md_answer == ground_truth:
         correct_predictions += 1
 
-accuracy = correct_predictions * 100 / (TOTAL_DATA_SIZE * TEST_SIZE)
+accuracy = correct_predictions * 100 / total
 
 print(f"Model accuracy: {accuracy}%")
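Because the test stream is interleaved from a roughly 50/50 mix of "Yes."/"No." sources, chance accuracy in the validation loop above is about 50%. A quick optional sanity check of the label mix (not part of this commit) can be run on the streaming dataset before training:

    # Optional sanity check (not in this commit): peek at the label balance of
    # the interleaved training stream built above, using only its first 1,000
    # samples.
    from collections import Counter
    from itertools import islice

    label_counts = Counter(sample["qa"]["answer"] for sample in islice(training_dataset, 1000))
    print(label_counts)  # expect roughly equal counts of "Yes." and "No."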