Skip to content

Commit

Permalink
Generalize wrap_xr
Browse files Browse the repository at this point in the history
- remove limitation around number of extra dims
- support custom dimension names
- easy option to force `y,x` spatial names
- adding missing parameters to docs
  • Loading branch information
Kirill888 committed May 13, 2024
1 parent e975e3f commit 61a43e6
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 27 deletions.
103 changes: 84 additions & 19 deletions odc/geo/_xr_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,20 +363,42 @@ def crop(


def xr_coords(
gbox: SomeGeoBox, crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME
gbox: SomeGeoBox,
crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME,
always_yx: bool = False,
dims: Optional[Tuple[str, str]] = None,
) -> Dict[Hashable, xarray.DataArray]:
"""
Dictionary of Coordinates in xarray format.
:param gbox:
:py:class:`~odc.geo.geobox.GeoBox` or :py:class:`~odc.geo.gcp.GCPGeoBox`
:param crs_coord_name:
Use custom name for CRS coordinate, default is "spatial_ref". Set to ``None`` to not generate
CRS coordinate at all.
Use custom name for CRS coordinate, default is "spatial_ref". Set to
``None`` to not generate CRS coordinate at all.
:param always_yx:
If True, always use names ``y,x`` for spatial coordinates even for
geographic geoboxes.
:param dims:
Use custom names for spatial dimensions, default is to use ``y,x`` or
``latitude, longitude`` based on projection used. Dimensions are supplied
in "array" order, i.e. ``('y', 'x')``.
:returns:
Dictionary ``name:str -> xr.DataArray``. Where names are either ``y,x`` for projected or
Dictionary ``name:str -> xr.DataArray``. Where names are either as
supplied by ``dims=`` or otherwise ``y,x`` for projected or
``latitude, longitude`` for geographic.
"""
if dims is None:
if always_yx:
dims = ("y", "x")
else:
dims = gbox.dimensions

attrs = {}
crs = gbox.crs
if crs is not None:
Expand All @@ -387,21 +409,19 @@ def xr_coords(

if isinstance(gbox, GCPGeoBox):
coords: Dict[Hashable, xarray.DataArray] = {
name: _mk_pixel_coord(name, sz)
for name, sz in zip(gbox.dimensions, gbox.shape)
name: _mk_pixel_coord(name, sz) for name, sz in zip(dims, gbox.shape)
}
gcps = gbox.gcps()
else:
transform = gbox.transform
if gbox.axis_aligned:
coords = {
name: _coord_to_xr(name, coord, **attrs)
for name, coord in gbox.coordinates.items()
for name, coord in zip(dims, gbox.coordinates.values())
}
else:
coords = {
name: _mk_pixel_coord(name, sz)
for name, sz in zip(gbox.dimensions, gbox.shape)
name: _mk_pixel_coord(name, sz) for name, sz in zip(dims, gbox.shape)
}

if crs_coord_name is not None and crs is not None:
Expand Down Expand Up @@ -544,6 +564,9 @@ def _extract_transform(

def _locate_geo_info(src: XarrayObject) -> GeoState:
# pylint: disable=too-many-locals
if len(src.dims) < 2:
return GeoState()

sdims = spatial_dims(src, relaxed=True)
if sdims is None:
return GeoState()
Expand Down Expand Up @@ -1009,6 +1032,8 @@ def wrap_xr(
time=None,
nodata=None,
crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME,
always_yx: bool = False,
dims: Optional[Tuple[str, ...]] = None,
axis: Optional[int] = None,
**attrs,
) -> xarray.DataArray:
Expand All @@ -1019,37 +1044,74 @@ def wrap_xr(
:param gbox: Geobox, must have the same shape as the Y,X axes of ``im`` (see ``axis``)
:param time: optional time axis value(s), defaults to None
:param nodata: optional `nodata` value, defaults to None
:param crs_coord_name: allows to change name of the crs coordinate variable
:param always_yx: If True, always use names ``y,x`` for spatial coordinates
:param dims: Custom names for spatial dimensions
:param axis: Which axis of the input array corresponds to Y,X
:param attrs: Any other attributes to set on the result
:return: xarray DataArray
"""
# pylint: disable=too-many-locals,too-many-arguments
assert dims is None or len(dims) == im.ndim

if axis is None:
axis = 1 if time is not None else 0
elif axis < 0: # handle numpy style negative axis
axis = int(im.ndim) + axis

if im.ndim == 2 and axis == 1:
im = im[numpy.newaxis, ...]

assert axis in (0, 1) # upto 1 extra dimension on the left only
assert im.ndim - axis - 2 in (0, 1) # upto 1 extra dimension on the right only
assert axis >= 0
assert im.ndim - axis - 2 >= 0
assert im.shape[axis : axis + 2] == gbox.shape

prefix_dims: Tuple[str, ...] = ("time",) if axis == 1 else ()
postfix_dims: Tuple[str, ...] = ("band",) if im.ndim - axis > 2 else ()
def _prefix_dims(n):
if n == 0:
return ()
if n == 1:
return ("time",)
return ("time", *[f"dim_{i}" for i in range(n - 1)])

def _postfix_dims(n):
if n == 0:
return ()
if n == 1:
return ("band",)
return (f"b_{i}" for i in range(n))

sdims: Optional[Tuple[str, str]] = None
if dims is None:
sdims = ("y", "x") if always_yx else gbox.dimensions
dims = (*_prefix_dims(axis), *sdims, *_postfix_dims(im.ndim - axis - 2))
else:
sdims = dims[axis], dims[axis + 1]

prefix_dims = dims[:axis]
postfix_dims = dims[axis + 2 :]

dims = (*prefix_dims, *gbox.dimensions, *postfix_dims)
coords = xr_coords(gbox, crs_coord_name=crs_coord_name)
coords = xr_coords(
gbox,
crs_coord_name=crs_coord_name,
always_yx=always_yx,
dims=sdims,
)

if time is not None:
if not isinstance(time, xarray.DataArray):
if len(prefix_dims) > 0 and isinstance(time, (str, datetime)):
time = [time]

time = xarray.DataArray(time, dims=prefix_dims).astype("datetime64[ns]")
time = xarray.DataArray(time, dims=prefix_dims[:1]).astype("datetime64[ns]")

coords["time"] = time

if postfix_dims:
coords["band"] = xarray.DataArray(
[f"b{i}" for i in range(im.shape[-1])], dims=postfix_dims
)
for a, dim in enumerate(postfix_dims):
nb = im.shape[axis + 2 + a]
coords[dim] = xarray.DataArray(
[f"b{i}" for i in range(nb)], dims=(dim,), name=dim
)

if nodata is not None:
attrs = {"nodata": nodata, **attrs}
Expand Down Expand Up @@ -1079,6 +1141,9 @@ def xr_zeros(
:param crs_coord_name: allows to change name of the crs coordinate variable
:return: :py:class:`xarray.DataArray` filled with zeros (numpy or dask)
.. seealso:: :py:meth:`odc.geo.xr.wrap_xr`
"""
if time is not None:
_shape: Tuple[int, ...] = (len(time), *geobox.shape.yx)
Expand Down
1 change: 0 additions & 1 deletion odc/geo/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def _purge_attributes(_x: xr.DataArray) -> xr.DataArray:
for attr in attributes_to_clear:
_x.attrs.pop(attr, None)
_x.encoding.pop("grid_mapping", None)
_x.encoding.pop("_transform", None)
return _x

# remove non-dimensional coordinate, which is CRS in our case
Expand Down
41 changes: 34 additions & 7 deletions tests/test_xr_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
)

# pylint: disable=redefined-outer-name,import-outside-toplevel,protected-access
# Small geoboxes built from the same bounding box, used to parametrize tests
# in this module: one in EPSG:4326 and two in EPSG:3857, the last of which
# passes its resolution as ``resxy_(1, 2)`` instead of a single number.
TEST_GEOBOXES_SMALL_AXIS_ALIGNED = [
GeoBox.from_bbox((-10, -2, 5, 4), "epsg:4326", tight=True, resolution=0.2),
GeoBox.from_bbox((-10, -2, 5, 4), "epsg:3857", tight=True, resolution=1),
GeoBox.from_bbox((-10, -2, 5, 4), "epsg:3857", tight=True, resolution=resxy_(1, 2)),
]


@pytest.fixture
Expand Down Expand Up @@ -175,6 +180,7 @@ def test_odc_extension(xx_epsg4326: xr.DataArray, geobox_epsg4326: GeoBox):
assert xx.odc.output_geobox("epsg:3857").crs == "epsg:3857"
assert xx.odc.map_bounds() == gbox.map_bounds()
assert xx.odc.output_geobox("utm").crs.epsg is not None
assert xx.odc.aspect == gbox.aspect

# this drops encoding/attributes, but crs/geobox should remain the same
_xx = xx * 10.0
Expand Down Expand Up @@ -351,6 +357,9 @@ def test_wrap_xr():
assert xx.dims == gbox.dims
assert xx.attrs == {}

assert wrap_xr(data, gbox, always_yx=True).dims == ("y", "x")
assert wrap_xr(data, gbox, dims=("Y", "X")).dims == ("Y", "X")

xx = wrap_xr(data, gbox, nodata=None)
assert xx.attrs == {}

Expand All @@ -376,6 +385,31 @@ def test_wrap_xr():
assert xx.band.data.tolist() == ["b0"]


@pytest.mark.parametrize("gbox", TEST_GEOBOXES_SMALL_AXIS_ALIGNED)
@pytest.mark.parametrize("nprefix", [0, 1, 2])
@pytest.mark.parametrize("npostfix", [0, 1, 2])
def test_wrap_xr_nd(gbox: GeoBox, nprefix: int, npostfix: int):
    """
    Exercise ``wrap_xr`` with 0..2 extra axes on each side of the spatial axes.

    Checks the auto-generated dimension names first, then repeats the call
    with fully custom ``dims=`` and verifies they are used verbatim.
    """
    # size-1 leading axes, the geobox plane, then size-3 trailing axes
    leading = (1,) * nprefix
    trailing = (3,) * npostfix
    pixels = np.zeros(leading + gbox.shape + trailing, dtype="uint16")

    # default naming
    xx = wrap_xr(pixels, gbox, axis=nprefix)
    assert xx.odc.geobox == gbox
    assert xx.odc.ydim == nprefix

    expected_prefix = ("time", "dim_0", "dim_1", "dim_2", "dim_3")
    assert xx.dims[:nprefix] == expected_prefix[:nprefix]

    if npostfix == 1:
        assert xx.dims[-1] == "band"
    elif npostfix > 1:
        assert xx.dims[nprefix + 2 :] == tuple(f"b_{i}" for i in range(npostfix))

    # custom naming: rename every dimension, keep "y", "x" for the spatial pair
    renamed = tuple(f"custom_{dim}" for dim in xx.dims)
    custom_dims = renamed[:nprefix] + ("y", "x") + renamed[nprefix + 2 :]

    yy = wrap_xr(pixels, gbox, axis=nprefix, dims=custom_dims)
    assert yy.dims == custom_dims
    assert yy.odc.geobox == gbox
    assert yy.odc.spatial_dims == custom_dims[nprefix : nprefix + 2]


@pytest.mark.parametrize("xx_time", [None, ["2020-01-30"]])
@pytest.mark.parametrize("xx_chunks", [None, (-1, -1), (4, 4)])
def test_xr_reproject(xx_epsg4326: xr.DataArray):
Expand Down Expand Up @@ -480,13 +514,6 @@ def test_is_dask_collection():
assert is_dask_collection is dask.is_dask_collection


TEST_GEOBOXES_SMALL_AXIS_ALIGNED = [
GeoBox.from_bbox((-10, -2, 5, 4), "epsg:4326", tight=True, resolution=0.2),
GeoBox.from_bbox((-10, -2, 5, 4), "epsg:3857", tight=True, resolution=1),
GeoBox.from_bbox((-10, -2, 5, 4), "epsg:3857", tight=True, resolution=resxy_(1, 2)),
]


@pytest.mark.parametrize("geobox", TEST_GEOBOXES_SMALL_AXIS_ALIGNED)
@pytest.mark.parametrize(
"bad_geo_transform",
Expand Down

0 comments on commit 61a43e6

Please sign in to comment.