Skip to content

Commit

Permalink
feat: Let find_nearest_nodes() locate more than 1 nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
pmav99 committed Sep 10, 2024
1 parent 9018101 commit c8d9766
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 23 deletions.
68 changes: 55 additions & 13 deletions pyposeidon/utils/cpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import typing as T

import numpy as np
import pandas as pd
import sklearn.neighbors
Expand All @@ -10,26 +12,52 @@ def closest_node(node, nodes):
return nodes[cdist([node], nodes).argmin()]


def get_ball_tree(
mesh_nodes: pd.DataFrame,
metric: str = "haversine",
) -> sklearn.neighbors.BallTree:
"""
Return a `BallTree` constructed from the provided `mesh_nodes`.
`mesh_nodes` must be a `pandas.DataFrames` with columns named
`lon` and `lat` and the coords must be in EPSG:4326.
"""
tree = sklearn.neighbors.BallTree(
np.radians(mesh_nodes[["lat", "lon"]]),
metric=metric,
)
return tree


def find_nearest_nodes(
mesh_nodes: pd.DataFrame,
points: pd.DataFrame,
k: int = 1,
metric: str = "haversine",
earth_radius = 6371000,
):
earth_radius=6371000,
tree: sklearn.neighbors.BallTree | None = None,
**kwargs: T.Any,
):
"""
Calculate the mesh nodes that are nearest to the specified `points`.
Calculate the `k` mesh nodes that are nearest to the specified `points`.
Both `mesh_nodes` and `points` must be `pandas.DataFrames` that have
columns named `lon` and `lat` and the coords must be in EPSG:4326.
As a speed optimization, the `tree` can be pre-constructed with ``get_ball_tree()``
(and/or serialized to disk using [skops](skops.io) or `pickle`)
and passed using the `tree` argument.
Returns the `points` DataFrame after adding these extra columns:
- `mesh_index` which is the index of the node in the `hgrid.gr3` file
- `mesh_index` which is the index of the mesh node
- `mesh_lon` which is the longitude of the nearest mesh node
- `mesh_lat` which is the latitude of the nearest mesh node
- `distance` which is the distance in meters between the point and the nearest mesh node
- `distance` which is the distance in meters between the point in question
and the nearest mesh node
Examples:
``` python
>>> mesh_nodes = pd.DataFrame({
... "lon": [0, 10, 20],
... "lat": [0, 5, 0],
Expand All @@ -45,18 +73,32 @@ def find_nearest_nodes(
0 1 1 a 0 0 0 157249.381272
1 11 4 b 1 10 5 157010.162641
2 21 1 c 2 20 0 157249.381272
>>> nearest_nodes = find_nearest_nodes(mesh_nodes, points, k=2)
>>> nearest_nodes
lon lat id mesh_index mesh_lon mesh_lat distance
0 1 1 a 0 0 0 1.572494e+05
1 1 1 a 1 10 5 1.093700e+06
2 11 4 b 1 10 5 1.570102e+05
3 11 4 b 2 20 0 1.094398e+06
4 21 1 c 2 20 0 1.572494e+05
5 21 1 c 1 10 5 1.299688e+06
```
"""
# The only requirement is that both `mesh_nodes and `points` have `lon/lat` columns
tree = sklearn.neighbors.BallTree(
np.radians(mesh_nodes[["lat", "lon"]]),
metric=metric,
# Resolve tree
if tree is None:
tree = get_ball_tree(mesh_nodes=mesh_nodes, metric=metric)
distances, indices = tree.query(
X=np.radians(points[["lat", "lon"]].values),
k=k,
return_distance=True,
**kwargs,
)
distances, indices = tree.query(np.radians(points[["lat", "lon"]].values))
closest_nodes = (
mesh_nodes
.rename(columns={"lon": "mesh_lon", "lat": "mesh_lat"})
mesh_nodes.rename(columns={"lon": "mesh_lon", "lat": "mesh_lat"})
.iloc[indices.flatten()]
.assign(distance=(distances.flatten() * earth_radius))
.reset_index(names=["mesh_index"])
)
return pd.concat((points, closest_nodes), axis="columns")
return pd.concat(
(points.loc[points.index.repeat(k)].reset_index(drop=True), closest_nodes), axis="columns"
)
46 changes: 36 additions & 10 deletions tests/utils/test_cpoint.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,49 @@
from __future__ import annotations

import pandas as pd
import pytest

from pyposeidon.utils.cpoint import find_nearest_nodes
from pyposeidon.utils.cpoint import get_ball_tree

EXPECTED_COLUMNS = ["lon", "lat", "id", "mesh_index", "mesh_lon", "mesh_lat", "distance"]


def test_find_nearest_nodes():
mesh_nodes = pd.DataFrame({
@pytest.fixture(scope="session")
def mesh_nodes():
return pd.DataFrame({
"lon": [0, 10, 20],
"lat": [0, 5, 0],
})
points = pd.DataFrame({
"lon": [1, 11, 21],
"lat": [1, 4, 1],
"id": ["a", "b", "c"],


@pytest.fixture(scope="session")
def points():
return pd.DataFrame({
"lon": [1, 11, 21, 2],
"lat": [1, 4, 1, 2],
"id": ["a", "b", "c", "d"],
})


@pytest.fixture(scope="session")
def ball_tree(mesh_nodes):
return get_ball_tree(mesh_nodes)


def test_find_nearest_nodes(mesh_nodes, points):
nearest_nodes = find_nearest_nodes(mesh_nodes, points)
assert isinstance(nearest_nodes, pd.DataFrame)
assert len(nearest_nodes) == 3
assert nearest_nodes.columns.tolist() == ["lon", "lat", "id", "mesh_index", "mesh_lon", "mesh_lat", "distance"]
assert nearest_nodes.mesh_index.tolist() == [0, 1, 2]
assert len(nearest_nodes) == len(points)
assert nearest_nodes.columns.tolist() == EXPECTED_COLUMNS
assert nearest_nodes.mesh_index.tolist() == [0, 1, 2, 0]
assert nearest_nodes.distance.min() > 150_000
assert nearest_nodes.distance.max() < 160_000
assert nearest_nodes.distance.max() < 320_000


@pytest.mark.parametrize("k", [pytest.param(2, id='2 points'), pytest.param(3, id='3 points')])
def test_find_nearest_nodes_multiple_points_and_pass_tree_as_argument(mesh_nodes, points, k, ball_tree):
nearest_nodes = find_nearest_nodes(mesh_nodes, points, k=k, tree=ball_tree)
assert isinstance(nearest_nodes, pd.DataFrame)
assert len(nearest_nodes) == len(points) * k
assert nearest_nodes.columns.tolist() == EXPECTED_COLUMNS

0 comments on commit c8d9766

Please sign in to comment.