diff --git a/docs/internal_api/index.rst b/docs/internal_api/index.rst index 31c4d5172..f6557ff41 100644 --- a/docs/internal_api/index.rst +++ b/docs/internal_api/index.rst @@ -237,6 +237,7 @@ Remapping remap.inverse_distance_weighted._inverse_distance_weighted_remap remap.inverse_distance_weighted._inverse_distance_weighted_remap_uxda remap.inverse_distance_weighted._inverse_distance_weighted_remap_uxds + remap.utils._remap_grid_parse Grid Parsing and Encoding diff --git a/test/test_remap.py b/test/test_remap.py index e3b6aacf3..6c3aa0370 100644 --- a/test/test_remap.py +++ b/test/test_remap.py @@ -3,6 +3,7 @@ from unittest import TestCase from pathlib import Path +import numpy.testing as nt import uxarray as ux @@ -104,16 +105,15 @@ def test_remap_return_types(self): dsfile_v1_geoflow, dsfile_v2_geoflow, dsfile_v3_geoflow ] source_uxds = ux.open_mfdataset(gridfile_geoflow, source_data_paths) - destination_uxds = ux.open_dataset(gridfile_CSne30, - dsfile_vortex_CSne30) + destination_grid = ux.open_grid(gridfile_CSne30) remap_uxda_to_grid = source_uxds['v1'].remap.nearest_neighbor( - destination_uxds.uxgrid) + destination_grid) assert isinstance(remap_uxda_to_grid, UxDataArray) remap_uxds_to_grid = source_uxds.remap.nearest_neighbor( - destination_uxds.uxgrid) + destination_grid) # Dataset with three vars: remapped "v1, v2, v3" assert isinstance(remap_uxds_to_grid, UxDataset) @@ -125,13 +125,17 @@ def test_edge_centers_remapping(self): # Open source and destination datasets to remap to source_grid = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) - destination_grid = ux.open_dataset(mpasfile_QU, mpasfile_QU) + destination_grid = ux.open_grid(mpasfile_QU) - remap_to_edge_centers = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_grid.uxgrid, - remap_to="edge centers") + remap_to_edge_centers_spherical = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_grid, + remap_to="edge centers", coord_type='spherical') + + 
remap_to_edge_centers_cartesian = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_grid, + remap_to="edge centers", coord_type='cartesian') # Assert the data variable lies on the "edge centers" - self.assertTrue(remap_to_edge_centers._edge_centered()) + self.assertTrue(remap_to_edge_centers_spherical._edge_centered()) + self.assertTrue(remap_to_edge_centers_cartesian._edge_centered()) def test_overwrite(self): """Tests that the remapping no longer overwrites the dataset.""" @@ -142,11 +146,74 @@ def test_overwrite(self): # Perform remapping remap_to_edge_centers = source_grid['v1'].remap.nearest_neighbor(destination_grid=destination_dataset.uxgrid, - remap_to="nodes") + remap_to="face centers", coord_type='cartesian') # Assert the remapped data is different from the original data assert not np.array_equal(destination_dataset['v1'], remap_to_edge_centers) + def test_source_data_remap(self): + """Test the remapping of all source data positions.""" + + # Open source and destination datasets to remap to + source_uxds = ux.open_dataset(mpasfile_QU, mpasfile_QU) + destination_grid = ux.open_grid(gridfile_geoflow) + + # Remap from `face_centers` + face_centers = source_uxds['latCell'].remap.nearest_neighbor( + destination_grid=destination_grid, + remap_to="nodes" + ) + + # Remap from `nodes` + nodes = source_uxds['latVertex'].remap.nearest_neighbor( + destination_grid=destination_grid, + remap_to="nodes" + ) + + # Remap from `edges` + edges = source_uxds['angleEdge'].remap.nearest_neighbor( + destination_grid=destination_grid, + remap_to="nodes" + ) + + self.assertTrue(len(face_centers.values) != 0) + self.assertTrue(len(nodes.values) != 0) + self.assertTrue(len(edges.values) != 0) + + def test_value_errors(self): + """Tests the raising of value errors and warnings in the function.""" + + # Open source and destination datasets to remap to + source_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) + source_uxds_2 = 
ux.open_dataset(mpasfile_QU, mpasfile_QU) + destination_grid = ux.open_grid(gridfile_geoflow) + + # Raise ValueError when `remap_to` is invalid + with nt.assert_raises(ValueError): + source_uxds['v1'].remap.nearest_neighbor( + destination_grid=destination_grid, + remap_to="test", coord_type='spherical' + ) + with nt.assert_raises(ValueError): + source_uxds['v1'].remap.nearest_neighbor( + destination_grid=destination_grid, + remap_to="test", coord_type="cartesian" + ) + + # Raise ValueError when `coord_type` is invalid + with nt.assert_raises(ValueError): + source_uxds['v1'].remap.nearest_neighbor( + destination_grid=destination_grid, + remap_to="nodes", coord_type="test" + ) + + # Raise ValueError when the source data is invalid + with nt.assert_raises(ValueError): + source_uxds_2['cellsOnCell'].remap.nearest_neighbor( + destination_grid=destination_grid, + remap_to="nodes" + ) + class TestInverseDistanceWeightedRemapping(TestCase): """Testing for inverse distance weighted remapping.""" @@ -156,10 +223,10 @@ def test_remap_center_nodes(self): # datasets to use for remap dataset = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) - destination_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) + destination_grid = ux.open_grid(gridfile_geoflow) data_on_face_centers = dataset['v1'].remap.inverse_distance_weighted( - destination_uxds.uxgrid, remap_to="face centers") + destination_grid, remap_to="face centers", power=6) assert not np.array_equal(dataset['v1'], data_on_face_centers) @@ -168,10 +235,10 @@ def test_remap_nodes(self): # datasets to use for remap dataset = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) - destination_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) + destination_grid = ux.open_grid(gridfile_geoflow) data_on_nodes = dataset['v1'].remap.inverse_distance_weighted( - destination_uxds.uxgrid, remap_to="nodes") + destination_grid, remap_to="nodes") assert not np.array_equal(dataset['v1'], data_on_nodes) @@ -217,17 
+284,16 @@ def test_remap_return_types(self): dsfile_v1_geoflow, dsfile_v2_geoflow, dsfile_v3_geoflow ] source_uxds = ux.open_mfdataset(gridfile_geoflow, source_data_paths) - destination_uxds = ux.open_dataset(gridfile_CSne30, - dsfile_vortex_CSne30) + destination_grid = ux.open_grid(gridfile_CSne30) remap_uxda_to_grid = source_uxds['v1'].remap.inverse_distance_weighted( - destination_uxds.uxgrid, power=3, k=10) + destination_grid, power=3, k=10) assert isinstance(remap_uxda_to_grid, UxDataArray) assert len(remap_uxda_to_grid) == 1 remap_uxds_to_grid = source_uxds.remap.inverse_distance_weighted( - destination_uxds.uxgrid) + destination_grid) # Dataset with three vars: remapped "v1, v2, v3" assert isinstance(remap_uxds_to_grid, UxDataset) @@ -239,15 +305,21 @@ def test_edge_remapping(self): # Open source and destination datasets to remap to source_grid = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) - destination_grid = ux.open_dataset(mpasfile_QU, mpasfile_QU) + destination_grid = ux.open_grid(mpasfile_QU) # Perform remapping to the edge centers of the dataset - remap_to_edge_centers = source_grid['v1'].remap.inverse_distance_weighted(destination_grid=destination_grid.uxgrid, - remap_to="edge centers") + remap_to_edge_centers_spherical = source_grid['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="edge centers", coord_type='spherical') + + remap_to_edge_centers_cartesian = source_grid['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="edge centers", coord_type='cartesian') # Assert the data variable lies on the "edge centers" - self.assertTrue(remap_to_edge_centers._edge_centered()) + self.assertTrue(remap_to_edge_centers_spherical._edge_centered()) + self.assertTrue(remap_to_edge_centers_cartesian._edge_centered()) def test_overwrite(self): """Tests that the remapping no longer overwrites the dataset.""" @@ -257,8 +329,93 @@ def test_overwrite(self): destination_dataset = 
ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) # Perform Remapping - remap_to_edge_centers = source_grid['v1'].remap.inverse_distance_weighted(destination_grid=destination_dataset.uxgrid, - remap_to="nodes") + remap_to_edge_centers = source_grid['v1'].remap.inverse_distance_weighted( + destination_grid=destination_dataset.uxgrid, + remap_to="face centers", coord_type='cartesian') # Assert the remapped data is different from the original data assert not np.array_equal(destination_dataset['v1'], remap_to_edge_centers) + + def test_source_data_remap(self): + """Test the remapping of all source data positions.""" + + # Open source and destination datasets to remap to + source_uxds = ux.open_dataset(mpasfile_QU, mpasfile_QU) + destination_grid = ux.open_grid(gridfile_geoflow) + + # Remap from `face_centers` + face_centers = source_uxds['latCell'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes" + ) + + # Remap from `nodes` + nodes = source_uxds['latVertex'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes" + ) + + # Remap from `edges` + edges = source_uxds['angleEdge'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes" + ) + + self.assertTrue(len(face_centers.values) != 0) + self.assertTrue(len(nodes.values) != 0) + self.assertTrue(len(edges.values) != 0) + + def test_value_errors(self): + """Tests the raising of value errors and warnings in the function.""" + + # Open source and destination datasets to remap to + source_uxds = ux.open_dataset(gridfile_geoflow, dsfile_v1_geoflow) + source_uxds_2 = ux.open_dataset(mpasfile_QU, mpasfile_QU) + destination_grid = ux.open_grid(gridfile_geoflow) + + # Raise ValueError when `k` =< 1 + with nt.assert_raises(ValueError): + source_uxds['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes", k=1 + ) + + # Raise ValueError when k is larger than `n_node` + with 
nt.assert_raises(ValueError): + source_uxds['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes", k=source_uxds.uxgrid.n_node + 1 + ) + + # Raise ValueError when `remap_to` is invalid + with nt.assert_raises(ValueError): + source_uxds['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="test", k=2, coord_type='spherical' + ) + with nt.assert_raises(ValueError): + source_uxds['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="test", k=2, coord_type="cartesian" + ) + + # Raise ValueError when `coord_type` is invalid + with nt.assert_raises(ValueError): + source_uxds['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes", k=2, coord_type="test" + ) + + # Raise ValueError when the source data is invalid + with nt.assert_raises(ValueError): + source_uxds_2['cellsOnCell'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes" + ) + + # Raise UserWarning when `power` > 5 + with nt.assert_warns(UserWarning): + source_uxds['v1'].remap.inverse_distance_weighted( + destination_grid=destination_grid, + remap_to="nodes", power=6 + ) diff --git a/uxarray/remap/inverse_distance_weighted.py b/uxarray/remap/inverse_distance_weighted.py index 92cf1602f..a4436aa64 100644 --- a/uxarray/remap/inverse_distance_weighted.py +++ b/uxarray/remap/inverse_distance_weighted.py @@ -12,16 +12,18 @@ from uxarray.grid import Grid import warnings +from uxarray.remap.utils import _remap_grid_parse + def _inverse_distance_weighted_remap( - source_grid, - destination_grid, - source_data, - remap_to="face centers", - coord_type="spherical", + source_grid: Grid, + destination_grid: Grid, + source_data: np.ndarray, + remap_to: str = "face centers", + coord_type: str = "spherical", power=2, k=8, -): +) -> np.ndarray: """Inverse Distance Weighted Remapping between two grids. 
Parameters: @@ -49,7 +51,7 @@ def _inverse_distance_weighted_remap( """ if power > 5: - warnings.warn("It is recommended not to exceed a power of 5.0.") + warnings.warn("It is recommended not to exceed a power of 5.0.", UserWarning) if k > source_grid.n_node: raise ValueError( f"Number of nearest neighbors to be used in the calculation is {k}, but should not exceed the " @@ -60,92 +62,18 @@ def _inverse_distance_weighted_remap( f"Number of nearest neighbors to be used in the calculation is {k}, but should be greater than 1" ) + # ensure array is a np.ndarray source_data = np.asarray(source_data) - n_elements = source_data.shape[-1] - - if n_elements == source_grid.n_node: - source_data_mapping = "nodes" - elif n_elements == source_grid.n_face: - source_data_mapping = "face centers" - elif n_elements == source_grid.n_edge: - source_data_mapping = "edge centers" - else: - raise ValueError( - f"Invalid source_data shape. The final dimension should match the number of corner " - f"nodes ({source_grid.n_node}), edge nodes ({source_grid.n_edge}), or face centers ({source_grid.n_face}) " - f"in the source grid, but received: {source_data.shape}" - ) - - if coord_type == "spherical": - if remap_to == "nodes": - lon, lat = ( - destination_grid.node_lon.values, - destination_grid.node_lat.values, - ) - elif remap_to == "face centers": - lon, lat = ( - destination_grid.face_lon.values, - destination_grid.face_lat.values, - ) - elif remap_to == "edge centers": - lon, lat = ( - destination_grid.edge_lon.values, - destination_grid.edge_lat.values, - ) - else: - raise ValueError( - f"Invalid remap_to. 
Expected 'nodes', 'edge centers', or 'face centers', " - f"but received: {remap_to}" - ) - - _source_tree = source_grid.get_ball_tree(coordinates=source_data_mapping) - - dest_coords = np.vstack([lon, lat]).T - - distances, nearest_neighbor_indices = _source_tree.query(dest_coords, k=k) - elif coord_type == "cartesian": - if remap_to == "nodes": - x, y, z = ( - destination_grid.node_x.values, - destination_grid.node_y.values, - destination_grid.node_z.values, - ) - elif remap_to == "face centers": - x, y, z = ( - destination_grid.face_x.values, - destination_grid.face_y.values, - destination_grid.face_z.values, - ) - elif remap_to == "edge centers": - x, y, z = ( - destination_grid.edge_x.values, - destination_grid.edge_y.values, - destination_grid.edge_z.values, - ) - else: - raise ValueError( - f"Invalid remap_to. Expected 'nodes', 'edge centers', or 'face centers', " - f"but received: {remap_to}" - ) - - _source_tree = source_grid.get_ball_tree( - coordinates=source_data_mapping, - coordinate_system="cartesian", - distance_metric="minkowski", - ) - - dest_coords = np.vstack([x, y, z]).T - - distances, nearest_neighbor_indices = _source_tree.query(dest_coords, k=k) - - else: - raise ValueError( - f"Invalid coord_type. 
Expected either 'spherical' or 'cartesian', but received {coord_type}" - ) - - if nearest_neighbor_indices.ndim > 1: - nearest_neighbor_indices = nearest_neighbor_indices.squeeze() + _, distances, nearest_neighbor_indices = _remap_grid_parse( + source_data, + source_grid, + destination_grid, + coord_type, + remap_to, + k=k, + query=True, + ) weights = 1 / (distances**power + 1e-6) weights /= np.sum(weights, axis=1, keepdims=True) diff --git a/uxarray/remap/nearest_neighbor.py b/uxarray/remap/nearest_neighbor.py index 1923bd1b7..83a8760c8 100644 --- a/uxarray/remap/nearest_neighbor.py +++ b/uxarray/remap/nearest_neighbor.py @@ -1,6 +1,8 @@ from __future__ import annotations from typing import TYPE_CHECKING +from uxarray.remap.utils import _remap_grid_parse + if TYPE_CHECKING: from uxarray.core.dataset import UxDataset from uxarray.core.dataarray import UxDataArray @@ -45,98 +47,15 @@ def _nearest_neighbor( # ensure array is a np.ndarray source_data = np.asarray(source_data) - n_elements = source_data.shape[-1] - - if n_elements == source_grid.n_node: - source_data_mapping = "nodes" - elif n_elements == source_grid.n_edge: - source_data_mapping = "edge centers" - elif n_elements == source_grid.n_face: - source_data_mapping = "face centers" - else: - raise ValueError( - f"Invalid source_data shape. 
The final dimension should be either match the number of corner " - f"nodes ({source_grid.n_node}), edge centers ({source_grid.n_edge}), or face centers ({source_grid.n_face}) in the" - f" source grid, but received: {source_data.shape}" - ) - - if coord_type == "spherical": - # get destination coordinate pairs - if remap_to == "nodes": - lon, lat = ( - destination_grid.node_lon.values, - destination_grid.node_lat.values, - ) - elif remap_to == "edge centers": - lon, lat = ( - destination_grid.edge_lon.values, - destination_grid.edge_lat.values, - ) - elif remap_to == "face centers": - lon, lat = ( - destination_grid.face_lon.values, - destination_grid.face_lat.values, - ) - else: - raise ValueError( - f"Invalid remap_to. Expected 'nodes', 'edge centers', or 'face centers', " - f"but received: {remap_to}" - ) - - # specify whether to query on the corner nodes or face centers based on source grid - _source_tree = source_grid.get_ball_tree(coordinates=source_data_mapping) - - # prepare coordinates for query - latlon = np.vstack([lon, lat]).T - - _, nearest_neighbor_indices = _source_tree.query(latlon, k=1) - - elif coord_type == "cartesian": - # get destination coordinates - if remap_to == "nodes": - cart_x, cart_y, cart_z = ( - destination_grid.node_x.values, - destination_grid.node_y.values, - destination_grid.node_z.values, - ) - elif remap_to == "edge centers": - cart_x, cart_y, cart_z = ( - destination_grid.edge_x.values, - destination_grid.edge_y.values, - destination_grid.edge_z.values, - ) - elif remap_to == "face centers": - cart_x, cart_y, cart_z = ( - destination_grid.face_x.values, - destination_grid.face_y.values, - destination_grid.face_z.values, - ) - else: - raise ValueError( - f"Invalid remap_to. 
import numpy as np

# Prefix of the destination ``Grid`` coordinate attributes for each supported
# ``remap_to`` value (e.g. "nodes" -> ``node_lon`` / ``node_x``).
_REMAP_TO_PREFIX = {
    "nodes": "node",
    "face centers": "face",
    "edge centers": "edge",
}

# For each supported ``coord_type``: the coordinate attribute suffixes (in the
# column order of the returned coordinate array) and the extra keyword
# arguments forwarded to ``source_grid.get_ball_tree``.
_COORD_TYPE_SPEC = {
    "spherical": (("lon", "lat"), {}),
    "cartesian": (
        ("x", "y", "z"),
        {"coordinate_system": "cartesian", "distance_metric": "minkowski"},
    ),
}


def _remap_grid_parse(
    source_data, source_grid, destination_grid, coord_type, remap_to, k, query
):
    """Gets the destination coordinates from the destination grid for
    remapping and, optionally, the nearest neighbor indices and distances
    obtained by querying a tree built on the source grid.

    Parameters:
    -----------
    source_data : np.ndarray
        Data variable to remap. Its final dimension must match the number of
        nodes, faces, or edges of ``source_grid``.
    source_grid : Grid
        Source grid that data is mapped from.
    destination_grid : Grid
        Destination grid to remap data to.
    coord_type: str
        Coordinate type to use for nearest neighbor query, either "spherical" or "cartesian".
    remap_to : str
        Location of where to map data, either "nodes", "edge centers", or "face centers".
    k : int
        Number of nearest neighbors to query for. Ignored when ``query`` is False.
    query : bool
        Whether to construct and query the tree based on the source grid.

    Returns:
    --------
    dest_coords : np.ndarray
        Destination coordinates selected by `remap_to`, one row per
        destination point (columns are lon/lat or x/y/z). When ``query`` is
        False this array is the only value returned.
    distances : np.ndarray
        Distances returned by the tree query for the `k` nearest neighbors.
        Only returned when ``query`` is True.
    nearest_neighbor_indices : np.ndarray
        Indices returned by the tree query for the `k` nearest neighbors, with
        length-1 axes squeezed out. Only returned when ``query`` is True.

    Raises:
    -------
    ValueError
        If the final dimension of ``source_data`` matches none of the source
        grid's node, face, or edge counts, or if ``coord_type`` or
        ``remap_to`` is not one of the supported values.
    """

    n_elements = source_data.shape[-1]

    # Infer where the source data lives from the size of its final dimension.
    if n_elements == source_grid.n_node:
        source_data_mapping = "nodes"
    elif n_elements == source_grid.n_face:
        source_data_mapping = "face centers"
    elif n_elements == source_grid.n_edge:
        source_data_mapping = "edge centers"
    else:
        raise ValueError(
            f"Invalid source_data shape. The final dimension should match the number of corner "
            f"nodes ({source_grid.n_node}), edge nodes ({source_grid.n_edge}), or face centers ({source_grid.n_face}) "
            f"in the source grid, but received: {source_data.shape}"
        )

    # ``coord_type`` is validated before ``remap_to``, as in the original
    # branch structure. TypeError covers unhashable arguments so that any
    # invalid value still surfaces as a ValueError.
    try:
        axes, tree_kwargs = _COORD_TYPE_SPEC[coord_type]
    except (KeyError, TypeError):
        raise ValueError(
            f"Invalid coord_type. Expected either 'spherical' or 'cartesian', but received {coord_type}"
        ) from None

    try:
        prefix = _REMAP_TO_PREFIX[remap_to]
    except (KeyError, TypeError):
        raise ValueError(
            f"Invalid remap_to. Expected 'nodes', 'edge centers', or 'face centers', "
            f"but received: {remap_to}"
        ) from None

    # One row per destination point: (lon, lat) or (x, y, z).
    dest_coords = np.vstack(
        [getattr(destination_grid, f"{prefix}_{axis}").values for axis in axes]
    ).T

    # Without a query there are no distances/indices to post-process, and no
    # reason to rebuild the source tree.
    if not query:
        return dest_coords

    _source_tree = source_grid.get_ball_tree(
        coordinates=source_data_mapping, reconstruct=True, **tree_kwargs
    )
    distances, nearest_neighbor_indices = _source_tree.query(dest_coords, k=k)

    # NOTE(review): squeeze() without an axis drops every length-1 axis, so a
    # single-point destination also loses its point axis -- confirm callers
    # tolerate that before relying on it.
    if nearest_neighbor_indices.ndim > 1:
        nearest_neighbor_indices = nearest_neighbor_indices.squeeze()

    return dest_coords, distances, nearest_neighbor_indices