From 8c415e24ae75f530411090258552b10f633fa1d1 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 9 Jan 2025 15:36:47 +0000 Subject: [PATCH] Simplified `sparse_xr_dot` Used chunking suggestion from https://github.com/pydata/xarray/issues/9934 --- openghg_inversions/array_ops.py | 106 +++++++------------------------- 1 file changed, 22 insertions(+), 84 deletions(-) diff --git a/openghg_inversions/array_ops.py b/openghg_inversions/array_ops.py index 9b589f0..84b0b68 100644 --- a/openghg_inversions/array_ops.py +++ b/openghg_inversions/array_ops.py @@ -12,15 +12,15 @@ work correctly. """ -from typing import Any, TypeVar +from typing import Any, overload, TypeVar from collections.abc import Sequence import numpy as np import pandas as pd import sparse import xarray as xr -from sparse import COO -from xarray.core.common import DataWithCoords +from sparse import COO, SparseArray +from xarray.core.common import DataWithCoords, is_chunked_array # type: ignore # type for xr.Dataset *or* xr.DataArray @@ -72,11 +72,15 @@ def get_xr_dummies( return result.unstack(stack_dim) if stack_dim else result -def sparse_xr_dot( - da1: xr.DataArray, - da2: DataSetOrArray, - debug: bool = False, -) -> DataSetOrArray: +@overload +def sparse_xr_dot(da1: xr.DataArray, da2: xr.DataArray) -> xr.DataArray: ... + + +@overload +def sparse_xr_dot(da1: xr.DataArray, da2: xr.Dataset) -> xr.Dataset: ... + + +def sparse_xr_dot(da1: xr.DataArray, da2: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset: """Compute the matrix "dot" of a tuple of DataArrays with sparse.COO values. This multiplies and sums over all common dimensions of the input DataArrays, and @@ -85,87 +89,21 @@ def sparse_xr_dot( Common dimensions are automatically selected by name. The input arrays must have at least one dimension in common. All matching dimensions will be used for multiplication. 
- NOTE: this function shouldn't be necessary, but `da1 @ da2` doesn't work properly if the - values of `da1` and `da2` are `sparse.COO` arrays. + Compared to just using da1 @ da2, this function has two advantages: + 1. if da1 is sparse but not a dask array, then da1 @ da2 will fail if da2 is a dask array + 2. da2 can be a Dataset, and currently DataArray @ Dataset is not allowed by xarray Args: da1, da2: xr.DataArrays to multiply and sum along common dimensions. - debug: if true, will print the dimensions of the inputs to `sparse.tensordot` - as well as the dimension of the result. - along_dim: name Returns: xr.Dataset or xr.DataArray containing the result of matrix/tensor multiplication. The type that is returned will be the same as the type of `da2`. - - Raises: - ValueError if the input DataArrays have no common dimensions to multiply. """ - common_dims = set(da1.dims).intersection(set(da2.dims)) - nc = len(common_dims) - - dims1 = set(da1.dims) - common_dims - dims2 = set(da2.dims) - common_dims - - broadcast_dims = list(dims1) + list(dims2) - - - if nc == 0: - raise ValueError(f"DataArrays \n{da1}\n{da2}\n have no common dimensions. 
Cannot compute `dot`.") - - tensor_dot_axes = tuple([tuple(range(-nc, 0))] * 2) - input_core_dims = [list(common_dims)] * 2 - - if debug: - print("common dims:", common_dims) - print("dims1:", dims1) - print("dims2:", dims2) - print("broadcast dims:", broadcast_dims) - - - # xarray will insert new axes into broadcast dims so that the number of axes - # in da1 and da2 are equal, unless the broadcast dim to be added would come first (from left to right) - # we need to remove these axes, because sparse.tensordot does not - to_select1 = [] - to_select2 = [] - - for dims, to_select in zip([dims1, dims2], [to_select1, to_select2]): - for bdim in broadcast_dims: - if bdim in dims: - to_select.append(slice(None)) - elif to_select: - to_select.append(0) - - to_select = tuple(to_select1 + to_select2) - - if debug: - print("select:", to_select) - - # compute tensor dot on last nc coordinates (because core dims are moved to end) - # and then drop 1D coordinates resulting from summing - def _func(x, y): - result = sparse.tensordot(x, y, axes=tensor_dot_axes) # type: ignore - - if debug: - print("raw _func result shape:", result.shape) - return result[to_select] - - def wrapper(da1, da2): - for arr in [da1, da2]: - print(f"_func received array of type {type(arr)}, shape {arr.shape}") - result = _func(da1, da2) - print(f"_func result shape: {result.shape}\n") - return result - - func = wrapper if debug else _func - - # return xr.apply_ufunc(func, da1, da2.as_numpy(), input_core_dims=input_core_dims, join="outer") - return xr.apply_ufunc( - func, - da1.transpose(..., *common_dims), # fix for issue with xarray 2025.1.0 release - da2.transpose(..., *common_dims), # this makes removing broadcast dims easier.. 
- input_core_dims=input_core_dims, - join="outer", - dask="parallelized", - output_dtypes=[da1.dtype], - ) + if isinstance(da1.data, SparseArray) and not is_chunked_array(da1): # type: ignore + da1 = da1.chunk() + + if isinstance(da2, xr.DataArray): + return da1 @ da2 + + return da2.map(lambda x: da1 @ x)