Skip to content

Commit

Permalink
Add cupy support + CI (#1066)
Browse files Browse the repository at this point in the history
* Add GPU CI

Signed-off-by: zethson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add draft of Test action

Signed-off-by: zethson <[email protected]>

* Remove python specification

Signed-off-by: zethson <[email protected]>

* Switch to mamba

Signed-off-by: zethson <[email protected]>

* Add shell check

Co-authored-by: Isaac Virshup <[email protected]>

* Switch to mamba

Signed-off-by: zethson <[email protected]>

* micromamba list

Signed-off-by: zethson <[email protected]>

* Add shell

Signed-off-by: zethson <[email protected]>

* Add environment-name

Signed-off-by: zethson <[email protected]>

* rename environment-name

Signed-off-by: zethson <[email protected]>

* specify python

* Remove env name

* Don't make a shell

* add env name

* Get git info so version is specified right

* proper cirun label

* Add gpu mark, --only-gpu argument

* Start tests

* Add gpu tests to ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Better lazy importing

* Fix test skipping

* Views

* Update conftest

* Basic concatenation support + bedtime

* `.raw` copys from VRAM to RAM (#1078)

* `.raw` copys from VRAM to RAM

* Update anndata/_core/raw.py

Co-authored-by: Isaac Virshup <[email protected]>

* Update anndata/_core/raw.py

Co-authored-by: Isaac Virshup <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update anndata/_core/raw.py

Co-authored-by: Isaac Virshup <[email protected]>

---------

Co-authored-by: Isaac Virshup <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Better gpu test

* Deduplicate some test params

* Support IO

* Fixes related to cupy/cupy#7757

* coverage

* Cancel jobs if new commits are pushed + whitespace to trigger precomit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update GPU CI name + paralellize GPU CI

* Simplify pytest setup

* Fix typo

* Release note

* Change run rules for GPU CI

---------

Signed-off-by: zethson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Isaac Virshup <[email protected]>
Co-authored-by: Severin Dicks <[email protected]>
Co-authored-by: Philipp A <[email protected]>
  • Loading branch information
5 people authored Jul 31, 2023
1 parent 0c4c0b0 commit 8b1a7e4
Show file tree
Hide file tree
Showing 17 changed files with 562 additions and 70 deletions.
9 changes: 9 additions & 0 deletions .cirun.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
runners:
- name: aws-gpu-runner
cloud: aws
instance_type: g4dn.xlarge
machine_image: ami-0678adbdcb4c3a662
preemptible: false
workflow: .github/workflows/test-gpu.yml
labels:
- cirun-aws-gpu
54 changes: 54 additions & 0 deletions .github/workflows/test-gpu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
name: AWS GPU

on:
push:
branches: [main]
workflow_dispatch:

# Cancel the job if new commits are pushed
# https://stackoverflow.com/questions/66335225/how-to-cancel-previous-runs-in-the-pr-when-you-push-new-commitsupdate-the-curre
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
test:
runs-on: "cirun-aws-gpu--${{ github.run_id }}"
defaults:
run:
shell: bash -el {0}
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Nvidia SMI sanity check
run: nvidia-smi

- uses: mamba-org/setup-micromamba@v1
with:
micromamba-version: "1.3.1-0"
environment-name: anndata-gpu-ci
create-args: >-
python=3.10
cupy
numba
pytest
pytest-cov
pytest-xdist
init-shell: >-
bash
generate-run-shell: false

- name: Install AnnData
run: pip install .[dev,test,gpu]

- name: Mamba list
run: micromamba list

- name: Run test
run: pytest -m gpu --cov --cov-report=xml --cov-context=test -n 4

- uses: codecov/codecov-action@v3
with:
flags: gpu-tests
4 changes: 4 additions & 0 deletions anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
ZarrArray,
ZappyArray,
DaskArray,
CupyArray,
CupySparseMatrix,
_move_adj_mtx,
)

Expand All @@ -62,6 +64,8 @@ class StorageType(Enum):
ZarrArray = ZarrArray
ZappyArray = ZappyArray
DaskArray = DaskArray
CupyArray = CupyArray
CupySparseMatrix = CupySparseMatrix

@classmethod
def classes(cls):
Expand Down
163 changes: 152 additions & 11 deletions anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from scipy.sparse import spmatrix

from .anndata import AnnData
from ..compat import AwkArray, DaskArray
from ..compat import AwkArray, DaskArray, CupySparseMatrix, CupyArray, CupyCSRMatrix
from ..utils import asarray, dim_len
from .index import _subset, make_slice
from anndata._warnings import ExperimentalFeatureWarning
Expand Down Expand Up @@ -132,24 +132,43 @@ def equal_array(a, b) -> bool:
return equal(pd.DataFrame(a), pd.DataFrame(asarray(b)))


@equal.register(CupyArray)
def equal_cupyarray(a, b) -> bool:
import cupy as cp

return bool(cp.array_equal(a, b, equal_nan=True))


@equal.register(pd.Series)
def equal_series(a, b) -> bool:
return a.equals(b)


@equal.register(sparse.spmatrix)
@equal.register(CupySparseMatrix)
def equal_sparse(a, b) -> bool:
# It's a weird api, don't blame me
if isinstance(b, sparse.spmatrix):
import array_api_compat

xp = array_api_compat.array_namespace(a.data)

if isinstance(b, (CupySparseMatrix, sparse.spmatrix)):
if isinstance(a, CupySparseMatrix):
# Comparison broken for CSC matrices
# https://github.com/cupy/cupy/issues/7757
a, b = CupyCSRMatrix(a), CupyCSRMatrix(b)
comp = a != b
if isinstance(comp, bool):
return not comp
if isinstance(comp, CupySparseMatrix):
# https://github.com/cupy/cupy/issues/7751
comp = comp.get()
# fmt: off
return (
(len(comp.data) == 0)
or (
np.isnan(a[comp]).all()
and np.isnan(b[comp]).all()
xp.isnan(a[comp]).all()
and xp.isnan(b[comp]).all()
)
)
# fmt: on
Expand All @@ -171,6 +190,17 @@ def as_sparse(x):
return x


def as_cp_sparse(x) -> CupySparseMatrix:
import cupyx.scipy.sparse as cpsparse

if isinstance(x, cpsparse.spmatrix):
return x
elif isinstance(x, np.ndarray):
return cpsparse.csr_matrix(as_sparse(x))
else:
return cpsparse.csr_matrix(x)


def unify_dtypes(dfs: Iterable[pd.DataFrame]) -> list[pd.DataFrame]:
"""
Attempts to unify datatypes from multiple dataframes.
Expand Down Expand Up @@ -268,6 +298,43 @@ def check_combinable_cols(cols: list[pd.Index], join: Literal["inner", "outer"])
)


# TODO: open PR or feature request to cupy
def _cpblock_diag(mats, format=None, dtype=None):
"""
Modified version of scipy.sparse.block_diag for cupy sparse.
"""
import cupy as cp
from cupyx.scipy import sparse as cpsparse

row = []
col = []
data = []
r_idx = 0
c_idx = 0
for a in mats:
# if isinstance(a, (list, numbers.Number)):
# a = cpsparse.coo_matrix(a)
nrows, ncols = a.shape
if cpsparse.issparse(a):
a = a.tocoo()
row.append(a.row + r_idx)
col.append(a.col + c_idx)
data.append(a.data)
else:
a_row, a_col = cp.divmod(cp.arange(nrows * ncols), ncols)
row.append(a_row + r_idx)
col.append(a_col + c_idx)
data.append(a.reshape(-1))
r_idx += nrows
c_idx += ncols
row = cp.concatenate(row)
col = cp.concatenate(col)
data = cp.concatenate(data)
return cpsparse.coo_matrix(
(data, (row, col)), shape=(r_idx, c_idx), dtype=dtype
).asformat(format)


###################
# Per element logic
###################
Expand Down Expand Up @@ -430,6 +497,10 @@ def apply(self, el, *, axis, fill_value=None):
return self._apply_to_awkward(el, axis=axis, fill_value=fill_value)
elif isinstance(el, DaskArray):
return self._apply_to_dask_array(el, axis=axis, fill_value=fill_value)
elif isinstance(el, CupyArray):
return self._apply_to_cupy_array(el, axis=axis, fill_value=fill_value)
elif isinstance(el, CupySparseMatrix):
return self._apply_to_sparse(el, axis=axis, fill_value=fill_value)
else:
return self._apply_to_array(el, axis=axis, fill_value=fill_value)

Expand Down Expand Up @@ -457,6 +528,32 @@ def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None):

return sub_el

def _apply_to_cupy_array(self, el, *, axis, fill_value=None):
import cupy as cp

if fill_value is None:
fill_value = default_fill_value([el])
if el.shape[axis] == 0:
# Presumably faster since it won't allocate the full array
shape = list(el.shape)
shape[axis] = len(self.new_idx)
return cp.broadcast_to(cp.asarray(fill_value), tuple(shape))

old_idx_tuple = [slice(None)] * len(el.shape)
old_idx_tuple[axis] = self.old_pos
old_idx_tuple = tuple(old_idx_tuple)
new_idx_tuple = [slice(None)] * len(el.shape)
new_idx_tuple[axis] = self.new_pos
new_idx_tuple = tuple(new_idx_tuple)

out_shape = list(el.shape)
out_shape[axis] = len(self.new_idx)

out = cp.full(tuple(out_shape), fill_value)
out[new_idx_tuple] = el[old_idx_tuple]

return out

def _apply_to_array(self, el, *, axis, fill_value=None):
if fill_value is None:
fill_value = default_fill_value([el])
Expand All @@ -474,12 +571,20 @@ def _apply_to_array(self, el, *, axis, fill_value=None):
)

def _apply_to_sparse(self, el: spmatrix, *, axis, fill_value=None) -> spmatrix:
if isinstance(el, CupySparseMatrix):
from cupyx.scipy import sparse
else:
from scipy import sparse
import array_api_compat

xp = array_api_compat.array_namespace(el.data)

if fill_value is None:
fill_value = default_fill_value([el])
if fill_value != 0:
to_fill = self.new_idx.get_indexer(self.new_idx.difference(self.old_idx))
else:
to_fill = np.array([])
to_fill = xp.array([])

# Fixing outer indexing for missing values
if el.shape[axis] == 0:
Expand All @@ -489,18 +594,21 @@ def _apply_to_sparse(self, el: spmatrix, *, axis, fill_value=None) -> spmatrix:
if fill_value == 0:
return sparse.csr_matrix(shape)
else:
return np.broadcast_to(fill_value, shape)
return type(el)(xp.broadcast_to(xp.asarray(fill_value), shape))

fill_idxer = None

if len(to_fill) > 0:
idxmtx_dtype = np.promote_types(el.dtype, np.array(fill_value).dtype)
if len(to_fill) > 0 or isinstance(el, CupySparseMatrix):
idxmtx_dtype = xp.promote_types(el.dtype, xp.array(fill_value).dtype)
else:
idxmtx_dtype = bool

if axis == 1:
idxmtx = sparse.coo_matrix(
(np.ones(len(self.new_pos), dtype=bool), (self.old_pos, self.new_pos)),
(
xp.ones(len(self.new_pos), dtype=idxmtx_dtype),
(xp.asarray(self.old_pos), xp.asarray(self.new_pos)),
),
shape=(len(self.old_idx), len(self.new_idx)),
dtype=idxmtx_dtype,
)
Expand All @@ -511,7 +619,10 @@ def _apply_to_sparse(self, el: spmatrix, *, axis, fill_value=None) -> spmatrix:
fill_idxer = (slice(None), to_fill)
elif axis == 0:
idxmtx = sparse.coo_matrix(
(np.ones(len(self.new_pos), dtype=bool), (self.new_pos, self.old_pos)),
(
xp.ones(len(self.new_pos), dtype=idxmtx_dtype),
(xp.asarray(self.new_pos), xp.asarray(self.old_pos)),
),
shape=(len(self.new_idx), len(self.old_idx)),
dtype=idxmtx_dtype,
)
Expand Down Expand Up @@ -626,6 +737,31 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
)

return ak.concatenate([f(a) for f, a in zip(reindexers, arrays)], axis=axis)
elif any(isinstance(a, CupySparseMatrix) for a in arrays):
import cupyx.scipy.sparse as cpsparse

sparse_stack = (cpsparse.vstack, cpsparse.hstack)[axis]
return sparse_stack(
[
f(as_cp_sparse(a), axis=1 - axis, fill_value=fill_value)
for f, a in zip(reindexers, arrays)
],
format="csr",
)
elif any(isinstance(a, CupyArray) for a in arrays):
import cupy as cp

if not all(isinstance(a, CupyArray) or 0 in a.shape for a in arrays):
raise NotImplementedError(
"Cannot concatenate a cupy array with other array types."
)
return cp.concatenate(
[
f(cp.asarray(x), fill_value=fill_value, axis=1 - axis)
for f, x in zip(reindexers, arrays)
],
axis=axis,
)
elif any(isinstance(a, sparse.spmatrix) for a in arrays):
sparse_stack = (sparse.vstack, sparse.hstack)[axis]
return sparse_stack(
Expand Down Expand Up @@ -774,7 +910,12 @@ def concat_pairwise_mapping(
m.get(k, sparse.csr_matrix((s, s), dtype=bool))
for m, s in zip(mappings, shapes)
]
result[k] = sparse.block_diag(els, format="csr")
if all(isinstance(el, (CupySparseMatrix, CupyArray)) for el in els):
from cupyx.scipy import sparse as cpsparse

result[k] = _cpblock_diag(els, format="csr")
else:
result[k] = sparse.block_diag(els, format="csr")
return result


Expand Down
8 changes: 7 additions & 1 deletion anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .aligned_mapping import AxisArrays, AxisArraysView
from .sparse_dataset import SparseDataset

from ..compat import CupyArray, CupySparseMatrix


# TODO: Implement views for Raw
class Raw:
Expand All @@ -31,7 +33,11 @@ def __init__(
self._var = _gen_dataframe(var, self.X.shape[1], ["var_names"])
self._varm = AxisArrays(self, 1, varm)
elif X is None: # construct from adata
self._X = adata.X.copy()
# Move from GPU to CPU since it's large and not always used
if isinstance(adata.X, (CupyArray, CupySparseMatrix)):
self._X = adata.X.get()
else:
self._X = adata.X.copy()
self._var = adata.var.copy()
self._varm = AxisArrays(self, 1, adata.varm.copy())
elif adata.isbacked:
Expand Down
Loading

0 comments on commit 8b1a7e4

Please sign in to comment.