increase data diversity
This commit is contained in:
41
moondream/siglip.py
Normal file
41
moondream/siglip.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
import datasets
|
||||||
|
import sklearn
|
||||||
|
|
||||||
|
# Near-duplicate filtering experiment: embed a sample of Midjourney images
# with SigLIP, then greedily keep one representative per radius-neighbourhood
# in embedding space and report how many representatives remain.

# `import sklearn` alone does not load the `neighbors` submodule, so
# `sklearn.neighbors` may raise AttributeError; import it explicitly.
import sklearn.neighbors

MODEL_NAME = "google/siglip-base-patch16-224"
SAMPLE_COUNT = 500       # number of streamed images to embed
EMBED_BATCH_SIZE = 32    # images per forward pass, bounds peak memory
DEDUP_RADIUS = 1.0       # euclidean radius treated as "same image"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = transformers.AutoModel.from_pretrained(MODEL_NAME).to(device)
processor = transformers.AutoProcessor.from_pretrained(MODEL_NAME)
nn = sklearn.neighbors.NearestNeighbors(metric="euclidean", radius=DEDUP_RADIUS)

ds = (
    datasets.load_dataset(
        "ehristoforu/midjourney-images",
        split="train",
        trust_remote_code=True,
        streaming=True,
    )
    .select_columns(["image"])
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is an AI image.",
        },
    })
    .take(SAMPLE_COUNT)
)

# Materialise the stream once: a streaming dataset is re-read from its
# source on every iteration, and only the images are needed from here on.
images = [row["image"] for row in ds]
if not images:
    raise RuntimeError("The dataset stream yielded no images to embed.")

# Embed in fixed-size batches instead of one forward pass over every image.
feature_chunks = []
with torch.no_grad():
    for start in range(0, len(images), EMBED_BATCH_SIZE):
        batch = images[start:start + EMBED_BATCH_SIZE]
        inputs = processor(images=batch, return_tensors="pt").to(device)
        feature_chunks.append(model.get_image_features(**inputs).cpu())
image_features = torch.cat(feature_chunks)

features = image_features.numpy()
nn.fit(features)

# Greedy selection: the first not-yet-covered image becomes a representative
# and everything within DEDUP_RADIUS of it (itself included) is marked used.
used_indices = set()
unique_indices = []
for i, feature in enumerate(features):
    if i in used_indices:
        continue

    # The radius defaults to the value given to the NearestNeighbors constructor.
    neighbors = nn.radius_neighbors(feature.reshape(1, -1), return_distance=False)[0]

    unique_indices.append(i)
    used_indices.update(neighbors.tolist())

