diff --git a/train.py b/train.py index d1caa7d..a412735 100644 --- a/train.py +++ b/train.py @@ -39,7 +39,7 @@ def load_data(): trust_remote_code=True, split="train", ) - flickr_dataset = load_dataset("nlphuji/flickr30k", split="test[:50%]") + flickr_dataset = load_dataset("nlphuji/flickr30k", split="test") painting_dataset = load_dataset( "keremberke/painting-style-classification", name="full", split="train" )