-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3e7b523
commit 869a99f
Showing
5 changed files
with
286 additions
and
264 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation | ||
# Copyright (c) 2021-2024 TileDB, Inc. | ||
# | ||
# Licensed under the MIT License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any, TypeVar | ||
|
||
from torch.utils.data import DataLoader | ||
|
||
from tiledbsoma_ml._distributed import init_multiprocessing | ||
from tiledbsoma_ml.pytorch import ( | ||
ExperimentAxisQueryIterableDataset, | ||
ExperimentAxisQueryIterDataPipe, | ||
) | ||
|
||
# Type variable used to annotate `_collate_noop` as returning the same type it is given.
_T = TypeVar("_T")
|
||
|
||
def experiment_dataloader(
    ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
    **dataloader_kwargs: Any,
) -> DataLoader:
    """Factory method for :class:`torch.utils.data.DataLoader`.

    This method can be used to safely instantiate a :class:`torch.utils.data.DataLoader` that works with
    :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` or
    :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`.

    Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant,
    when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``.
    Specifying any of these parameters will result in an error.

    Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on
    :class:`torch.utils.data.DataLoader` parameters.

    Args:
        ds:
            A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May
            include chained data pipes.
        **dataloader_kwargs:
            Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor,
            except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not
            supported when using data loaders in this module. If ``collate_fn`` is not supplied, a
            pass-through (no-op) collation function is used.

    Returns:
        A :class:`torch.utils.data.DataLoader`.

    Raises:
        ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params
            are passed as keyword arguments. The message names the offending parameter(s).

    Lifecycle:
        experimental
    """
    unsupported_dataloader_args = [
        "shuffle",
        "batch_size",
        "sampler",
        "batch_sampler",
    ]
    # Collect the rejected names in declaration order so the error message is deterministic.
    offending_args = [
        arg for arg in unsupported_dataloader_args if arg in dataloader_kwargs
    ]
    if offending_args:
        # The original message text is kept as a prefix; the caller's actual mistake is appended.
        raise ValueError(
            f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported"
            f" (got: {', '.join(offending_args)})"
        )

    # Multiprocessing setup is only needed when the DataLoader will spawn worker processes.
    if dataloader_kwargs.get("num_workers", 0) > 0:
        init_multiprocessing()

    # Unless the caller supplies a collate function, pass each item through unchanged.
    dataloader_kwargs.setdefault("collate_fn", _collate_noop)

    return DataLoader(
        ds,
        batch_size=None,  # batching is handled by upstream iterator
        shuffle=False,  # shuffling is handled by upstream iterator
        **dataloader_kwargs,
    )
|
||
|
||
def _collate_noop(datum: _T) -> _T:
    """Identity collation function: hand back ``datum`` exactly as received.

    Private.
    """
    return datum
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation | ||
# Copyright (c) 2021-2024 TileDB, Inc. | ||
# | ||
# Licensed under the MIT License. | ||
|
||
from functools import partial | ||
from typing import Any, Tuple | ||
from unittest.mock import patch | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
import pandas as pd | ||
import pytest | ||
from tiledbsoma import Experiment | ||
|
||
from tests._utils import IterableWrappers, IterableWrapperType, pytorch_x_value_gen | ||
from tiledbsoma_ml import ExperimentAxisQueryIterDataPipe | ||
from tiledbsoma_ml.dataloader import experiment_dataloader | ||
|
||
|
||
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_multiprocessing__returns_full_result(
    PipeClass: IterableWrapperType,
    soma_experiment: Experiment,
) -> None:
    """Check that a multi-worker DataLoader yields every row.

    Applies to both ``ExperimentAxisQueryIterDataPipe`` and ``ExperimentAxisQueryIterableDataset``: the
    ``soma_joinid`` values gathered from all worker processes must be exactly ``0..5``.
    """
    with soma_experiment.axis_query(measurement_name="RNA") as query:
        pipe = PipeClass(
            query,
            X_name="raw",
            obs_column_names=["soma_joinid", "label"],
            io_batch_size=3,  # two chunks, one per worker
        )
        # The DataLoader wrapper is what sets up the multiprocessing
        loader = experiment_dataloader(pipe, num_workers=2)

        collected = list(loader)

        joinid_chunks = [item[1]["soma_joinid"].to_numpy() for item in collected]
        soma_joinids = np.concatenate(joinid_chunks)
        assert sorted(soma_joinids) == list(range(6))
