add script for generating dataset
This commit is contained in:
107
moondream/generate_dataset.py
Normal file
107
moondream/generate_dataset.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import os
|
||||
import torch
|
||||
import datasets
|
||||
import diffusers
|
||||
from .hyperparams import MOONDREAM_REVISION
|
||||
|
||||
auth_token = os.getenv("HF_ACCESS_TOKEN")
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
|
||||
moondream = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
"vikhyatk/moondream2",
|
||||
revision=MOONDREAM_REVISION,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
def collate(batch):
|
||||
images = []
|
||||
questions = []
|
||||
|
||||
for sample in batch:
|
||||
images.append(sample["image"])
|
||||
questions.append("Describe this image.")
|
||||
|
||||
return images, questions
|
||||
|
||||
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\
|
||||
.select_columns(["image"])\
|
||||
.take(1)
|
||||
|
||||
ds = datasets.concatenate_datasets([
|
||||
flickr_dataset,
|
||||
wiki_art_dataset,
|
||||
anime_dataset,
|
||||
coco_dataset,
|
||||
movie_poster_dataset,
|
||||
cars_dataset,
|
||||
website_dataset,
|
||||
movie_scene_dataset,
|
||||
])
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
ds,
|
||||
batch_size=8,
|
||||
collate_fn=collate
|
||||
)
|
||||
|
||||
captions = []
|
||||
for batch in data_loader:
|
||||
images, questions = batch
|
||||
answers = moondream.batch_answer(
|
||||
images=images,
|
||||
prompts=questions,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
for ans in answers:
|
||||
print(ans)
|
||||
print()
|
||||
|
||||
captions.extend(answers)
|
||||
|
||||
ds = ds.add_column("caption", captions)
|
||||
|
||||
del moondream
|
||||
|
||||
pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
torch_dtype=torch.bfloat16,
|
||||
token=auth_token,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe(
|
||||
"A capybara holding a sign that reads Hello World",
|
||||
num_inference_steps=28,
|
||||
guidance_scale=3.5,
|
||||
).images[0]
|
||||
image.save("capybara.png")
|
Reference in New Issue
Block a user