Skip to content

Commit

Permalink
Issue #479 local UDFs now use the same dimension order as in backend
Browse files Browse the repository at this point in the history
  • Loading branch information
JeroenVerstraelen committed Oct 25, 2023
1 parent a2ff467 commit c7418a4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 26 deletions.
10 changes: 5 additions & 5 deletions openeo/udf/run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,13 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D
d = XarrayDataCube(datacube)
else:
raise ValueError(datacube)
d_array = d.get_array()
expected_order = ("t", "bands", "y", "x")
dims = [d for d in expected_order if d in d_array.dims]

# TODO: skip going through XarrayDataCube above, we only need xarray.DataArray here anyway.
# datacube's data is to be float and x,y not provided
d = XarrayDataCube(d.get_array()
.astype(numpy.float64)
.drop(labels='x')
.drop(labels='y')
)
d = XarrayDataCube(d_array.transpose(*dims).astype(numpy.float64).drop(labels="x").drop(labels="y"))
# wrap to udf_data
udf_data = UdfData(datacube_list=[d])

Expand Down
33 changes: 12 additions & 21 deletions tests/udf/test_run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,15 @@ def test_execute_local_udf_basic():
assert isinstance(res, UdfData)
result = res.get_datacube_list()[0].get_array()

assert result.shape == (3, 1, 5, 6)
assert result.shape == (3, 1, 6, 5)
expected = xarray.DataArray(
[
[_ndvi(0, 100), _ndvi(1, 101)],
[_ndvi(10, 110), _ndvi(11, 111)]
],
dims=["x", "y"],
[[_ndvi(0, 100), _ndvi(10, 110)], [_ndvi(1, 101), _ndvi(11, 111)]],
dims=["y", "x"],
coords={"t": numpy.datetime64("2020-08-01"), "bands": "ndvi"}
)
xarray.testing.assert_equal(result[0, 0, 0:2, 0:2], expected)

assert result[2, 0, 3, 4] == _ndvi(2034, 2134)
assert result[2, 0, 4, 3] == _ndvi(2034, 2134)


def test_run_local_udf_from_file_json(tmp_path):
Expand All @@ -266,18 +263,15 @@ def test_run_local_udf_from_file_json(tmp_path):
assert isinstance(res, UdfData)
result = res.get_datacube_list()[0].get_array()

assert result.shape == (3, 1, 5, 6)
assert result.shape == (3, 1, 6, 5)
expected = xarray.DataArray(
[
[_ndvi(0, 100), _ndvi(1, 101)],
[_ndvi(10, 110), _ndvi(11, 111)]
],
dims=["x", "y"],
[[_ndvi(0, 100), _ndvi(10, 110)], [_ndvi(1, 101), _ndvi(11, 111)]],
dims=["y", "x"],
coords={"t": numpy.datetime64("2020-08-01"), "bands": "ndvi"}
)
xarray.testing.assert_equal(result[0, 0, 0:2, 0:2], expected)

assert result[2, 0, 3, 4] == _ndvi(2034, 2134)
assert result[2, 0, 4, 3] == _ndvi(2034, 2134)


def test_run_local_udf_from_file_netcdf(tmp_path):
Expand All @@ -294,15 +288,12 @@ def test_run_local_udf_from_file_netcdf(tmp_path):
assert isinstance(res, UdfData)
result = res.get_datacube_list()[0].get_array()

assert result.shape == (3, 1, 5, 6)
assert result.shape == (3, 1, 6, 5)
expected = xarray.DataArray(
[
[_ndvi(0, 100), _ndvi(1, 101)],
[_ndvi(10, 110), _ndvi(11, 111)]
],
dims=["x", "y"],
[[_ndvi(0, 100), _ndvi(10, 110)], [_ndvi(1, 101), _ndvi(11, 111)]],
dims=["y", "x"],
coords={"t": numpy.datetime64("2020-08-01"), "bands": "ndvi"}
)
xarray.testing.assert_equal(result[0, 0, 0:2, 0:2], expected)

assert result[2, 0, 3, 4] == _ndvi(2034, 2134)
assert result[2, 0, 4, 3] == _ndvi(2034, 2134)

0 comments on commit c7418a4

Please sign in to comment.