increase data diversity

This commit is contained in:
2024-12-13 22:19:53 +00:00
parent 46b60151c7
commit 14c6f26ddc
4 changed files with 150 additions and 84 deletions

41
moondream/siglip.py Normal file
View File

@@ -0,0 +1,41 @@
import transformers
import torch
import datasets
import sklearn
# Near-duplicate filtering of a Midjourney image sample using SigLIP image
# embeddings: embed the images, then greedily keep one representative per
# Euclidean neighbourhood of radius 1.0 and print how many representatives
# remain.

# `import sklearn` on its own does not guarantee that the `neighbors`
# submodule is loaded, so import it explicitly before it is used.
import sklearn.neighbors

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = transformers.AutoModel.from_pretrained("google/siglip-base-patch16-224").to(device)
processor = transformers.AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
nn = sklearn.neighbors.NearestNeighbors(metric="euclidean", radius=1.0)

ds = (
    datasets.load_dataset(
        "ehristoforu/midjourney-images",
        split="train",
        trust_remote_code=True,
        streaming=True,
    )
    .select_columns(["image"])
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is an AI image.",
        },
    })
    .take(500)
)

# Read the stream exactly once; every later step works on this in-memory list
# instead of iterating the remote dataset again.
images = [row["image"] for row in ds]
if not images:
    raise RuntimeError("No images were streamed from the dataset.")

# Embed in small batches so the whole sample never has to go through the
# model in a single forward pass.
EMBED_BATCH_SIZE = 50
feature_chunks = []
with torch.no_grad():
    for start in range(0, len(images), EMBED_BATCH_SIZE):
        batch = images[start:start + EMBED_BATCH_SIZE]
        inputs = processor(images=batch, return_tensors="pt").to(device)
        feature_chunks.append(model.get_image_features(**inputs).cpu())
image_features = torch.cat(feature_chunks)

# scikit-learn works on array-likes; hand it one 2-D NumPy matrix rather than
# torch tensors wrapped in Python lists.
feature_matrix = image_features.numpy()
nn.fit(feature_matrix)

# Greedy selection: walk the images in order, keep the first one that has not
# been claimed yet, and mark everything within the radius (the image itself
# included) as used.
used_indices = set()
unique_indices = []
for i in range(len(feature_matrix)):
    if i in used_indices:
        continue
    neighbors = nn.radius_neighbors(
        feature_matrix[i:i + 1], radius=1.0, return_distance=False
    )[0]
    unique_indices.append(i)
    used_indices.update(neighbors.tolist())

print(len(unique_indices))

View File

