diff --git a/test/grid_geoflow.exo b/test/grid_geoflow.exo deleted file mode 100644 index 8abc65564..000000000 Binary files a/test/grid_geoflow.exo and /dev/null differ diff --git a/test/test_cross_sections.py b/test/test_cross_sections.py index 26bf8777c..eea11fe8f 100644 --- a/test/test_cross_sections.py +++ b/test/test_cross_sections.py @@ -33,43 +33,47 @@ class TestQuadHex: """ def test_constant_lat_cross_section_grid(self): - uxgrid = ux.open_grid(quad_hex_grid_path) - grid_top_two = uxgrid.cross_section.constant_latitude(lat=0.1) + for method in ["bounding_box_intersection", "edge_intersection"]: - assert grid_top_two.n_face == 2 + uxgrid = ux.open_grid(quad_hex_grid_path) - grid_bottom_two = uxgrid.cross_section.constant_latitude(lat=-0.1) + grid_top_two = uxgrid.cross_section.constant_latitude(lat=0.1, method=method) - assert grid_bottom_two.n_face == 2 + assert grid_top_two.n_face == 2 - grid_all_four = uxgrid.cross_section.constant_latitude(lat=0.0) + grid_bottom_two = uxgrid.cross_section.constant_latitude(lat=-0.1, method=method) - assert grid_all_four.n_face == 4 + assert grid_bottom_two.n_face == 2 - with pytest.raises(ValueError): - # no intersections found at this line - uxgrid.cross_section.constant_latitude(lat=10.0) + grid_all_four = uxgrid.cross_section.constant_latitude(lat=0.0, method=method) + + assert grid_all_four.n_face == 4 + + with pytest.raises(ValueError): + # no intersections found at this line + uxgrid.cross_section.constant_latitude(lat=10.0, method=method) def test_constant_lat_cross_section_uxds(self): - uxds = ux.open_dataset(quad_hex_grid_path, quad_hex_data_path) + for method in ["bounding_box_intersection", "edge_intersection"]: + uxds = ux.open_dataset(quad_hex_grid_path, quad_hex_data_path) - da_top_two = uxds['t2m'].cross_section.constant_latitude(lat=0.1) + da_top_two = uxds['t2m'].cross_section.constant_latitude(lat=0.1, method=method) - nt.assert_array_equal(da_top_two.data, uxds['t2m'].isel(n_face=[1, 2]).data) + nt.assert_array_equal(da_top_two.data, uxds['t2m'].isel(n_face=[1, 2]).data) - da_bottom_two = uxds['t2m'].cross_section.constant_latitude(lat=-0.1) + da_bottom_two = uxds['t2m'].cross_section.constant_latitude(lat=-0.1, method=method) - nt.assert_array_equal(da_bottom_two.data, uxds['t2m'].isel(n_face=[0, 3]).data) + nt.assert_array_equal(da_bottom_two.data, uxds['t2m'].isel(n_face=[0, 3]).data) - da_all_four = uxds['t2m'].cross_section.constant_latitude(lat=0.0) + da_all_four = uxds['t2m'].cross_section.constant_latitude(lat=0.0, method=method) - nt.assert_array_equal(da_all_four.data , uxds['t2m'].data) + nt.assert_array_equal(da_all_four.data , uxds['t2m'].data) - with pytest.raises(ValueError): - # no intersections found at this line - uxds['t2m'].cross_section.constant_latitude(lat=10.0) + with pytest.raises(ValueError): + # no intersections found at this line + uxds['t2m'].cross_section.constant_latitude(lat=10.0, method=method) class TestGeosCubeSphere: diff --git a/uxarray/cross_sections/dataarray_accessor.py b/uxarray/cross_sections/dataarray_accessor.py index d840892b6..378bb1d5b 100644 --- a/uxarray/cross_sections/dataarray_accessor.py +++ b/uxarray/cross_sections/dataarray_accessor.py @@ -21,7 +21,7 @@ def __repr__(self): return prefix + methods_heading - def constant_latitude(self, lat: float, method="fast"): + def constant_latitude(self, lat: float, method="edge_intersection"): """Extracts a cross-section of the data array at a specified constant latitude. @@ -32,9 +32,11 @@ def constant_latitude(self, lat: float, method="fast"): method : str, optional The internal method to use when identifying faces at the constant latitude. Options are: - - 'fast': Uses a faster but potentially less accurate method for face identification. - - 'accurate': Uses a slower but more accurate method. - Default is 'fast'. + - 'edge_intersection': The intersection of each edge with a line of constant latitude is calculated, with + faces that contain that edges included in the result. + - 'bounding_box_intersection': The minimum and maximum latitude of each face is used to determine if + the line of constant latitude intersects it. + Default is 'edge_intersection'. Raises ------ @@ -44,11 +46,6 @@ def constant_latitude(self, lat: float, method="fast"): Examples -------- >>> uxda.constant_latitude_cross_section(lat=-15.5) - - Notes - ----- - The accuracy and performance of the function can be controlled using the `method` parameter. - For higher precision requreiments, consider using method='acurate'. """ faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat, method) diff --git a/uxarray/cross_sections/grid_accessor.py b/uxarray/cross_sections/grid_accessor.py index 067e8f5fb..4d1458809 100644 --- a/uxarray/cross_sections/grid_accessor.py +++ b/uxarray/cross_sections/grid_accessor.py @@ -20,15 +20,12 @@ def __repr__(self): methods_heading += " * constant_latitude(lat, )\n" return prefix + methods_heading - def constant_latitude(self, lat: float, return_face_indices=False, method="fast"): + def constant_latitude( + self, lat: float, return_face_indices=False, method="edge_intersection" + ): """Extracts a cross-section of the grid at a specified constant latitude. - This method identifies and returns all faces (or grid elements) that intersect - with a given latitude. The returned cross-section can include either just the grid - or both the grid elements and the corresponding face indices, depending - on the `return_face_indices` parameter. - Parameters ---------- lat : float @@ -40,9 +37,11 @@ def constant_latitude(self, lat: float, return_face_indices=False, method="fast" method : str, optional The internal method to use when identifying faces at the constant latitude. Options are: - - 'fast': Uses a faster but potentially less accurate method for face identification. - - 'accurate': Uses a slower but more accurate method. - Default is 'fast'. + - 'edge_intersection': The intersection of each edge with a line of constant latitude is calculated, with + faces that contain that edges included in the result. + - 'bounding_box_intersection': The minimum and maximum latitude of each face is used to determine if + the line of constant latitude intersects it. + Default is 'edge_intersection'. Returns ------- diff --git a/uxarray/grid/grid.py b/uxarray/grid/grid.py index c90b97786..604084b77 100644 --- a/uxarray/grid/grid.py +++ b/uxarray/grid/grid.py @@ -67,7 +67,8 @@ ) from uxarray.grid.intersections import ( - fast_constant_lat_intersections, + constant_lat_intersections_edges, + constant_lat_intersections_face_bounds, ) from spatialpandas import GeoDataFrame @@ -1967,58 +1968,52 @@ def isel(self, **dim_kwargs): "Indexing must be along a grid dimension: ('n_node', 'n_edge', 'n_face')" ) - def get_edges_at_constant_latitude(self, lat, method="fast"): + def get_edges_at_constant_latitude(self, lat, method="edge_intersection"): """Identifies the edges of the grid that intersect with a specified constant latitude. - This method computes the intersection of grid edges with a given latitude and - returns a collection of edges that cross or are aligned with that latitude. - The method used for identifying these edges can be controlled by the `method` - parameter. - Parameters ---------- lat : float The latitude at which to identify intersecting edges, in degrees. method : str, optional - The computational method used to determine edge intersections. Options are: - - 'fast': Uses a faster but potentially less accurate method for determining intersections. - - 'accurate': Uses a slower but more precise method. - Default is 'fast'. + Method used to determine edge intersections. Options are: + - 'edge_intersection': The intersection of each edge is explicit computed Returns ------- edges : array A squeezed array of edges that intersect the specified constant latitude. """ - if method == "fast": - edges = fast_constant_lat_intersections( + + if lat > 90.0 or lat < -90.0: + raise ValueError( + f"Latitude must be between -90 and 90 degrees. Received {lat}" + ) + + if method == "edge_intersection": + edges = constant_lat_intersections_edges( lat, self.edge_node_z.values, self.n_edge ) - elif method == "accurate": - raise NotImplementedError("Accurate method not yet implemented.") else: raise ValueError(f"Invalid method: {method}.") return edges.squeeze() - def get_faces_at_constant_latitude(self, lat, method="fast"): + def get_faces_at_constant_latitude(self, lat, method="edge_intersection"): """Identifies the faces of the grid that intersect with a specified constant latitude. - This method finds the faces (or cells) of the grid that intersect a given latitude - by first identifying the intersecting edges and then determining the faces connected - to these edges. The method used for identifying edges can be adjusted with the `method` - parameter. - Parameters ---------- lat : float The latitude at which to identify intersecting faces, in degrees. method : str, optional - The computational method used to determine intersecting edges. Options are: - - 'fast': Uses a faster but potentially less accurate method for determining intersections. - - 'accurate': Uses a slower but more precise method. - Default is 'fast'. + The computational method used to determine intersecting faces. Options are: + - 'edge_intersection': The intersection of each edge with a line of constant latitude is calculated, with + faces that contain that edges included in the result. + - 'bounding_box_intersection': The minimum and maximum latitude of each face is used to determine if + the line of constant latitude intersects it. + Default is 'edge_intersection' Returns ------- @@ -2027,7 +2022,21 @@ def get_faces_at_constant_latitude(self, lat, method="fast"): Faces that are invalid or missing (e.g., with a fill value) are excluded from the result. """ - edges = self.get_edges_at_constant_latitude(lat, method) - faces = np.unique(self.edge_face_connectivity[edges].data.ravel()) - return faces[faces != INT_FILL_VALUE] + if lat > 90.0 or lat < -90.0: + raise ValueError( + f"Latitude must be between -90 and 90 degrees. Received {lat}" + ) + + if method == "bounding_box_intersection": + faces = constant_lat_intersections_face_bounds( + lat=lat, + face_min_lat_rad=self.bounds.values[:, 0, 0], + face_max_lat_rad=self.bounds.values[:, 0, 1], + ) + return faces + elif method == "edge_intersection": + edges = self.get_edges_at_constant_latitude(lat, method) + faces = np.unique(self.edge_face_connectivity[edges].data.ravel()) + + return faces[faces != INT_FILL_VALUE] diff --git a/uxarray/grid/intersections.py b/uxarray/grid/intersections.py index 7e74622d4..c8aaa62f1 100644 --- a/uxarray/grid/intersections.py +++ b/uxarray/grid/intersections.py @@ -11,7 +11,7 @@ @njit(parallel=True, nogil=True, cache=True) -def fast_constant_lat_intersections(lat, edge_node_z, n_edge): +def constant_lat_intersections_edges(lat, edge_node_z, n_edge): """Determine which edges intersect a constant line of latitude on a sphere. Parameters @@ -49,6 +49,36 @@ def fast_constant_lat_intersections(lat, edge_node_z, n_edge): return np.unique(intersecting_edges) +@njit +def constant_lat_intersections_face_bounds(lat, face_min_lat_rad, face_max_lat_rad): + """Identifies the candidate faces on a grid that intersect with a given + constant latitude. + + This function checks whether the specified latitude, `lat`, in degrees lies within + the latitude bounds of grid faces, defined by `face_min_lat_rad` and `face_max_lat_rad`, + which are given in radians. The function returns the indices of the faces where the + latitude is within these bounds. + + Parameters + ---------- + lat : float + The latitude in degrees for which to find intersecting faces. + face_min_lat_rad : numpy.ndarray + A 1D array containing the minimum latitude bounds (in radians) of each face. + face_max_lat_rad : numpy.ndarray + A 1D array containing the maximum latitude bounds (in radians) of each face. + + Returns + ------- + candidate_faces : numpy.ndarray + A 1D array containing the indices of the faces that intersect with the given latitude. + """ + lat = np.deg2rad(lat) + within_bounds = (face_min_lat_rad <= lat) & (face_max_lat_rad >= lat) + candidate_faces = np.where(within_bounds)[0] + return candidate_faces + + def gca_gca_intersection(gca1_cart, gca2_cart, fma_disabled=True): """Calculate the intersection point(s) of two Great Circle Arcs (GCAs) in a Cartesian coordinate system.