From c7418a43785adb10758dcd1229b720af6860def1 Mon Sep 17 00:00:00 2001 From: JeroenVerstraelen Date: Wed, 25 Oct 2023 11:10:38 +0200 Subject: [PATCH] Issue #479 local UDFs now use the same dimension order as in backend --- openeo/udf/run_code.py | 10 +++++----- tests/udf/test_run_code.py | 33 ++++++++++++--------------------- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/openeo/udf/run_code.py b/openeo/udf/run_code.py index bbb9304d1..ad267a9a7 100644 --- a/openeo/udf/run_code.py +++ b/openeo/udf/run_code.py @@ -203,13 +203,13 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D d = XarrayDataCube(datacube) else: raise ValueError(datacube) + d_array = d.get_array() + expected_order = ("t", "bands", "y", "x") + dims = [d for d in expected_order if d in d_array.dims] + # TODO: skip going through XarrayDataCube above, we only need xarray.DataArray here anyway. # datacube's data is to be float and x,y not provided - d = XarrayDataCube(d.get_array() - .astype(numpy.float64) - .drop(labels='x') - .drop(labels='y') - ) + d = XarrayDataCube(d_array.transpose(*dims).astype(numpy.float64).drop(labels="x").drop(labels="y")) # wrap to udf_data udf_data = UdfData(datacube_list=[d]) diff --git a/tests/udf/test_run_code.py b/tests/udf/test_run_code.py index 5b28afde2..b0ce5a84e 100644 --- a/tests/udf/test_run_code.py +++ b/tests/udf/test_run_code.py @@ -238,18 +238,15 @@ def test_execute_local_udf_basic(): assert isinstance(res, UdfData) result = res.get_datacube_list()[0].get_array() - assert result.shape == (3, 1, 5, 6) + assert result.shape == (3, 1, 6, 5) expected = xarray.DataArray( - [ - [_ndvi(0, 100), _ndvi(1, 101)], - [_ndvi(10, 110), _ndvi(11, 111)] - ], - dims=["x", "y"], + [[_ndvi(0, 100), _ndvi(10, 110)], [_ndvi(1, 101), _ndvi(11, 111)]], + dims=["y", "x"], coords={"t": numpy.datetime64("2020-08-01"), "bands": "ndvi"} ) xarray.testing.assert_equal(result[0, 0, 0:2, 0:2], expected) - assert result[2, 0, 3, 4] == _ndvi(2034, 2134) + assert result[2, 0, 4, 3] == _ndvi(2034, 2134) def test_run_local_udf_from_file_json(tmp_path): @@ -266,18 +263,15 @@ def test_run_local_udf_from_file_json(tmp_path): assert isinstance(res, UdfData) result = res.get_datacube_list()[0].get_array() - assert result.shape == (3, 1, 5, 6) + assert result.shape == (3, 1, 6, 5) expected = xarray.DataArray( - [ - [_ndvi(0, 100), _ndvi(1, 101)], - [_ndvi(10, 110), _ndvi(11, 111)] - ], - dims=["x", "y"], + [[_ndvi(0, 100), _ndvi(10, 110)], [_ndvi(1, 101), _ndvi(11, 111)]], + dims=["y", "x"], coords={"t": numpy.datetime64("2020-08-01"), "bands": "ndvi"} ) xarray.testing.assert_equal(result[0, 0, 0:2, 0:2], expected) - assert result[2, 0, 3, 4] == _ndvi(2034, 2134) + assert result[2, 0, 4, 3] == _ndvi(2034, 2134) def test_run_local_udf_from_file_netcdf(tmp_path): @@ -294,15 +288,12 @@ def test_run_local_udf_from_file_netcdf(tmp_path): assert isinstance(res, UdfData) result = res.get_datacube_list()[0].get_array() - assert result.shape == (3, 1, 5, 6) + assert result.shape == (3, 1, 6, 5) expected = xarray.DataArray( - [ - [_ndvi(0, 100), _ndvi(1, 101)], - [_ndvi(10, 110), _ndvi(11, 111)] - ], - dims=["x", "y"], + [[_ndvi(0, 100), _ndvi(10, 110)], [_ndvi(1, 101), _ndvi(11, 111)]], + dims=["y", "x"], coords={"t": numpy.datetime64("2020-08-01"), "bands": "ndvi"} ) xarray.testing.assert_equal(result[0, 0, 0:2, 0:2], expected) - assert result[2, 0, 3, 4] == _ndvi(2034, 2134) + assert result[2, 0, 4, 3] == _ndvi(2034, 2134)