Skip to content

Commit

Permalink
Updated sparse_xr_dot to work with dask
Browse files Browse the repository at this point in the history
Also added a more complete fix for removing the axes
added by apply_ufunc.
  • Loading branch information
brendan-m-murphy committed Jan 9, 2025
1 parent 47f15ab commit 1c6c769
Showing 1 changed file with 45 additions and 44 deletions.
89 changes: 45 additions & 44 deletions openghg_inversions/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def sparse_xr_dot(
da1: xr.DataArray,
da2: DataSetOrArray,
debug: bool = False,
broadcast_dims: Sequence[str] | None = None,
) -> DataSetOrArray:
"""Compute the matrix "dot" of a tuple of DataArrays with sparse.COO values.
Expand Down Expand Up @@ -105,66 +104,68 @@ def sparse_xr_dot(
common_dims = set(da1.dims).intersection(set(da2.dims))
nc = len(common_dims)

dims1 = set(da1.dims) - common_dims
dims2 = set(da2.dims) - common_dims

broadcast_dims = list(dims1) + list(dims2)


if nc == 0:
raise ValueError(f"DataArrays \n{da1}\n{da2}\n have no common dimensions. Cannot compute `dot`.")

if broadcast_dims is not None:
_broadcast_dims = set(broadcast_dims).intersection(common_dims)
else:
_broadcast_dims = set([])
tensor_dot_axes = tuple([tuple(range(-nc, 0))] * 2)
input_core_dims = [list(common_dims)] * 2

contract_dims = common_dims.difference(_broadcast_dims)
ncontract = len(contract_dims)
if debug:
print("common dims:", common_dims)
print("dims1:", dims1)
print("dims2:", dims2)
print("broadcast dims:", broadcast_dims)

tensor_dot_axes = tuple([tuple(range(-ncontract, 0))] * 2)
input_core_dims = [list(contract_dims)] * 2

# compute tensor dot on last nc coordinates (because core dims are moved to end)
# and then drop 1D coordinates resulting from summing
def _func(x, y, debug=False):
result = sparse.tensordot(x, y, axes=tensor_dot_axes) # type: ignore
# xarray will insert new axes into broadcast dims so that the number of axes
# in da1 and da2 is equal, unless the broadcast dim to be added would come first (from left to right).
# We need to remove these axes, because sparse.tensordot does not.
to_select1 = []
to_select2 = []

xs = list(x.shape[:-ncontract])
nxs = len(xs)
ys = list(y.shape[:-ncontract])
nys = len(ys)
pad_y = nxs - nys
if pad_y > 0:
ys = [1] * pad_y + ys

idx1, idx2 = [], []
for i, j in zip(xs, ys):
if j in {i, 1}:
idx1.append(slice(None))
idx2.append(0)
elif i == 1:
# x broadcasted to match y's dim
idx1.append(0)
idx2.append(slice(None))
for dims, to_select in zip([dims1, dims2], [to_select1, to_select2]):
for bdim in broadcast_dims:
if bdim in dims:
to_select.append(slice(None))
elif to_select:
to_select.append(0)

if debug:
print("pad y", pad_y)
print(xs, ys)
print(idx1, idx2)
to_select = tuple(to_select1 + to_select2)

idx2 = idx2[pad_y:]
idx3 = [0] * (result.ndim - len(idx1) - len(idx2))
idx = tuple(idx1 + idx3 + idx2)
if debug:
print("select:", to_select)

if debug:
print("x.shape", x.shape, "y.shape", y.shape)
print("idx", idx)
print("result shape:", result.shape)
# compute tensor dot on last nc coordinates (because core dims are moved to end)
# and then drop 1D coordinates resulting from summing
def _func(x, y):
result = sparse.tensordot(x, y, axes=tensor_dot_axes) # type: ignore

return result[idx] # type: ignore
if debug:
print("raw _func result shape:", result.shape)
return result[to_select]

def wrapper(da1, da2):
for arr in [da1, da2]:
print(f"_func received array of type {type(arr)}, shape {arr.shape}")
result = _func(da1, da2, debug=True)
result = _func(da1, da2)
print(f"_func result shape: {result.shape}\n")
return result

func = wrapper if debug else _func

return xr.apply_ufunc(func, da1, da2.as_numpy(), input_core_dims=input_core_dims, join="outer")
# return xr.apply_ufunc(func, da1, da2.as_numpy(), input_core_dims=input_core_dims, join="outer")
return xr.apply_ufunc(
func,
da1.transpose(..., *common_dims), # fix for issue with xarray 2025.1.0 release
da2.transpose(..., *common_dims), # this makes removing broadcast dims easier..
input_core_dims=input_core_dims,
join="outer",
dask="parallelized",
output_dtypes=[da1.dtype],
)

0 comments on commit 1c6c769

Please sign in to comment.