diff --git a/train.py b/train.py index 72a46a2..506501d 100644 --- a/train.py +++ b/train.py @@ -33,7 +33,7 @@ def load_data(): diffusion_db_dataset = load_dataset( "poloclub/diffusiondb", - "2m_random_50k", + "2m_random_10k", trust_remote_code=True, split="train", )