diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index 7eac8c2..ee5844a 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -96,15 +96,23 @@ def project_dataset(ds: xr.Dataset, query_crs: str | pyproj.CRS) -> xr.Dataset: always_xy=True, ) - # TODO: Handle rotated pole - cf_coords = target_crs.coordinate_system.to_cf() - - # Get the new X and Y coordinates - target_y_coord = next(coord for coord in cf_coords if coord["axis"] == "Y") - target_x_coord = next(coord for coord in cf_coords if coord["axis"] == "X") - - X = ds.cf["X"] - Y = ds.cf["Y"] + # Unpack the coordinates + try: + X = ds.cf["X"] + Y = ds.cf["Y"] + except KeyError: + # If the dataset has multiple X axes, we can try to find the right one + source_cf_coords = data_crs.coordinate_system.to_cf() + + source_x_coord = next( + coord["standard_name"] for coord in source_cf_coords if coord["axis"] == "X" + ) + source_y_coord = next( + coord["standard_name"] for coord in source_cf_coords if coord["axis"] == "Y" + ) + + X = ds.cf[source_x_coord] + Y = ds.cf[source_y_coord] # Transform the coordinates # If the data is vectorized, we just transform the points in full @@ -124,6 +132,13 @@ def project_dataset(ds: xr.Dataset, query_crs: str | pyproj.CRS) -> xr.Dataset: c for c in ds.coords if x_dim in ds[c].dims or y_dim in ds[c].dims ] + # TODO: Handle rotated pole + target_cf_coords = target_crs.coordinate_system.to_cf() + + # Get the new X and Y coordinates + target_x_coord = next(coord for coord in target_cf_coords if coord["axis"] == "X") + target_y_coord = next(coord for coord in target_cf_coords if coord["axis"] == "Y") + target_x_coord_name = target_x_coord["standard_name"] target_y_coord_name = target_y_coord["standard_name"]