increase data diversity

This commit is contained in:
2024-12-13 22:19:53 +00:00
parent 46b60151c7
commit 14c6f26ddc
4 changed files with 150 additions and 84 deletions

8
utils/datasets.py Normal file
View File

@@ -0,0 +1,8 @@
import datasets
def split_streaming_dataset(ds: datasets.IterableDataset, total_size: int, test_size: float) -> dict[str, datasets.IterableDataset]:
    """Split a streaming dataset into consecutive train and test portions.

    The first ``round(total_size * (1 - test_size))`` examples of ``ds``
    form the "train" split; the examples that follow, up to ``total_size``
    in total, form the "test" split. Only ``take``/``skip`` are called on
    ``ds``, so nothing is materialised here.

    Args:
        ds: The streaming dataset to split.
        total_size: Number of examples from ``ds`` to distribute across
            both splits. Must be non-negative.
        test_size: Fraction of ``total_size`` assigned to the test split.
            Must lie in the closed interval [0, 1].

    Returns:
        A dict with the keys "train" and "test".

    Raises:
        ValueError: If ``total_size`` is negative or ``test_size`` is
            outside [0, 1] (including NaN).
    """
    # Validate eagerly: an out-of-range value would otherwise yield a
    # negative split size that is handed to take()/skip() and would
    # presumably only fail (or misbehave) once the stream is iterated.
    if total_size < 0:
        raise ValueError(f"total_size must be non-negative, got {total_size!r}")
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size!r}")

    train_size = round(total_size * (1 - test_size))
    return {
        "train": ds.take(train_size),
        # Skip past the train examples so the two splits never overlap.
        "test": ds.skip(train_size).take(total_size - train_size),
    }