Files
mai/utils/datasets.py
2024-12-13 22:19:53 +00:00

9 lines
303 B
Python

import datasets
def split_streaming_dataset(
    ds: datasets.IterableDataset,
    total_size: int,
    test_size: float,
) -> dict[str, datasets.IterableDataset]:
    """Split a streaming dataset into ``"train"`` and ``"test"`` parts.

    The first ``round(total_size * (1 - test_size))`` examples of *ds* form
    the train split; the examples after them, up to *total_size* examples in
    total, form the test split. Both splits are built with
    ``IterableDataset.take`` / ``IterableDataset.skip`` rather than being
    materialized here.

    NOTE(review): each split iterates *ds* from its beginning independently,
    so this assumes *ds* yields examples in the same order on every pass
    (e.g. unshuffled, or shuffled with a fixed seed). If the order can change
    between passes, the train and test splits may overlap — confirm for
    callers that shuffle before splitting.

    Args:
        ds: The streaming dataset to split.
        total_size: Number of leading examples of *ds* to distribute between
            the two splits. Must be non-negative.
        test_size: Fraction of *total_size* assigned to the test split; must
            lie in ``[0, 1]``.

    Returns:
        A dict with key ``"train"`` mapping to the train split and key
        ``"test"`` mapping to the test split.

    Raises:
        ValueError: If *total_size* is negative, or *test_size* is outside
            ``[0, 1]`` (NaN included).
    """
    if total_size < 0:
        raise ValueError(f"total_size must be non-negative, got {total_size}")
    # Written as a negated chained comparison so NaN (for which every
    # comparison is False) is rejected as well.
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be in [0, 1], got {test_size}")

    # round() rounds halves to even, e.g. total_size=5, test_size=0.5 gives
    # 2 train / 3 test examples. With test_size validated above,
    # 0 <= train_count <= total_size, so test_count is never negative.
    train_count = round(total_size * (1 - test_size))
    test_count = total_size - train_count
    return {
        "train": ds.take(train_count),
        # The test split is the window [train_count, total_size) of the stream.
        "test": ds.skip(train_count).take(test_count),
    }