From ebc19c6cead46aa549a5405cf42aa2a4d14c549a Mon Sep 17 00:00:00 2001 From: Vladimir Rudnykh Date: Sun, 17 Nov 2024 09:43:14 +0700 Subject: [PATCH] Split Dataset into train/test/val (#604) Co-authored-by: Helio Machado <0x2b3bfa0+git@googlemail.com> --- src/datachain/toolkit/__init__.py | 3 ++ src/datachain/toolkit/split.py | 67 +++++++++++++++++++++++++++++++ tests/conftest.py | 42 ++++++++++++++++++- tests/func/test_toolkit.py | 42 +++++++++++++++++++ 4 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 src/datachain/toolkit/__init__.py create mode 100644 src/datachain/toolkit/split.py create mode 100644 tests/func/test_toolkit.py diff --git a/src/datachain/toolkit/__init__.py b/src/datachain/toolkit/__init__.py new file mode 100644 index 000000000..0bcb55a4e --- /dev/null +++ b/src/datachain/toolkit/__init__.py @@ -0,0 +1,3 @@ +from .split import train_test_split + +__all__ = ["train_test_split"] diff --git a/src/datachain/toolkit/split.py b/src/datachain/toolkit/split.py new file mode 100644 index 000000000..1fb62fe0a --- /dev/null +++ b/src/datachain/toolkit/split.py @@ -0,0 +1,67 @@ +from datachain import C, DataChain + + +def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]: + """ + Splits a DataChain into multiple subsets based on the provided weights. + + This function partitions the rows or items of a DataChain into disjoint subsets, + ensuring that the relative sizes of the subsets correspond to the given weights. + It is particularly useful for creating training, validation, and test datasets. + + Args: + dc (DataChain): + The DataChain instance to split. + weights (list[float]): + A list of weights indicating the relative proportions of the splits. + The weights do not need to sum to 1; they will be normalized internally. + For example: + - `[0.7, 0.3]` corresponds to a 70/30 split; + - `[2, 1, 1]` corresponds to a 50/25/25 split. + + Returns: + list[DataChain]: + A list of DataChain instances, one for each weight in the weights list. + + Examples: + Train-test split: + ```python + from datachain import DataChain + from datachain.toolkit import train_test_split + + # Load a DataChain from a storage source (e.g., S3 bucket) + dc = DataChain.from_storage("s3://bucket/dir/") + + # Perform a 70/30 train-test split + train, test = train_test_split(dc, [0.7, 0.3]) + + # Save the resulting splits + train.save("dataset_train") + test.save("dataset_test") + ``` + + Train-test-validation split: + ```python + train, test, val = train_test_split(dc, [0.7, 0.2, 0.1]) + train.save("dataset_train") + test.save("dataset_test") + val.save("dataset_val") + ``` + + Note: + The splits are random but deterministic, based on Dataset `sys__rand` field. + """ + if len(weights) < 2: + raise ValueError("Weights should have at least two elements") + if any(weight < 0 for weight in weights): + raise ValueError("Weights should be non-negative") + + weights_normalized = [weight / sum(weights) for weight in weights] + + return [ + dc.filter( + C("sys__rand") % 1000 >= round(sum(weights_normalized[:index]) * 1000), + C("sys__rand") % 1000 < round(sum(weights_normalized[: index + 1]) * 1000), + ) + for index, _ in enumerate(weights_normalized) + ] diff --git a/tests/conftest.py b/tests/conftest.py index a9bd17f6d..3a15e216b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,7 @@ SQLiteWarehouse, ) from datachain.dataset import DatasetRecord -from datachain.lib.dc import DataChain +from datachain.lib.dc import DataChain, Sys from datachain.query.session import Session from datachain.utils import ( ENV_DATACHAIN_GLOBAL_CONFIG_DIR, @@ -701,3 +701,43 @@ def studio_datasets(requests_mock): ] requests_mock.post(f"{STUDIO_URL}/api/datachain/ls-datasets", json=datasets) + + +@pytest.fixture +def not_random_ds(test_session): + return DataChain.from_records( + [ + {"sys__id": 1, "sys__rand": 50, "fib": 0}, + {"sys__id": 2, "sys__rand": 150, "fib": 1}, + {"sys__id": 3, "sys__rand": 250, "fib": 1}, + {"sys__id": 4, "sys__rand": 350, "fib": 2}, + {"sys__id": 5, "sys__rand": 450, "fib": 3}, + {"sys__id": 6, "sys__rand": 550, "fib": 5}, + {"sys__id": 7, "sys__rand": 650, "fib": 8}, + {"sys__id": 8, "sys__rand": 750, "fib": 13}, + {"sys__id": 9, "sys__rand": 850, "fib": 21}, + {"sys__id": 10, "sys__rand": 950, "fib": 34}, + ], + session=test_session, + schema={"sys": Sys, "fib": int}, + ) + + +@pytest.fixture +def pseudo_random_ds(test_session): + return DataChain.from_records( + [ + {"sys__id": 1, "sys__rand": 1344339883, "fib": 0}, + {"sys__id": 2, "sys__rand": 3901153096, "fib": 1}, + {"sys__id": 3, "sys__rand": 4255991360, "fib": 1}, + {"sys__id": 4, "sys__rand": 2526403609, "fib": 2}, + {"sys__id": 5, "sys__rand": 1871733386, "fib": 3}, + {"sys__id": 6, "sys__rand": 9380910850, "fib": 5}, + {"sys__id": 7, "sys__rand": 2770679740, "fib": 8}, + {"sys__id": 8, "sys__rand": 2538886575, "fib": 13}, + {"sys__id": 9, "sys__rand": 3969542617, "fib": 21}, + {"sys__id": 10, "sys__rand": 7541790992, "fib": 34}, + ], + session=test_session, + schema={"sys": Sys, "fib": int}, + ) diff --git a/tests/func/test_toolkit.py b/tests/func/test_toolkit.py new file mode 100644 index 000000000..f2388254d --- /dev/null +++ b/tests/func/test_toolkit.py @@ -0,0 +1,42 @@ +import pytest + +from datachain.toolkit import train_test_split + + +@pytest.mark.parametrize( + "weights,expected", + [ + [[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]], + [[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]], + [[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]], + ], +) +def test_train_test_split_not_random(not_random_ds, weights, expected): + res = train_test_split(not_random_ds, weights) + assert len(res) == len(expected) + + for i, dc in enumerate(res): + assert list(dc.collect("sys.id")) == expected[i] + + +@pytest.mark.parametrize( + "weights,expected", + [ + [[1, 1], [[2, 3, 5], [1, 4, 6, 7, 8, 9, 10]]], + [[4, 1], [[2, 3, 4, 5, 7, 8, 9], [1, 6, 10]]], + [[0.7, 0.2, 0.1], [[2, 3, 4, 5, 8, 9], [1, 6, 7], [10]]], + ], +) +def test_train_test_split_random(pseudo_random_ds, weights, expected): + res = train_test_split(pseudo_random_ds, weights) + assert len(res) == len(expected) + + for i, dc in enumerate(res): + assert list(dc.collect("sys.id")) == expected[i] + + +def test_train_test_split_errors(not_random_ds): + with pytest.raises(ValueError, match="Weights should have at least two elements"): + train_test_split(not_random_ds, [0.5]) + with pytest.raises(ValueError, match="Weights should be non-negative"): + train_test_split(not_random_ds, [-1, 1])