Skip to content

Commit

Permalink
dataloader.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Dec 17, 2024
1 parent 3e7b523 commit 869a99f
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 264 deletions.
2 changes: 1 addition & 1 deletion src/tiledbsoma_ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

"""An API to support machine learning applications built on SOMA."""

from .dataloader import experiment_dataloader
from .pytorch import (
ExperimentAxisQueryIterableDataset,
ExperimentAxisQueryIterDataPipe,
experiment_dataloader,
)

__version__ = "0.1.0-dev"
Expand Down
85 changes: 85 additions & 0 deletions src/tiledbsoma_ml/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
# Copyright (c) 2021-2024 TileDB, Inc.
#
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, TypeVar

from torch.utils.data import DataLoader

from tiledbsoma_ml._distributed import init_multiprocessing
from tiledbsoma_ml.pytorch import (
ExperimentAxisQueryIterableDataset,
ExperimentAxisQueryIterDataPipe,
)

_T = TypeVar("_T")


def experiment_dataloader(
ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
**dataloader_kwargs: Any,
) -> DataLoader:
"""Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a
:class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`
or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`.
Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant,
when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``.
Specifying any of these parameters will result in an error.
Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on
:class:`torch.utils.data.DataLoader` parameters.
Args:
ds:
A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May
include chained data pipes.
**dataloader_kwargs:
Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor,
except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not
supported when using data loaders in this module.
Returns:
A :class:`torch.utils.data.DataLoader`.
Raises:
ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params
are passed as keyword arguments.
Lifecycle:
experimental
"""
unsupported_dataloader_args = [
"shuffle",
"batch_size",
"sampler",
"batch_sampler",
]
if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()):
raise ValueError(
f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported"
)

if dataloader_kwargs.get("num_workers", 0) > 0:
init_multiprocessing()

if "collate_fn" not in dataloader_kwargs:
dataloader_kwargs["collate_fn"] = _collate_noop

return DataLoader(
ds,
batch_size=None, # batching is handled by upstream iterator
shuffle=False, # shuffling is handled by upstream iterator
**dataloader_kwargs,
)


def _collate_noop(datum: _T) -> _T:
"""Noop collation for use with a dataloader instance.
Private.
"""
return datum
72 changes: 0 additions & 72 deletions src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
import time
from math import ceil
from typing import (
Any,
ContextManager,
Iterable,
Iterator,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -36,15 +34,12 @@
from tiledbsoma_ml._distributed import (
get_distributed_world_rank,
get_worker_world_rank,
init_multiprocessing,
)
from tiledbsoma_ml._experiment_locator import ExperimentLocator
from tiledbsoma_ml._utils import NDArrayNumber, batched, splits

logger = logging.getLogger("tiledbsoma_ml.pytorch")

_T = TypeVar("_T")

NDArrayJoinId = npt.NDArray[np.int64]
XBatch = Union[NDArrayNumber, sparse.csr_matrix]
Batch = Tuple[XBatch, pd.DataFrame]
Expand Down Expand Up @@ -872,70 +867,3 @@ def set_epoch(self, epoch: int) -> None:
@property
def epoch(self) -> int:
return self._exp_iter.epoch


def experiment_dataloader(
ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
**dataloader_kwargs: Any,
) -> torch.utils.data.DataLoader:
"""Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a
:class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`
or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`.
Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant,
when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``.
Specifying any of these parameters will result in an error.
Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on
:class:`torch.utils.data.DataLoader` parameters.
Args:
ds:
A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May
include chained data pipes.
**dataloader_kwargs:
Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor,
except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not
supported when using data loaders in this module.
Returns:
A :class:`torch.utils.data.DataLoader`.
Raises:
ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params
are passed as keyword arguments.
Lifecycle:
experimental
"""
unsupported_dataloader_args = [
"shuffle",
"batch_size",
"sampler",
"batch_sampler",
]
if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()):
raise ValueError(
f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported"
)

if dataloader_kwargs.get("num_workers", 0) > 0:
init_multiprocessing()

if "collate_fn" not in dataloader_kwargs:
dataloader_kwargs["collate_fn"] = _collate_noop

return torch.utils.data.DataLoader(
ds,
batch_size=None, # batching is handled by upstream iterator
shuffle=False, # shuffling is handled by upstream iterator
**dataloader_kwargs,
)


def _collate_noop(datum: _T) -> _T:
"""Noop collation for use with a dataloader instance.
Private.
"""
return datum
199 changes: 199 additions & 0 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
# Copyright (c) 2021-2024 TileDB, Inc.
#
# Licensed under the MIT License.

from functools import partial
from typing import Any, Tuple
from unittest.mock import patch

import numpy as np
import numpy.typing as npt
import pandas as pd
import pytest
from tiledbsoma import Experiment

from tests._utils import IterableWrappers, IterableWrapperType, pytorch_x_value_gen
from tiledbsoma_ml import ExperimentAxisQueryIterDataPipe
from tiledbsoma_ml.dataloader import experiment_dataloader


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_multiprocessing__returns_full_result(
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
) -> None:
"""Tests that ``ExperimentAxisQueryIterDataPipe`` / ``ExperimentAxisQueryIterableDataset``
provide all data, as collected from multiple processes that are managed by a PyTorch DataLoader
with multiple workers configured."""
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["soma_joinid", "label"],
io_batch_size=3, # two chunks, one per worker
)
# Wrap with a DataLoader, which sets up the multiprocessing
dl = experiment_dataloader(dp, num_workers=2)

full_result = list(iter(dl))

soma_joinids = np.concatenate(
[t[1]["soma_joinid"].to_numpy() for t in full_result]
)
assert sorted(soma_joinids) == list(range(6))


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__non_batched(
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
data = [row for row in dl]
assert all(d[0].shape == (3,) for d in data)
assert all(d[1].shape == (1, 1) for d in data)

row = data[0]
assert row[0].tolist() == [0, 1, 0]
assert row[1]["label"].tolist() == ["0"]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched(
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
batch_size=3,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
data = [row for row in dl]

batch = data[0]
assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
assert batch[1].to_numpy().tolist() == [[0], [1], [2]]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[
(10, 3, pytorch_x_value_gen, use_eager_fetch)
for use_eager_fetch in (True, False)
],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched_length(
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
batch_size=3,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
assert len(dl) == len(list(dl))


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,batch_size",
[(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__collate_fn(
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
batch_size: int,
) -> None:
def collate_fn(
batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]
) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]:
assert isinstance(data, tuple)
assert len(data) == 2
assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame)
if batch_size > 1:
assert data[0].shape[0] == data[1].shape[0]
assert data[0].shape[0] <= batch_size
else:
assert data[0].ndim == 1
assert data[1].shape[1] <= batch_size
return data

with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
batch_size=batch_size,
shuffle=False,
)
dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size))
assert len(list(dl)) > 0


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)]
)
def test__pytorch_splitting(
soma_experiment: Experiment,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = ExperimentAxisQueryIterDataPipe(
query,
X_name="raw",
obs_column_names=["label"],
)
# function not available for IterableDataset, yet....
dp_train, dp_test = dp.random_split(
weights={"train": 0.7, "test": 0.3}, seed=1234
)
dl = experiment_dataloader(dp_train)

all_rows = list(iter(dl))
assert len(all_rows) == 7


def test_experiment_dataloader__unsupported_params__fails() -> None:
with patch(
"tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe"
) as dummy_exp_data_pipe:
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, shuffle=True)
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, batch_size=3)
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[])
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, sampler=[])
Loading

0 comments on commit 869a99f

Please sign in to comment.