Split `{test_,}pytorch.py` into a few files #23 (Merged)

Changes from all commits (9 commits):
- 16a6090: `pytorch.py` nits (ryan-williams)
- b03c65f: rename `XObsDatum` to `Batch` (ryan-williams)
- 8385bdc: `_experiment_locator.py` (ryan-williams)
- 446cdee: `_distributed.py` (ryan-williams)
- b15df81: `_utils.py` (ryan-williams)
- 2af96ed: `_csr.py` (ryan-williams)
- 3e7b523: `tests/{_utils,conftest}.py` (ryan-williams)
- 869a99f: `dataloader.py` (ryan-williams)
- 0708496: `data{pipe,set}.py` (ryan-williams)
`_csr.py` (new file, 272 lines):

```python
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
# Copyright (c) 2021-2024 TileDB, Inc.
#
# Licensed under the MIT License.

from math import ceil
from typing import Any, List, Sequence, Tuple, Type

import numba
import numpy as np
import numpy.typing as npt
from scipy import sparse
from typing_extensions import Self

from tiledbsoma_ml._utils import NDArrayNumber

_CSRIdxArray = npt.NDArray[np.unsignedinteger[Any]]


class CSR_IO_Buffer:
    """Implement a minimal CSR matrix with specific optimizations for use in this package.

    Operations supported are:

    * Incrementally build a CSR from COO, allowing overlapped I/O and CSR conversion for
      I/O batches, and a final "merge" step which combines the result.
    * Zero intermediate copy conversion of an arbitrary row slice to dense (i.e.,
      mini-batch extraction).
    * Parallel processing, where possible (construction, merge, etc.).
    * Minimize memory use for index arrays.

    Overall is significantly faster, and uses less memory, than the equivalent
    ``scipy.sparse`` operations.
    """

    __slots__ = ("indptr", "indices", "data", "shape")

    def __init__(
        self,
        indptr: _CSRIdxArray,
        indices: _CSRIdxArray,
        data: NDArrayNumber,
        shape: Tuple[int, int],
    ) -> None:
        """Construct from PJV (pointers, indices, values) format."""
        assert len(data) == len(indices)
        assert len(data) <= np.iinfo(indptr.dtype).max
        assert shape[1] <= np.iinfo(indices.dtype).max
        assert indptr[-1] == len(data) and indptr[0] == 0

        self.shape = shape
        self.indptr = indptr
        self.indices = indices
        self.data = data

    @staticmethod
    def from_ijd(
        i: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int]
    ) -> "CSR_IO_Buffer":
        """Factory from COO."""
        nnz = len(d)
        indptr: _CSRIdxArray = np.zeros((shape[0] + 1), dtype=smallest_uint_dtype(nnz))
        indices: _CSRIdxArray = np.empty((nnz,), dtype=smallest_uint_dtype(shape[1]))
        data = np.empty((nnz,), dtype=d.dtype)
        _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data)
        return CSR_IO_Buffer(indptr, indices, data, shape)

    @staticmethod
    def from_pjd(
        p: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int]
    ) -> "CSR_IO_Buffer":
        """Factory from CSR."""
        return CSR_IO_Buffer(p, j, d, shape)

    @property
    def nnz(self) -> int:
        return len(self.indices)

    @property
    def nbytes(self) -> int:
        return int(self.indptr.nbytes + self.indices.nbytes + self.data.nbytes)

    @property
    def dtype(self) -> npt.DTypeLike:
        return self.data.dtype

    def slice_tonumpy(self, row_index: slice) -> NDArrayNumber:
        """Extract slice as a dense ndarray. Does not assume any particular ordering of
        the minor axis."""
        assert isinstance(row_index, slice)
        assert row_index.step in (1, None)
        row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1)
        n_rows = max(row_idx_end - row_idx_start, 0)
        out = np.zeros((n_rows, self.shape[1]), dtype=self.data.dtype)
        if n_rows > 0:
            _csr_to_dense_inner(
                row_idx_start, n_rows, self.indptr, self.indices, self.data, out
            )
        return out

    def slice_toscipy(self, row_index: slice) -> sparse.csr_matrix:
        """Extract slice as a ``sparse.csr_matrix``. Does not assume any particular
        ordering of the minor axis, but will return a canonically ordered scipy
        sparse object."""
        assert isinstance(row_index, slice)
        assert row_index.step in (1, None)
        row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1)
        n_rows = max(row_idx_end - row_idx_start, 0)
        if n_rows == 0:
            return sparse.csr_matrix((0, self.shape[1]), dtype=self.dtype)

        indptr = self.indptr[row_idx_start : row_idx_end + 1].copy()
        indices = self.indices[indptr[0] : indptr[-1]].copy()
        data = self.data[indptr[0] : indptr[-1]].copy()
        indptr -= indptr[0]
        return sparse.csr_matrix((data, indices, indptr), shape=(n_rows, self.shape[1]))

    @staticmethod
    def merge(mtxs: Sequence["CSR_IO_Buffer"]) -> "CSR_IO_Buffer":
        assert len(mtxs) > 0
        nnz = sum(m.nnz for m in mtxs)
        shape = mtxs[0].shape
        for m in mtxs[1:]:
            assert m.shape == shape
            assert m.indices.dtype == mtxs[0].indices.dtype

        indptr = np.sum(
            [m.indptr for m in mtxs], axis=0, dtype=smallest_uint_dtype(nnz)
        )
        indices = np.empty((nnz,), dtype=mtxs[0].indices.dtype)
        data = np.empty((nnz,), mtxs[0].data.dtype)

        _csr_merge_inner(
            tuple((m.indptr.astype(indptr.dtype), m.indices, m.data) for m in mtxs),
            indptr,
            indices,
            data,
        )
        return CSR_IO_Buffer.from_pjd(indptr, indices, data, shape)

    def sort_indices(self) -> Self:
        """Sort indices, IN PLACE."""
        _csr_sort_indices(self.indptr, self.indices, self.data)
        return self


def smallest_uint_dtype(max_val: int) -> Type[np.unsignedinteger[Any]]:
    dts: List[Type[np.unsignedinteger[Any]]] = [np.uint16, np.uint32]
    for dt in dts:
        if max_val <= np.iinfo(dt).max:
            return dt
    return np.uint64


@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
def _csr_merge_inner(
    As: Tuple[Tuple[_CSRIdxArray, _CSRIdxArray, NDArrayNumber], ...],  # P, J, D
    Bp: _CSRIdxArray,
    Bj: _CSRIdxArray,
    Bd: NDArrayNumber,
) -> None:
    n_rows = len(Bp) - 1
    offsets = Bp.copy()
    for Ap, Aj, Ad in As:
        n_elmts = Ap[1:] - Ap[:-1]
        for n in numba.prange(n_rows):
            Bj[offsets[n] : offsets[n] + n_elmts[n]] = Aj[Ap[n] : Ap[n] + n_elmts[n]]
            Bd[offsets[n] : offsets[n] + n_elmts[n]] = Ad[Ap[n] : Ap[n] + n_elmts[n]]
        offsets[:-1] += n_elmts


@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
def _csr_to_dense_inner(
    row_idx_start: int,
    n_rows: int,
    indptr: _CSRIdxArray,
    indices: _CSRIdxArray,
    data: NDArrayNumber,
    out: NDArrayNumber,
) -> None:
    for i in numba.prange(row_idx_start, row_idx_start + n_rows):
        for j in range(indptr[i], indptr[i + 1]):
            out[i - row_idx_start, indices[j]] = data[j]


@numba.njit(nogil=True, parallel=True, inline="always")  # type:ignore[misc]
def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNumber:
    """Private: parallel row count."""
    nnz = len(Ai)

    partition_size = 32 * 1024**2
    n_partitions = ceil(nnz / partition_size)
    if n_partitions > 1:
        counts = np.zeros((n_partitions, n_rows), dtype=Bp.dtype)
        for p in numba.prange(n_partitions):
            for n in range(p * partition_size, min(nnz, (p + 1) * partition_size)):
                row = Ai[n]
                counts[p, row] += 1

        Bp[:-1] = counts.sum(axis=0)
    else:
        for n in range(nnz):
            row = Ai[n]
            Bp[row] += 1

    return Bp


@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
def _coo_to_csr_inner(
    n_rows: int,
    Ai: _CSRIdxArray,
    Aj: _CSRIdxArray,
    Ad: NDArrayNumber,
    Bp: _CSRIdxArray,
    Bj: _CSRIdxArray,
    Bd: NDArrayNumber,
) -> None:
    nnz = len(Ai)

    _count_rows(n_rows, Ai, Bp)

    # Cumulative sum to get the row index pointers (NOTE: starting with zero)
    cumsum = 0
    for n in range(n_rows):
        tmp = Bp[n]
        Bp[n] = cumsum
        cumsum += tmp
    Bp[n_rows] = nnz

    # Reorganize all the data. Side effect: pointers shifted (reversed in the
    # subsequent section).
    #
    # Method is concurrent (partitioned by rows) if the number of rows is greater
    # than 2**partition_bits. This partitioning scheme leverages the fact
    # that reads are much cheaper than writes.
    #
    # The code is equivalent to:
    #   for n in range(nnz):
    #       row = Ai[n]
    #       dst_row = Bp[row]
    #       Bj[dst_row] = Aj[n]
    #       Bd[dst_row] = Ad[n]
    #       Bp[row] += 1

    partition_bits = 13
    n_partitions = (n_rows + 2**partition_bits - 1) >> partition_bits
    for p in numba.prange(n_partitions):
        for n in range(nnz):
            row = Ai[n]
            if (row >> partition_bits) != p:
                continue
            dst_row = Bp[row]
            Bj[dst_row] = Aj[n]
            Bd[dst_row] = Ad[n]
            Bp[row] += 1

    # Shift the pointers by one slot (i.e., start at zero)
    prev_ptr = 0
    for n in range(n_rows + 1):
        tmp = Bp[n]
        Bp[n] = prev_ptr
        prev_ptr = tmp


@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
def _csr_sort_indices(Bp: _CSRIdxArray, Bj: _CSRIdxArray, Bd: NDArrayNumber) -> None:
    """In-place sort of minor axis indices."""
    n_rows = len(Bp) - 1
    for r in numba.prange(n_rows):
        row_start = Bp[r]
        row_end = Bp[r + 1]
        order = np.argsort(Bj[row_start:row_end])
        Bj[row_start:row_end] = Bj[row_start:row_end][order]
        Bd[row_start:row_end] = Bd[row_start:row_end][order]
```
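For orientation, here is a minimal usage sketch (illustrative only, not code from this PR; shapes and values are invented): two I/O batches are built from COO triples, merged, and a row slice is extracted as a dense mini-batch.

```python
import numpy as np

from tiledbsoma_ml._csr import CSR_IO_Buffer

shape = (4, 5)

# Two I/O batches of COO triples (row, col, value) over the same logical matrix.
a = CSR_IO_Buffer.from_ijd(
    np.array([0, 1], dtype=np.uint32),
    np.array([2, 4], dtype=np.uint32),
    np.array([1.0, 2.0], dtype=np.float32),
    shape,
)
b = CSR_IO_Buffer.from_ijd(
    np.array([2, 3], dtype=np.uint32),
    np.array([0, 3], dtype=np.uint32),
    np.array([3.0, 4.0], dtype=np.float32),
    shape,
)

# The final "merge" step, then rows 1..2 as a dense mini-batch.
merged = CSR_IO_Buffer.merge([a, b])
dense = merged.slice_tonumpy(slice(1, 3))
assert dense.shape == (2, 5)

# The same slice as a canonically ordered scipy matrix.
sp = merged.sort_indices().slice_toscipy(slice(1, 3))
assert np.array_equal(sp.toarray(), dense)
```

Note that `merge` asserts all buffers share a shape and index dtype, which `from_ijd` guarantees here since both batches use the same `shape`.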
`_distributed.py` (new file, 67 lines):

```python
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
# Copyright (c) 2021-2024 TileDB, Inc.
#
# Licensed under the MIT License.

import logging
import os
from typing import Tuple

import torch

logger = logging.getLogger("tiledbsoma_ml.pytorch")


def get_distributed_world_rank() -> Tuple[int, int]:
    """Return tuple containing equivalent of ``torch.distributed`` world size and rank."""
    world_size, rank = 1, 0
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
        rank = int(os.environ["RANK"])
    elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There
        # is a NODE_RANK for the node's rank, but no way to tell the local node's
        # world. So computing a global rank is impossible(?). Using LOCAL_RANK as a
        # proxy, which works fine on a single-CPU box. TODO: could throw/error
        # if NODE_RANK != 0.
        world_size = int(os.environ["WORLD_SIZE"])
        rank = int(os.environ["LOCAL_RANK"])
    elif torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    return world_size, rank


def get_worker_world_rank() -> Tuple[int, int]:
    """Return number of DataLoader workers and our worker rank/id."""
    num_workers, worker = 1, 0
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        num_workers = int(os.environ["NUM_WORKERS"])
        worker = int(os.environ["WORKER"])
    else:
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            num_workers = worker_info.num_workers
            worker = worker_info.id
    return num_workers, worker


def init_multiprocessing() -> None:
    """Ensures use of "spawn" for starting child processes with multiprocessing.

    Forked processes are known to be problematic:
    https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks
    Also, CUDA does not support forked child processes:
    https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing

    Private.
    """
    orig_start_method = torch.multiprocessing.get_start_method()
    if orig_start_method != "spawn":
        if orig_start_method:
            logger.warning(
                "switching torch multiprocessing start method from "
                f'"{orig_start_method}" to "spawn"'
            )
        torch.multiprocessing.set_start_method("spawn", force=True)
```
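To see how the two lookups compose, here is a sketch (not code from this PR; the partition arithmetic is an assumed convention for illustration, not an API of this module): the distributed world size/rank and the DataLoader worker count/id multiply out to a unique data-partition assignment per process.

```python
from tiledbsoma_ml._distributed import (
    get_distributed_world_rank,
    get_worker_world_rank,
)

world_size, rank = get_distributed_world_rank()
num_workers, worker = get_worker_world_rank()

# One partition per (DDP rank, DataLoader worker) pair; both lookups fall back
# to a single-process world when no env vars or torch.distributed state exist.
n_partitions = world_size * num_workers
partition = rank * num_workers + worker
print(f"this process reads partition {partition} of {n_partitions}")
```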
`_experiment_locator.py` (new file, 39 lines):

```python
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
# Copyright (c) 2021-2024 TileDB, Inc.
#
# Licensed under the MIT License.

from contextlib import contextmanager
from typing import Dict, Generator, Union

import attrs
from tiledbsoma import Experiment, SOMATileDBContext


@attrs.define(frozen=True, kw_only=True)
class ExperimentLocator:
    """State required to open the Experiment.

    Serializable across multiple processes.

    Private implementation class.
    """

    uri: str
    tiledb_timestamp_ms: int
    tiledb_config: Dict[str, Union[str, float]]

    @classmethod
    def create(cls, experiment: Experiment) -> "ExperimentLocator":
        return ExperimentLocator(
            uri=experiment.uri,
            tiledb_timestamp_ms=experiment.tiledb_timestamp_ms,
            tiledb_config=experiment.context.tiledb_config,
        )

    @contextmanager
    def open_experiment(self) -> Generator[Experiment, None, None]:
        context = SOMATileDBContext(tiledb_config=self.tiledb_config)
        yield Experiment.open(
            self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context
        )
```
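This class exists because an open `Experiment` handle does not survive a process boundary, while this frozen attrs value does. A round-trip sketch (illustrative only; the URI is hypothetical):

```python
import pickle

import tiledbsoma
from tiledbsoma_ml._experiment_locator import ExperimentLocator

# Capture the state needed to re-open the Experiment elsewhere.
with tiledbsoma.Experiment.open("s3://bucket/experiment") as exp:  # hypothetical URI
    locator = ExperimentLocator.create(exp)

# What a "spawn"-ed DataLoader worker would effectively do via pickling.
payload = pickle.dumps(locator)
restored = pickle.loads(payload)
with restored.open_experiment() as exp2:
    print(exp2.uri, exp2.tiledb_timestamp_ms)
```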
Review comment: do you need to define `__all__`?

Reply: It's down at L14, unchanged.
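For context on the question, a module-level `__all__` pins the names exported by a wildcard import. A sketch with illustrative names (the reply indicates the package's actual `__all__` already exists and is unchanged by this PR):

```python
# tiledbsoma_ml/__init__.py (sketch; the listed names are illustrative)
__all__ = [
    "ExperimentAxisQueryIterableDataset",
    "ExperimentAxisQueryIterDataPipe",
    "experiment_dataloader",
]
```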