increase data diversity
This commit is contained in:
8
utils/datasets.py
Normal file
8
utils/datasets.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import datasets
|
||||
|
||||
def split_streaming_dataset(ds: datasets.IterableDataset, total_size: int, test_size: float) -> dict[str, datasets.IterableDataset]:
    """Split a streaming dataset into consecutive "train" and "test" parts.

    The first ``round(total_size * (1 - test_size))`` examples become the
    "train" split; the examples that follow, up to ``total_size`` in total,
    become the "test" split.

    Args:
        ds: Streaming dataset to split. Only its ``take`` and ``skip``
            methods are used.
        total_size: Number of examples to distribute across both splits.
            The caller supplies this; it is not read from ``ds``.
        test_size: Fraction of ``total_size`` assigned to the "test" split,
            between 0 and 1 inclusive.

    Returns:
        A dict with the keys "train" and "test".

    Raises:
        ValueError: If ``total_size`` is negative or ``test_size`` is not
            within [0, 1].
    """
    # Reject bad arguments here: otherwise a negative count would be handed
    # to take()/skip() and only surface wherever the splits are consumed.
    if total_size < 0:
        raise ValueError(f"total_size must be non-negative, got {total_size}")
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size}")

    size = round(total_size * (1 - test_size))
    # NOTE(review): the two splits are only disjoint if `ds` yields examples
    # in the same order each time it is iterated -- confirm for the source used.
    return {
        "train": ds.take(size),
        "test": ds.skip(size).take(total_size - size),
    }
|
Reference in New Issue
Block a user