Skip to content

Commit

Permalink
fix: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsail committed Sep 14, 2024
1 parent 2645e7f commit 8e85173
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 47 deletions.
32 changes: 16 additions & 16 deletions pyposeidon/telemac.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,19 @@ def write_netcdf(ds, outpath):


def extract_t_elev_2D(
ds: xr.Dataset, x: float, y: float, var: str = "elev", xstr: str = "longitude", ystr: str = "latitude", max_dist: float = 1000,
ds: xr.Dataset,
x: float,
y: float,
var: str = "elev",
xstr: str = "longitude",
ystr: str = "latitude",
max_dist: float = 1000,
):
lons, lats = ds[xstr].values, ds[ystr].values
mesh = pd.DataFrame(np.vstack([x, y]).T, columns = ["lon", "lat"])
points = pd.DataFrame(np.vstack([lons, lats]).T, columns = ["lon", "lat"])
mesh = pd.DataFrame(np.vstack([x, y]).T, columns=["lon", "lat"])
points = pd.DataFrame(np.vstack([lons, lats]).T, columns=["lon", "lat"])
df = find_nearest_nodes(mesh, points, 1)
df = df[df.distance < max_dist]
df = df[df.distance < max_dist]
indx = df["mesh_index"]
ds_ = ds.isel(node=indx.values[0])
out_ = ds_[var].values
Expand Down Expand Up @@ -1102,7 +1108,7 @@ def run(self, api=True, **kwargs):
return

if self.fortran:
user_fortran = 'user_fortran'
user_fortran = "user_fortran"
else:
user_fortran = None

Expand Down Expand Up @@ -1430,16 +1436,10 @@ def set_obs(self, **kwargs):
return

mesh = pd.DataFrame(
np.array(
[
self.mesh.Dataset.SCHISM_hgrid_node_x.values,
self.mesh.Dataset.SCHISM_hgrid_node_y.values
]
).T,
columns = ["lon", "lat"])
points = pd.DataFrame(
np.array([tgn.longitude.values, tgn.latitude.values]).T,
columns = ["lon", "lat"])
np.array([self.mesh.Dataset.SCHISM_hgrid_node_x.values, self.mesh.Dataset.SCHISM_hgrid_node_y.values]).T,
columns=["lon", "lat"],
)
points = pd.DataFrame(np.array([tgn.longitude.values, tgn.latitude.values]).T, columns=["lon", "lat"])
df = find_nearest_nodes(mesh, points, 1)
df = df[df.distance < max_dist]

Expand All @@ -1453,7 +1453,7 @@ def set_obs(self, **kwargs):

# convert to MERCATOR coordinates
# dirty fix (this needs to be fixed in TELEMAC directly)
x, y = longlat2spherical(df["lon"], df["lat"],0,0)
x, y = longlat2spherical(df["lon"], df["lat"], 0, 0)
df["x"] = x
df["y"] = y

