Skip to content

Commit

Permalink
Simplified sparse_xr_dot
Browse files Browse the repository at this point in the history
Used chunking suggestion from pydata/xarray#9934
  • Loading branch information
brendan-m-murphy committed Jan 9, 2025
1 parent 1c6c769 commit 8c415e2
Showing 1 changed file with 22 additions and 84 deletions.
106 changes: 22 additions & 84 deletions openghg_inversions/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
work correctly.
"""

from typing import Any, TypeVar
from typing import Any, overload, TypeVar
from collections.abc import Sequence

import numpy as np
import pandas as pd
import sparse
import xarray as xr
from sparse import COO
from xarray.core.common import DataWithCoords
from sparse import COO, SparseArray
from xarray.core.common import DataWithCoords, is_chunked_array # type: ignore


# type for xr.Dataset *or* xr.DataArray
Expand Down Expand Up @@ -72,11 +72,15 @@ def get_xr_dummies(
return result.unstack(stack_dim) if stack_dim else result


def sparse_xr_dot(
da1: xr.DataArray,
da2: DataSetOrArray,
debug: bool = False,
) -> DataSetOrArray:
@overload
def sparse_xr_dot(da1: xr.DataArray, da2: xr.DataArray) -> xr.DataArray: ...


@overload
def sparse_xr_dot(da1: xr.DataArray, da2: xr.Dataset) -> xr.Dataset: ...


def sparse_xr_dot(da1: xr.DataArray, da2: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset:
"""Compute the matrix "dot" of a tuple of DataArrays with sparse.COO values.
This multiplies and sums over all common dimensions of the input DataArrays, and
Expand All @@ -85,87 +89,21 @@ def sparse_xr_dot(
Common dimensions are automatically selected by name. The input arrays must have at
least one dimension in common. All matching dimensions will be used for multiplication.
NOTE: this function shouldn't be necessary, but `da1 @ da2` doesn't work properly if the
values of `da1` and `da2` are `sparse.COO` arrays.
Compared to just using da1 @ da2, this function has two advantages:
1. if da1 is sparse but not a dask array, then da1 @ da2 will fail if da2 is a dask array
2. da2 can be a Dataset, and currently DataArray @ Dataset is not allowed by xarray
Args:
da1, da2: xr.DataArrays to multiply and sum along common dimensions.
debug: if true, will print the dimensions of the inputs to `sparse.tensordot`
as well as the dimension of the result.
along_dim: name
Returns:
xr.Dataset or xr.DataArray containing the result of matrix/tensor multiplication.
The type that is returned will be the same as the type of `da2`.
Raises:
ValueError if the input DataArrays have no common dimensions to multiply.
"""
common_dims = set(da1.dims).intersection(set(da2.dims))
nc = len(common_dims)

dims1 = set(da1.dims) - common_dims
dims2 = set(da2.dims) - common_dims

broadcast_dims = list(dims1) + list(dims2)


if nc == 0:
raise ValueError(f"DataArrays \n{da1}\n{da2}\n have no common dimensions. Cannot compute `dot`.")

tensor_dot_axes = tuple([tuple(range(-nc, 0))] * 2)
input_core_dims = [list(common_dims)] * 2

if debug:
print("common dims:", common_dims)
print("dims1:", dims1)
print("dims2:", dims2)
print("broadcast dims:", broadcast_dims)


# xarray will insert new axes into broadcast dims so that the number of axes
# in da1 and da2 are equal, unless the broadcast dim to be added would come first (from left to right)
# we need to remove these axes, because sparse.tensordot does not
to_select1 = []
to_select2 = []

for dims, to_select in zip([dims1, dims2], [to_select1, to_select2]):
for bdim in broadcast_dims:
if bdim in dims:
to_select.append(slice(None))
elif to_select:
to_select.append(0)

to_select = tuple(to_select1 + to_select2)

if debug:
print("select:", to_select)

# compute tensor dot on last nc coordinates (because core dims are moved to end)
# and then drop 1D coordinates resulting from summing
def _func(x, y):
result = sparse.tensordot(x, y, axes=tensor_dot_axes) # type: ignore

if debug:
print("raw _func result shape:", result.shape)
return result[to_select]

def wrapper(da1, da2):
for arr in [da1, da2]:
print(f"_func received array of type {type(arr)}, shape {arr.shape}")
result = _func(da1, da2)
print(f"_func result shape: {result.shape}\n")
return result

func = wrapper if debug else _func

# return xr.apply_ufunc(func, da1, da2.as_numpy(), input_core_dims=input_core_dims, join="outer")
return xr.apply_ufunc(
func,
da1.transpose(..., *common_dims), # fix for issue with xarray 2025.1.0 release
da2.transpose(..., *common_dims), # this makes removing broadcast dims easier..
input_core_dims=input_core_dims,
join="outer",
dask="parallelized",
output_dtypes=[da1.dtype],
)
if isinstance(da1.data, SparseArray) and not is_chunked_array(da1): # type: ignore
da1 = da1.chunk()

if isinstance(da2, xr.DataArray):
return da1 @ da2

return da2.map(lambda x: da1 @ x)

0 comments on commit 8c415e2

Please sign in to comment.