@@ -18,45 +18,9 @@ moondream = transformers.AutoModelForCausalLM.from_pretrained(
device_map={"": DEVICE}, device_map={"": DEVICE},
) )
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is an AI image."
}
})
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is a real image."
}
})
midjourney_dataset = datasets.load_dataset("ehristoforu/midjourney-images", split="train", streaming=True)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is an AI image."
}
})
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True) pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
img = Image.open("samples/frames_3.jpg") img = Image.open("samples/Untitled.jpg")
md_answer = moondream.answer_question( md_answer = moondream.answer_question(
moondream.encode_image(img), moondream.encode_image(img),
"Describe this image.", "Describe this image.",
@@ -68,28 +32,28 @@ md_answer = moondream.answer_question(
print(md_answer) print(md_answer)
correct_predictions = 0 # correct_predictions = 0
for i, sample in enumerate(midjourney_dataset): # for i, sample in enumerate(flickr_dataset):
if i > 4: # if i > 4:
break # break
sample["image"].save(f"samples/{i}.png", "PNG") # sample["image"].save(f"samples/{i}.png", "PNG")
md_answer = moondream.answer_question( # md_answer = moondream.answer_question(
moondream.encode_image(sample['image']), # moondream.encode_image(sample['image']),
sample['qa']['question'], # sample['qa']['question'],
tokenizer=tokenizer, # tokenizer=tokenizer,
num_beams=4, # num_beams=4,
no_repeat_ngram_size=5, # no_repeat_ngram_size=5,
early_stopping=True # early_stopping=True
) # )
print(f"Question: {sample['qa']['question']}") # print(f"Question: {sample['qa']['question']}")
print(f"Ground truth: {sample['qa']['answer']}") # print(f"Ground truth: {sample['qa']['answer']}")
print(f"Moondream: {md_answer}") # print(f"Moondream: {md_answer}")
print() # print()
if md_answer.lower() == sample['qa']['answer'].lower(): # if md_answer.lower() == sample['qa']['answer'].lower():
correct_predictions += 1 # correct_predictions += 1
print(f"Accuracy: {correct_predictions * 100 / 10}%") # print(f"Accuracy: {correct_predictions * 100 / 10}%")

View File

@@ -4,58 +4,114 @@ import datasets
import transformers import transformers
import bitsandbytes import bitsandbytes
import pathlib import pathlib
import io
import PIL
import utils.datasets
from tqdm import tqdm from tqdm import tqdm
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
DEVICE = "cuda" DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16 DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
MD_REVISION = "2024-07-23" MD_REVISION = "2024-07-23"
TOTAL_DATA_SIZE = 8000
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\ diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", split="train", trust_remote_code=True, streaming=True)\
.select_columns(["image"])\ .select_columns(["image"])\
.map(lambda row: { .map(lambda row: {
**row, **row,
"qa": { "qa": {
"question": "Describe this image.", "question": "Is this image AI generated?",
"answer": "This is an AI image." "answer": "Yes."
} }
})\ })
.train_test_split(test_size=TEST_SIZE) diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=2000, test_size=TEST_SIZE)
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\ midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split="train", streaming=True)\
.take(2500)\ .select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Is this image AI generated?",
"answer": "Yes."
}
})
midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=2000, test_size=TEST_SIZE)
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
.select_columns(["image"])\ .select_columns(["image"])\
.map(lambda row: { .map(lambda row: {
**row, **row,
"qa": { "qa": {
"question": "Describe this image.", "question": "Is this image AI generated?",
"answer": "This is a real image." "answer": "No."
} }
})\ })
.train_test_split(test_size=TEST_SIZE) flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=800, test_size=TEST_SIZE)
wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train")\ wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
.take(2500)\
.select_columns(["image"])\ .select_columns(["image"])\
.map(lambda row: { .map(lambda row: {
**row, **row,
"qa": { "qa": {
"question": "Describe thie image.", "question": "Is this image AI generated?",
"answer": "This is a real image." "answer": "No."
} }
})\ })
.train_test_split(test_size=TEST_SIZE) wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=800, test_size=TEST_SIZE)
training_dataset = datasets.concatenate_datasets([ anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Is this image AI generated?",
"answer": "No."
}
})
anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=800, test_size=TEST_SIZE)
coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Is this image AI generated?",
"answer": "No."
}
})
coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=800, test_size=TEST_SIZE)
movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
.select_columns(["age"])\
.map(lambda row: {
**row,
"qa": {
"question": "Is this image AI generated?",
"answer": "No."
}
})
movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=800, test_size=TEST_SIZE)
training_dataset = datasets.interleave_datasets([
diffusion_db_dataset["train"], diffusion_db_dataset["train"],
midjourney_dataset["train"],
flickr_dataset["train"], flickr_dataset["train"],
wiki_art_dataset["train"], wiki_art_dataset["train"],
]).shuffle() anime_dataset["train"],
test_dataset = datasets.concatenate_datasets([ coco_dataset["train"],
movie_poster_dataset["train"],
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
test_dataset = datasets.interleave_datasets([
diffusion_db_dataset["test"], diffusion_db_dataset["test"],
midjourney_dataset["test"],
flickr_dataset["test"], flickr_dataset["test"],
wiki_art_dataset["test"], wiki_art_dataset["test"],
]).shuffle() anime_dataset["test"],
coco_dataset["test"],
movie_poster_dataset["test"],
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
print("Training and test dataset prepared.")
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2") tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = transformers.AutoModelForCausalLM.from_pretrained( moondream = transformers.AutoModelForCausalLM.from_pretrained(
@@ -150,7 +206,6 @@ dataloaders = {
"train": torch.utils.data.DataLoader( "train": torch.utils.data.DataLoader(
training_dataset, training_dataset,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collate, collate_fn=collate,
), ),
} }
@@ -158,7 +213,7 @@ dataloaders = {
moondream.text_model.train() moondream.text_model.train()
moondream.text_model.transformer.gradient_checkpointing_enable() moondream.text_model.transformer.gradient_checkpointing_enable()
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS total_steps = EPOCHS * (TOTAL_DATA_SIZE * (1 - TEST_SIZE)) // GRAD_ACCUM_STEPS
optimizer = bitsandbytes.optim.Adam8bit( optimizer = bitsandbytes.optim.Adam8bit(
[{"params": moondream.text_model.parameters()}], [{"params": moondream.text_model.parameters()}],
lr=LR*0.1, lr=LR*0.1,
@@ -184,6 +239,7 @@ for epoch in range(EPOCHS):
moondream.save_pretrained("checkpoints/moondream-mai") moondream.save_pretrained("checkpoints/moondream-mai")
moondream.eval()
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True) pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
correct_predictions = 0 correct_predictions = 0
@@ -201,9 +257,6 @@ for sample in tqdm(test_dataset, desc="Validation"):
if md_answer == ground_truth: if md_answer == ground_truth:
correct_predictions += 1 correct_predictions += 1
if i % 10 == 0: accuracy = correct_predictions * 100 / (TOTAL_DATA_SIZE * TEST_SIZE)
print(f"Question: f{sample["qa"]["answer"]")
accuracy = correct_predictions * 100 / len(test_dataset)
print(f"Model accuracy: f{accuracy}%") print(f"Model accuracy: f{accuracy}%")

8
utils/datasets.py Normal file
View File

@@ -0,0 +1,8 @@
import datasets
def split_streaming_dataset(
    ds: "datasets.IterableDataset",
    total_size: int,
    test_size: float,
) -> "dict[str, datasets.IterableDataset]":
    """Split the first ``total_size`` rows of a streaming dataset into train/test.

    The first ``round(total_size * (1 - test_size))`` rows form the train
    split; the rows that follow, up to ``total_size`` in total, form the test
    split. Only ``take`` and ``skip`` are called on ``ds``.

    Args:
        ds: Dataset exposing ``take(n)`` and ``skip(n)``.
        total_size: Total number of rows to use across both splits.
        test_size: Fraction of ``total_size`` assigned to the test split,
            between 0 and 1 inclusive.

    Returns:
        A dict with the keys ``"train"`` and ``"test"``.

    Raises:
        ValueError: If ``total_size`` is negative or ``test_size`` lies
            outside ``[0, 1]``.
    """
    if total_size < 0:
        raise ValueError(f"total_size must be non-negative, got {total_size}")
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be within [0, 1], got {test_size}")

    train_count = round(total_size * (1 - test_size))
    test_count = total_size - train_count
    return {
        "train": ds.take(train_count),
        # Skip past the train rows so the two splits never overlap.
        "test": ds.skip(train_count).take(test_count),
    }