Files
mai/moondream/generate_dataset.py

108 lines
3.1 KiB
Python

import os

import datasets
import diffusers
import torch
import transformers

from .hyperparams import MOONDREAM_REVISION
# Hugging Face access token read from the environment (None when unset).
# It is passed as `token=` when the diffusion pipeline is loaded later on.
auth_token = os.getenv("HF_ACCESS_TOKEN")

# Pin the tokenizer to the same repository revision as the model so the two
# cannot drift apart when the upstream repository is updated.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2",
    revision=MOONDREAM_REVISION,
)

# Load the Moondream checkpoint at the pinned revision in float16 with
# FlashAttention 2 and move it to the GPU. `trust_remote_code=True` executes
# modelling code shipped in the model repository.
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MOONDREAM_REVISION,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")
def collate(batch, *, prompt="Describe this image."):
    """Split a batch of samples into parallel image and prompt lists.

    Args:
        batch: Sized collection of mapping-like samples, each providing an
            ``"image"`` entry.
        prompt: Question paired with every image. Defaults to the captioning
            prompt, so ``collate(batch)`` keeps working as a ``collate_fn``.

    Returns:
        A ``(images, questions)`` tuple of two lists of equal length, where
        ``questions`` repeats ``prompt`` once per image.

    Raises:
        KeyError: If a sample has no ``"image"`` entry.
    """
    images = [sample["image"] for sample in batch]
    questions = [prompt] * len(images)
    return images, questions
def _first_image_stream(path, *config, **load_kwargs):
    """Open ``path`` in streaming mode, reduced to the ``image`` column of one sample.

    Positional ``config`` values and ``load_kwargs`` are forwarded unchanged to
    ``datasets.load_dataset``; ``streaming=True`` is always set.
    """
    stream = datasets.load_dataset(path, *config, streaming=True, **load_kwargs)
    return stream.select_columns(["image"]).take(1)


# One single-sample image stream per source dataset.
flickr_dataset = _first_image_stream("nlphuji/flickr30k", split="test")
wiki_art_dataset = _first_image_stream("huggan/wikiart", split="train")
anime_dataset = _first_image_stream(
    "animelover/danbooru2022",
    "1-full",
    trust_remote_code=True,
    split="train",
)
coco_dataset = _first_image_stream("detection-datasets/coco", split="train")
movie_poster_dataset = _first_image_stream("skvarre/movie_posters-100k", split="train")
cars_dataset = _first_image_stream("tanganke/stanford_cars", split="train")
website_dataset = _first_image_stream(
    "silatus/1k_Website_Screenshots_and_Metadata",
    split="train",
)
movie_scene_dataset = _first_image_stream(
    "unography/movie-scenes-resized-captioned",
    split="train",
)
# Chain the per-source streams into a single dataset, in this order.
_streams = [
    flickr_dataset,
    wiki_art_dataset,
    anime_dataset,
    coco_dataset,
    movie_poster_dataset,
    cars_dataset,
    website_dataset,
    movie_scene_dataset,
]
ds = datasets.concatenate_datasets(_streams)

# Batch the combined dataset; `collate` turns each batch of samples into
# an (images, questions) pair.
data_loader = torch.utils.data.DataLoader(dataset=ds, batch_size=8, collate_fn=collate)
# Caption every batch with Moondream, echoing each answer as it is produced.
captions = []
for images, questions in data_loader:
    answers = moondream.batch_answer(
        tokenizer=tokenizer,
        images=images,
        prompts=questions,
    )
    # Print each caption followed by a blank separator line.
    for caption in answers:
        print(caption)
        print()
    captions += answers

# Attach the collected captions to the dataset as a new "caption" column.
ds = ds.add_column("caption", captions)

# Drop this script's reference to the captioning model before the diffusion
# pipeline is loaded (presumably so its memory can be reclaimed).
del moondream
# Load Stable Diffusion 3.5 Large in bfloat16, authenticating with the token
# read from the environment, then move the pipeline to the GPU.
pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    token=auth_token,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

# Generate one image from a fixed prompt and write it to disk.
image = pipe(
    prompt="A capybara holding a sign that reads Hello World",
    guidance_scale=3.5,
    num_inference_steps=28,
).images[0]
image.save("capybara.png")