|
||
|
||
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen,use_eager_fetch",
    [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__non_batched(
    PipeClass: IterableWrapperType,
    soma_experiment: Experiment,
    use_eager_fetch: bool,
) -> None:
    """With no ``batch_size`` given, every yielded item is one row.

    Each item is expected to hold a 1-D X array of length 3 and a ``(1, 1)`` obs frame.
    """
    with soma_experiment.axis_query(measurement_name="RNA") as query:
        pipe = PipeClass(
            query,
            X_name="raw",
            obs_column_names=["label"],
            shuffle=False,
            use_eager_fetch=use_eager_fetch,
        )
        loader = experiment_dataloader(pipe)
        items = list(loader)

        x_shapes_ok = all(item[0].shape == (3,) for item in items)
        assert x_shapes_ok
        obs_shapes_ok = all(item[1].shape == (1, 1) for item in items)
        assert obs_shapes_ok

        first = items[0]
        assert first[0].tolist() == [0, 1, 0]
        assert first[1]["label"].tolist() == ["0"]
|
||
|
||
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen,use_eager_fetch",
    [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched(
    PipeClass: IterableWrapperType,
    soma_experiment: Experiment,
    use_eager_fetch: bool,
) -> None:
    """With ``batch_size=3``, the first yielded item is a batch of three rows."""
    with soma_experiment.axis_query(measurement_name="RNA") as query:
        pipe = PipeClass(
            query,
            X_name="raw",
            batch_size=3,
            shuffle=False,
            use_eager_fetch=use_eager_fetch,
        )
        loader = experiment_dataloader(pipe)
        batches = list(loader)

        first_batch = batches[0]
        expected_x = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
        expected_obs = [[0], [1], [2]]
        assert first_batch[0].tolist() == expected_x
        assert first_batch[1].to_numpy().tolist() == expected_obs
|
||
|
||
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen,use_eager_fetch",
    [
        (10, 3, pytorch_x_value_gen, use_eager_fetch)
        for use_eager_fetch in (True, False)
    ],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched_length(
    PipeClass: IterableWrapperType,
    soma_experiment: Experiment,
    use_eager_fetch: bool,
) -> None:
    """``len()`` of the loader must equal the number of batches it actually yields."""
    with soma_experiment.axis_query(measurement_name="RNA") as query:
        pipe = PipeClass(
            query,
            X_name="raw",
            obs_column_names=["label"],
            batch_size=3,
            shuffle=False,
            use_eager_fetch=use_eager_fetch,
        )
        loader = experiment_dataloader(pipe)
        reported_length = len(loader)
        yielded_batches = list(loader)
        assert reported_length == len(yielded_batches)
|
||
|
||
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen,batch_size",
    [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__collate_fn(
    PipeClass: IterableWrapperType,
    soma_experiment: Experiment,
    batch_size: int,
) -> None:
    """A user-supplied ``collate_fn`` is invoked with well-formed ``(X, obs)`` items.

    The collate function validates each item it receives, so merely iterating the loader
    exercises the checks; the final assertion guards against an empty (vacuous) run.
    """

    def collate_fn(
        batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]
    ) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]:
        """Validate the shape of one yielded item and return it unchanged."""
        assert isinstance(data, tuple)
        assert len(data) == 2
        assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame)
        if batch_size > 1:
            # Batched: X and obs must have matching row counts, at most `batch_size`.
            assert data[0].shape[0] == data[1].shape[0]
            assert data[0].shape[0] <= batch_size
        else:
            # Non-batched: X is a single 1-D row.
            assert data[0].ndim == 1
            # Bound the obs *row* count (axis 0). The previous check used axis 1 (the
            # column count), which only passed because a single obs column is requested.
            assert data[1].shape[0] <= batch_size
        return data

    with soma_experiment.axis_query(measurement_name="RNA") as query:
        dp = PipeClass(
            query,
            X_name="raw",
            obs_column_names=["label"],
            batch_size=batch_size,
            shuffle=False,
        )
        # Bind `batch_size` so the DataLoader can call the function with the item alone.
        dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size))
        assert len(list(dl)) > 0
|
||
|
||
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)]
)
def test__pytorch_splitting(
    soma_experiment: Experiment,
) -> None:
    """A seeded 70/30 ``random_split`` of the data pipe leaves 7 rows in the train split."""
    with soma_experiment.axis_query(measurement_name="RNA") as query:
        pipe = ExperimentAxisQueryIterDataPipe(
            query,
            X_name="raw",
            obs_column_names=["label"],
        )
        split_weights = {"train": 0.7, "test": 0.3}
        # function not available for IterableDataset, yet....
        train_pipe, test_pipe = pipe.random_split(weights=split_weights, seed=1234)
        train_loader = experiment_dataloader(train_pipe)

        train_rows = list(train_loader)
        assert len(train_rows) == 7
|
||
|
||
def test_experiment_dataloader__unsupported_params__fails() -> None:
    """Each DataLoader parameter that ``experiment_dataloader`` rejects must raise ``ValueError``."""
    rejected_kwargs = (
        {"shuffle": True},
        {"batch_size": 3},
        {"batch_sampler": []},
        {"sampler": []},
    )
    with patch(
        "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe"
    ) as dummy_exp_data_pipe:
        for kwargs in rejected_kwargs:
            with pytest.raises(ValueError):
                experiment_dataloader(dummy_exp_data_pipe, **kwargs)
Oops, something went wrong.