increase data diversity

2024-12-13 22:19:53 +00:00
parent 46b60151c7
commit 14c6f26ddc
4 changed files with 150 additions and 84 deletions

@@ -4,58 +4,114 @@ import datasets
import transformers
import bitsandbytes
import pathlib
import io
import PIL
import utils.datasets
from tqdm import tqdm
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
MD_REVISION = "2024-07-23"
+TOTAL_DATA_SIZE = 8000
-diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
+diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", split="train", trust_remote_code=True, streaming=True)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
-            "question": "Describe this image.",
-            "answer": "This is an AI image."
+            "question": "Is this image AI generated?",
+            "answer": "Yes."
        }
-    })\
-    .train_test_split(test_size=TEST_SIZE)
+    })
+diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=2000, test_size=TEST_SIZE)
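`utils.datasets.split_streaming_dataset` is a project-local helper that this diff does not show. A minimal sketch of what it plausibly does, emulating `train_test_split` for a streaming `IterableDataset` via `take`/`skip` (the signature and the order of the slices are assumptions, not taken from this commit):

```python
# Hypothetical sketch of utils.datasets.split_streaming_dataset; the real
# helper is not shown in this diff. Assumes the input is a streaming
# datasets.IterableDataset and mimics train_test_split()'s return shape:
import datasets

def split_streaming_dataset(dataset: datasets.IterableDataset, total_size: int, test_size: float) -> dict:
    test_count = int(total_size * test_size)
    return {
        "test": dataset.take(test_count),                                 # held-out slice
        "train": dataset.skip(test_count).take(total_size - test_count),  # remaining examples
    }
```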
-flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
-    .take(2500)\
+midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Is this image AI generated?",
+            "answer": "Yes."
+        }
+    })
+midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=2000, test_size=TEST_SIZE)
+flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
-            "question": "Describe this image.",
-            "answer": "This is a real image."
+            "question": "Is this image AI generated?",
+            "answer": "No."
        }
-    })\
-    .train_test_split(test_size=TEST_SIZE)
+    })
+flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=800, test_size=TEST_SIZE)
-wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train")\
-    .take(2500)\
+wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
-            "question": "Describe this image.",
-            "answer": "This is a real image."
+            "question": "Is this image AI generated?",
+            "answer": "No."
        }
-    })\
-    .train_test_split(test_size=TEST_SIZE)
+    })
+wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=800, test_size=TEST_SIZE)
-training_dataset = datasets.concatenate_datasets([
+anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Is this image AI generated?",
+            "answer": "No."
+        }
+    })
+anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=800, test_size=TEST_SIZE)
+coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Is this image AI generated?",
+            "answer": "No."
+        }
+    })
+coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=800, test_size=TEST_SIZE)
+movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Is this image AI generated?",
+            "answer": "No."
+        }
+    })
+movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=800, test_size=TEST_SIZE)
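Each new source repeats the same `select_columns`/`map` boilerplate with only the answer string changing. A small helper could collapse the repetition; this is a hypothetical refactor, not part of the commit:

```python
def label_source(dataset: datasets.IterableDataset, answer: str) -> datasets.IterableDataset:
    """Attach the shared binary QA pair to a streaming image dataset."""
    return dataset.select_columns(["image"]).map(lambda row: {
        **row,
        "qa": {"question": "Is this image AI generated?", "answer": answer},
    })

# Usage (coco_stream stands in for the raw load_dataset(..., streaming=True) handle):
# coco_dataset = label_source(coco_stream, answer="No")
```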
+training_dataset = datasets.interleave_datasets([
    diffusion_db_dataset["train"],
+    midjourney_dataset["train"],
    flickr_dataset["train"],
    wiki_art_dataset["train"],
-]).shuffle()
-test_dataset = datasets.concatenate_datasets([
+    anime_dataset["train"],
+    coco_dataset["train"],
+    movie_poster_dataset["train"],
+], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
+test_dataset = datasets.interleave_datasets([
    diffusion_db_dataset["test"],
+    midjourney_dataset["test"],
    flickr_dataset["test"],
    wiki_art_dataset["test"],
-]).shuffle()
+    anime_dataset["test"],
+    coco_dataset["test"],
+    movie_poster_dataset["test"],
+], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
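`interleave_datasets` alternates examples across the sources round-robin, and `stopping_strategy="all_exhausted"` keeps cycling the shorter sources until the longest one runs out, rather than stopping at the shortest as the default does; `cast_column("image", datasets.Image(decode=True))` then guarantees every interleaved row decodes to a PIL image on access. A toy illustration of the interleaving semantics (not from this repo):

```python
# Toy illustration of round-robin interleaving with "all_exhausted":
import datasets

a = datasets.Dataset.from_dict({"x": [0, 1]})
b = datasets.Dataset.from_dict({"x": [10, 11, 12, 13]})

mixed = datasets.interleave_datasets([a, b], stopping_strategy="all_exhausted")
print([row["x"] for row in mixed])  # [0, 10, 1, 11, 0, 12, 1, 13] -- `a` is cycled twice
```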
print("Training and test dataset prepared.")
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = transformers.AutoModelForCausalLM.from_pretrained(
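The `from_pretrained` call is truncated by the diff view. For reference, moondream2 is conventionally loaded roughly like this (a sketch following the model card; the exact argument list in this repo is an assumption):

```python
# Sketch of the truncated call above, following the moondream2 model card;
# the argument list is an assumption, not copied from this repo:
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,    # pinned snapshot, defined near the top of the file
    trust_remote_code=True,  # moondream2 ships custom modeling code
    torch_dtype=DTYPE,
).to(DEVICE)
```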
@@ -150,7 +206,6 @@ dataloaders = {
"train": torch.utils.data.DataLoader(
training_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collate,
),
}
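`shuffle=True` has to go because the training set is now a streaming `IterableDataset`, and `DataLoader` raises a `ValueError` when asked to sample-shuffle one. If shuffling is still wanted it must happen dataset-side with a buffer, for example (buffer size and seed are illustrative, not from this commit):

```python
# DataLoader refuses shuffle=True for IterableDataset inputs (streaming HF
# datasets are iterable), so shuffling moves to the dataset side:
shuffled_training_dataset = training_dataset.shuffle(seed=42, buffer_size=256)
train_loader = torch.utils.data.DataLoader(
    shuffled_training_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate,  # the same collate used in the dataloaders dict above
)
```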
@@ -158,7 +213,7 @@ dataloaders = {
moondream.text_model.train()
moondream.text_model.transformer.gradient_checkpointing_enable()
-total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
+total_steps = EPOCHS * (TOTAL_DATA_SIZE * (1 - TEST_SIZE)) // GRAD_ACCUM_STEPS
optimizer = bitsandbytes.optim.Adam8bit(
    [{"params": moondream.text_model.parameters()}],
    lr=LR*0.1,
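With a streaming pipeline `len(dataloaders["train"])` is undefined, so the schedule length is now derived from the configured sizes instead; the per-source `total_size` values above (2000 + 2000 + 5 × 800) sum to `TOTAL_DATA_SIZE = 8000`. A worked example with illustrative hyperparameters (the real `TEST_SIZE`, `EPOCHS`, and `GRAD_ACCUM_STEPS` live in `.hyperparams` and are not shown here):

```python
# Worked example; TEST_SIZE, EPOCHS and GRAD_ACCUM_STEPS are illustrative:
TOTAL_DATA_SIZE = 8000   # 2000 + 2000 + 5 * 800, matching the splits above
TEST_SIZE = 0.1
EPOCHS = 1
GRAD_ACCUM_STEPS = 2
total_steps = EPOCHS * (TOTAL_DATA_SIZE * (1 - TEST_SIZE)) // GRAD_ACCUM_STEPS
print(total_steps)  # 3600.0 (a float because of 1 - TEST_SIZE; wrap in int() if needed)
```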
@@ -184,6 +239,7 @@ for epoch in range(EPOCHS):
moondream.save_pretrained("checkpoints/moondream-mai")
moondream.eval()
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
+correct_predictions = 0
@@ -201,9 +257,6 @@ for sample in tqdm(test_dataset, desc="Validation"):
    if md_answer == ground_truth:
        correct_predictions += 1
-    if i % 10 == 0:
-        print(f"Question: {sample['qa']['answer']}")
-accuracy = correct_predictions * 100 / len(test_dataset)
+accuracy = correct_predictions * 100 / (TOTAL_DATA_SIZE * TEST_SIZE)
print(f"Model accuracy: {accuracy}%")