Skip to content

Commit

Permalink
speedup(~7x) of the clipping array inside scaling function (#3100)
Browse files Browse the repository at this point in the history
Co-authored-by: Severin Dicks <[email protected]>
Co-authored-by: Intron7 <[email protected]>
  • Loading branch information
3 people authored Jun 18, 2024
1 parent 3f2af97 commit ad657ed
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.10.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@
* `sparse_mean_variance_axis` now uses all cores for the calculations {pr}`3015` {smaller}`S Dicks`
* `pp.highly_variable_genes` with `flavor=seurat_v3` now uses a numba kernel {pr}`3017` {smaller}`S Dicks`
* Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
* Speed up clipping of array in {func}`~scanpy.pp.scale` {pr}`3100` {smaller}`P Ashish & S Dicks`
30 changes: 23 additions & 7 deletions src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ def _scale_sparse_numba(indptr, indices, data, *, std, mask_obs, clip):
data[j] /= std[indices[j]]


@numba.njit(parallel=True, cache=True)
def clip_array(X: np.ndarray, max_value: float | None = 10, zero_center: bool = True):
a_min, a_max = -max_value, max_value
if X.ndim > 1:
for r, c in numba.pndindex(X.shape):
if X[r, c] > a_max:
X[r, c] = a_max
elif X[r, c] < a_min and zero_center:
X[r, c] = a_min
else:
for i in numba.prange(X.size):
if X[i] > a_max:
X[i] = a_max
elif X[i] < a_min and zero_center:
X[i] = a_min
return X


@renamed_arg("X", "data", pos_0=True)
@old_positionals("zero_center", "max_value", "copy", "layer", "obsm")
@singledispatch
Expand Down Expand Up @@ -197,14 +215,12 @@ def clip_set(x):

X = da.map_blocks(clip_set, X)
else:
if zero_center:
a_min, a_max = -max_value, max_value
X = np.clip(X, a_min, a_max) # dask does not accept these as kwargs
if isinstance(X, DaskArray):
X = X.map_blocks(clip_array, max_value, zero_center)
elif issparse(X):
X.data = clip_array(X.data, max_value=max_value, zero_center=False)
else:
if issparse(X):
X.data[X.data > max_value] = max_value
else:
X[X > max_value] = max_value
X = clip_array(X, max_value=max_value, zero_center=zero_center)
if return_mean_std:
return X, mean, std
else:
Expand Down

0 comments on commit ad657ed

Please sign in to comment.