Split Dataset into train/test/val (#604)
Co-authored-by: Helio Machado <[email protected]>
dreadatour and 0x2b3bfa0 authored Nov 17, 2024
1 parent 1fe4891 commit ebc19c6
Showing 4 changed files with 153 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/datachain/toolkit/__init__.py
@@ -0,0 +1,3 @@
from .split import train_test_split

__all__ = ["train_test_split"]
67 changes: 67 additions & 0 deletions src/datachain/toolkit/split.py
@@ -0,0 +1,67 @@
from datachain import C, DataChain


def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
    """
    Splits a DataChain into multiple subsets based on the provided weights.

    This function partitions the rows or items of a DataChain into disjoint subsets,
    ensuring that the relative sizes of the subsets correspond to the given weights.
    It is particularly useful for creating training, validation, and test datasets.

    Args:
        dc (DataChain):
            The DataChain instance to split.
        weights (list[float]):
            A list of weights indicating the relative proportions of the splits.
            The weights do not need to sum to 1; they are normalized internally.
            For example:
            - `[0.7, 0.3]` corresponds to a 70/30 split;
            - `[2, 1, 1]` corresponds to a 50/25/25 split.

    Returns:
        list[DataChain]:
            A list of DataChain instances, one for each weight in the weights list.

    Examples:
        Train-test split:
        ```python
        from datachain import DataChain
        from datachain.toolkit import train_test_split

        # Load a DataChain from a storage source (e.g., an S3 bucket)
        dc = DataChain.from_storage("s3://bucket/dir/")

        # Perform a 70/30 train-test split
        train, test = train_test_split(dc, [0.7, 0.3])

        # Save the resulting splits
        train.save("dataset_train")
        test.save("dataset_test")
        ```

        Train-test-validation split:
        ```python
        train, test, val = train_test_split(dc, [0.7, 0.2, 0.1])
        train.save("dataset_train")
        test.save("dataset_test")
        val.save("dataset_val")
        ```

    Note:
        The splits are random but deterministic, based on the dataset's
        `sys__rand` field.
    """
    if len(weights) < 2:
        raise ValueError("Weights should have at least two elements")
    if any(weight < 0 for weight in weights):
        raise ValueError("Weights should be non-negative")

    weights_normalized = [weight / sum(weights) for weight in weights]

    return [
        dc.filter(
            C("sys__rand") % 1000 >= round(sum(weights_normalized[:index]) * 1000),
            C("sys__rand") % 1000 < round(sum(weights_normalized[: index + 1]) * 1000),
        )
        for index, _ in enumerate(weights_normalized)
    ]
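The filters above carve the range of `sys__rand % 1000` into contiguous half-open buckets, one per normalized weight. A minimal sketch of the same boundary arithmetic in plain Python (illustrative only, not part of the commit):

```python
# Reproduce the bucket boundaries train_test_split derives from the weights.
# Each normalized weight claims a half-open range of `sys__rand % 1000`
# values, so every row lands in exactly one split.
weights = [0.7, 0.2, 0.1]
normalized = [w / sum(weights) for w in weights]

buckets = [
    (
        round(sum(normalized[:i]) * 1000),       # inclusive lower bound
        round(sum(normalized[: i + 1]) * 1000),  # exclusive upper bound
    )
    for i in range(len(normalized))
]
print(buckets)  # [(0, 700), (700, 900), (900, 1000)]
```

Because consecutive buckets share their boundary value, the filters are mutually exclusive and together cover the whole 0-999 range.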
42 changes: 41 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
     SQLiteWarehouse,
 )
 from datachain.dataset import DatasetRecord
-from datachain.lib.dc import DataChain
+from datachain.lib.dc import DataChain, Sys
 from datachain.query.session import Session
 from datachain.utils import (
     ENV_DATACHAIN_GLOBAL_CONFIG_DIR,
@@ -701,3 +701,43 @@ def studio_datasets(requests_mock):
    ]

    requests_mock.post(f"{STUDIO_URL}/api/datachain/ls-datasets", json=datasets)


@pytest.fixture
def not_random_ds(test_session):
    return DataChain.from_records(
        [
            {"sys__id": 1, "sys__rand": 50, "fib": 0},
            {"sys__id": 2, "sys__rand": 150, "fib": 1},
            {"sys__id": 3, "sys__rand": 250, "fib": 1},
            {"sys__id": 4, "sys__rand": 350, "fib": 2},
            {"sys__id": 5, "sys__rand": 450, "fib": 3},
            {"sys__id": 6, "sys__rand": 550, "fib": 5},
            {"sys__id": 7, "sys__rand": 650, "fib": 8},
            {"sys__id": 8, "sys__rand": 750, "fib": 13},
            {"sys__id": 9, "sys__rand": 850, "fib": 21},
            {"sys__id": 10, "sys__rand": 950, "fib": 34},
        ],
        session=test_session,
        schema={"sys": Sys, "fib": int},
    )


@pytest.fixture
def pseudo_random_ds(test_session):
    return DataChain.from_records(
        [
            {"sys__id": 1, "sys__rand": 1344339883, "fib": 0},
            {"sys__id": 2, "sys__rand": 3901153096, "fib": 1},
            {"sys__id": 3, "sys__rand": 4255991360, "fib": 1},
            {"sys__id": 4, "sys__rand": 2526403609, "fib": 2},
            {"sys__id": 5, "sys__rand": 1871733386, "fib": 3},
            {"sys__id": 6, "sys__rand": 9380910850, "fib": 5},
            {"sys__id": 7, "sys__rand": 2770679740, "fib": 8},
            {"sys__id": 8, "sys__rand": 2538886575, "fib": 13},
            {"sys__id": 9, "sys__rand": 3969542617, "fib": 21},
            {"sys__id": 10, "sys__rand": 7541790992, "fib": 34},
        ],
        session=test_session,
        schema={"sys": Sys, "fib": int},
    )
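A note on the fixture design (an inference from the values, not stated in the commit): `not_random_ds` spaces its `sys__rand` values evenly at 50, 150, ..., 950, so each row's bucket under `% 1000` is obvious by inspection, while `pseudo_random_ds` uses large arbitrary values whose residues only become clear once computed. A quick illustrative check:

```python
# With weights [4, 1] the split boundary is round(0.8 * 1000) == 800, so the
# evenly spaced rows divide exactly 8/2, matching the [[1, ..., 8], [9, 10]]
# expectation in the tests below.
rands = [50, 150, 250, 350, 450, 550, 650, 750, 850, 950]
train_ids = [i + 1 for i, r in enumerate(rands) if r % 1000 < 800]
test_ids = [i + 1 for i, r in enumerate(rands) if r % 1000 >= 800]
print(train_ids, test_ids)  # [1, 2, 3, 4, 5, 6, 7, 8] [9, 10]
```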
42 changes: 42 additions & 0 deletions tests/func/test_toolkit.py
@@ -0,0 +1,42 @@
import pytest

from datachain.toolkit import train_test_split


@pytest.mark.parametrize(
    "weights,expected",
    [
        [[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
        [[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
        [[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
    ],
)
def test_train_test_split_not_random(not_random_ds, weights, expected):
    res = train_test_split(not_random_ds, weights)
    assert len(res) == len(expected)

    for i, dc in enumerate(res):
        assert list(dc.collect("sys.id")) == expected[i]


@pytest.mark.parametrize(
    "weights,expected",
    [
        [[1, 1], [[2, 3, 5], [1, 4, 6, 7, 8, 9, 10]]],
        [[4, 1], [[2, 3, 4, 5, 7, 8, 9], [1, 6, 10]]],
        [[0.7, 0.2, 0.1], [[2, 3, 4, 5, 8, 9], [1, 6, 7], [10]]],
    ],
)
def test_train_test_split_random(pseudo_random_ds, weights, expected):
    res = train_test_split(pseudo_random_ds, weights)
    assert len(res) == len(expected)

    for i, dc in enumerate(res):
        assert list(dc.collect("sys.id")) == expected[i]


def test_train_test_split_errors(not_random_ds):
    with pytest.raises(ValueError, match="Weights should have at least two elements"):
        train_test_split(not_random_ds, [0.5])
    with pytest.raises(ValueError, match="Weights should be non-negative"):
        train_test_split(not_random_ds, [-1, 1])
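The expected ids in the pseudo-random cases can be re-derived by hand from the fixture's `sys__rand` residues. A standalone sketch for the `[1, 1]` case (illustrative, not part of the test suite):

```python
# The single boundary for weights [1, 1] is round(0.5 * 1000) == 500, so
# rows split on `sys__rand % 1000 < 500`.
rands = {
    1: 1344339883, 2: 3901153096, 3: 4255991360, 4: 2526403609,
    5: 1871733386, 6: 9380910850, 7: 2770679740, 8: 2538886575,
    9: 3969542617, 10: 7541790992,
}
train = [i for i, r in rands.items() if r % 1000 < 500]
test = [i for i, r in rands.items() if r % 1000 >= 500]
print(train, test)  # [2, 3, 5] [1, 4, 6, 7, 8, 9, 10]
```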
