-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split Dataset into train/test/val (#604)
Co-authored-by: Helio Machado <[email protected]>
- Loading branch information
1 parent
1fe4891
commit ebc19c6
Showing
4 changed files
with
153 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .split import train_test_split | ||
|
||
__all__ = ["train_test_split"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from datachain import C, DataChain | ||
|
||
|
||
def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
    """
    Split a DataChain into multiple subsets based on the provided weights.

    This function partitions the rows or items of a DataChain into disjoint
    subsets, ensuring that the relative sizes of the subsets correspond to the
    given weights. It is particularly useful for creating training, validation,
    and test datasets.

    Args:
        dc (DataChain):
            The DataChain instance to split.
        weights (list[float]):
            A list of weights indicating the relative proportions of the splits.
            The weights do not need to sum to 1; they will be normalized
            internally. For example:
            - `[0.7, 0.3]` corresponds to a 70/30 split;
            - `[2, 1, 1]` corresponds to a 50/25/25 split.

    Returns:
        list[DataChain]:
            A list of DataChain instances, one for each weight in the weights
            list.

    Raises:
        ValueError: If fewer than two weights are given, any weight is
            negative, or all weights are zero.

    Examples:
        Train-test split:
        ```python
        from datachain import DataChain
        from datachain.toolkit import train_test_split

        # Load a DataChain from a storage source (e.g., S3 bucket)
        dc = DataChain.from_storage("s3://bucket/dir/")

        # Perform a 70/30 train-test split
        train, test = train_test_split(dc, [0.7, 0.3])

        # Save the resulting splits
        train.save("dataset_train")
        test.save("dataset_test")
        ```

        Train-test-validation split:
        ```python
        train, test, val = train_test_split(dc, [0.7, 0.2, 0.1])
        train.save("dataset_train")
        test.save("dataset_test")
        val.save("dataset_val")
        ```

    Note:
        The splits are random but deterministic, based on the Dataset
        `sys__rand` field.
    """
    if len(weights) < 2:
        raise ValueError("Weights should have at least two elements")
    if any(weight < 0 for weight in weights):
        raise ValueError("Weights should be non-negative")

    total = sum(weights)
    if total == 0:
        # Without this guard, all-zero weights would slip past the
        # non-negativity check and crash with ZeroDivisionError below.
        raise ValueError("Weights should not all be zero")

    # Cumulative split boundaries scaled to the [0, 1000) bucket space of
    # `sys__rand % 1000`; split i selects the half-open bucket range
    # [boundaries[i], boundaries[i + 1]).  A single running sum avoids
    # re-summing the normalized prefix for every split (O(n) vs O(n^2)).
    boundaries = [0]
    cumulative = 0.0
    for weight in weights:
        cumulative += weight / total
        boundaries.append(round(cumulative * 1000))

    return [
        dc.filter(
            C("sys__rand") % 1000 >= lo,
            C("sys__rand") % 1000 < hi,
        )
        for lo, hi in zip(boundaries[:-1], boundaries[1:])
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import pytest | ||
|
||
from datachain.toolkit import train_test_split | ||
|
||
|
||
@pytest.mark.parametrize(
    "weights,expected",
    [
        [[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
        [[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
        [[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
    ],
)
def test_train_test_split_not_random(not_random_ds, weights, expected):
    # With a non-random dataset the splits must land on exact, contiguous
    # id ranges proportional to the weights.
    splits = train_test_split(not_random_ds, weights)
    assert len(splits) == len(expected)

    for split, expected_ids in zip(splits, expected):
        assert list(split.collect("sys.id")) == expected_ids
|
||
|
||
@pytest.mark.parametrize(
    "weights,expected",
    [
        [[1, 1], [[2, 3, 5], [1, 4, 6, 7, 8, 9, 10]]],
        [[4, 1], [[2, 3, 4, 5, 7, 8, 9], [1, 6, 10]]],
        [[0.7, 0.2, 0.1], [[2, 3, 4, 5, 8, 9], [1, 6, 7], [10]]],
    ],
)
def test_train_test_split_random(pseudo_random_ds, weights, expected):
    # The pseudo-random fixture has fixed sys__rand values, so the split
    # membership is deterministic even though the ids are scattered.
    splits = train_test_split(pseudo_random_ds, weights)
    assert len(splits) == len(expected)

    for split, expected_ids in zip(splits, expected):
        assert list(split.collect("sys.id")) == expected_ids
|
||
|
||
def test_train_test_split_errors(not_random_ds):
    # A single weight cannot describe a split into multiple subsets.
    with pytest.raises(ValueError, match="Weights should have at least two elements"):
        train_test_split(not_random_ds, [0.5])

    # Negative weights are meaningless as proportions and must be rejected.
    with pytest.raises(ValueError, match="Weights should be non-negative"):
        train_test_split(not_random_ds, [-1, 1])