Files
mai/utils/datasets.py

9 lines
303 B
Python
Raw Normal View History

2024-12-13 22:19:53 +00:00
import datasets
def split_streaming_dataset(
    ds: datasets.IterableDataset,
    total_size: int,
    test_size: float,
) -> dict[str, datasets.IterableDataset]:
    """Split a streaming dataset into train/test views by position.

    The first ``round(total_size * (1 - test_size))`` examples of *ds* form
    the train split. The examples after them, up to ``total_size`` examples
    in total, form the test split. Both splits are built with
    ``IterableDataset.take`` / ``IterableDataset.skip``, so nothing is
    materialized here.

    Python's ``round`` rounds halves to even. For example, ``total_size=5``
    with ``test_size=0.5`` yields 2 train and 3 test examples.

    NOTE(review): both splits are derived from the same *ds*. They are
    presumably only disjoint if *ds* yields examples in the same order on
    every pass; confirm this when *ds* is shuffled.

    Args:
        ds: The streaming dataset to split.
        total_size: Number of leading examples of *ds* to distribute between
            the two splits. Examples beyond this count are not used.
        test_size: Fraction of ``total_size`` assigned to the test split.
            Must be in ``[0, 1]``.

    Returns:
        A dict with keys ``"train"`` and ``"test"`` mapping to the two splits.

    Raises:
        ValueError: If ``total_size`` is negative, or if ``test_size`` is
            outside ``[0, 1]`` (NaN included).
    """
    # Validate eagerly. Otherwise a bad argument turns into a negative
    # take/skip count that only surfaces (or misbehaves) once a split is
    # iterated.
    if total_size < 0:
        raise ValueError(f"total_size must be non-negative, got {total_size}")
    # The negated range check also rejects NaN, since every comparison
    # with NaN is False.
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be in [0, 1], got {test_size}")

    train_size = round(total_size * (1 - test_size))
    return {
        "train": ds.take(train_size),
        # The test split starts right after the train split and takes the
        # remainder of total_size.
        "test": ds.skip(train_size).take(total_size - train_size),
    }