Expand Down
4 changes: 1 addition & 3 deletions pyposeidon/utils/cfl.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def parse_hgrid(path: os.PathLike[str] | str) -> dict[str, T.Any]:
no_closed_boundaries = int(fd.readline().split(b"=")[0].strip())
total_closed_boundary_nodes = int(fd.readline().split(b"=")[0].strip())
for i in range(no_closed_boundaries):
no_nodes_in_boundary, boundary_type = map(
int, (fd.readline().split(b"=")[0].strip().split(b" "))
)
no_nodes_in_boundary, boundary_type = map(int, (fd.readline().split(b"=")[0].strip().split(b" ")))
boundary_nodes = np.fromiter(
fd,
count=no_nodes_in_boundary,
Expand Down
4 changes: 1 addition & 3 deletions pyposeidon/utils/cpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,4 @@ def find_nearest_nodes(
.assign(distance=(distances.flatten() * earth_radius))
.reset_index(names=["mesh_index"])
)
return pd.concat(
(points.loc[points.index.repeat(k)].reset_index(drop=True), closest_nodes), axis="columns"
)
return pd.concat((points.loc[points.index.repeat(k)].reset_index(drop=True), closest_nodes), axis="columns")
20 changes: 17 additions & 3 deletions pyposeidon/utils/obs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Observational Data retrieval """

from __future__ import annotations

import itertools
Expand Down Expand Up @@ -152,14 +153,27 @@ def serialize_stations(
msg = f"stations must have these columns too: {mandatory_cols.difference(df_cols)}"
raise ValueError(msg)
#
basic_cols = ["mesh_lon", "mesh_lat", "z", "separator", "unique_id", "mesh_index", "lon", "lat", "depth", "distance"]
basic_cols = [
"mesh_lon",
"mesh_lat",
"z",
"separator",
"unique_id",
"mesh_index",
"lon",
"lat",
"depth",
"distance",
]
station_in = stations.assign(
z=0,
separator="\t!\t",
)
station_in = station_in.set_index(station_in.index +1)
station_in = station_in.set_index(station_in.index + 1)
station_in = station_in[basic_cols]
with open(f"{path}", "w") as fd:
fd.write(f"{schism_station_flag.strip()}\t ! https://schism-dev.github.io/schism/master/input-output/optional-inputs.html#stationin-bp-format\n")
fd.write(
f"{schism_station_flag.strip()}\t ! https://schism-dev.github.io/schism/master/input-output/optional-inputs.html#stationin-bp-format\n"
)
fd.write(f"{len(station_in)}\t ! number of stations\n")
station_in.to_csv(fd, header=None, sep=" ", float_format="%.10f")
1 change: 1 addition & 0 deletions pyposeidon/utils/pplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import sys
import os

# from pyposeidon.tools import to_geodataframe

ffmpeg = sys.exec_prefix + "/bin/ffmpeg"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_telemac.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
},
}

case4 = { # test does not work with telemac3d: mesh quality is too bad
case4 = { # test does not work with telemac3d: mesh quality is too bad
"solver_name": "telemac",
"mesh_file": MESH_FILE,
"module": "telemac3d",
Expand Down
24 changes: 14 additions & 10 deletions tests/utils/test_cpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,23 @@

@pytest.fixture(scope="session")
def mesh_nodes():
return pd.DataFrame({
"lon": [0, 10, 20],
"lat": [0, 5, 0],
})
return pd.DataFrame(
{
"lon": [0, 10, 20],
"lat": [0, 5, 0],
}
)


@pytest.fixture(scope="session")
def points():
return pd.DataFrame({
"lon": [1, 11, 21, 2],
"lat": [1, 4, 1, 2],
"id": ["a", "b", "c", "d"],
})
return pd.DataFrame(
{
"lon": [1, 11, 21, 2],
"lat": [1, 4, 1, 2],
"id": ["a", "b", "c", "d"],
}
)


@pytest.fixture(scope="session")
Expand All @@ -41,7 +45,7 @@ def test_find_nearest_nodes(mesh_nodes, points):
assert nearest_nodes.distance.max() < 320_000


@pytest.mark.parametrize("k", [pytest.param(2, id='2 points'), pytest.param(3, id='3 points')])
@pytest.mark.parametrize("k", [pytest.param(2, id="2 points"), pytest.param(3, id="3 points")])
def test_find_nearest_nodes_multiple_points_and_pass_tree_as_argument(mesh_nodes, points, k, ball_tree):
nearest_nodes = find_nearest_nodes(mesh_nodes, points, k=k, tree=ball_tree)
assert isinstance(nearest_nodes, pd.DataFrame)
Expand Down
24 changes: 13 additions & 11 deletions tests/utils/test_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@ def test_serialize_stations(tmp_path):
3 20.0000000000 0.0000000000 0 ! c 2 21.0000000000 1.0000000000 1 157249.3812719441
"""
)
stations = pd.DataFrame({
'lon': [1., 11., 21.],
'lat': [1., 4., 1.],
'unique_id': ["a", "b", "c"],
'extra_col': ["AA", "BB", "CC"],
'mesh_index': [0, 1, 2],
'mesh_lon': [0., 10., 20.],
'mesh_lat': [0., 5., 0.],
'distance': [157249.38127194397, 157010.16264060183, 157249.38127194406],
'depth': [3, 5, 1],
})
stations = pd.DataFrame(
{
"lon": [1.0, 11.0, 21.0],
"lat": [1.0, 4.0, 1.0],
"unique_id": ["a", "b", "c"],
"extra_col": ["AA", "BB", "CC"],
"mesh_index": [0, 1, 2],
"mesh_lon": [0.0, 10.0, 20.0],
"mesh_lat": [0.0, 5.0, 0.0],
"distance": [157249.38127194397, 157010.16264060183, 157249.38127194406],
"depth": [3, 5, 1],
}
)
path = tmp_path / "station.in"
serialize_stations(stations, path)
contents = path.read_text()
Expand Down

0 comments on commit 8e85173

Please sign in to comment.