diff --git a/moondream/generate_dataset.py b/moondream/generate_dataset.py new file mode 100644 index 0000000..76c1c47 --- /dev/null +++ b/moondream/generate_dataset.py @@ -0,0 +1,107 @@ +import os +import torch +import datasets +import diffusers +from .hyperparams import MOONDREAM_REVISION + +auth_token = os.getenv("HF_ACCESS_TOKEN") + +tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2") +moondream = transformers.AutoModelForCausalLM.from_pretrained( + "vikhyatk/moondream2", + revision=MOONDREAM_REVISION, + trust_remote_code=True, + attn_implementation="flash_attention_2", + torch_dtype=torch.float16, +).to("cuda") + +def collate(batch): + images = [] + questions = [] + + for sample in batch: + images.append(sample["image"]) + questions.append("Describe this image.") + + return images, questions + +flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\ + .select_columns(["image"])\ + .take(1) + +ds = datasets.concatenate_datasets([ + flickr_dataset, + wiki_art_dataset, + anime_dataset, + coco_dataset, + movie_poster_dataset, + cars_dataset, + website_dataset, + movie_scene_dataset, +]) + +data_loader = torch.utils.data.DataLoader( + ds, + batch_size=8, + collate_fn=collate +) + +captions = [] +for batch in data_loader: + images, questions = batch + answers = moondream.batch_answer( + images=images, + prompts=questions, + tokenizer=tokenizer + ) + + for ans in answers: + print(ans) + print() + + captions.extend(answers) + +ds = ds.add_column("caption", captions) + +del moondream + +pipe = diffusers.StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-large", + torch_dtype=torch.bfloat16, + token=auth_token, +).to("cuda") + +image = pipe( + "A capybara holding a sign that reads Hello World", + num_inference_steps=28, + guidance_scale=3.5, +).images[0] +image.save("capybara.png") diff --git a/moondream/hyperparams.py b/moondream/hyperparams.py index 81dcc7b..2a7faa6 100644 --- a/moondream/hyperparams.py +++ b/moondream/hyperparams.py @@ -1,8 +1,10 @@ +MOONDREAM_REVISION = "2024-08-26" + TEST_SIZE = 0.2 # Number of times to repeat the training dataset. Increasing this may cause the model to overfit or # lose generalization due to catastrophic forgetting. Decreasing it may cause the model to underfit. -EPOCHS = 1 +EPOCHS = 2 # Number of samples to process in each batch. Set this to the highest value that doesn't cause an # out-of-memory error. Decrease it if you're running out of memory. @@ -10,7 +12,7 @@ BATCH_SIZE = 8 # Number of batches to process before updating the model. You can use this to simulate a higher batch # size than your GPU can handle. Set this to 1 to disable gradient accumulation. -GRAD_ACCUM_STEPS = 2 +GRAD_ACCUM_STEPS = 1 # Learning rate for the Adam optimizer. Needs to be tuned on a case-by-case basis. As a general rule # of thumb, increase it by 1.4 times each time you double the effective batch size. diff --git a/moondream/test.py b/moondream/test.py index 5ad381b..0b8bec4 100644 --- a/moondream/test.py +++ b/moondream/test.py @@ -23,7 +23,7 @@ pathlib.Path("./samples").mkdir(parents=True, exist_ok=True) img = Image.open("samples/Untitled.jpg") md_answer = moondream.answer_question( moondream.encode_image(img), - "Describe this image.", + "Is this image AI generated?", tokenizer=tokenizer, num_beams=4, no_repeat_ngram_size=5, diff --git a/moondream/train.py b/moondream/train.py index e2289c9..aa051e3 100644 --- a/moondream/train.py +++ b/moondream/train.py @@ -13,7 +13,7 @@ from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOC DEVICE = "cuda" DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16 MD_REVISION = "2024-07-23" -TOTAL_DATA_SIZE = 8000 +TOTAL_DATA_SIZE = 20000 diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", split="train", trust_remote_code=True, streaming=True)\ .select_columns(["image"])\ @@ -24,7 +24,7 @@ diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_ "answer": "Yes." } }) -diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=2000, test_size=TEST_SIZE) +diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=5000, test_size=TEST_SIZE) midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split="train", streaming=True)\ .select_columns(["image"])\ @@ -35,7 +35,7 @@ midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split "answer": "Yes." } }) -midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=2000, test_size=TEST_SIZE) +midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=5000, test_size=TEST_SIZE) flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\ .select_columns(["image"])\ @@ -46,7 +46,7 @@ flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", stream "answer": "No." } }) -flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=800, test_size=TEST_SIZE) +flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=1250, test_size=TEST_SIZE) wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\ .select_columns(["image"])\ @@ -57,7 +57,7 @@ wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", stream "answer": "No." } }) -wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=800, test_size=TEST_SIZE) +wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=1250, test_size=TEST_SIZE) anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\ .select_columns(["image"])\ @@ -68,7 +68,7 @@ anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust "answer": "No." } }) -anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=800, test_size=TEST_SIZE) +anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=1250, test_size=TEST_SIZE) coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\ .select_columns(["image"])\ @@ -79,10 +79,10 @@ coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", s "answer": "No." } }) -coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=800, test_size=TEST_SIZE) +coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=1250, test_size=TEST_SIZE) movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\ - .select_columns(["age"])\ + .select_columns(["image"])\ .map(lambda row: { **row, "qa": { @@ -90,7 +90,40 @@ movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split "answer": "No." } }) -movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=800, test_size=TEST_SIZE) +movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=1250, test_size=TEST_SIZE) + +cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\ + .select_columns(["image"])\ + .map(lambda row: { + **row, + "qa": { + "question": "Is this image AI generated?", + "answer": "No." + } + }) +cars_dataset = utils.datasets.split_streaming_dataset(cars_dataset, total_size=1250, test_size=TEST_SIZE) + +website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\ + .select_columns(["image"])\ + .map(lambda row: { + **row, + "qa": { + "question": "Is this image AI generated?", + "answer": "No.", + } + }) +website_dataset = utils.datasets.split_streaming_dataset(website_dataset, total_size=1250, test_size=TEST_SIZE) + +movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\ + .select_columns(["image"])\ + .map(lambda row: { + **row, + "qa": { + "question": "Is this image AI generated?", + "answer": "No.", + } + }) +movie_scene_dataset = utils.datasets.split_streaming_dataset(movie_scene_dataset, total_size=1250, test_size=TEST_SIZE) training_dataset = datasets.interleave_datasets([ diffusion_db_dataset["train"], @@ -100,6 +133,9 @@ training_dataset = datasets.interleave_datasets([ anime_dataset["train"], coco_dataset["train"], movie_poster_dataset["train"], + cars_dataset["train"], + website_dataset["train"], + movie_scene_dataset["train"], ], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True)) test_dataset = datasets.interleave_datasets([ diffusion_db_dataset["test"], @@ -109,6 +145,9 @@ test_dataset = datasets.interleave_datasets([ anime_dataset["test"], coco_dataset["test"], movie_poster_dataset["test"], + cars_dataset["test"], + website_dataset["test"], + movie_scene_dataset["test"], ], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True)) print("Training and test dataset prepared.") @@ -242,8 +281,11 @@ moondream.save_pretrained("checkpoints/moondream-mai") moondream.eval() pathlib.Path("./samples").mkdir(parents=True, exist_ok=True) +total = 0 correct_predictions = 0 for sample in tqdm(test_dataset, desc="Validation"): + total += 1 + md_answer = moondream.answer_question( moondream.encode_image(sample['image']), sample['qa']['question'], @@ -257,6 +299,6 @@ for sample in tqdm(test_dataset, desc="Validation"): if md_answer == ground_truth: correct_predictions += 1 -accuracy = correct_predictions * 100 / (TOTAL_DATA_SIZE * TEST_SIZE) +accuracy = correct_predictions * 100 / total print(f"Model accuracy: f{accuracy}%")