From 16a6090e9575e690b624d4e8394da77c84381fe3 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Fri, 6 Dec 2024 11:24:25 -0500 Subject: [PATCH 1/9] pytorch.py nits --- src/tiledbsoma_ml/pytorch.py | 12 ++++++------ tests/test_pytorch.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index a5561fb..d783ba7 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -175,7 +175,7 @@ def __init__( When using this class in any distributed mode, calling the :meth:`set_epoch` method at the beginning of each epoch **before** creating the :class:`DataLoader` iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, - the same ordering will be always used. + the same ordering will always be used. In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you must provide a seed, ensuring that the same shuffle is used across all replicas. @@ -251,19 +251,19 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]: if self.shuffle: assert self.io_batch_size % self.shuffle_chunk_size == 0 shuffle_split = np.array_split( - _gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size)) + _gpu_split, max(1, ceil(min_len / self.shuffle_chunk_size)) ) # Deterministically create RNG - state must be same across all processes, ensuring # that the joinid partitions are identical across all processes. rng = np.random.default_rng(self.seed + self.epoch + 99) rng.shuffle(shuffle_split) - obs_joinids_chunked = list( + obs_joinids_chunked = [ np.concatenate(b) for b in _batched( shuffle_split, self.io_batch_size // self.shuffle_chunk_size ) - ) + ] else: obs_joinids_chunked = np.array_split( _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size)) @@ -463,7 +463,7 @@ def _io_batch_iter( f"Retrieving next SOMA IO batch of length {len(obs_coords)}..." ) - # to maximize optty's for concurrency, when in eager_fetch mode, + # To maximize opportunities for concurrency, when in eager_fetch mode, # create the X read iterator first, as the eager iterator will begin # the read-ahead immediately. Then proceed to fetch obs DataFrame. # This matters most on latent backing stores, e.g., S3. @@ -910,7 +910,7 @@ def epoch(self) -> int: def experiment_dataloader( - ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset, + ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, **dataloader_kwargs: Any, ) -> torch.utils.data.DataLoader: """Factory method for :class:`torch.utils.data.DataLoader`. 
This method can be used to safely instantiate a diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 05bf6ca..52f7aee 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -776,7 +776,7 @@ def test_experiment_axis_query_iterable_error_checks( dp[0] with pytest.raises(ValueError): - dp = ExperimentAxisQueryIterable( + ExperimentAxisQueryIterable( query, obs_column_names=(), X_name="raw", From b03c65fb6aa359329cc04d075c0ef14e16f22a4e Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 11:00:20 -0500 Subject: [PATCH 2/9] rename `XObsDatum` to `Batch` --- src/tiledbsoma_ml/pytorch.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index d783ba7..c6ae245 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -52,12 +52,12 @@ NDArrayNumber = npt.NDArray[np.number[Any]] NDArrayJoinId = npt.NDArray[np.int64] _CSRIdxArray = npt.NDArray[np.unsignedinteger[Any]] -XDatum = Union[NDArrayNumber, sparse.csr_matrix] -XObsDatum = Tuple[XDatum, pd.DataFrame] -"""Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, -which pairs a slice of ``X`` rows with a corresponding slice of ``obs``. In the default case, -the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs`` -respectively). If the object is created with ``return_sparse_X`` as True, the ``X`` slice is +XBatch = Union[NDArrayNumber, sparse.csr_matrix] +Batch = Tuple[XBatch, pd.DataFrame] +""""Batch" type yielded by ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``; +pairs a slice of ``X`` rows with a corresponding slice of ``obs``. In the default case. +a Batch is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs``, +respectively). If the iterator is created with ``return_sparse_X`` as True, the ``X`` slice is returned as a :class:`scipy.sparse.csr_matrix`. If the ``batch_size`` is 1, the :class:`numpy.ndarray` will be returned with rank 1; in all other cases, objects are returned with rank 2.""" @@ -91,7 +91,7 @@ def open_experiment(self) -> Generator[soma.Experiment, None, None]: ) -class ExperimentAxisQueryIterable(Iterable[XObsDatum]): +class ExperimentAxisQueryIterable(Iterable[Batch]): """An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator produces a batch containing equal-sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and @@ -318,7 +318,7 @@ def _init_once(self, exp: soma.Experiment | None = None) -> None: self._initialized = True - def __iter__(self) -> Iterator[XObsDatum]: + def __iter__(self) -> Iterator[Batch]: """Create iterator over query. Returns: @@ -422,7 +422,7 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch - def __getitem__(self, index: int) -> XObsDatum: + def __getitem__(self, index: int) -> Batch: raise NotImplementedError( "`ExperimentAxisQueryIterable` can only be iterated - does not support mapping" ) @@ -523,7 +523,7 @@ def _mini_batch_iter( obs: soma.DataFrame, X: soma.SparseNDArray, obs_joinid_iter: Iterator[NDArrayJoinId], - ) -> Iterator[XObsDatum]: + ) -> Iterator[Batch]: """Break IO batches into shuffled mini-batch-sized chunks. Private method. 
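As orientation for the consumer-side contract these first patches describe (the per-epoch set_epoch call and the renamed Batch tuple of X rows plus an obs DataFrame), here is a hedged usage sketch. The experiment path, measurement name "RNA", layer name "raw", and the parameter values are illustrative assumptions, not taken from the patches themselves.

# Sketch only: the URI, "RNA", and "raw" below are placeholders.
import tiledbsoma as soma
from tiledbsoma_ml import ExperimentAxisQueryIterableDataset
from tiledbsoma_ml.pytorch import experiment_dataloader

with soma.Experiment.open("path/to/experiment") as exp:
    with exp.axis_query(measurement_name="RNA") as query:
        ds = ExperimentAxisQueryIterableDataset(
            query,
            X_name="raw",
            obs_column_names=["soma_joinid"],
            batch_size=64,
            seed=42,
        )
        dl = experiment_dataloader(ds)
        for epoch in range(2):
            # In shuffled/distributed use, set_epoch must be called before
            # creating each epoch's iterator, or the same ordering repeats.
            ds.set_epoch(epoch)
            for X_batch, obs_batch in dl:
                # X_batch: (64, n_vars) numpy.ndarray (a csr_matrix when
                # return_sparse_X=True); obs_batch: pandas.DataFrame.
                pass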
@@ -598,7 +598,7 @@ def _mini_batch_iter( class ExperimentAxisQueryIterDataPipe( torchdata.datapipes.iter.IterDataPipe[ # type:ignore[misc] - torch.utils.data.dataset.Dataset[XObsDatum] + torch.utils.data.dataset.Dataset[Batch] ], ): """A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. @@ -646,7 +646,7 @@ def __init__( shuffle_chunk_size=shuffle_chunk_size, ) - def __iter__(self) -> Iterator[XObsDatum]: + def __iter__(self) -> Iterator[Batch]: """ See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. @@ -699,7 +699,7 @@ def epoch(self) -> int: class ExperimentAxisQueryIterableDataset( - torch.utils.data.IterableDataset[XObsDatum] # type:ignore[misc] + torch.utils.data.IterableDataset[Batch] # type:ignore[misc] ): """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. @@ -843,7 +843,7 @@ def __init__( shuffle_chunk_size=shuffle_chunk_size, ) - def __iter__(self) -> Iterator[XObsDatum]: + def __iter__(self) -> Iterator[Batch]: """Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.csr_matrix`) and :class:`pandas.DataFrame`. From 8385bdc831ee33daf24fe2b1c316392be3dd9dbf Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 11:02:09 -0500 Subject: [PATCH 3/9] `_experiment_locator.py` --- src/tiledbsoma_ml/_experiment_locator.py | 39 ++++++++++++++++++++++++ src/tiledbsoma_ml/pytorch.py | 37 ++-------------------- 2 files changed, 42 insertions(+), 34 deletions(-) create mode 100644 src/tiledbsoma_ml/_experiment_locator.py diff --git a/src/tiledbsoma_ml/_experiment_locator.py b/src/tiledbsoma_ml/_experiment_locator.py new file mode 100644 index 0000000..062111d --- /dev/null +++ b/src/tiledbsoma_ml/_experiment_locator.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from contextlib import contextmanager +from typing import Dict, Generator, Union + +import attrs +from tiledbsoma import Experiment, SOMATileDBContext + + +@attrs.define(frozen=True, kw_only=True) +class ExperimentLocator: + """State required to open the Experiment. + + Serializable across multiple processes. + + Private implementation class. 
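A quick illustration of why this small class exists: rather than pickling an open Experiment handle into DataLoader worker processes, the iterable pickles this locator (URI, timestamp, TileDB config) and each worker re-opens its own handle from it. A minimal sketch, with an illustrative experiment URI:

# Sketch only; "path/to/experiment" is a placeholder URI.
import pickle

import tiledbsoma as soma
from tiledbsoma_ml._experiment_locator import ExperimentLocator

with soma.Experiment.open("path/to/experiment") as exp:
    locator = ExperimentLocator.create(exp)

# The frozen attrs instance survives pickling across process boundaries.
locator = pickle.loads(pickle.dumps(locator))
with locator.open_experiment() as exp:
    print(exp.uri)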
+ """ + + uri: str + tiledb_timestamp_ms: int + tiledb_config: Dict[str, Union[str, float]] + + @classmethod + def create(cls, experiment: Experiment) -> "ExperimentLocator": + return ExperimentLocator( + uri=experiment.uri, + tiledb_timestamp_ms=experiment.tiledb_timestamp_ms, + tiledb_config=experiment.context.tiledb_config, + ) + + @contextmanager + def open_experiment(self) -> Generator[Experiment, None, None]: + context = SOMATileDBContext(tiledb_config=self.tiledb_config) + yield Experiment.open( + self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context + ) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index c6ae245..28be47f 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -13,14 +13,11 @@ import os import sys import time -from contextlib import contextmanager from itertools import islice from math import ceil from typing import ( Any, ContextManager, - Dict, - Generator, Iterable, Iterator, List, @@ -31,7 +28,6 @@ Union, ) -import attrs import numba import numpy as np import numpy.typing as npt @@ -44,6 +40,8 @@ from somacore.query._eager_iter import EagerIterator as _EagerIterator from typing_extensions import Self +from tiledbsoma_ml._experiment_locator import ExperimentLocator + logger = logging.getLogger("tiledbsoma_ml.pytorch") _T = TypeVar("_T") @@ -62,35 +60,6 @@ will be returned with rank 1; in all other cases, objects are returned with rank 2.""" -@attrs.define(frozen=True, kw_only=True) -class _ExperimentLocator: - """State required to open the Experiment. - - Serializable across multiple processes. - - Private implementation class. - """ - - uri: str - tiledb_timestamp_ms: int - tiledb_config: Dict[str, Union[str, float]] - - @classmethod - def create(cls, experiment: soma.Experiment) -> "_ExperimentLocator": - return _ExperimentLocator( - uri=experiment.uri, - tiledb_timestamp_ms=experiment.tiledb_timestamp_ms, - tiledb_config=experiment.context.tiledb_config, - ) - - @contextmanager - def open_experiment(self) -> Generator[soma.Experiment, None, None]: - context = soma.SOMATileDBContext(tiledb_config=self.tiledb_config) - yield soma.Experiment.open( - self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context - ) - - class ExperimentAxisQueryIterable(Iterable[Batch]): """An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. 
Each step of the iterator @@ -184,7 +153,7 @@ def __init__( super().__init__() # Anything set in the instance needs to be pickle-able for multi-process DataLoaders - self.experiment_locator = _ExperimentLocator.create(query.experiment) + self.experiment_locator = ExperimentLocator.create(query.experiment) self.layer_name = X_name self.measurement_name = query.measurement_name self.obs_query = query._matrix_axis_query.obs From 446cdee0ad9947dc511ad98cf8538b763ae3a271 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 11:06:11 -0500 Subject: [PATCH 4/9] `_distributed.py` --- src/tiledbsoma_ml/_distributed.py | 67 ++++++++++++++++++++++++++++ src/tiledbsoma_ml/pytorch.py | 74 +++++-------------------------- 2 files changed, 79 insertions(+), 62 deletions(-) create mode 100644 src/tiledbsoma_ml/_distributed.py diff --git a/src/tiledbsoma_ml/_distributed.py b/src/tiledbsoma_ml/_distributed.py new file mode 100644 index 0000000..75c865b --- /dev/null +++ b/src/tiledbsoma_ml/_distributed.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +import logging +import os +from typing import Tuple + +import torch + +logger = logging.getLogger("tiledbsoma_ml.pytorch") + + +def get_distributed_world_rank() -> Tuple[int, int]: + """Return tuple containing equivalent of ``torch.distributed`` world size and rank.""" + world_size, rank = 1, 0 + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ: + # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There + # is a NODE_RANK for the node's rank, but no way to tell the local node's + # world. So computing a global rank is impossible(?). Using LOCAL_RANK as a + # proxy, which works fine on a single-CPU box. TODO: could throw/error + # if NODE_RANK != 0. + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["LOCAL_RANK"]) + elif torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + return world_size, rank + + +def get_worker_world_rank() -> Tuple[int, int]: + """Return number of DataLoader workers and our worker rank/id""" + num_workers, worker = 1, 0 + if "WORKER" in os.environ and "NUM_WORKERS" in os.environ: + num_workers = int(os.environ["NUM_WORKERS"]) + worker = int(os.environ["WORKER"]) + else: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + num_workers = worker_info.num_workers + worker = worker_info.id + return num_workers, worker + + +def init_multiprocessing() -> None: + """Ensures use of "spawn" for starting child processes with multiprocessing. + + Forked processes are known to be problematic: + https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks + Also, CUDA does not support forked child processes: + https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + + Private. 
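For context, a hedged sketch of where this helper is triggered: when the dataloader factory is asked for worker processes, it calls init_multiprocessing() before PyTorch starts them, so workers are spawned rather than forked (the helper name make_loader and the worker count are illustrative).

# Sketch: multi-worker loading. `ds` is assumed to be an
# ExperimentAxisQueryIterableDataset (or the IterDataPipe variant).
from tiledbsoma_ml.pytorch import experiment_dataloader

def make_loader(ds):
    # num_workers > 0 makes experiment_dataloader() call init_multiprocessing(),
    # which forces the "spawn" start method so CUDA state in the parent is not
    # inherited via fork.
    return experiment_dataloader(ds, num_workers=2)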
+ """ + orig_start_method = torch.multiprocessing.get_start_method() + if orig_start_method != "spawn": + if orig_start_method: + logger.warning( + "switching torch multiprocessing start method from " + f'"{torch.multiprocessing.get_start_method()}" to "spawn"' + ) + torch.multiprocessing.set_start_method("spawn", force=True) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 28be47f..5e5f7bd 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -40,6 +40,11 @@ from somacore.query._eager_iter import EagerIterator as _EagerIterator from typing_extensions import Self +from tiledbsoma_ml._distributed import ( + get_distributed_world_rank, + get_worker_world_rank, + init_multiprocessing, +) from tiledbsoma_ml._experiment_locator import ExperimentLocator logger = logging.getLogger("tiledbsoma_ml.pytorch") @@ -206,7 +211,7 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]: obs_joinids: NDArrayJoinId = self._obs_joinids # 1. Get the split for the model replica/GPU - world_size, rank = _get_distributed_world_rank() + world_size, rank = get_distributed_world_rank() _gpu_splits = _splits(len(obs_joinids), world_size) _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]] @@ -239,7 +244,7 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]: ) # 4. Partition by DataLoader worker - n_workers, worker_id = _get_worker_world_rank() + n_workers, worker_id = get_worker_world_rank() obs_splits = _splits(len(obs_joinids_chunked), n_workers) obs_partition_joinids = obs_joinids_chunked[ obs_splits[worker_id] : obs_splits[worker_id + 1] @@ -305,8 +310,8 @@ def __iter__(self) -> Iterator[Batch]: "(see https://github.com/pytorch/pytorch/issues/20248)" ) - world_size, rank = _get_distributed_world_rank() - n_workers, worker_id = _get_worker_world_rank() + world_size, rank = get_distributed_world_rank() + n_workers, worker_id = get_worker_world_rank() logger.debug( f"Iterator created {rank=}, {world_size=}, {worker_id=}, {n_workers=}, seed={self.seed}, epoch={self.epoch}" ) @@ -367,8 +372,8 @@ def shape(self) -> Tuple[int, int]: self._init_once() assert self._obs_joinids is not None assert self._var_joinids is not None - world_size, rank = _get_distributed_world_rank() - n_workers, worker_id = _get_worker_world_rank() + world_size, rank = get_distributed_world_rank() + n_workers, worker_id = get_worker_world_rank() # Every "distributed" process must receive the same number of "obs" rows; the last ≤world_size may be dropped # (see _create_obs_joinids_partition). obs_per_proc = len(self._obs_joinids) // world_size @@ -924,7 +929,7 @@ def experiment_dataloader( ) if dataloader_kwargs.get("num_workers", 0) > 0: - _init_multiprocessing() + init_multiprocessing() if "collate_fn" not in dataloader_kwargs: dataloader_kwargs["collate_fn"] = _collate_noop @@ -984,61 +989,6 @@ def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: yield batch -def _get_distributed_world_rank() -> Tuple[int, int]: - """Return tuple containing equivalent of ``torch.distributed`` world size and rank.""" - world_size, rank = 1, 0 - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - world_size = int(os.environ["WORLD_SIZE"]) - rank = int(os.environ["RANK"]) - elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ: - # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There - # is a NODE_RANK for the node's rank, but no way to tell the local node's - # world. 
So computing a global rank is impossible(?). Using LOCAL_RANK as a - # proxy, which works fine on a single-CPU box. TODO: could throw/error - # if NODE_RANK != 0. - world_size = int(os.environ["WORLD_SIZE"]) - rank = int(os.environ["LOCAL_RANK"]) - elif torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - return world_size, rank - - -def _get_worker_world_rank() -> Tuple[int, int]: - """Return number of DataLoader workers and our worker rank/id""" - num_workers, worker = 1, 0 - if "WORKER" in os.environ and "NUM_WORKERS" in os.environ: - num_workers = int(os.environ["NUM_WORKERS"]) - worker = int(os.environ["WORKER"]) - else: - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - num_workers = worker_info.num_workers - worker = worker_info.id - return num_workers, worker - - -def _init_multiprocessing() -> None: - """Ensures use of "spawn" for starting child processes with multiprocessing. - - Forked processes are known to be problematic: - https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks - Also, CUDA does not support forked child processes: - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing - - Private. - """ - orig_start_method = torch.multiprocessing.get_start_method() - if orig_start_method != "spawn": - if orig_start_method: - logger.warning( - "switching torch multiprocessing start method from " - f'"{torch.multiprocessing.get_start_method()}" to "spawn"' - ) - torch.multiprocessing.set_start_method("spawn", force=True) - - class _CSR_IO_Buffer: """Implement a minimal CSR matrix with specific optimizations for use in this package. From b15df81e7ed49d91c963cf599634625fed996c58 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 11:11:48 -0500 Subject: [PATCH 5/9] `_utils.py` --- src/tiledbsoma_ml/_utils.py | 53 ++++++++++++++++++++++++++++++++++++ src/tiledbsoma_ml/pytorch.py | 50 +++------------------------------- tests/test_pytorch.py | 35 ------------------------ tests/test_utils.py | 41 ++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 81 deletions(-) create mode 100644 src/tiledbsoma_ml/_utils.py create mode 100644 tests/test_utils.py diff --git a/src/tiledbsoma_ml/_utils.py b/src/tiledbsoma_ml/_utils.py new file mode 100644 index 0000000..2aad4fe --- /dev/null +++ b/src/tiledbsoma_ml/_utils.py @@ -0,0 +1,53 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +import itertools +import sys +from itertools import islice +from typing import Iterable, Iterator, Tuple, TypeVar + +import numpy as np +import numpy.typing as npt + +_T_co = TypeVar("_T_co", covariant=True) + + +def splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: + """For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes. + + A total_length of L, split into N sections, will return L%N sections of size L//N+1, + and the remainder as size L//N. This results in the same split as numpy.array_split, + for an array of length L and sections N. + + Private. 
+ + Examples + -------- + >>> splits(10, 3) + array([0, 4, 7, 10]) + >>> splits(4, 2) + array([0, 2, 4]) + """ + if sections <= 0: + raise ValueError("number of sections must greater than 0.") from None + each_section, extras = divmod(total_length, sections) + per_section_sizes = ( + [0] + extras * [each_section + 1] + (sections - extras) * [each_section] + ) + splits = np.array(per_section_sizes, dtype=np.intp).cumsum() + return splits + + +if sys.version_info >= (3, 12): + batched = itertools.batched +else: + + def batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: + """Same as the Python 3.12+ ``itertools.batched`` -- polyfill for old Python versions.""" + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 5e5f7bd..5c229b2 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -7,13 +7,10 @@ import contextlib import gc -import itertools import logging import math import os -import sys import time -from itertools import islice from math import ceil from typing import ( Any, @@ -46,11 +43,11 @@ init_multiprocessing, ) from tiledbsoma_ml._experiment_locator import ExperimentLocator +from tiledbsoma_ml._utils import batched, splits logger = logging.getLogger("tiledbsoma_ml.pytorch") _T = TypeVar("_T") -_T_co = TypeVar("_T_co", covariant=True) NDArrayNumber = npt.NDArray[np.number[Any]] NDArrayJoinId = npt.NDArray[np.int64] @@ -212,7 +209,7 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]: # 1. Get the split for the model replica/GPU world_size, rank = get_distributed_world_rank() - _gpu_splits = _splits(len(obs_joinids), world_size) + _gpu_splits = splits(len(obs_joinids), world_size) _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]] # 2. Trim to be all of equal length - equivalent to a "drop_last" @@ -234,7 +231,7 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]: rng.shuffle(shuffle_split) obs_joinids_chunked = [ np.concatenate(b) - for b in _batched( + for b in batched( shuffle_split, self.io_batch_size // self.shuffle_chunk_size ) ] @@ -245,7 +242,7 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]: # 4. Partition by DataLoader worker n_workers, worker_id = get_worker_world_rank() - obs_splits = _splits(len(obs_joinids_chunked), n_workers) + obs_splits = splits(len(obs_joinids_chunked), n_workers) obs_partition_joinids = obs_joinids_chunked[ obs_splits[worker_id] : obs_splits[worker_id + 1] ].copy() @@ -950,45 +947,6 @@ def _collate_noop(datum: _T) -> _T: return datum -def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: - """For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes. - - A total_length of L, split into N sections, will return L%N sections of size L//N+1, - and the remainder as size L//N. This results in the same split as numpy.array_split, - for an array of length L and sections N. - - Private. 
- - Examples - -------- - >>> _splits(10, 3) - array([0, 4, 7, 10]) - >>> _splits(4, 2) - array([0, 2, 4]) - """ - if sections <= 0: - raise ValueError("number of sections must greater than 0.") from None - each_section, extras = divmod(total_length, sections) - per_section_sizes = ( - [0] + extras * [each_section + 1] + (sections - extras) * [each_section] - ) - splits = np.array(per_section_sizes, dtype=np.intp).cumsum() - return splits - - -if sys.version_info >= (3, 12): - _batched = itertools.batched -else: - - def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: - """Same as the Python 3.12+ ``itertools.batched`` -- polyfill for old Python versions.""" - if n < 1: - raise ValueError("n must be at least one") - it = iter(iterable) - while batch := tuple(islice(it, n)): - yield batch - - class _CSR_IO_Buffer: """Implement a minimal CSR matrix with specific optimizations for use in this package. diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 52f7aee..06c8a71 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -798,41 +798,6 @@ def test_experiment_dataloader__unsupported_params__fails() -> None: experiment_dataloader(dummy_exp_data_pipe, sampler=[]) -def test_batched() -> None: - from tiledbsoma_ml.pytorch import _batched - - assert list(_batched(range(6), 1)) == list((i,) for i in range(6)) - assert list(_batched(range(6), 2)) == [(0, 1), (2, 3), (4, 5)] - assert list(_batched(range(6), 3)) == [(0, 1, 2), (3, 4, 5)] - assert list(_batched(range(6), 4)) == [(0, 1, 2, 3), (4, 5)] - assert list(_batched(range(6), 5)) == [(0, 1, 2, 3, 4), (5,)] - assert list(_batched(range(6), 6)) == [(0, 1, 2, 3, 4, 5)] - assert list(_batched(range(6), 7)) == [(0, 1, 2, 3, 4, 5)] - - # bogus batch value - with pytest.raises(ValueError): - list(_batched([0, 1], 0)) - with pytest.raises(ValueError): - list(_batched([2, 3], -1)) - - -def test_splits() -> None: - from tiledbsoma_ml.pytorch import _splits - - assert _splits(10, 1).tolist() == [0, 10] - assert _splits(10, 2).tolist() == [0, 5, 10] - assert _splits(10, 3).tolist() == [0, 4, 7, 10] - assert _splits(10, 4).tolist() == [0, 3, 6, 8, 10] - assert _splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - assert _splits(10, 11).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10] - - # bad number of sections - with pytest.raises(ValueError): - _splits(10, 0) - with pytest.raises(ValueError): - _splits(10, -1) - - @pytest.mark.parametrize( # keep these small as we materialize as a dense ndarray "shape", [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..18fdb6a --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. 
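Before the unit tests, a brief sketch of how these two helpers cooperate in the partitioning code in pytorch.py: splits carves the obs joinids up per GPU rank (and per DataLoader worker), and batched regroups shuffle chunks into IO-batch-sized reads. The sizes below are illustrative.

from math import ceil

import numpy as np
from tiledbsoma_ml._utils import batched, splits

obs_joinids = np.arange(8)                            # 8 obs rows, 2 ranks
gpu_splits = splits(len(obs_joinids), 2)              # array([0, 4, 8])
gpu_split = obs_joinids[gpu_splits[0]:gpu_splits[1]]  # rank 0 gets [0 1 2 3]

# shuffle_chunk_size=2 and io_batch_size=4, i.e. 2 chunks per IO batch
shuffle_chunks = np.array_split(gpu_split, ceil(len(gpu_split) / 2))
io_batches = [np.concatenate(b) for b in batched(shuffle_chunks, 4 // 2)]
# io_batches == [array([0, 1, 2, 3])]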
+ +import pytest + + +def test_batched() -> None: + from tiledbsoma_ml._utils import batched + + assert list(batched(range(6), 1)) == list((i,) for i in range(6)) + assert list(batched(range(6), 2)) == [(0, 1), (2, 3), (4, 5)] + assert list(batched(range(6), 3)) == [(0, 1, 2), (3, 4, 5)] + assert list(batched(range(6), 4)) == [(0, 1, 2, 3), (4, 5)] + assert list(batched(range(6), 5)) == [(0, 1, 2, 3, 4), (5,)] + assert list(batched(range(6), 6)) == [(0, 1, 2, 3, 4, 5)] + assert list(batched(range(6), 7)) == [(0, 1, 2, 3, 4, 5)] + + # bogus batch value + with pytest.raises(ValueError): + list(batched([0, 1], 0)) + with pytest.raises(ValueError): + list(batched([2, 3], -1)) + + +def test_splits() -> None: + from tiledbsoma_ml._utils import splits + + assert splits(10, 1).tolist() == [0, 10] + assert splits(10, 2).tolist() == [0, 5, 10] + assert splits(10, 3).tolist() == [0, 4, 7, 10] + assert splits(10, 4).tolist() == [0, 3, 6, 8, 10] + assert splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert splits(10, 11).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10] + + # bad number of sections + with pytest.raises(ValueError): + splits(10, 0) + with pytest.raises(ValueError): + splits(10, -1) From 2af96edfea211942b09bd519c42c1d151b4e5a4c Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 11:18:12 -0500 Subject: [PATCH 6/9] `_csr.py` --- src/tiledbsoma_ml/_csr.py | 272 ++++++++++++++++++++++++++++++++++ src/tiledbsoma_ml/_utils.py | 3 +- src/tiledbsoma_ml/pytorch.py | 273 +---------------------------------- tests/test_csr.py | 142 ++++++++++++++++++ tests/test_pytorch.py | 133 ----------------- 5 files changed, 422 insertions(+), 401 deletions(-) create mode 100644 src/tiledbsoma_ml/_csr.py create mode 100644 tests/test_csr.py diff --git a/src/tiledbsoma_ml/_csr.py b/src/tiledbsoma_ml/_csr.py new file mode 100644 index 0000000..ea3dbbf --- /dev/null +++ b/src/tiledbsoma_ml/_csr.py @@ -0,0 +1,272 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from math import ceil +from typing import Any, List, Sequence, Tuple, Type + +import numba +import numpy as np +import numpy.typing as npt +from scipy import sparse +from typing_extensions import Self + +from tiledbsoma_ml._utils import NDArrayNumber + +_CSRIdxArray = npt.NDArray[np.unsignedinteger[Any]] + + +class CSR_IO_Buffer: + """Implement a minimal CSR matrix with specific optimizations for use in this package. + + Operations supported are: + * Incrementally build a CSR from COO, allowing overlapped I/O and CSR conversion for I/O batches, + and a final "merge" step which combines the result. + * Zero intermediate copy conversion of an arbitrary row slice to dense (i.e., mini-batch extraction). + * Parallel processing, where possible (construction, merge, etc.). + * Minimize memory use for index arrays. + + Overall is significantly faster, and uses less memory, than the equivalent ``scipy.sparse`` operations. 
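To make the intended lifecycle concrete, a small sketch follows (random data; the two buffers stand in for two I/O reads over the same logical shape):

import numpy as np
from scipy import sparse
from tiledbsoma_ml._csr import CSR_IO_Buffer

coo = sparse.random(100, 20, density=0.1, format="coo", dtype=np.float32)
half = coo.nnz // 2
bufs = [
    CSR_IO_Buffer.from_ijd(coo.row[s], coo.col[s], coo.data[s], shape=coo.shape)
    for s in (slice(None, half), slice(half, None))
]
merged = CSR_IO_Buffer.merge(bufs)             # combine per-read results
mini_batch = merged.slice_tonumpy(slice(0, 8))  # (8, 20) dense mini-batch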
+ """ + + __slots__ = ("indptr", "indices", "data", "shape") + + def __init__( + self, + indptr: _CSRIdxArray, + indices: _CSRIdxArray, + data: NDArrayNumber, + shape: Tuple[int, int], + ) -> None: + """Construct from PJV format.""" + assert len(data) == len(indices) + assert len(data) <= np.iinfo(indptr.dtype).max + assert shape[1] <= np.iinfo(indices.dtype).max + assert indptr[-1] == len(data) and indptr[0] == 0 + + self.shape = shape + self.indptr = indptr + self.indices = indices + self.data = data + + @staticmethod + def from_ijd( + i: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int] + ) -> "CSR_IO_Buffer": + """Factory from COO""" + nnz = len(d) + indptr: _CSRIdxArray = np.zeros((shape[0] + 1), dtype=smallest_uint_dtype(nnz)) + indices: _CSRIdxArray = np.empty((nnz,), dtype=smallest_uint_dtype(shape[1])) + data = np.empty((nnz,), dtype=d.dtype) + _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data) + return CSR_IO_Buffer(indptr, indices, data, shape) + + @staticmethod + def from_pjd( + p: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int] + ) -> "CSR_IO_Buffer": + """Factory from CSR""" + return CSR_IO_Buffer(p, j, d, shape) + + @property + def nnz(self) -> int: + return len(self.indices) + + @property + def nbytes(self) -> int: + return int(self.indptr.nbytes + self.indices.nbytes + self.data.nbytes) + + @property + def dtype(self) -> npt.DTypeLike: + return self.data.dtype + + def slice_tonumpy(self, row_index: slice) -> NDArrayNumber: + """Extract slice as a dense ndarray. Does not assume any particular ordering of minor axis.""" + assert isinstance(row_index, slice) + assert row_index.step in (1, None) + row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) + n_rows = max(row_idx_end - row_idx_start, 0) + out = np.zeros((n_rows, self.shape[1]), dtype=self.data.dtype) + if n_rows >= 0: + _csr_to_dense_inner( + row_idx_start, n_rows, self.indptr, self.indices, self.data, out + ) + return out + + def slice_toscipy(self, row_index: slice) -> sparse.csr_matrix: + """Extract slice as a ``sparse.csr_matrix``. 
Does not assume any particular ordering of + minor axis, but will return a canonically ordered scipy sparse object.""" + assert isinstance(row_index, slice) + assert row_index.step in (1, None) + row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) + n_rows = max(row_idx_end - row_idx_start, 0) + if n_rows == 0: + return sparse.csr_matrix((0, self.shape[1]), dtype=self.dtype) + + indptr = self.indptr[row_idx_start : row_idx_end + 1].copy() + indices = self.indices[indptr[0] : indptr[-1]].copy() + data = self.data[indptr[0] : indptr[-1]].copy() + indptr -= indptr[0] + return sparse.csr_matrix((data, indices, indptr), shape=(n_rows, self.shape[1])) + + @staticmethod + def merge(mtxs: Sequence["CSR_IO_Buffer"]) -> "CSR_IO_Buffer": + assert len(mtxs) > 0 + nnz = sum(m.nnz for m in mtxs) + shape = mtxs[0].shape + for m in mtxs[1:]: + assert m.shape == mtxs[0].shape + assert m.indices.dtype == mtxs[0].indices.dtype + assert all(m.shape == shape for m in mtxs) + + indptr = np.sum( + [m.indptr for m in mtxs], axis=0, dtype=smallest_uint_dtype(nnz) + ) + indices = np.empty((nnz,), dtype=mtxs[0].indices.dtype) + data = np.empty((nnz,), mtxs[0].data.dtype) + + _csr_merge_inner( + tuple((m.indptr.astype(indptr.dtype), m.indices, m.data) for m in mtxs), + indptr, + indices, + data, + ) + return CSR_IO_Buffer.from_pjd(indptr, indices, data, shape) + + def sort_indices(self) -> Self: + """Sort indices, IN PLACE.""" + _csr_sort_indices(self.indptr, self.indices, self.data) + return self + + +def smallest_uint_dtype(max_val: int) -> Type[np.unsignedinteger[Any]]: + dts: List[Type[np.unsignedinteger[Any]]] = [np.uint16, np.uint32] + for dt in dts: + if max_val <= np.iinfo(dt).max: + return dt + else: + return np.uint64 + + +@numba.njit(nogil=True, parallel=True) # type:ignore[misc] +def _csr_merge_inner( + As: Tuple[Tuple[_CSRIdxArray, _CSRIdxArray, NDArrayNumber], ...], # P,J,D + Bp: _CSRIdxArray, + Bj: _CSRIdxArray, + Bd: NDArrayNumber, +) -> None: + n_rows = len(Bp) - 1 + offsets = Bp.copy() + for Ap, Aj, Ad in As: + n_elmts = Ap[1:] - Ap[:-1] + for n in numba.prange(n_rows): + Bj[offsets[n] : offsets[n] + n_elmts[n]] = Aj[Ap[n] : Ap[n] + n_elmts[n]] + Bd[offsets[n] : offsets[n] + n_elmts[n]] = Ad[Ap[n] : Ap[n] + n_elmts[n]] + offsets[:-1] += n_elmts + + +@numba.njit(nogil=True, parallel=True) # type:ignore[misc] +def _csr_to_dense_inner( + row_idx_start: int, + n_rows: int, + indptr: _CSRIdxArray, + indices: _CSRIdxArray, + data: NDArrayNumber, + out: NDArrayNumber, +) -> None: + for i in numba.prange(row_idx_start, row_idx_start + n_rows): + for j in range(indptr[i], indptr[i + 1]): + out[i - row_idx_start, indices[j]] = data[j] + + +@numba.njit(nogil=True, parallel=True, inline="always") # type:ignore[misc] +def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNumber: + """Private: parallel row count.""" + nnz = len(Ai) + + partition_size = 32 * 1024**2 + n_partitions = ceil(nnz / partition_size) + if n_partitions > 1: + counts = np.zeros((n_partitions, n_rows), dtype=Bp.dtype) + for p in numba.prange(n_partitions): + for n in range(p * partition_size, min(nnz, (p + 1) * partition_size)): + row = Ai[n] + counts[p, row] += 1 + + Bp[:-1] = counts.sum(axis=0) + else: + for n in range(nnz): + row = Ai[n] + Bp[row] += 1 + + return Bp + + +@numba.njit(nogil=True, parallel=True) # type:ignore[misc] +def _coo_to_csr_inner( + n_rows: int, + Ai: _CSRIdxArray, + Aj: _CSRIdxArray, + Ad: NDArrayNumber, + Bp: _CSRIdxArray, + Bj: _CSRIdxArray, + Bd: NDArrayNumber, +) -> 
None: + nnz = len(Ai) + + _count_rows(n_rows, Ai, Bp) + + # cum sum to get the row index pointers (NOTE: starting with zero) + cumsum = 0 + for n in range(n_rows): + tmp = Bp[n] + Bp[n] = cumsum + cumsum += tmp + Bp[n_rows] = nnz + + # Reorganize all the data. Side effect: pointers shifted (reversed in the + # subsequent section). + # + # Method is concurrent (partitioned by rows) if number of rows is greater + # than 2**partition_bits. This partitioning scheme leverages the fact + # that reads are much cheaper than writes. + # + # The code is equivalent to: + # for n in range(nnz): + # row = Ai[n] + # dst_row = Bp[row] + # Bj[dst_row] = Aj[n] + # Bd[dst_row] = Ad[n] + # Bp[row] += 1 + + partition_bits = 13 + n_partitions = (n_rows + 2**partition_bits - 1) >> partition_bits + for p in numba.prange(n_partitions): + for n in range(nnz): + row = Ai[n] + if (row >> partition_bits) != p: + continue + dst_row = Bp[row] + Bj[dst_row] = Aj[n] + Bd[dst_row] = Ad[n] + Bp[row] += 1 + + # Shift the pointers by one slot (i.e., start at zero) + prev_ptr = 0 + for n in range(n_rows + 1): + tmp = Bp[n] + Bp[n] = prev_ptr + prev_ptr = tmp + + +@numba.njit(nogil=True, parallel=True) # type:ignore[misc] +def _csr_sort_indices(Bp: _CSRIdxArray, Bj: _CSRIdxArray, Bd: NDArrayNumber) -> None: + """In-place sort of minor axis indices""" + n_rows = len(Bp) - 1 + for r in numba.prange(n_rows): + row_start = Bp[r] + row_end = Bp[r + 1] + order = np.argsort(Bj[row_start:row_end]) + Bj[row_start:row_end] = Bj[row_start:row_end][order] + Bd[row_start:row_end] = Bd[row_start:row_end][order] diff --git a/src/tiledbsoma_ml/_utils.py b/src/tiledbsoma_ml/_utils.py index 2aad4fe..83456eb 100644 --- a/src/tiledbsoma_ml/_utils.py +++ b/src/tiledbsoma_ml/_utils.py @@ -6,12 +6,13 @@ import itertools import sys from itertools import islice -from typing import Iterable, Iterator, Tuple, TypeVar +from typing import Any, Iterable, Iterator, Tuple, TypeVar import numpy as np import numpy.typing as npt _T_co = TypeVar("_T_co", covariant=True) +NDArrayNumber = npt.NDArray[np.number[Any]] def splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 5c229b2..50686e8 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -8,7 +8,6 @@ import contextlib import gc import logging -import math import os import time from math import ceil @@ -17,15 +16,12 @@ ContextManager, Iterable, Iterator, - List, Sequence, Tuple, - Type, TypeVar, Union, ) -import numba import numpy as np import numpy.typing as npt import pandas as pd @@ -35,23 +31,21 @@ import torch import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator -from typing_extensions import Self +from tiledbsoma_ml._csr import CSR_IO_Buffer from tiledbsoma_ml._distributed import ( get_distributed_world_rank, get_worker_world_rank, init_multiprocessing, ) from tiledbsoma_ml._experiment_locator import ExperimentLocator -from tiledbsoma_ml._utils import batched, splits +from tiledbsoma_ml._utils import NDArrayNumber, batched, splits logger = logging.getLogger("tiledbsoma_ml.pytorch") _T = TypeVar("_T") -NDArrayNumber = npt.NDArray[np.number[Any]] NDArrayJoinId = npt.NDArray[np.int64] -_CSRIdxArray = npt.NDArray[np.unsignedinteger[Any]] XBatch = Union[NDArrayNumber, sparse.csr_matrix] Batch = Tuple[XBatch, pd.DataFrame] """"Batch" type yielded by ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``; @@ -403,7 +397,7 @@ def 
_io_batch_iter( obs: soma.DataFrame, X: soma.SparseNDArray, obs_joinid_iter: Iterator[NDArrayJoinId], - ) -> Iterator[Tuple[_CSR_IO_Buffer, pd.DataFrame]]: + ) -> Iterator[Tuple[CSR_IO_Buffer, pd.DataFrame]]: """Iterate over IO batches, i.e., SOMA query reads, producing tuples of ``(X: csr_array, obs: DataFrame)``. ``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed and shuffled @@ -448,9 +442,9 @@ def make_io_buffer( obs_coords: NDArrayJoinId, var_coords: NDArrayJoinId, obs_indexer: soma.IntIndexer, - ) -> _CSR_IO_Buffer: + ) -> CSR_IO_Buffer: """This function provides a GC after we throw off (large) garbage.""" - m = _CSR_IO_Buffer.from_ijd( + m = CSR_IO_Buffer.from_ijd( obs_indexer.get_indexer(X_tbl["soma_dim_0"]), var_indexer.get_indexer(X_tbl["soma_dim_1"]), X_tbl["soma_data"].to_numpy(), @@ -478,7 +472,7 @@ def make_io_buffer( [self.obs_column_names] ) # fmt: on - X_io_batch = _CSR_IO_Buffer.merge(tuple(_io_buf_iter)) + X_io_batch = CSR_IO_Buffer.merge(tuple(_io_buf_iter)) del obs_indexer, obs_coords, obs_shuffled_coords, _io_buf_iter gc.collect() @@ -945,258 +939,3 @@ def _collate_noop(datum: _T) -> _T: Private. """ return datum - - -class _CSR_IO_Buffer: - """Implement a minimal CSR matrix with specific optimizations for use in this package. - - Operations supported are: - * Incrementally build a CSR from COO, allowing overlapped I/O and CSR conversion for I/O batches, - and a final "merge" step which combines the result. - * Zero intermediate copy conversion of an arbitrary row slice to dense (i.e., mini-batch extraction). - * Parallel processing, where possible (construction, merge, etc.). - * Minimize memory use for index arrays. - - Overall is significantly faster, and uses less memory, than the equivalent ``scipy.sparse`` operations. - """ - - __slots__ = ("indptr", "indices", "data", "shape") - - def __init__( - self, - indptr: _CSRIdxArray, - indices: _CSRIdxArray, - data: NDArrayNumber, - shape: Tuple[int, int], - ) -> None: - """Construct from PJV format.""" - assert len(data) == len(indices) - assert len(data) <= np.iinfo(indptr.dtype).max - assert shape[1] <= np.iinfo(indices.dtype).max - assert indptr[-1] == len(data) and indptr[0] == 0 - - self.shape = shape - self.indptr = indptr - self.indices = indices - self.data = data - - @staticmethod - def from_ijd( - i: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int] - ) -> _CSR_IO_Buffer: - """Factory from COO""" - nnz = len(d) - indptr: _CSRIdxArray = np.zeros((shape[0] + 1), dtype=smallest_uint_dtype(nnz)) - indices: _CSRIdxArray = np.empty((nnz,), dtype=smallest_uint_dtype(shape[1])) - data = np.empty((nnz,), dtype=d.dtype) - _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data) - return _CSR_IO_Buffer(indptr, indices, data, shape) - - @staticmethod - def from_pjd( - p: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int] - ) -> _CSR_IO_Buffer: - """Factory from CSR""" - return _CSR_IO_Buffer(p, j, d, shape) - - @property - def nnz(self) -> int: - return len(self.indices) - - @property - def nbytes(self) -> int: - return int(self.indptr.nbytes + self.indices.nbytes + self.data.nbytes) - - @property - def dtype(self) -> npt.DTypeLike: - return self.data.dtype - - def slice_tonumpy(self, row_index: slice) -> NDArrayNumber: - """Extract slice as a dense ndarray. 
Does not assume any particular ordering of minor axis.""" - assert isinstance(row_index, slice) - assert row_index.step in (1, None) - row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) - n_rows = max(row_idx_end - row_idx_start, 0) - out = np.zeros((n_rows, self.shape[1]), dtype=self.data.dtype) - if n_rows >= 0: - _csr_to_dense_inner( - row_idx_start, n_rows, self.indptr, self.indices, self.data, out - ) - return out - - def slice_toscipy(self, row_index: slice) -> sparse.csr_matrix: - """Extract slice as a ``sparse.csr_matrix``. Does not assume any particular ordering of - minor axis, but will return a canonically ordered scipy sparse object.""" - assert isinstance(row_index, slice) - assert row_index.step in (1, None) - row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) - n_rows = max(row_idx_end - row_idx_start, 0) - if n_rows == 0: - return sparse.csr_matrix((0, self.shape[1]), dtype=self.dtype) - - indptr = self.indptr[row_idx_start : row_idx_end + 1].copy() - indices = self.indices[indptr[0] : indptr[-1]].copy() - data = self.data[indptr[0] : indptr[-1]].copy() - indptr -= indptr[0] - return sparse.csr_matrix((data, indices, indptr), shape=(n_rows, self.shape[1])) - - @staticmethod - def merge(mtxs: Sequence[_CSR_IO_Buffer]) -> _CSR_IO_Buffer: - assert len(mtxs) > 0 - nnz = sum(m.nnz for m in mtxs) - shape = mtxs[0].shape - for m in mtxs[1:]: - assert m.shape == mtxs[0].shape - assert m.indices.dtype == mtxs[0].indices.dtype - assert all(m.shape == shape for m in mtxs) - - indptr = np.sum( - [m.indptr for m in mtxs], axis=0, dtype=smallest_uint_dtype(nnz) - ) - indices = np.empty((nnz,), dtype=mtxs[0].indices.dtype) - data = np.empty((nnz,), mtxs[0].data.dtype) - - _csr_merge_inner( - tuple((m.indptr.astype(indptr.dtype), m.indices, m.data) for m in mtxs), - indptr, - indices, - data, - ) - return _CSR_IO_Buffer.from_pjd(indptr, indices, data, shape) - - def sort_indices(self) -> Self: - """Sort indices, IN PLACE.""" - _csr_sort_indices(self.indptr, self.indices, self.data) - return self - - -def smallest_uint_dtype(max_val: int) -> Type[np.unsignedinteger[Any]]: - dts: List[Type[np.unsignedinteger[Any]]] = [np.uint16, np.uint32] - for dt in dts: - if max_val <= np.iinfo(dt).max: - return dt - else: - return np.uint64 - - -@numba.njit(nogil=True, parallel=True) # type:ignore[misc] -def _csr_merge_inner( - As: Tuple[Tuple[_CSRIdxArray, _CSRIdxArray, NDArrayNumber], ...], # P,J,D - Bp: _CSRIdxArray, - Bj: _CSRIdxArray, - Bd: NDArrayNumber, -) -> None: - n_rows = len(Bp) - 1 - offsets = Bp.copy() - for Ap, Aj, Ad in As: - n_elmts = Ap[1:] - Ap[:-1] - for n in numba.prange(n_rows): - Bj[offsets[n] : offsets[n] + n_elmts[n]] = Aj[Ap[n] : Ap[n] + n_elmts[n]] - Bd[offsets[n] : offsets[n] + n_elmts[n]] = Ad[Ap[n] : Ap[n] + n_elmts[n]] - offsets[:-1] += n_elmts - - -@numba.njit(nogil=True, parallel=True) # type:ignore[misc] -def _csr_to_dense_inner( - row_idx_start: int, - n_rows: int, - indptr: _CSRIdxArray, - indices: _CSRIdxArray, - data: NDArrayNumber, - out: NDArrayNumber, -) -> None: - for i in numba.prange(row_idx_start, row_idx_start + n_rows): - for j in range(indptr[i], indptr[i + 1]): - out[i - row_idx_start, indices[j]] = data[j] - - -@numba.njit(nogil=True, parallel=True, inline="always") # type:ignore[misc] -def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNumber: - """Private: parallel row count.""" - nnz = len(Ai) - - partition_size = 32 * 1024**2 - n_partitions = math.ceil(nnz / partition_size) - 
if n_partitions > 1: - counts = np.zeros((n_partitions, n_rows), dtype=Bp.dtype) - for p in numba.prange(n_partitions): - for n in range(p * partition_size, min(nnz, (p + 1) * partition_size)): - row = Ai[n] - counts[p, row] += 1 - - Bp[:-1] = counts.sum(axis=0) - else: - for n in range(nnz): - row = Ai[n] - Bp[row] += 1 - - return Bp - - -@numba.njit(nogil=True, parallel=True) # type:ignore[misc] -def _coo_to_csr_inner( - n_rows: int, - Ai: _CSRIdxArray, - Aj: _CSRIdxArray, - Ad: NDArrayNumber, - Bp: _CSRIdxArray, - Bj: _CSRIdxArray, - Bd: NDArrayNumber, -) -> None: - nnz = len(Ai) - - _count_rows(n_rows, Ai, Bp) - - # cum sum to get the row index pointers (NOTE: starting with zero) - cumsum = 0 - for n in range(n_rows): - tmp = Bp[n] - Bp[n] = cumsum - cumsum += tmp - Bp[n_rows] = nnz - - # Reorganize all the data. Side effect: pointers shifted (reversed in the - # subsequent section). - # - # Method is concurrent (partitioned by rows) if number of rows is greater - # than 2**partition_bits. This partitioning scheme leverages the fact - # that reads are much cheaper than writes. - # - # The code is equivalent to: - # for n in range(nnz): - # row = Ai[n] - # dst_row = Bp[row] - # Bj[dst_row] = Aj[n] - # Bd[dst_row] = Ad[n] - # Bp[row] += 1 - - partition_bits = 13 - n_partitions = (n_rows + 2**partition_bits - 1) >> partition_bits - for p in numba.prange(n_partitions): - for n in range(nnz): - row = Ai[n] - if (row >> partition_bits) != p: - continue - dst_row = Bp[row] - Bj[dst_row] = Aj[n] - Bd[dst_row] = Ad[n] - Bp[row] += 1 - - # Shift the pointers by one slot (i.e., start at zero) - prev_ptr = 0 - for n in range(n_rows + 1): - tmp = Bp[n] - Bp[n] = prev_ptr - prev_ptr = tmp - - -@numba.njit(nogil=True, parallel=True) # type:ignore[misc] -def _csr_sort_indices(Bp: _CSRIdxArray, Bj: _CSRIdxArray, Bd: NDArrayNumber) -> None: - """In-place sort of minor axis indices""" - n_rows = len(Bp) - 1 - for r in numba.prange(n_rows): - row_start = Bp[r] - row_end = Bp[r + 1] - order = np.argsort(Bj[row_start:row_end]) - Bj[row_start:row_end] = Bj[row_start:row_end][order] - Bd[row_start:row_end] = Bd[row_start:row_end][order] diff --git a/tests/test_csr.py b/tests/test_csr.py new file mode 100644 index 0000000..b997cec --- /dev/null +++ b/tests/test_csr.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. 
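One detail of the moved code that is easy to miss: index arrays are allocated with the smallest unsigned dtype that can hold their maximum value (nnz for indptr, the column count for indices), which is part of the memory saving over scipy.sparse. A quick illustration:

import numpy as np
from tiledbsoma_ml._csr import smallest_uint_dtype

assert smallest_uint_dtype(60_000) is np.uint16   # fits in uint16
assert smallest_uint_dtype(100_000) is np.uint32  # too big for uint16
assert smallest_uint_dtype(2**40) is np.uint64    # too big for uint32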
+ +import sys +from typing import Tuple + +import numpy as np +import numpy.typing as npt +import pytest +from scipy import sparse + + +@pytest.mark.parametrize( # keep these small as we materialize as a dense ndarray + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) +def test_construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: + from tiledbsoma_ml._csr import CSR_IO_Buffer + + sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.05) + sp_csr = sp_coo.tocsr() + + _ncsr = CSR_IO_Buffer.from_ijd( + sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape + ) + assert _ncsr.nnz == sp_coo.nnz == sp_csr.nnz + assert _ncsr.dtype == sp_coo.dtype == sp_csr.dtype + assert _ncsr.nbytes == ( + _ncsr.data.nbytes + _ncsr.indices.nbytes + _ncsr.indptr.nbytes + ) + + # CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until + # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. + assert ( + sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) + != sp_csr + ).nnz == 0 + + # Check dense slicing + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_coo.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) + + # Check sparse slicing + assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 + assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 + + +@pytest.mark.parametrize( + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) +def test_construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: + from tiledbsoma_ml._csr import CSR_IO_Buffer + + sp_csr = sparse.random(shape[0], shape[1], dtype=dtype, format="csr", density=0.05) + + _ncsr = CSR_IO_Buffer.from_pjd( + sp_csr.indptr.copy(), + sp_csr.indices.copy(), + sp_csr.data.copy(), + shape=sp_csr.shape, + ) + + # CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until + # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. 
+ assert ( + sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) + != sp_csr + ).nnz == 0 + + # Check dense slicing + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) + + # Check sparse slicing + assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 + assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 + + +@pytest.mark.parametrize( + "shape", + [(100, 10), (10, 100)], +) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) +@pytest.mark.parametrize("n_splits", [2, 3, 4]) +def test_merge(shape: Tuple[int, int], dtype: npt.DTypeLike, n_splits: int) -> None: + from tiledbsoma_ml._csr import CSR_IO_Buffer + + sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.5) + splits = [ + t + for t in zip( + np.array_split(sp_coo.row, n_splits), + np.array_split(sp_coo.col, n_splits), + np.array_split(sp_coo.data, n_splits), + **(dict(strict=False) if sys.version_info >= (3, 10) else {}), + ) + ] + _ncsr = CSR_IO_Buffer.merge( + [CSR_IO_Buffer.from_ijd(i, j, d, shape=sp_coo.shape) for i, j, d in splits] + ) + + assert ( + sp_coo.tocsr() + != sparse.csr_matrix( + (_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape + ) + ).nnz == 0 + + +@pytest.mark.parametrize( + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +def test_sort_indices(shape: Tuple[int, int]) -> None: + from tiledbsoma_ml._csr import CSR_IO_Buffer + + sp_coo = sparse.random( + shape[0], shape[1], dtype=np.float32, format="coo", density=0.05 + ) + sp_csr = sp_coo.tocsr() + + _ncsr = CSR_IO_Buffer.from_ijd( + sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape + ).sort_indices() + + assert np.array_equal(sp_csr.indptr, _ncsr.indptr) + assert np.array_equal(sp_csr.indices, _ncsr.indices) + assert np.array_equal(sp_csr.data, _ncsr.data) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 06c8a71..f514c53 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -5,7 +5,6 @@ from __future__ import annotations -import sys from functools import partial from pathlib import Path from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union @@ -796,135 +795,3 @@ def test_experiment_dataloader__unsupported_params__fails() -> None: experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) with pytest.raises(ValueError): experiment_dataloader(dummy_exp_data_pipe, sampler=[]) - - -@pytest.mark.parametrize( # keep these small as we materialize as a dense ndarray - "shape", - [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], -) -@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) -def test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: - from tiledbsoma_ml.pytorch import _CSR_IO_Buffer - - sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.05) - sp_csr = sp_coo.tocsr() - - _ncsr = _CSR_IO_Buffer.from_ijd( - sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape - ) - assert _ncsr.nnz == sp_coo.nnz == sp_csr.nnz - assert _ncsr.dtype == sp_coo.dtype 
== sp_csr.dtype - assert _ncsr.nbytes == ( - _ncsr.data.nbytes + _ncsr.indices.nbytes + _ncsr.indptr.nbytes - ) - - # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until - # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. - assert ( - sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) - != sp_csr - ).nnz == 0 - - # Check dense slicing - assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_coo.toarray()) - assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) - assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) - assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) - assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) - - # Check sparse slicing - assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 - assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 - assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 - assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 - - -@pytest.mark.parametrize( - "shape", - [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], -) -@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) -def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: - from tiledbsoma_ml.pytorch import _CSR_IO_Buffer - - sp_csr = sparse.random(shape[0], shape[1], dtype=dtype, format="csr", density=0.05) - - _ncsr = _CSR_IO_Buffer.from_pjd( - sp_csr.indptr.copy(), - sp_csr.indices.copy(), - sp_csr.data.copy(), - shape=sp_csr.shape, - ) - - # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until - # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. 
- assert ( - sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) - != sp_csr - ).nnz == 0 - - # Check dense slicing - assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) - assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) - assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) - assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) - - # Check sparse slicing - assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 - assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 - assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 - assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 - - -@pytest.mark.parametrize( - "shape", - [(100, 10), (10, 100)], -) -@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) -@pytest.mark.parametrize("n_splits", [2, 3, 4]) -def test_csr__merge( - shape: Tuple[int, int], dtype: npt.DTypeLike, n_splits: int -) -> None: - from tiledbsoma_ml.pytorch import _CSR_IO_Buffer - - sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.5) - splits = [ - t - for t in zip( - np.array_split(sp_coo.row, n_splits), - np.array_split(sp_coo.col, n_splits), - np.array_split(sp_coo.data, n_splits), - **(dict(strict=False) if sys.version_info >= (3, 10) else {}), - ) - ] - _ncsr = _CSR_IO_Buffer.merge( - [_CSR_IO_Buffer.from_ijd(i, j, d, shape=sp_coo.shape) for i, j, d in splits] - ) - - assert ( - sp_coo.tocsr() - != sparse.csr_matrix( - (_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape - ) - ).nnz == 0 - - -@pytest.mark.parametrize( - "shape", - [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], -) -def test_csr__sort_indices(shape: Tuple[int, int]) -> None: - from tiledbsoma_ml.pytorch import _CSR_IO_Buffer - - sp_coo = sparse.random( - shape[0], shape[1], dtype=np.float32, format="coo", density=0.05 - ) - sp_csr = sp_coo.tocsr() - - _ncsr = _CSR_IO_Buffer.from_ijd( - sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape - ).sort_indices() - - assert np.array_equal(sp_csr.indptr, _ncsr.indptr) - assert np.array_equal(sp_csr.indices, _ncsr.indices) - assert np.array_equal(sp_csr.data, _ncsr.data) From 3e7b5234e64d838b3e6d9476c37388c1852ef4ac Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 11:30:42 -0500 Subject: [PATCH 7/9] `tests/{_utils,conftest}.py` --- tests/__init__.py | 0 tests/_utils.py | 97 +++++++++++++++++++++++++++ tests/conftest.py | 43 ++++++++++++ tests/test_pytorch.py | 150 ++++-------------------------------------- 4 files changed, 151 insertions(+), 139 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/_utils.py create mode 100644 tests/conftest.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_utils.py b/tests/_utils.py new file mode 100644 index 0000000..6f51fdb --- /dev/null +++ b/tests/_utils.py @@ -0,0 +1,97 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. 
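A small worked example may help when reading pytorch_x_value_gen below: the occupancy pattern is a checkerboard of ones derived from row/column index parity, shown here for a 3x4 block starting at the origin.

import numpy as np

print(np.indices((3, 4)).sum(axis=0) % 2)
# [[0 1 0 1]
#  [1 0 1 0]
#  [0 1 0 1]]
# pytorch_x_value_gen wraps this pattern in a COO matrix and offsets the
# row/col indices by the starts of the supplied obs/var ranges.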
+ +from functools import partial +from typing import Callable, Type, Union + +import numpy as np +import pyarrow as pa +from scipy.sparse import coo_matrix, spmatrix +from tiledbsoma._collection import CollectionBase + +from tiledbsoma_ml import ( + ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, +) +from tiledbsoma_ml.pytorch import ExperimentAxisQueryIterable + +assert_array_equal = partial(np.testing.assert_array_equal, strict=True) + +# These control which classes are tested (for most, but not all tests). +# Centralized to allow easy add/delete of specific test parameters. +IterableWrapperType = Union[ + Type[ExperimentAxisQueryIterDataPipe], + Type[ExperimentAxisQueryIterableDataset], +] +IterableWrappers = ( + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, +) +PipeClassType = Union[ + Type[ExperimentAxisQueryIterable], + IterableWrapperType, +] +PipeClasses = ( + ExperimentAxisQueryIterable, + *IterableWrappers, +) +XValueGen = Callable[[range, range], spmatrix] + + +def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix: + occupied_shape = ( + obs_range.stop - obs_range.start, + var_range.stop - var_range.start, + ) + checkerboard_of_ones = coo_matrix(np.indices(occupied_shape).sum(axis=0) % 2) + checkerboard_of_ones.row += obs_range.start + checkerboard_of_ones.col += var_range.start + return checkerboard_of_ones + + +def pytorch_seq_x_value_gen(obs_range: range, var_range: range) -> spmatrix: + """A sparse matrix where the values of each col are the obs_range values. Useful for checking the + X values are being returned in the correct order.""" + data = np.vstack([list(obs_range)] * len(var_range)).flatten() + rows = np.vstack([list(obs_range)] * len(var_range)).flatten() + cols = np.column_stack([list(var_range)] * len(obs_range)).flatten() + return coo_matrix((data, (rows, cols))) + + +def add_dataframe(coll: CollectionBase, key: str, value_range: range) -> None: + df = coll.add_new_dataframe( + key, + schema=pa.schema( + [ + ("soma_joinid", pa.int64()), + ("label", pa.large_string()), + ("label2", pa.large_string()), + ] + ), + index_column_names=["soma_joinid"], + ) + df.write( + pa.Table.from_pydict( + { + "soma_joinid": list(value_range), + "label": [str(i) for i in value_range], + "label2": ["c" for i in value_range], + } + ) + ) + + +def add_sparse_array( + coll: CollectionBase, + key: str, + obs_range: range, + var_range: range, + value_gen: XValueGen, +) -> None: + a = coll.add_new_sparse_ndarray( + key, type=pa.float32(), shape=(obs_range.stop, var_range.stop) + ) + tensor = pa.SparseCOOTensor.from_scipy(value_gen(obs_range, var_range)) + a.write(tensor) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..fd7a7ce --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. +# +# conftest.py defines pytest fixtures that are available to all test files. 
+ +from pathlib import Path +from typing import Union + +import pytest +from tiledbsoma import Experiment, Measurement +from tiledbsoma._collection import Collection + +from ._utils import XValueGen, add_dataframe, add_sparse_array + + +@pytest.fixture +def X_layer_names() -> list[str]: + return ["raw"] + + +@pytest.fixture(scope="function") +def soma_experiment( + tmp_path: Path, + obs_range: Union[int, range], + var_range: Union[int, range], + X_value_gen: XValueGen, +) -> Experiment: + with Experiment.create((tmp_path / "exp").as_posix()) as exp: + if isinstance(obs_range, int): + obs_range = range(obs_range) + if isinstance(var_range, int): + var_range = range(var_range) + + add_dataframe(exp, "obs", obs_range) + ms = exp.add_new_collection("ms") + rna = ms.add_new_collection("RNA", Measurement) + add_dataframe(rna, "var", var_range) + rna_x = rna.add_new_collection("X", Collection) + add_sparse_array(rna_x, "raw", obs_range, var_range, X_value_gen) + + return Experiment.open((tmp_path / "exp").as_posix()) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index f514c53..5c4a576 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -6,162 +6,34 @@ from __future__ import annotations from functools import partial -from pathlib import Path -from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union +from typing import Any, Tuple from unittest.mock import patch import numpy as np import numpy.typing as npt import pandas as pd -import pyarrow as pa import pytest import tiledbsoma as soma from pandas._testing import assert_frame_equal from scipy import sparse -from scipy.sparse import coo_matrix, spmatrix -from tiledbsoma import Experiment, _factory -from tiledbsoma._collection import CollectionBase +from tiledbsoma import Experiment from torch.utils.data._utils.worker import WorkerInfo +from tests._utils import ( + IterableWrappers, + IterableWrapperType, + PipeClasses, + PipeClassType, + assert_array_equal, + pytorch_seq_x_value_gen, + pytorch_x_value_gen, +) from tiledbsoma_ml.pytorch import ( ExperimentAxisQueryIterable, - ExperimentAxisQueryIterableDataset, ExperimentAxisQueryIterDataPipe, experiment_dataloader, ) -assert_array_equal = partial(np.testing.assert_array_equal, strict=True) - -# These control which classes are tested (for most, but not all tests). -# Centralized to allow easy add/delete of specific test parameters. -IterableWrapperType = Union[ - Type[ExperimentAxisQueryIterDataPipe], - Type[ExperimentAxisQueryIterableDataset], -] -IterableWrappers = ( - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, -) -PipeClassType = Union[ - Type[ExperimentAxisQueryIterable], - IterableWrapperType, -] -PipeClasses = ( - ExperimentAxisQueryIterable, - *IterableWrappers, -) -XValueGen = Callable[[range, range], spmatrix] - - -def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix: - occupied_shape = ( - obs_range.stop - obs_range.start, - var_range.stop - var_range.start, - ) - checkerboard_of_ones = coo_matrix(np.indices(occupied_shape).sum(axis=0) % 2) - checkerboard_of_ones.row += obs_range.start - checkerboard_of_ones.col += var_range.start - return checkerboard_of_ones - - -def pytorch_seq_x_value_gen(obs_range: range, var_range: range) -> spmatrix: - """A sparse matrix where the values of each col are the obs_range values. 
Useful for checking the - X values are being returned in the correct order.""" - data = np.vstack([list(obs_range)] * len(var_range)).flatten() - rows = np.vstack([list(obs_range)] * len(var_range)).flatten() - cols = np.column_stack([list(var_range)] * len(obs_range)).flatten() - return coo_matrix((data, (rows, cols))) - - -@pytest.fixture -def X_layer_names() -> list[str]: - return ["raw"] - - -@pytest.fixture -def obsp_layer_names() -> Optional[list[str]]: - return None - - -@pytest.fixture -def varp_layer_names() -> Optional[list[str]]: - return None - - -def add_dataframe(coll: CollectionBase, key: str, value_range: range) -> None: - df = coll.add_new_dataframe( - key, - schema=pa.schema( - [ - ("soma_joinid", pa.int64()), - ("label", pa.large_string()), - ("label2", pa.large_string()), - ] - ), - index_column_names=["soma_joinid"], - ) - df.write( - pa.Table.from_pydict( - { - "soma_joinid": list(value_range), - "label": [str(i) for i in value_range], - "label2": ["c" for i in value_range], - } - ) - ) - - -def add_sparse_array( - coll: CollectionBase, - key: str, - obs_range: range, - var_range: range, - value_gen: XValueGen, -) -> None: - a = coll.add_new_sparse_ndarray( - key, type=pa.float32(), shape=(obs_range.stop, var_range.stop) - ) - tensor = pa.SparseCOOTensor.from_scipy(value_gen(obs_range, var_range)) - a.write(tensor) - - -@pytest.fixture(scope="function") -def soma_experiment( - tmp_path: Path, - obs_range: Union[int, range], - var_range: Union[int, range], - X_value_gen: XValueGen, - obsp_layer_names: Sequence[str], - varp_layer_names: Sequence[str], -) -> soma.Experiment: - with soma.Experiment.create((tmp_path / "exp").as_posix()) as exp: - if isinstance(obs_range, int): - obs_range = range(obs_range) - if isinstance(var_range, int): - var_range = range(var_range) - - add_dataframe(exp, "obs", obs_range) - ms = exp.add_new_collection("ms") - rna = ms.add_new_collection("RNA", soma.Measurement) - add_dataframe(rna, "var", var_range) - rna_x = rna.add_new_collection("X", soma.Collection) - add_sparse_array(rna_x, "raw", obs_range, var_range, X_value_gen) - - if obsp_layer_names: - obsp = rna.add_new_collection("obsp") - for obsp_layer_name in obsp_layer_names: - add_sparse_array( - obsp, obsp_layer_name, obs_range, var_range, X_value_gen - ) - - if varp_layer_names: - varp = rna.add_new_collection("varp") - for varp_layer_name in varp_layer_names: - add_sparse_array( - varp, varp_layer_name, obs_range, var_range, X_value_gen - ) - return _factory.open((tmp_path / "exp").as_posix()) - @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", From 869a99fdffdc2ce4b830373acdab6b67700047ce Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 11:35:42 -0500 Subject: [PATCH 8/9] `dataloader.py` --- src/tiledbsoma_ml/__init__.py | 2 +- src/tiledbsoma_ml/dataloader.py | 85 ++++++++++++++ src/tiledbsoma_ml/pytorch.py | 72 ------------ tests/test_dataloader.py | 199 ++++++++++++++++++++++++++++++++ tests/test_pytorch.py | 192 +----------------------------- 5 files changed, 286 insertions(+), 264 deletions(-) create mode 100644 src/tiledbsoma_ml/dataloader.py create mode 100644 tests/test_dataloader.py diff --git a/src/tiledbsoma_ml/__init__.py b/src/tiledbsoma_ml/__init__.py index 263608f..793da7a 100644 --- a/src/tiledbsoma_ml/__init__.py +++ b/src/tiledbsoma_ml/__init__.py @@ -5,10 +5,10 @@ """An API to support machine learning applications built on SOMA.""" +from .dataloader import experiment_dataloader from .pytorch import ( 
ExperimentAxisQueryIterableDataset, ExperimentAxisQueryIterDataPipe, - experiment_dataloader, ) __version__ = "0.1.0-dev" diff --git a/src/tiledbsoma_ml/dataloader.py b/src/tiledbsoma_ml/dataloader.py new file mode 100644 index 0000000..001e332 --- /dev/null +++ b/src/tiledbsoma_ml/dataloader.py @@ -0,0 +1,85 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, TypeVar + +from torch.utils.data import DataLoader + +from tiledbsoma_ml._distributed import init_multiprocessing +from tiledbsoma_ml.pytorch import ( + ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, +) + +_T = TypeVar("_T") + + +def experiment_dataloader( + ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + **dataloader_kwargs: Any, +) -> DataLoader: + """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a + :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` + or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`. + + Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant, + when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``. + Specifying any of these parameters will result in an error. + + Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on + :class:`torch.utils.data.DataLoader` parameters. + + Args: + ds: + A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May + include chained data pipes. + **dataloader_kwargs: + Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, + except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not + supported when using data loaders in this module. + + Returns: + A :class:`torch.utils.data.DataLoader`. + + Raises: + ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params + are passed as keyword arguments. + + Lifecycle: + experimental + """ + unsupported_dataloader_args = [ + "shuffle", + "batch_size", + "sampler", + "batch_sampler", + ] + if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): + raise ValueError( + f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported" + ) + + if dataloader_kwargs.get("num_workers", 0) > 0: + init_multiprocessing() + + if "collate_fn" not in dataloader_kwargs: + dataloader_kwargs["collate_fn"] = _collate_noop + + return DataLoader( + ds, + batch_size=None, # batching is handled by upstream iterator + shuffle=False, # shuffling is handled by upstream iterator + **dataloader_kwargs, + ) + + +def _collate_noop(datum: _T) -> _T: + """Noop collation for use with a dataloader instance. + + Private. 
+ """ + return datum diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 50686e8..410cfeb 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -12,13 +12,11 @@ import time from math import ceil from typing import ( - Any, ContextManager, Iterable, Iterator, Sequence, Tuple, - TypeVar, Union, ) @@ -36,15 +34,12 @@ from tiledbsoma_ml._distributed import ( get_distributed_world_rank, get_worker_world_rank, - init_multiprocessing, ) from tiledbsoma_ml._experiment_locator import ExperimentLocator from tiledbsoma_ml._utils import NDArrayNumber, batched, splits logger = logging.getLogger("tiledbsoma_ml.pytorch") -_T = TypeVar("_T") - NDArrayJoinId = npt.NDArray[np.int64] XBatch = Union[NDArrayNumber, sparse.csr_matrix] Batch = Tuple[XBatch, pd.DataFrame] @@ -872,70 +867,3 @@ def set_epoch(self, epoch: int) -> None: @property def epoch(self) -> int: return self._exp_iter.epoch - - -def experiment_dataloader( - ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, - **dataloader_kwargs: Any, -) -> torch.utils.data.DataLoader: - """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a - :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` - or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`. - - Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant, - when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``. - Specifying any of these parameters will result in an error. - - Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on - :class:`torch.utils.data.DataLoader` parameters. - - Args: - ds: - A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May - include chained data pipes. - **dataloader_kwargs: - Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, - except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not - supported when using data loaders in this module. - - Returns: - A :class:`torch.utils.data.DataLoader`. - - Raises: - ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params - are passed as keyword arguments. - - Lifecycle: - experimental - """ - unsupported_dataloader_args = [ - "shuffle", - "batch_size", - "sampler", - "batch_sampler", - ] - if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): - raise ValueError( - f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported" - ) - - if dataloader_kwargs.get("num_workers", 0) > 0: - init_multiprocessing() - - if "collate_fn" not in dataloader_kwargs: - dataloader_kwargs["collate_fn"] = _collate_noop - - return torch.utils.data.DataLoader( - ds, - batch_size=None, # batching is handled by upstream iterator - shuffle=False, # shuffling is handled by upstream iterator - **dataloader_kwargs, - ) - - -def _collate_noop(datum: _T) -> _T: - """Noop collation for use with a dataloader instance. - - Private. 
- """ - return datum diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py new file mode 100644 index 0000000..62e9b4c --- /dev/null +++ b/tests/test_dataloader.py @@ -0,0 +1,199 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from functools import partial +from typing import Any, Tuple +from unittest.mock import patch + +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest +from tiledbsoma import Experiment + +from tests._utils import IterableWrappers, IterableWrapperType, pytorch_x_value_gen +from tiledbsoma_ml import ExperimentAxisQueryIterDataPipe +from tiledbsoma_ml.dataloader import experiment_dataloader + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] +) +@pytest.mark.parametrize("PipeClass", IterableWrappers) +def test_multiprocessing__returns_full_result( + PipeClass: IterableWrapperType, + soma_experiment: Experiment, +) -> None: + """Tests that ``ExperimentAxisQueryIterDataPipe`` / ``ExperimentAxisQueryIterableDataset`` + provide all data, as collected from multiple processes that are managed by a PyTorch DataLoader + with multiple workers configured.""" + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid", "label"], + io_batch_size=3, # two chunks, one per worker + ) + # Wrap with a DataLoader, which sets up the multiprocessing + dl = experiment_dataloader(dp, num_workers=2) + + full_result = list(iter(dl)) + + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + assert sorted(soma_joinids) == list(range(6)) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", IterableWrappers) +def test_experiment_dataloader__non_batched( + PipeClass: IterableWrapperType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] + assert all(d[0].shape == (3,) for d in data) + assert all(d[1].shape == (1, 1) for d in data) + + row = data[0] + assert row[0].tolist() == [0, 1, 0] + assert row[1]["label"].tolist() == ["0"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", IterableWrappers) +def test_experiment_dataloader__batched( + PipeClass: IterableWrapperType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] + + batch = data[0] + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].to_numpy().tolist() == [[0], [1], [2]] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (10, 3, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) 
+@pytest.mark.parametrize("PipeClass", IterableWrappers) +def test_experiment_dataloader__batched_length( + PipeClass: IterableWrapperType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + assert len(dl) == len(list(dl)) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,batch_size", + [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)], +) +@pytest.mark.parametrize("PipeClass", IterableWrappers) +def test_experiment_dataloader__collate_fn( + PipeClass: IterableWrapperType, + soma_experiment: Experiment, + batch_size: int, +) -> None: + def collate_fn( + batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame] + ) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]: + assert isinstance(data, tuple) + assert len(data) == 2 + assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame) + if batch_size > 1: + assert data[0].shape[0] == data[1].shape[0] + assert data[0].shape[0] <= batch_size + else: + assert data[0].ndim == 1 + assert data[1].shape[1] <= batch_size + return data + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=batch_size, + shuffle=False, + ) + dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size)) + assert len(list(dl)) > 0 + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)] +) +def test__pytorch_splitting( + soma_experiment: Experiment, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = ExperimentAxisQueryIterDataPipe( + query, + X_name="raw", + obs_column_names=["label"], + ) + # function not available for IterableDataset, yet.... 
+ dp_train, dp_test = dp.random_split( + weights={"train": 0.7, "test": 0.3}, seed=1234 + ) + dl = experiment_dataloader(dp_train) + + all_rows = list(iter(dl)) + assert len(all_rows) == 7 + + +def test_experiment_dataloader__unsupported_params__fails() -> None: + with patch( + "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe" + ) as dummy_exp_data_pipe: + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, shuffle=True) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, batch_size=3) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, sampler=[]) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 5c4a576..3faf124 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -5,12 +5,9 @@ from __future__ import annotations -from functools import partial -from typing import Any, Tuple from unittest.mock import patch import numpy as np -import numpy.typing as npt import pandas as pd import pytest import tiledbsoma as soma @@ -20,19 +17,13 @@ from torch.utils.data._utils.worker import WorkerInfo from tests._utils import ( - IterableWrappers, - IterableWrapperType, PipeClasses, PipeClassType, assert_array_equal, pytorch_seq_x_value_gen, pytorch_x_value_gen, ) -from tiledbsoma_ml.pytorch import ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - experiment_dataloader, -) +from tiledbsoma_ml.pytorch import ExperimentAxisQueryIterable @pytest.mark.parametrize( @@ -324,35 +315,6 @@ def test_batching__partial_soma_batches_are_concatenated( assert [len(batch[0]) for batch in batches] == [3, 3, 3, 1] -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] -) -@pytest.mark.parametrize("PipeClass", IterableWrappers) -def test_multiprocessing__returns_full_result( - PipeClass: IterableWrapperType, - soma_experiment: Experiment, -) -> None: - """Tests that ``ExperimentAxisQueryIterDataPipe`` / ``ExperimentAxisQueryIterableDataset`` - provide all data, as collected from multiple processes that are managed by a PyTorch DataLoader - with multiple workers configured.""" - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = PipeClass( - query, - X_name="raw", - obs_column_names=["soma_joinid", "label"], - io_batch_size=3, # two chunks, one per worker - ) - # Wrap with a DataLoader, which sets up the multiprocessing - dl = experiment_dataloader(dp, num_workers=2) - - full_result = list(iter(dl)) - - soma_joinids = np.concatenate( - [t[1]["soma_joinid"].to_numpy() for t in full_result] - ) - assert sorted(soma_joinids) == list(range(6)) - - @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)], @@ -464,144 +426,6 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( assert soma_joinids == expected_joinids -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], -) -@pytest.mark.parametrize("PipeClass", IterableWrappers) -def test_experiment_dataloader__non_batched( - PipeClass: IterableWrapperType, - soma_experiment: Experiment, - use_eager_fetch: bool, -) -> None: - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = PipeClass( - query, - X_name="raw", - obs_column_names=["label"], - shuffle=False, - 
use_eager_fetch=use_eager_fetch, - ) - dl = experiment_dataloader(dp) - data = [row for row in dl] - assert all(d[0].shape == (3,) for d in data) - assert all(d[1].shape == (1, 1) for d in data) - - row = data[0] - assert row[0].tolist() == [0, 1, 0] - assert row[1]["label"].tolist() == ["0"] - - -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], -) -@pytest.mark.parametrize("PipeClass", IterableWrappers) -def test_experiment_dataloader__batched( - PipeClass: IterableWrapperType, - soma_experiment: Experiment, - use_eager_fetch: bool, -) -> None: - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = PipeClass( - query, - X_name="raw", - batch_size=3, - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - dl = experiment_dataloader(dp) - data = [row for row in dl] - - batch = data[0] - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1].to_numpy().tolist() == [[0], [1], [2]] - - -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [ - (10, 3, pytorch_x_value_gen, use_eager_fetch) - for use_eager_fetch in (True, False) - ], -) -@pytest.mark.parametrize("PipeClass", IterableWrappers) -def test_experiment_dataloader__batched_length( - PipeClass: IterableWrapperType, - soma_experiment: Experiment, - use_eager_fetch: bool, -) -> None: - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = PipeClass( - query, - X_name="raw", - obs_column_names=["label"], - batch_size=3, - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - dl = experiment_dataloader(dp) - assert len(dl) == len(list(dl)) - - -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,batch_size", - [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)], -) -@pytest.mark.parametrize("PipeClass", IterableWrappers) -def test_experiment_dataloader__collate_fn( - PipeClass: IterableWrapperType, - soma_experiment: Experiment, - batch_size: int, -) -> None: - def collate_fn( - batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame] - ) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]: - assert isinstance(data, tuple) - assert len(data) == 2 - assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame) - if batch_size > 1: - assert data[0].shape[0] == data[1].shape[0] - assert data[0].shape[0] <= batch_size - else: - assert data[0].ndim == 1 - assert data[1].shape[1] <= batch_size - return data - - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = PipeClass( - query, - X_name="raw", - obs_column_names=["label"], - batch_size=batch_size, - shuffle=False, - ) - dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size)) - assert len(list(dl)) > 0 - - -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)] -) -def test__pytorch_splitting( - soma_experiment: Experiment, -) -> None: - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = ExperimentAxisQueryIterDataPipe( - query, - X_name="raw", - obs_column_names=["label"], - ) - # function not available for IterableDataset, yet.... 
- dp_train, dp_test = dp.random_split( - weights={"train": 0.7, "test": 0.3}, seed=1234 - ) - dl = experiment_dataloader(dp_train) - - all_rows = list(iter(dl)) - assert len(all_rows) == 7 - - @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)] ) @@ -653,17 +477,3 @@ def test_experiment_axis_query_iterable_error_checks( X_name="raw", shuffle=True, ) - - -def test_experiment_dataloader__unsupported_params__fails() -> None: - with patch( - "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe" - ) as dummy_exp_data_pipe: - with pytest.raises(ValueError): - experiment_dataloader(dummy_exp_data_pipe, shuffle=True) - with pytest.raises(ValueError): - experiment_dataloader(dummy_exp_data_pipe, batch_size=3) - with pytest.raises(ValueError): - experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) - with pytest.raises(ValueError): - experiment_dataloader(dummy_exp_data_pipe, sampler=[]) From 07084961142b0ebbed9acc3897eb83a13b9a185a Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 15 Dec 2024 12:08:43 -0500 Subject: [PATCH 9/9] `data{pipe,set}.py` --- src/tiledbsoma_ml/__init__.py | 6 +- src/tiledbsoma_ml/dataloader.py | 6 +- src/tiledbsoma_ml/datapipe.py | 114 ++++++++++++ src/tiledbsoma_ml/dataset.py | 222 ++++++++++++++++++++++ src/tiledbsoma_ml/pytorch.py | 314 -------------------------------- tests/test_dataloader.py | 2 +- 6 files changed, 341 insertions(+), 323 deletions(-) create mode 100644 src/tiledbsoma_ml/datapipe.py create mode 100644 src/tiledbsoma_ml/dataset.py diff --git a/src/tiledbsoma_ml/__init__.py b/src/tiledbsoma_ml/__init__.py index 793da7a..0cab2bb 100644 --- a/src/tiledbsoma_ml/__init__.py +++ b/src/tiledbsoma_ml/__init__.py @@ -6,10 +6,8 @@ """An API to support machine learning applications built on SOMA.""" from .dataloader import experiment_dataloader -from .pytorch import ( - ExperimentAxisQueryIterableDataset, - ExperimentAxisQueryIterDataPipe, -) +from .datapipe import ExperimentAxisQueryIterDataPipe +from .dataset import ExperimentAxisQueryIterableDataset __version__ = "0.1.0-dev" diff --git a/src/tiledbsoma_ml/dataloader.py b/src/tiledbsoma_ml/dataloader.py index 001e332..4cc1a4b 100644 --- a/src/tiledbsoma_ml/dataloader.py +++ b/src/tiledbsoma_ml/dataloader.py @@ -10,10 +10,8 @@ from torch.utils.data import DataLoader from tiledbsoma_ml._distributed import init_multiprocessing -from tiledbsoma_ml.pytorch import ( - ExperimentAxisQueryIterableDataset, - ExperimentAxisQueryIterDataPipe, -) +from tiledbsoma_ml.datapipe import ExperimentAxisQueryIterDataPipe +from tiledbsoma_ml.dataset import ExperimentAxisQueryIterableDataset _T = TypeVar("_T") diff --git a/src/tiledbsoma_ml/datapipe.py b/src/tiledbsoma_ml/datapipe.py new file mode 100644 index 0000000..0f7c20c --- /dev/null +++ b/src/tiledbsoma_ml/datapipe.py @@ -0,0 +1,114 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Iterator, Sequence, Tuple + +from somacore import ExperimentAxisQuery +from torch.utils.data.dataset import Dataset +from torchdata.datapipes.iter import IterDataPipe + +from tiledbsoma_ml.pytorch import Batch, ExperimentAxisQueryIterable + + +class ExperimentAxisQueryIterDataPipe( + IterDataPipe[Dataset[Batch]] # type:ignore[misc] +): + """A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. 
+ + This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for + legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the + TorchData [README](https://github.com/pytorch/data/blob/v0.8.0/README.md) for more information. + + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + + def __init__( + self, + query: ExperimentAxisQuery, + X_name: str = "raw", + obs_column_names: Sequence[str] = ("soma_joinid",), + batch_size: int = 1, + shuffle: bool = True, + seed: int | None = None, + io_batch_size: int = 2**16, + shuffle_chunk_size: int = 64, + return_sparse_X: bool = False, + use_eager_fetch: bool = True, + ): + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + super().__init__() + self._exp_iter = ExperimentAxisQueryIterable( + query=query, + X_name=X_name, + obs_column_names=obs_column_names, + batch_size=batch_size, + shuffle=shuffle, + seed=seed, + io_batch_size=io_batch_size, + return_sparse_X=return_sparse_X, + use_eager_fetch=use_eager_fetch, + shuffle_chunk_size=shuffle_chunk_size, + ) + + def __iter__(self) -> Iterator[Batch]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + batch_size = self._exp_iter.batch_size + for X, obs in self._exp_iter: + if batch_size == 1: + X = X[0] # This is a no-op for `csr_matrix`s + yield X, obs + + def __len__(self) -> int: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + return len(self._exp_iter) + + @property + def shape(self) -> Tuple[int, int]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + return self._exp_iter.shape + + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for this Data iterator. + + When :attr:`shuffle=True`, this will ensure that all replicas use a different + random ordering for each epoch. Failure to call this method before each epoch + will result in the same data ordering. + + This call must be made before the per-epoch iterator is created. + + Lifecycle: + experimental + """ + self._exp_iter.set_epoch(epoch) + + @property + def epoch(self) -> int: + return self._exp_iter.epoch diff --git a/src/tiledbsoma_ml/dataset.py b/src/tiledbsoma_ml/dataset.py new file mode 100644 index 0000000..f1f7cd6 --- /dev/null +++ b/src/tiledbsoma_ml/dataset.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Iterator, Sequence, Tuple + +from somacore import ExperimentAxisQuery +from torch.utils.data import IterableDataset + +from tiledbsoma_ml.pytorch import Batch, ExperimentAxisQueryIterable + + +class ExperimentAxisQueryIterableDataset(IterableDataset[Batch]): # type:ignore[misc] + """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. 
+ + This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as + specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of + ``obs`` and ``X`` data. Each iteration will yield a tuple containing an :class:`numpy.ndarray` + and a :class:`pandas.DataFrame`. + + For example: + + >>> import torch + >>> import tiledbsoma + >>> import tiledbsoma_ml + >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp: + ... with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: + ... ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query) + ... dataloader = torch.utils.data.DataLoader(ds) + >>> data = next(iter(dataloader)) + >>> data + (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), + soma_joinid + 0 57905025) + >>> data[0] + array([0., 0., 0., ..., 0., 0., 0.], dtype=float32) + >>> data[1] + soma_joinid + 0 57905025 + + The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each + iteration. If the ``batch_size`` is 1, then each result will have rank 1, else it will have rank 2. A ``batch_size`` + of 1 is compatible with :class:`torch.utils.data.DataLoader`-implemented batching, but it will usually be more + performant to create mini-batches using this class, and set the ``DataLoader`` batch size to `None`. + + The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the + default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension). + + The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A + larger value will increase total memory usage and may reduce average read time per row. + + Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using + :class:`DataLoader` shuffling. The shuffling algorithm works as follows: + + 1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk". + 2. A random selection of shuffle chunks is drawn and read as a single I/O buffer (of size ``io_buffer_size``). + 3. The entire I/O buffer is shuffled. + + Put another way, we read randomly selected groups of observations from across all query results, concatenate + those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle + is therefore determined by the ``io_buffer_size`` (number of rows read), and the ``shuffle_chunk_size`` + (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness, and decrease I/O + performance. + + This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader` + and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition + data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any + data tail will be dropped. 
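+
+    As a concrete illustration of the shuffle described above: with the default ``io_batch_size`` of ``2**16``
+    (65,536) rows and the default ``shuffle_chunk_size`` of 64, each I/O buffer is assembled from 1,024 randomly
+    drawn chunks of 64 contiguous rows, and the concatenated buffer is then shuffled before mini-batches are yielded.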
+
+    Lifecycle:
+        experimental
+    """
+
+    def __init__(
+        self,
+        query: ExperimentAxisQuery,
+        X_name: str = "raw",
+        obs_column_names: Sequence[str] = ("soma_joinid",),
+        batch_size: int = 1,
+        shuffle: bool = True,
+        seed: int | None = None,
+        io_batch_size: int = 2**16,
+        shuffle_chunk_size: int = 64,
+        return_sparse_X: bool = False,
+        use_eager_fetch: bool = True,
+    ):
+        """
+        Construct a new ``ExperimentAxisQueryIterableDataset``, suitable for use with :class:`torch.utils.data.DataLoader`.
+
+        The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as
+        a NumPy ``ndarray`` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas ``DataFrame`` respectively.
+
+        Args:
+            query:
+                A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over.
+            X_name:
+                The name of the ``X`` layer to read.
+            obs_column_names:
+                The names of the ``obs`` columns to return. At least one column name must be specified.
+                Default is ``('soma_joinid',)``.
+            batch_size:
+                The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of
+                ``1`` will result in a :class:`numpy.ndarray` of rank 1 being returned (a single row); larger values will
+                result in :class:`numpy.ndarray`\ s of rank 2 (multiple rows).
+
+                Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader`
+                batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader``
+                ``batch_size`` parameter to ``None``.
+            shuffle:
+                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
+            io_batch_size:
+                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
+                this class's behavior: 1) The maximum memory utilization, with larger values providing
+                better read performance, but also requiring more memory; 2) The number of rows read prior to shuffling
+                (see ``shuffle`` parameter for details). The default value of 65,536 (``2**16``) provides high performance, but
+                may need to be reduced in memory-limited hosts (or where a large number of :class:`DataLoader` workers
+                are employed).
+            shuffle_chunk_size:
+                The number of contiguous rows sampled, prior to concatenation and shuffling.
+                Larger numbers correspond to less randomness, but greater read performance.
+                If ``shuffle == False``, this parameter is ignored.
+            return_sparse_X:
+                If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will
+                return ``X`` data as a :class:`numpy.ndarray`.
+            seed:
+                The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using
+                :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
+                processes.
+            use_eager_fetch:
+                Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
+                available for processing via the iterator. This allows network (or filesystem) requests to be made in
+                parallel with client-side processing of the SOMA data, potentially improving overall performance at the
+                cost of doubling memory utilization. Defaults to ``True``.
+
+        Raises:
+            ``ValueError`` on various unsupported or malformed parameter values.
+ + Lifecycle: + experimental + + """ + super().__init__() + self._exp_iter = ExperimentAxisQueryIterable( + query=query, + X_name=X_name, + obs_column_names=obs_column_names, + batch_size=batch_size, + shuffle=shuffle, + seed=seed, + io_batch_size=io_batch_size, + return_sparse_X=return_sparse_X, + use_eager_fetch=use_eager_fetch, + shuffle_chunk_size=shuffle_chunk_size, + ) + + def __iter__(self) -> Iterator[Batch]: + """Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.csr_matrix`) and + :class:`pandas.DataFrame`. + + Returns: + ``iterator`` + + Lifecycle: + experimental + """ + batch_size = self._exp_iter.batch_size + for X, obs in self._exp_iter: + if batch_size == 1: + X = X[0] # This is a no-op for `csr_matrix`s + yield X, obs + + def __len__(self) -> int: + """Return number of batches this iterable will produce. + + See important caveats in the PyTorch + [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) + documentation regarding ``len(dataloader)``, which also apply to this class. + + Returns: + ``int`` (number of batches). + + Lifecycle: + experimental + """ + return len(self._exp_iter) + + @property + def shape(self) -> Tuple[int, int]: + """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. + + If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), + the number of batches will reflect the size of the data partition assigned to the active process. + + Returns: + A tuple of two ``int`` values: number of batches, number of vars. + + Lifecycle: + experimental + """ + return self._exp_iter.shape + + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for this Data iterator. + + When :attr:`shuffle=True`, this will ensure that all replicas use a different + random ordering for each epoch. Failure to call this method before each epoch + will result in the same data ordering. + + This call must be made before the per-epoch iterator is created. + + Lifecycle: + experimental + """ + self._exp_iter.set_epoch(epoch) + + @property + def epoch(self) -> int: + return self._exp_iter.epoch diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 410cfeb..ca275f1 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -27,7 +27,6 @@ import scipy.sparse as sparse import tiledbsoma as soma import torch -import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator from tiledbsoma_ml._csr import CSR_IO_Buffer @@ -554,316 +553,3 @@ def _mini_batch_iter( # yield the remnant, if any if result is not None: yield result - - -class ExperimentAxisQueryIterDataPipe( - torchdata.datapipes.iter.IterDataPipe[ # type:ignore[misc] - torch.utils.data.dataset.Dataset[Batch] - ], -): - """A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. - - This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for - legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the - TorchData [README](https://github.com/pytorch/data/blob/v0.8.0/README.md) for more information. - - See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. 
- - Lifecycle: - deprecated - """ - - def __init__( - self, - query: soma.ExperimentAxisQuery, - X_name: str = "raw", - obs_column_names: Sequence[str] = ("soma_joinid",), - batch_size: int = 1, - shuffle: bool = True, - seed: int | None = None, - io_batch_size: int = 2**16, - shuffle_chunk_size: int = 64, - return_sparse_X: bool = False, - use_eager_fetch: bool = True, - ): - """ - See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. - - Lifecycle: - deprecated - """ - super().__init__() - self._exp_iter = ExperimentAxisQueryIterable( - query=query, - X_name=X_name, - obs_column_names=obs_column_names, - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - io_batch_size=io_batch_size, - return_sparse_X=return_sparse_X, - use_eager_fetch=use_eager_fetch, - shuffle_chunk_size=shuffle_chunk_size, - ) - - def __iter__(self) -> Iterator[Batch]: - """ - See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. - - Lifecycle: - deprecated - """ - batch_size = self._exp_iter.batch_size - for X, obs in self._exp_iter: - if batch_size == 1: - X = X[0] # This is a no-op for `csr_matrix`s - yield X, obs - - def __len__(self) -> int: - """ - See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. - - Lifecycle: - deprecated - """ - return len(self._exp_iter) - - @property - def shape(self) -> Tuple[int, int]: - """ - See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. - - Lifecycle: - deprecated - """ - return self._exp_iter.shape - - def set_epoch(self, epoch: int) -> None: - """ - Set the epoch for this Data iterator. - - When :attr:`shuffle=True`, this will ensure that all replicas use a different - random ordering for each epoch. Failure to call this method before each epoch - will result in the same data ordering. - - This call must be made before the per-epoch iterator is created. - - Lifecycle: - experimental - """ - self._exp_iter.set_epoch(epoch) - - @property - def epoch(self) -> int: - return self._exp_iter.epoch - - -class ExperimentAxisQueryIterableDataset( - torch.utils.data.IterableDataset[Batch] # type:ignore[misc] -): - """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. - - This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as - specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of - ``obs`` and ``X`` data. Each iteration will yield a tuple containing an :class:`numpy.ndarray` - and a :class:`pandas.DataFrame`. - - For example: - - >>> import torch - >>> import tiledbsoma - >>> import tiledbsoma_ml - >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp: - ... with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: - ... ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query) - ... dataloader = torch.utils.data.DataLoader(ds) - >>> data = next(iter(dataloader)) - >>> data - (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), - soma_joinid - 0 57905025) - >>> data[0] - array([0., 0., 0., ..., 0., 0., 0.], dtype=float32) - >>> data[1] - soma_joinid - 0 57905025 - - The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each - iteration. 
If the ``batch_size`` is 1, then each result will have rank 1, else it will have rank 2. A ``batch_size`` - of 1 is compatible with :class:`torch.utils.data.DataLoader`-implemented batching, but it will usually be more - performant to create mini-batches using this class, and set the ``DataLoader`` batch size to `None`. - - The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the - default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension). - - The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A - larger value will increase total memory usage and may reduce average read time per row. - - Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using - :class:`DataLoader` shuffling. The shuffling algorithm works as follows: - - 1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk". - 2. A random selection of shuffle chunks is drawn and read as a single I/O buffer (of size ``io_buffer_size``). - 3. The entire I/O buffer is shuffled. - - Put another way, we read randomly selected groups of observations from across all query results, concatenate - those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle - is therefore determined by the ``io_buffer_size`` (number of rows read), and the ``shuffle_chunk_size`` - (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness, and decrease I/O - performance. - - This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader` - and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition - data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any - data tail will be dropped. - - Lifecycle: - experimental - """ - - def __init__( - self, - query: soma.ExperimentAxisQuery, - X_name: str = "raw", - obs_column_names: Sequence[str] = ("soma_joinid",), - batch_size: int = 1, - shuffle: bool = True, - seed: int | None = None, - io_batch_size: int = 2**16, - shuffle_chunk_size: int = 64, - return_sparse_X: bool = False, - use_eager_fetch: bool = True, - ): - """ - Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. - - The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as - a NumPy ``ndarray`` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas ``DataFrame`` respectively. - - Args: - query: - A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over. - X_name: - The name of the ``X`` layer to read. - obs_column_names: - The names of the ``obs`` columns to return. At least one column name must be specified. - Default is ``('soma_joinid',)``. - batch_size: - The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of - ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will - result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). 
- - Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` - batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader`` - batch_size parameter to ``None``. - shuffle: - Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. - io_batch_size: - The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of - this class's behavior: 1) The maximum memory utilization, with larger values providing - better read performance, but also requiring more memory; 2) The number of rows read prior to shuffling - (see ``shuffle`` parameter for details). The default value of 131,072 provides high performance, but - may need to be reduced in memory limited hosts (or where a large number of :class:`DataLoader` workers - are employed). - shuffle_chunk_size: - The number of contiguous rows sampled, prior to concatenation and shuffling. - Larger numbers correspond to less randomness, but greater read performance. - If ``shuffle == False``, this parameter is ignored. - return_sparse_X: - If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will - return ``X`` data as a :class:`numpy.ndarray`. - seed: - The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using - :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker - processes. - use_eager_fetch: - Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made - available for processing via the iterator. This allows network (or filesystem) requests to be made in - parallel with client-side processing of the SOMA data, potentially improving overall performance at the - cost of doubling memory utilization. Defaults to ``True``. - - Raises: - ``ValueError`` on various unsupported or malformed parameter values. - - Lifecycle: - experimental - - """ - super().__init__() - self._exp_iter = ExperimentAxisQueryIterable( - query=query, - X_name=X_name, - obs_column_names=obs_column_names, - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - io_batch_size=io_batch_size, - return_sparse_X=return_sparse_X, - use_eager_fetch=use_eager_fetch, - shuffle_chunk_size=shuffle_chunk_size, - ) - - def __iter__(self) -> Iterator[Batch]: - """Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.csr_matrix`) and - :class:`pandas.DataFrame`. - - Returns: - ``iterator`` - - Lifecycle: - experimental - """ - batch_size = self._exp_iter.batch_size - for X, obs in self._exp_iter: - if batch_size == 1: - X = X[0] # This is a no-op for `csr_matrix`s - yield X, obs - - def __len__(self) -> int: - """Return number of batches this iterable will produce. - - See important caveats in the PyTorch - [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) - documentation regarding ``len(dataloader)``, which also apply to this class. - - Returns: - ``int`` (number of batches). - - Lifecycle: - experimental - """ - return len(self._exp_iter) - - @property - def shape(self) -> Tuple[int, int]: - """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. - - If used in multiprocessing mode (i.e. 
:class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), - the number of batches will reflect the size of the data partition assigned to the active process. - - Returns: - A tuple of two ``int`` values: number of batches, number of vars. - - Lifecycle: - experimental - """ - return self._exp_iter.shape - - def set_epoch(self, epoch: int) -> None: - """ - Set the epoch for this Data iterator. - - When :attr:`shuffle=True`, this will ensure that all replicas use a different - random ordering for each epoch. Failure to call this method before each epoch - will result in the same data ordering. - - This call must be made before the per-epoch iterator is created. - - Lifecycle: - experimental - """ - self._exp_iter.set_epoch(epoch) - - @property - def epoch(self) -> int: - return self._exp_iter.epoch diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 62e9b4c..aaa8fda 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -187,7 +187,7 @@ def test__pytorch_splitting( def test_experiment_dataloader__unsupported_params__fails() -> None: with patch( - "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe" + "tiledbsoma_ml.datapipe.ExperimentAxisQueryIterDataPipe" ) as dummy_exp_data_pipe: with pytest.raises(ValueError): experiment_dataloader(dummy_exp_data_pipe, shuffle=True)
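
The series above splits the former monolithic `pytorch.py` into `pytorch.py` (core iterable), `datapipe.py`,
`dataset.py`, and `dataloader.py`. The following is a minimal sketch of how the relocated pieces are used
together. The experiment URI ("my_experiment_path") and measurement/layer names are placeholders; all other
parameters are those documented in `dataset.py` and `dataloader.py` above.

    import tiledbsoma
    from tiledbsoma_ml import ExperimentAxisQueryIterableDataset, experiment_dataloader

    # Placeholder URI and measurement; substitute your own SOMA Experiment.
    with tiledbsoma.Experiment.open("my_experiment_path") as exp:
        with exp.axis_query(measurement_name="RNA") as query:
            ds = ExperimentAxisQueryIterableDataset(
                query,
                X_name="raw",
                obs_column_names=("soma_joinid",),
                batch_size=128,
                shuffle=True,
                seed=42,  # must be specified explicitly when training with DDP
            )
            # experiment_dataloader() rejects shuffle/batch_size/sampler/batch_sampler kwargs;
            # batching and shuffling are handled by the dataset itself.
            dl = experiment_dataloader(ds, num_workers=2)
            for epoch in range(2):
                ds.set_epoch(epoch)  # re-randomize the shuffle for each epoch
                for X, obs in dl:
                    # X: numpy.ndarray with up to 128 rows; obs: corresponding pandas.DataFrame
                    pass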