print(len(unique_indices))
|
@@ -18,45 +18,9 @@ moondream = transformers.AutoModelForCausalLM.from_pretrained(
|
|||||||
device_map={"": DEVICE},
|
device_map={"": DEVICE},
|
||||||
)
|
)
|
||||||
|
|
||||||
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
|
|
||||||
.shuffle()\
|
|
||||||
.take(100)\
|
|
||||||
.select_columns(["image"])\
|
|
||||||
.map(lambda row: {
|
|
||||||
**row,
|
|
||||||
"qa": {
|
|
||||||
"question": "Describe this image.",
|
|
||||||
"answer": "This is an AI image."
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
|
|
||||||
.shuffle()\
|
|
||||||
.take(100)\
|
|
||||||
.select_columns(["image"])\
|
|
||||||
.map(lambda row: {
|
|
||||||
**row,
|
|
||||||
"qa": {
|
|
||||||
"question": "Describe this image.",
|
|
||||||
"answer": "This is a real image."
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
midjourney_dataset = datasets.load_dataset("ehristoforu/midjourney-images", split="train", streaming=True)\
|
|
||||||
.select_columns(["image"])\
|
|
||||||
.map(lambda row: {
|
|
||||||
**row,
|
|
||||||
"qa": {
|
|
||||||
"question": "Describe this image.",
|
|
||||||
"answer": "This is an AI image."
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
|
|
||||||
|
|
||||||
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
img = Image.open("samples/frames_3.jpg")
|
img = Image.open("samples/Untitled.jpg")
|
||||||
md_answer = moondream.answer_question(
|
md_answer = moondream.answer_question(
|
||||||
moondream.encode_image(img),
|
moondream.encode_image(img),
|
||||||
"Describe this image.",
|
"Describe this image.",
|
||||||
@@ -68,28 +32,28 @@ md_answer = moondream.answer_question(
|
|||||||
|
|
||||||
print(md_answer)
|
print(md_answer)
|
||||||
|
|
||||||
correct_predictions = 0
|
# correct_predictions = 0
|
||||||
for i, sample in enumerate(midjourney_dataset):
|
# for i, sample in enumerate(flickr_dataset):
|
||||||
if i > 4:
|
# if i > 4:
|
||||||
break
|
# break
|
||||||
|
|
||||||
sample["image"].save(f"samples/{i}.png", "PNG")
|
# sample["image"].save(f"samples/{i}.png", "PNG")
|
||||||
|
|
||||||
md_answer = moondream.answer_question(
|
# md_answer = moondream.answer_question(
|
||||||
moondream.encode_image(sample['image']),
|
# moondream.encode_image(sample['image']),
|
||||||
sample['qa']['question'],
|
# sample['qa']['question'],
|
||||||
tokenizer=tokenizer,
|
# tokenizer=tokenizer,
|
||||||
num_beams=4,
|
# num_beams=4,
|
||||||
no_repeat_ngram_size=5,
|
# no_repeat_ngram_size=5,
|
||||||
early_stopping=True
|
# early_stopping=True
|
||||||
)
|
# )
|
||||||
|
|
||||||
print(f"Question: {sample['qa']['question']}")
|
# print(f"Question: {sample['qa']['question']}")
|
||||||
print(f"Ground truth: {sample['qa']['answer']}")
|
# print(f"Ground truth: {sample['qa']['answer']}")
|
||||||
print(f"Moondream: {md_answer}")
|
# print(f"Moondream: {md_answer}")
|
||||||
print()
|
# print()
|
||||||
|
|
||||||
if md_answer.lower() == sample['qa']['answer'].lower():
|
# if md_answer.lower() == sample['qa']['answer'].lower():
|
||||||
correct_predictions += 1
|
# correct_predictions += 1
|
||||||
|
|
||||||
print(f"Accuracy: {correct_predictions * 100 / 10}%")
|
# print(f"Accuracy: {correct_predictions * 100 / 10}%")
|
||||||
|
@@ -4,58 +4,114 @@ import datasets
|
|||||||
import transformers
|
import transformers
|
||||||
import bitsandbytes
|
import bitsandbytes
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import io
|
||||||
|
import PIL
|
||||||
|
import utils.datasets
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
|
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
|
||||||
|
|
||||||
DEVICE = "cuda"
|
DEVICE = "cuda"
|
||||||
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
|
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
|
||||||
MD_REVISION = "2024-07-23"
|
MD_REVISION = "2024-07-23"
|
||||||
|
TOTAL_DATA_SIZE = 8000
|
||||||
|
|
||||||
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
|
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", split="train", trust_remote_code=True, streaming=True)\
|
||||||
.select_columns(["image"])\
|
.select_columns(["image"])\
|
||||||
.map(lambda row: {
|
.map(lambda row: {
|
||||||
**row,
|
**row,
|
||||||
"qa": {
|
"qa": {
|
||||||
"question": "Describe this image.",
|
"question": "Is this image AI generated?",
|
||||||
"answer": "This is an AI image."
|
"answer": "Yes."
|
||||||
}
|
}
|
||||||
})\
|
})
|
||||||
.train_test_split(test_size=TEST_SIZE)
|
diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=2000, test_size=TEST_SIZE)
|
||||||
|
|
||||||
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
|
midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split="train", streaming=True)\
|
||||||
.take(2500)\
|
.select_columns(["image"])\
|
||||||
|
.map(lambda row: {
|
||||||
|
**row,
|
||||||
|
"qa": {
|
||||||
|
"question": "Is this image AI generated?",
|
||||||
|
"answer": "Yes."
|
||||||
|
}
|
||||||
|
})
|
||||||
|
midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=2000, test_size=TEST_SIZE)
|
||||||
|
|
||||||
|
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
|
||||||
.select_columns(["image"])\
|
.select_columns(["image"])\
|
||||||
.map(lambda row: {
|
.map(lambda row: {
|
||||||
**row,
|
**row,
|
||||||
"qa": {
|
"qa": {
|
||||||
"question": "Describe this image.",
|
"question": "Is this image AI generated?",
|
||||||
"answer": "This is a real image."
|
"answer": "No."
|
||||||
}
|
}
|
||||||
})\
|
})
|
||||||
.train_test_split(test_size=TEST_SIZE)
|
flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=800, test_size=TEST_SIZE)
|
||||||
|
|
||||||
wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train")\
|
wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
|
||||||
.take(2500)\
|
|
||||||
.select_columns(["image"])\
|
.select_columns(["image"])\
|
||||||
.map(lambda row: {
|
.map(lambda row: {
|
||||||
**row,
|
**row,
|
||||||
"qa": {
|
"qa": {
|
||||||
"question": "Describe thie image.",
|
"question": "Is this image AI generated?",
|
||||||
"answer": "This is a real image."
|
"answer": "No."
|
||||||
}
|
}
|
||||||
})\
|
})
|
||||||
.train_test_split(test_size=TEST_SIZE)
|
wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=800, test_size=TEST_SIZE)
|
||||||
|
|
||||||
training_dataset = datasets.concatenate_datasets([
|
anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
|
||||||
|
.select_columns(["image"])\
|
||||||
|
.map(lambda row: {
|
||||||
|
**row,
|
||||||
|
"qa": {
|
||||||
|
"question": "Is this image AI generated?",
|
||||||
|
"answer": "No."
|
||||||
|
}
|
||||||
|
})
|
||||||
|
anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=800, test_size=TEST_SIZE)
|
||||||
|
|
||||||
|
coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
|
||||||
|
.select_columns(["image"])\
|
||||||
|
.map(lambda row: {
|
||||||
|
**row,
|
||||||
|
"qa": {
|
||||||
|
"question": "Is this image AI generated?",
|
||||||
|
"answer": "No."
|
||||||
|
}
|
||||||
|
})
|
||||||
|
coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=800, test_size=TEST_SIZE)
|
||||||
|
|
||||||
|
# Movie posters serve as negative ("No.") examples for the AI-detection question.
movie_poster_dataset = (
    datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)
    # Select "image" (not "age") so this source has the same schema as the
    # other sources and the interleaved result can be cast on its "image"
    # column. NOTE(review): assumes the dataset exposes an "image" column;
    # confirm against the dataset card.
    .select_columns(["image"])
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Is this image AI generated?",
            "answer": "No.",
        },
    })
)
# Carve the stream into train/test parts; 800 rows are consumed in total.
movie_poster_dataset = utils.datasets.split_streaming_dataset(
    movie_poster_dataset, total_size=800, test_size=TEST_SIZE
)
|
||||||
|
|
||||||
|
training_dataset = datasets.interleave_datasets([
|
||||||
diffusion_db_dataset["train"],
|
diffusion_db_dataset["train"],
|
||||||
|
midjourney_dataset["train"],
|
||||||
flickr_dataset["train"],
|
flickr_dataset["train"],
|
||||||
wiki_art_dataset["train"],
|
wiki_art_dataset["train"],
|
||||||
]).shuffle()
|
anime_dataset["train"],
|
||||||
test_dataset = datasets.concatenate_datasets([
|
coco_dataset["train"],
|
||||||
|
movie_poster_dataset["train"],
|
||||||
|
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
|
||||||
|
test_dataset = datasets.interleave_datasets([
|
||||||
diffusion_db_dataset["test"],
|
diffusion_db_dataset["test"],
|
||||||
|
midjourney_dataset["test"],
|
||||||
flickr_dataset["test"],
|
flickr_dataset["test"],
|
||||||
wiki_art_dataset["test"],
|
wiki_art_dataset["test"],
|
||||||
]).shuffle()
|
anime_dataset["test"],
|
||||||
|
coco_dataset["test"],
|
||||||
|
movie_poster_dataset["test"],
|
||||||
|
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
|
||||||
|
|
||||||
|
print("Training and test dataset prepared.")
|
||||||
|
|
||||||
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
|
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
|
||||||
moondream = transformers.AutoModelForCausalLM.from_pretrained(
|
moondream = transformers.AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -150,7 +206,6 @@ dataloaders = {
|
|||||||
"train": torch.utils.data.DataLoader(
|
"train": torch.utils.data.DataLoader(
|
||||||
training_dataset,
|
training_dataset,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
shuffle=True,
|
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@@ -158,7 +213,7 @@ dataloaders = {
|
|||||||
moondream.text_model.train()
|
moondream.text_model.train()
|
||||||
moondream.text_model.transformer.gradient_checkpointing_enable()
|
moondream.text_model.transformer.gradient_checkpointing_enable()
|
||||||
|
|
||||||
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
|
total_steps = EPOCHS * (TOTAL_DATA_SIZE * (1 - TEST_SIZE)) // GRAD_ACCUM_STEPS
|
||||||
optimizer = bitsandbytes.optim.Adam8bit(
|
optimizer = bitsandbytes.optim.Adam8bit(
|
||||||
[{"params": moondream.text_model.parameters()}],
|
[{"params": moondream.text_model.parameters()}],
|
||||||
lr=LR*0.1,
|
lr=LR*0.1,
|
||||||
@@ -184,6 +239,7 @@ for epoch in range(EPOCHS):
|
|||||||
|
|
||||||
moondream.save_pretrained("checkpoints/moondream-mai")
|
moondream.save_pretrained("checkpoints/moondream-mai")
|
||||||
|
|
||||||
|
moondream.eval()
|
||||||
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
correct_predictions = 0
|
correct_predictions = 0
|
||||||
@@ -201,9 +257,6 @@ for sample in tqdm(test_dataset, desc="Validation"):
|
|||||||
if md_answer == ground_truth:
|
if md_answer == ground_truth:
|
||||||
correct_predictions += 1
|
correct_predictions += 1
|
||||||
|
|
||||||
if i % 10 == 0:
|
accuracy = correct_predictions * 100 / (TOTAL_DATA_SIZE * TEST_SIZE)
|
||||||
print(f"Question: f{sample["qa"]["answer"]")
|
|
||||||
|
|
||||||
accuracy = correct_predictions * 100 / len(test_dataset)
|
|
||||||
|
|
||||||
print(f"Model accuracy: f{accuracy}%")
|
# Report the final accuracy; the stray "f" inside the f-string used to be
# printed literally in front of the number (e.g. "f87.5%").
print(f"Model accuracy: {accuracy}%")
|
||||||
|
8
utils/datasets.py
Normal file
8
utils/datasets.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import datasets
|
||||||
|
|
||||||
|
def split_streaming_dataset(ds: datasets.IterableDataset, total_size: int, test_size: float) -> dict[str, datasets.IterableDataset]:
    """Split the head of a streaming dataset into train and test parts.

    A streaming dataset has no length, so the caller states how many rows to
    consume. The first ``round(total_size * (1 - test_size))`` rows become the
    training part and the rows immediately after them, up to ``total_size``
    rows overall, become the test part. The two parts never overlap.

    Args:
        ds: Dataset to split; only its ``take`` and ``skip`` methods are used.
        total_size: Total number of rows to consume from ``ds``.
        test_size: Fraction of ``total_size`` assigned to the test part,
            between 0 and 1 inclusive.

    Returns:
        A dict with the keys ``"train"`` and ``"test"``.

    Raises:
        ValueError: If ``total_size`` is negative or ``test_size`` lies
            outside ``[0, 1]``; either would yield negative row counts.
    """
    if total_size < 0:
        raise ValueError(f"total_size must be non-negative, got {total_size}")
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be within [0, 1], got {test_size}")

    size = round(total_size * (1 - test_size))
    return {
        "train": ds.take(size),
        # Skip exactly the rows handed to "train" so the parts are disjoint.
        "test": ds.skip(size).take(total_size - size),
    }
|
Reference in New Issue
Block a user