diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml
new file mode 100644
index 0000000..2145d43
--- /dev/null
+++ b/.github/workflows/cd.yml
@@ -0,0 +1,53 @@
+name: cd-module
+
+on:
+ workflow_dispatch:
+ inputs:
+ module-name:
+ description: 'Module name'
+ required: true
+ type: string
+ push:
+ tags:
+ - '[0-9]+.[0-9]+.[0-9]+'
+jobs:
+ install_module:
+ runs-on: [hpc]
+ steps:
+ - uses: ecmwf-actions/reusable-workflows/ci-hpc-generic@v2
+ with:
+ troika_user: ${{ secrets.HPC_CI_SSH_USER }}
+ template: |
+ # Declare an array to store the module versions
+ python3_versions=("3.10.10-01" "3.11.8-01")
+
+
+ HAT_VERSION=${{ github.event_name == 'workflow_dispatch' && inputs.module-name || github.ref_name }}
+
+ PREFIX=/usr/local/apps/hat/${HAT_VERSION}
+ rm -rf $PREFIX
+ mkdir -p $PREFIX
+
+ # Loop through the module versions
+ for version in "${python3_versions[@]}"
+ do
+ # Load the module for the current version
+ module load python3/$version
+
+ PYTHONUSERBASE=$PREFIX pip3 install --user git+https://github.com/ecmwf/hat.git@${HAT_VERSION}
+
+ module unload python3
+ done
+
+ cat > $PREFIX/README.txt << EOF
+ -- [hat] ($HAT_VERSION) [$HAT_VERSION]
+ EOF
+
+ software-sync -s local -p hat
+ module load modulemgr
+ modulemgr -f -m all sync hat
+
+ sbatch_options: |
+ #SBATCH --job-name=cd_hat
+ #SBATCH --time=00:10:00
+ #SBATCH --qos=deploy
\ No newline at end of file
diff --git a/.github/workflows/on-push.yml b/.github/workflows/on-push.yml
index 968e5fa..e4bb583 100644
--- a/.github/workflows/on-push.yml
+++ b/.github/workflows/on-push.yml
@@ -5,7 +5,7 @@ on:
branches:
- main
tags:
- - "*"
+ - "[0-9]+.[0-9]+.[0-9]+"
pull_request:
branches:
- main
diff --git a/environment.yml b/environment.yml
index 62feffc..0ff4ee5 100644
--- a/environment.yml
+++ b/environment.yml
@@ -14,7 +14,6 @@ dependencies:
- tqdm
- typer
- humanize
- - typer
- ipyleaflet
- ipywidgets
- pip
diff --git a/hat/config.py b/hat/config.py
index 49318da..f4e1037 100644
--- a/hat/config.py
+++ b/hat/config.py
@@ -13,7 +13,9 @@ def load_package_config(fname):
resource_package = "hat"
resource_path = os.path.join("config_json", fname)
- config_string = pkg_resources.resource_string(resource_package, resource_path)
+ config_string = pkg_resources.resource_string(
+ resource_package, resource_path
+ )
config = json.loads(config_string.decode())
return config
@@ -50,7 +52,9 @@ def booleanify(config, key):
return
if str(config[key]).lower() not in ["true", "false"]:
- raise ValueError(f'"{key}" configuration variable must be "True" or "False"')
+ raise ValueError(
+ f'"{key}" configuration variable must be "True" or "False"'
+ )
if config[key].lower() == "true":
config[key] = True
diff --git a/hat/data.py b/hat/data.py
index 5b6c81a..c9d9ee5 100644
--- a/hat/data.py
+++ b/hat/data.py
@@ -77,7 +77,9 @@ def find_files(simulation_files):
fpaths = glob.glob(simulation_files)
if not fpaths:
- raise Exception(f"Could not find any file from regex {simulation_files}")
+ raise Exception(
+ f"Could not find any file from regex {simulation_files}"
+ )
else:
print("Found following simulation files:")
print(*fpaths, sep="\n")
@@ -113,7 +115,9 @@ def dirsize(simulation_datadir, bytesize=False):
"""given a root directory return total size of all files
in a directory in human readable format"""
- if not os.path.exists(simulation_datadir) or not os.path.isdir(simulation_datadir):
+ if not os.path.exists(simulation_datadir) or not os.path.isdir(
+ simulation_datadir
+ ):
print("Not a directory", simulation_datadir)
return
@@ -178,7 +182,9 @@ def raster_loader(fpath: str, epsg: int = 4326, engine=None):
xarr.attrs["fname"] = os.path.basename(fpath)
xarr.attrs["experiment_name"] = pathlib.Path(fpath).stem
xarr.attrs["file_extension"] = pathlib.Path(fpath).suffix
- xarr.attrs["river_network"] = river_network_from_filename(xarr.attrs["fname"])
+ xarr.attrs["river_network"] = river_network_from_filename(
+ xarr.attrs["fname"]
+ )
return xarr
@@ -261,7 +267,9 @@ def save_dataset_to_netcdf(ds: xr.Dataset, fpath: str):
def find_main_var(ds, min_dim=3):
- variable_names = [k for k in ds.variables if len(ds.variables[k].dims) >= min_dim]
+ variable_names = [
+ k for k in ds.variables if len(ds.variables[k].dims) >= min_dim
+ ]
if len(variable_names) > 1:
raise Exception("More than one variable in dataset")
elif len(variable_names) == 0:
@@ -291,10 +299,16 @@ def read_simulation_as_xarray(options):
fs = earthkit.data.from_source(src_type, *args)
xarray_kwargs = {}
- if isinstance(fs, earthkit.data.readers.netcdf.fieldlist.NetCDFMultiFieldList):
- xarray_kwargs["xarray_open_mfdataset_kwargs"] = {"chunks": {"time": "auto"}}
+ if isinstance(
+ fs, earthkit.data.readers.netcdf.fieldlist.NetCDFMultiFieldList
+ ):
+ xarray_kwargs["xarray_open_mfdataset_kwargs"] = {
+ "chunks": {"time": "auto"}
+ }
else:
- xarray_kwargs["xarray_open_dataset_kwargs"] = {"chunks": {"time": "auto"}}
+ xarray_kwargs["xarray_open_dataset_kwargs"] = {
+ "chunks": {"time": "auto"}
+ }
# xarray dataset
ds = fs.to_xarray(**xarray_kwargs)
diff --git a/hat/extract_simulation_timeseries.py b/hat/extract_simulation_timeseries.py
index acb9d03..12f97ae 100644
--- a/hat/extract_simulation_timeseries.py
+++ b/hat/extract_simulation_timeseries.py
@@ -36,7 +36,11 @@ def extract_timeseries(
# all timeseries (i.e. handle proximal and duplicate stations)
da_stations = assign_stations(
- stations, station_mask, da_points, coords, config["station_id_column_name"]
+ stations,
+ station_mask,
+ da_points,
+ coords,
+ config["station_id_column_name"],
)
return da_stations
diff --git a/hat/filters.py b/hat/filters.py
index bb27f93..5514b60 100644
--- a/hat/filters.py
+++ b/hat/filters.py
@@ -4,7 +4,10 @@
# @st.cache_data
def temporal_filter(
- _metadata, _observations: pd.DataFrame, timeperiod, station_id_name="station_id"
+ _metadata,
+ _observations: pd.DataFrame,
+ timeperiod,
+ station_id_name="station_id",
):
"""
filter station metadata and timeseries by timeperiod
@@ -89,7 +92,9 @@ def stations_with_discharge(obs, timeperiod, metadata):
return metadata, obsdis
-def apply_filter(df: pd.DataFrame, key: str, operator: str, value: str) -> pd.DataFrame:
+def apply_filter(
+ df: pd.DataFrame, key: str, operator: str, value: str
+) -> pd.DataFrame:
"""
Apply the filter on the DataFrame based on the provided
key, operator, and value.
@@ -104,7 +109,9 @@ def apply_filter(df: pd.DataFrame, key: str, operator: str, value: str) -> pd.Da
}
if key not in df.columns:
- raise ValueError(f"Key '{key}' does not exist as column name in dataframe")
+ raise ValueError(
+ f"Key '{key}' does not exist as column name in dataframe"
+ )
if operator not in operators:
raise ValueError(f"Operator '{operator}' is not supported")
@@ -134,19 +141,25 @@ def filter_dataframe(df, filters: str):
continue
parts = filter_str.split()
if len(parts) != 3:
- raise ValueError("Invalid filter format. Expected 'key operator value'.")
+ raise ValueError(
+ "Invalid filter format. Expected 'key operator value'."
+ )
key, operator, value = parts
df = apply_filter(df, key, operator, value)
if len(df) == 0:
- raise ValueError("There are no remaining rows (try different filters?)")
+ raise ValueError(
+ "There are no remaining rows (try different filters?)"
+ )
return df
-def filter_timeseries(sims_ds: xr.DataArray, obs_ds: xr.DataArray, threshold=80):
+def filter_timeseries(
+ sims_ds: xr.DataArray, obs_ds: xr.DataArray, threshold=80
+):
"""Clean the simulation and observation timeseries
Only keep..
diff --git a/hat/geo.py b/hat/geo.py
index 6c0408b..38ac2b2 100644
--- a/hat/geo.py
+++ b/hat/geo.py
@@ -116,7 +116,9 @@ def river_network_geometry(metadata, river_network, version=1):
if river_network == "EPSG:4326":
return metadata
- lon_name, lat_name = name_of_adjusted_coords(river_network, version=version)
+ lon_name, lat_name = name_of_adjusted_coords(
+ river_network, version=version
+ )
gdf = metadata.copy(deep=True)
gdf["geometry"] = gpd.points_from_xy(gdf[lon_name], gdf[lat_name])
@@ -157,7 +159,9 @@ def river_network_to_coord_names(river_network: str = "") -> dict:
x, y = river_network_to_coords[river_network]
return {"x": x, "y": y}
else:
- print(f"River network '{river_network}' not in: {valid_river_networks}")
+ print(
+ f"River network '{river_network}' not in: {valid_river_networks}"
+ )
def geojson_schema():
@@ -230,7 +234,9 @@ def geojson_schema():
{"$ref": "#/definitions/pointCoordinates"},
{"$ref": "#/definitions/multiPointCoordinates"},
{"$ref": "#/definitions/lineStringCoordinates"},
- {"$ref": "#/definitions/multiLineStringCoordinates"},
+ {
+ "$ref": "#/definitions/multiLineStringCoordinates"
+ },
{"$ref": "#/definitions/polygonCoordinates"},
{"$ref": "#/definitions/multiPolygonCoordinates"},
]
@@ -307,8 +313,12 @@ def geopoints_to_array(gdf, array_coords) -> np.ndarray:
point_ys = np.array([point.y for point in points])
# nearest neighbour indices of the points in the array
- x_indices = [(np.abs(array_coords["x"] - point_x)).argmin() for point_x in point_xs]
- y_indices = [(np.abs(array_coords["y"] - point_y)).argmin() for point_y in point_ys]
+ x_indices = [
+ (np.abs(array_coords["x"] - point_x)).argmin() for point_x in point_xs
+ ]
+ y_indices = [
+ (np.abs(array_coords["y"] - point_y)).argmin() for point_y in point_ys
+ ]
# create an empty boolean array
shape = (len(array_coords["y"]), len(array_coords["x"]))
@@ -320,7 +330,9 @@ def geopoints_to_array(gdf, array_coords) -> np.ndarray:
return arr
-def geopoints_from_csv(fpath: str, lat_name: str, lon_name: str) -> gpd.GeoDataFrame:
+def geopoints_from_csv(
+ fpath: str, lat_name: str, lon_name: str
+) -> gpd.GeoDataFrame:
"""Load georeferenced points from file.
Requires name of latitude and longitude columns"""
diff --git a/hat/interactive/explorers.py b/hat/interactive/explorers.py
index 0c88d4c..0955da5 100644
--- a/hat/interactive/explorers.py
+++ b/hat/interactive/explorers.py
@@ -106,7 +106,9 @@ def prepare_observations_data(observations, sim_ds, obs_var_name):
return obs_ds
-def find_common_stations(station_index, stations_metadata, obs_ds, sim_ds, statistics):
+def find_common_stations(
+ station_index, stations_metadata, obs_ds, sim_ds, statistics
+):
"""
Find common stations between observations, simulations and station
metadata.
@@ -260,7 +262,11 @@ def __init__(self, config):
self.title_label = ipywidgets.Label(
"Interactive Map Visualisation for Hydrological Model Performance",
layout=ipywidgets.Layout(justify_content="center"),
- style={"font_weight": "bold", "font_size": "24px", "font_family": "Arial"},
+ style={
+ "font_weight": "bold",
+ "font_size": "24px",
+ "font_family": "Arial",
+ },
)
# Create the interactive widgets
@@ -269,7 +275,9 @@ def __init__(self, config):
widgets = {}
widgets["plot"] = PlotlyWidget(datasets)
widgets["stats"] = StatisticsWidget(self.statistics)
- widgets["meta"] = MetaDataWidget(self.stations_metadata, self.station_index)
+ widgets["meta"] = MetaDataWidget(
+ self.stations_metadata, self.station_index
+ )
self.widgets = WidgetsManager(
widgets, config["station_id_column_name"], self.loading_widget
)
@@ -302,7 +310,10 @@ def create_frame(self):
width="40%",
)
right_layout = ipywidgets.Layout(
- justify_content="center", align_items="center", spacing="2px", width="60%"
+ justify_content="center",
+ align_items="center",
+ spacing="2px",
+ width="60%",
)
# Frames
diff --git a/hat/interactive/leaflet.py b/hat/interactive/leaflet.py
index 10ff223..b2de215 100644
--- a/hat/interactive/leaflet.py
+++ b/hat/interactive/leaflet.py
@@ -39,7 +39,8 @@ def __init__(
basemap=ipyleaflet.basemaps.OpenStreetMap.Mapnik,
):
self.map = ipyleaflet.Map(
- basemap=basemap, layout=ipywidgets.Layout(width="100%", height="600px")
+ basemap=basemap,
+ layout=ipywidgets.Layout(width="100%", height="600px"),
)
self.legend_widget = ipywidgets.Output()
@@ -56,7 +57,10 @@ def _set_boundaries(self, stations_metadata, coord_names):
min_lat, max_lat = min(lats), max(lats)
min_lon, max_lon = min(lons), max(lons)
- bounds = [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))]
+ bounds = [
+ (float(min_lat), float(min_lon)),
+ (float(max_lat), float(max_lon)),
+ ]
self.map.fit_bounds(bounds)
def add_geolayer(self, geodata, colormap, widgets, coord_names=None):
@@ -221,7 +225,8 @@ def legend(self):
"""
# Convert the colormap to a list of RGB values
rgb_values = [
- mpl.colors.rgb2hex(self.colormap(i)) for i in np.linspace(0, 1, 256)
+ mpl.colors.rgb2hex(self.colormap(i))
+ for i in np.linspace(0, 1, 256)
]
# Create a gradient style using the RGB values
diff --git a/hat/interactive/widgets.py b/hat/interactive/widgets.py
index ba7c5d0..715100e 100644
--- a/hat/interactive/widgets.py
+++ b/hat/interactive/widgets.py
@@ -156,7 +156,9 @@ def _filter_nan_values(dates, data_values):
assert len(dates) == len(
data_values
), "Dates and data values must be the same length."
- valid_dates = [date for date, val in zip(dates, data_values) if not np.isnan(val)]
+ valid_dates = [
+ date for date, val in zip(dates, data_values) if not np.isnan(val)
+ ]
valid_data = [val for val in data_values if not np.isnan(val)]
return valid_dates, valid_data
@@ -192,7 +194,11 @@ def __init__(self, datasets):
height=350,
margin=dict(l=120),
legend=dict(
- orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
+ orientation="h",
+ yanchor="bottom",
+ y=1.02,
+ xanchor="right",
+ x=1,
),
xaxis_title="Date",
xaxis_tickformat="%d-%m-%Y",
@@ -214,7 +220,9 @@ def __init__(self, datasets):
date_picker_box = HBox([self.start_date_picker, self.end_date_picker])
layout = Layout(justify_content="center", align_items="center")
- output = VBox([self.figure, date_label, date_picker_box], layout=layout)
+ output = VBox(
+ [self.figure, date_label, date_picker_box], layout=layout
+ )
super().__init__(output)
def _update_plot_dates(self):
@@ -262,9 +270,7 @@ def _update_title(self, metadata):
"""
station_id = metadata["station_id"]
station_name = metadata["StationName"]
- updated_title = (
- f"Selected station:
ID: {station_id}, name: {station_name} "
- )
+ updated_title = f"Selected station:
ID: {station_id}, name: {station_name} " # noqa: E501
self.figure.update_layout(
title={
"text": updated_title,
@@ -435,7 +441,9 @@ def __init__(self, dataframe, station_index):
def _extract_dataframe(self, station_id):
stations_df = self.dataframe
- selected_station_df = stations_df[stations_df[self.station_index] == station_id]
+ selected_station_df = stations_df[
+ stations_df[self.station_index] == station_id
+ ]
return selected_station_df
@@ -489,7 +497,11 @@ def _extract_dataframe(self, station_id):
statistics_df = pd.DataFrame(data, columns=columns)
# Round the numerical columns to 2 decimal places
- numerical_columns = [col for col in statistics_df.columns if col != "Exp. name"]
- statistics_df[numerical_columns] = statistics_df[numerical_columns].round(2)
+ numerical_columns = [
+ col for col in statistics_df.columns if col != "Exp. name"
+ ]
+ statistics_df[numerical_columns] = statistics_df[
+ numerical_columns
+ ].round(2)
return statistics_df
diff --git a/hat/mapping/evaluation.py b/hat/mapping/evaluation.py
index 61c08a4..e4bdb8f 100644
--- a/hat/mapping/evaluation.py
+++ b/hat/mapping/evaluation.py
@@ -43,7 +43,9 @@ def calculate_rmse(df, column_reference, column_evaluated):
rmse = round(
np.sqrt(((df[column_reference] - df[column_evaluated]) ** 2).mean()), 2
)
- print(f"RMSE between {column_reference} and {column_evaluated}: {rmse} km2")
+ print(
+ f"RMSE between {column_reference} and {column_evaluated}: {rmse} km2"
+ )
return rmse
@@ -101,7 +103,9 @@ def count_and_analyze_area_distance(
for index, row in df.iterrows():
area_diff = abs(
- calculate_area_diff_percentage(row[eval_area_col], row[ref_area_col])
+ calculate_area_diff_percentage(
+ row[eval_area_col], row[ref_area_col]
+ )
)
if area_diff <= area_diff_limit:
@@ -123,7 +127,9 @@ def count_and_analyze_area_distance(
row[eval_lon_col],
)
)
- distance_freq[grid_distance] = distance_freq.get(grid_distance, 0) + 1
+ distance_freq[grid_distance] = (
+ distance_freq.get(grid_distance, 0) + 1
+ )
count_inside_area_limit += 1
if grid_distance <= distance_limit:
diff --git a/hat/mapping/station_mapping.py b/hat/mapping/station_mapping.py
index a9f0425..4f2c6f0 100644
--- a/hat/mapping/station_mapping.py
+++ b/hat/mapping/station_mapping.py
@@ -97,7 +97,9 @@ def find_best_matching_grid(
def create_grid_polygon(lat, lon, cell_size):
"""Create a rectangular polygon around the given lat/lon based on cell size."""
half_cell = cell_size / 2
- return box(lon - half_cell, lat - half_cell, lon + half_cell, lat + half_cell)
+ return box(
+ lon - half_cell, lat - half_cell, lon + half_cell, lat + half_cell
+ )
def process_station_data(
@@ -152,7 +154,9 @@ def process_station_data(
nearest grid cell, and best matching grid cell (if applicable).
"""
lat, lon = float(station[lat_col]), float(station[lon_col])
- station_area = float(station[csv_variable]) if station[csv_variable] else np.nan
+ station_area = (
+ float(station[csv_variable]) if station[csv_variable] else np.nan
+ )
# manually mapped variable
if manual_area is not None:
@@ -171,7 +175,9 @@ def process_station_data(
manual_lat_idx, manual_lon_idx = get_grid_index(
manual_lat, manual_lon, latitudes, longitudes
)
- manual_area = float(station[manual_area]) if station[manual_area] else np.nan
+ manual_area = (
+ float(station[manual_area]) if station[manual_area] else np.nan
+ )
else:
manual_lat = np.nan
@@ -182,8 +188,12 @@ def process_station_data(
# Nearest grid cell
lat_idx, lon_idx = get_grid_index(lat, lon, latitudes, longitudes)
near_grid_area = nc_data[lat_idx, lon_idx]
- near_grid_area = float(near_grid_area) if not is_masked(near_grid_area) else np.nan
- near_area_diff = calculate_area_diff_percentage(near_grid_area, station_area)
+ near_grid_area = (
+ float(near_grid_area) if not is_masked(near_grid_area) else np.nan
+ )
+ near_area_diff = calculate_area_diff_percentage(
+ near_grid_area, station_area
+ )
near_distance_km = calculate_distance_km(
lat, lon, latitudes[lat_idx], longitudes[lon_idx]
)
@@ -207,7 +217,9 @@ def process_station_data(
)
optimum_grid_area = nc_data[optimum_lat_idx, optimum_lon_idx]
optimum_grid_area = (
- float(optimum_grid_area) if not is_masked(optimum_grid_area) else np.nan
+ float(optimum_grid_area)
+ if not is_masked(optimum_grid_area)
+ else np.nan
)
optimum_area_diff = calculate_area_diff_percentage(
optimum_grid_area, station_area
@@ -368,8 +380,12 @@ def save_geo_dataframes(df, out_dir, cell_size):
)
# Convert any additional geometry columns to WKT for serialization
- df["near_grid_polygon_wkt"] = df["near_grid_polygon"].apply(lambda x: x.wkt)
- df["optimum_grid_polygon_wkt"] = df["optimum_grid_polygon"].apply(lambda x: x.wkt)
+ df["near_grid_polygon_wkt"] = df["near_grid_polygon"].apply(
+ lambda x: x.wkt
+ )
+ df["optimum_grid_polygon_wkt"] = df["optimum_grid_polygon"].apply(
+ lambda x: x.wkt
+ )
# Drop the Shapely object columns that were replaced by wkts
df = df.drop(columns=["near_grid_polygon", "optimum_grid_polygon"])
@@ -387,7 +403,10 @@ def save_geo_dataframes(df, out_dir, cell_size):
# Create GeoDataFrames
gdf_station_point = gpd.GeoDataFrame(
- df, geometry=[Point(xy) for xy in zip(df["station_lon"], df["station_lat"])]
+ df,
+ geometry=[
+ Point(xy) for xy in zip(df["station_lon"], df["station_lat"])
+ ],
)
gdf_near_grid_polygon = gpd.GeoDataFrame(
df, geometry=df["near_grid_polygon_wkt"].apply(loads)
@@ -408,7 +427,8 @@ def save_geo_dataframes(df, out_dir, cell_size):
os.path.join(out_dir, "optimum_grid.geojson"), driver="GeoJSON"
)
gdf_line_optimum.to_file(
- os.path.join(out_dir, "stations2grid_optimum_line.geojson"), driver="GeoJSON"
+ os.path.join(out_dir, "stations2grid_optimum_line.geojson"),
+ driver="GeoJSON",
)
gdf_station_point.to_csv(os.path.join(out_dir, "stations.csv"))
diff --git a/hat/mapping/visualisation.py b/hat/mapping/visualisation.py
index c6399df..f7f3a25 100644
--- a/hat/mapping/visualisation.py
+++ b/hat/mapping/visualisation.py
@@ -16,7 +16,9 @@
class GeoJSONLayerManager:
- def __init__(self, path, style_callback=None, point_style=None, name="Layer"):
+ def __init__(
+ self, path, style_callback=None, point_style=None, name="Layer"
+ ):
self.gdf = gpd.read_file(path)
self.style_callback = style_callback
self.point_style = point_style
@@ -37,7 +39,9 @@ def add_to_map(self, map_object):
name=self.name,
)
else:
- self.layer = GeoJSON(data=self.gdf.__geo_interface__, name=self.name)
+ self.layer = GeoJSON(
+ data=self.gdf.__geo_interface__, name=self.name
+ )
map_object.add_layer(self.layer)
@@ -122,12 +126,18 @@ def line_click_handler(feature, **kwargs):
station_area = feature["properties"].get(station_area_attr, "N/A")
near_area = feature["properties"].get(near_area_attr, "N/A")
optimum_area = feature["properties"].get(optimum_area_attr, "N/A")
- optimum_distance_cells = feature["properties"].get(optimum_dist_attr, "N/A")
+ optimum_distance_cells = feature["properties"].get(
+ optimum_dist_attr, "N/A"
+ )
# Format numbers with comma separators
- station_area = f"{station_area:,.1f}" if station_area != "N/A" else station_area
+ station_area = (
+ f"{station_area:,.1f}" if station_area != "N/A" else station_area
+ )
near_area = f"{near_area:,.1f}" if near_area != "N/A" else near_area
- optimum_area = f"{optimum_area:,.1f}" if optimum_area != "N/A" else optimum_area
+ optimum_area = (
+ f"{optimum_area:,.1f}" if optimum_area != "N/A" else optimum_area
+ )
optimum_distance_cells = (
f"{optimum_distance_cells:,.1f}"
if optimum_distance_cells != "N/A"
@@ -203,7 +213,9 @@ def vector_style(feature, color, opacity=0.5, weight=1):
}
-def attribute_based_style(row, attribute_name, threshold, color_above, color_below):
+def attribute_based_style(
+ row, attribute_name, threshold, color_above, color_below
+):
"""Style function for GeoJSON features based on an attribute value."""
attribute_value = row.get(attribute_name)
if attribute_value is not None:
diff --git a/hat/observations.py b/hat/observations.py
index 467eb70..45e2b72 100644
--- a/hat/observations.py
+++ b/hat/observations.py
@@ -22,7 +22,10 @@ def add_geometry_column(gdf: gpd.GeoDataFrame, coord_names):
# Filter rows that do not plot on Earth (e.g. -9999)
gdf = gdf[
- (gdf["x"] >= -180) & (gdf["x"] <= 180) & (gdf["y"] >= -90) & (gdf["y"] <= 90)
+ (gdf["x"] >= -180)
+ & (gdf["x"] <= 180)
+ & (gdf["y"] >= -90)
+ & (gdf["y"] <= 90)
]
# Create a geometry column
diff --git a/hat/timeseries.py b/hat/timeseries.py
index 3cb6a4e..c528a5b 100644
--- a/hat/timeseries.py
+++ b/hat/timeseries.py
@@ -114,7 +114,9 @@ def assign_stations(
station_id = station[station_dim]
# timeseries index for a given station
- timeseries_index = station_timeseries_index(station, lon_in_mask, lat_in_mask)
+ timeseries_index = station_timeseries_index(
+ station, lon_in_mask, lat_in_mask
+ )
# add to the list
stations_id += [station_id]
diff --git a/hat/tools/extract_simulation_timeseries_cli.py b/hat/tools/extract_simulation_timeseries_cli.py
index e63c833..445b069 100644
--- a/hat/tools/extract_simulation_timeseries_cli.py
+++ b/hat/tools/extract_simulation_timeseries_cli.py
@@ -19,11 +19,16 @@
# hat modules
from hat.data import read_simulation_as_xarray, save_dataset_to_netcdf
-from hat.extract_simulation_timeseries import DEFAULT_CONFIG, extract_timeseries
+from hat.extract_simulation_timeseries import (
+ DEFAULT_CONFIG,
+ extract_timeseries,
+)
from hat.observations import read_station_metadata_file
-def print_overview(config: dict, station_metadata: gpd.GeoDataFrame, simulation):
+def print_overview(
+ config: dict, station_metadata: gpd.GeoDataFrame, simulation
+):
"""Print overview of relevant information for user"""
title("Configuration", color="cyan")
diff --git a/hat/tools/hydrostats_cli.py b/hat/tools/hydrostats_cli.py
index 204f92a..fa05935 100644
--- a/hat/tools/hydrostats_cli.py
+++ b/hat/tools/hydrostats_cli.py
@@ -29,10 +29,14 @@ def check_inputs(functions, sims, obs):
)
if not sims.endswith(".nc"):
- raise UserError(f"Simulation filepath must end with .nc was given: {sims}")
+ raise UserError(
+ f"Simulation filepath must end with .nc was given: {sims}"
+ )
if not obs.endswith(".nc"):
- raise UserError(f"Observation filepath must end with .nc was given: {obs}")
+ raise UserError(
+ f"Observation filepath must end with .nc was given: {obs}"
+ )
return True
@@ -91,7 +95,9 @@ def hydrostats_cli(
obs_da = obs_ds[var]
# clean timeseries
- sims_da, obs_da = filter_timeseries(sims_da, obs_da, threshold=obs_threshold)
+ sims_da, obs_da = filter_timeseries(
+ sims_da, obs_da, threshold=obs_threshold
+ )
# calculate statistics
statistics_ds = run_analysis(functions, sims_da, obs_da)
diff --git a/hat/version.py b/hat/version.py
deleted file mode 100644
index 49e0fc1..0000000
--- a/hat/version.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "0.7.0"
diff --git a/pyproject.toml b/pyproject.toml
index bde8513..15b40ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,105 @@
-# required for isort and black to play nicely
+
+[build-system]
+requires = ["setuptools>=65", "setuptools_scm[toml]>=6.2"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "hydro_analysis_toolkit"
+requires-python = ">=3.8"
+authors = [
+ {name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int"},
+]
+maintainers = [
+ {name = "Corentin Carton de Wiart", email = "corentin.carton@ecmwf.int"},
+]
+description = "ECMWF Hydrological Analysis Tools"
+license = {file = "LICENSE"}
+classifiers = [
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Natural Language :: English",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.10",
+ "Topic :: Scientific/Engineering",
+]
+dynamic = ["version", "readme"]
+
+dependencies = [
+ "numpy",
+ "pandas",
+ "xarray",
+ "matplotlib",
+ "geopandas",
+ "typer",
+ "humanize",
+ "tqdm",
+ "ipyleaflet",
+ "ipywidgets",
+ "earthkit-data",
+ "cfgrib"
+]
+
+[project.urls]
+ repository = "https://github.com/ecmwf/hat"
+ documentation = "https://hydro-analysis-toolkit.readthedocs.io"
+ issues = "https://github.com/ecmwf/hat/issues"
+
+[project.optional-dependencies]
+ test = [
+ "pytest",
+ ]
+ docs = [
+ "mkdocs",
+ "mkdocs-material",
+ "mkdocstrings-python",
+ "mkdocs-autorefs",
+ "mkdocs-gen-files",
+ "pymdown-extensions",
+ "markdown-exec[ansi]",
+ ]
+
+[project.scripts]
+ hat-extract-timeseries = "hat.tools.extract_simulation_timeseries_cli:main"
+ hat-hydrostats = "hat.tools.hydrostats_cli:main"
+ hat-stations-mapping = "hat.mapping.station_mapping:main"
+
+# Code inspection
+[tool.black]
+line-length = 79
+
[tool.isort]
-profile = "black"
\ No newline at end of file
+profile="black"
+
+# Testing
+[tool.pytest]
+addopts = "--pdbcls=IPython.terminal.debugger:Pdb"
+testpaths = [
+ "tests",
+]
+
+# Packaging/setuptools options
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.dynamic]
+readme = {file = ["README.md"], content-type = "text/markdown"}
+
+# [tool.setuptools.package_data]
+# hat = ["config_json/*.json"]
+
+[tool.setuptools.packages.find]
+where = ["."]
+exclude = ["tests"]
+
+[tool.setuptools_scm]
+write_to = "hat/_version.py"
+write_to_template = '''
+# Do not change! Do not track in version control!
+__version__ = "{version}"
+'''
+parentdir_prefix_version='hat-' # get version from GitHub-like tarballs
+fallback_version='0.7.1'
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index d436303..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,27 +0,0 @@
-[metadata]
-name = hydro_analysis_toolkit
-version = attr: hat.version.__version__
-author = European Centre for Medium-Range Weather Forecasts (ECMWF)
-author_email = software.support@ecmwf.int
-license = Apache 2.0
-license_files = LICENSE
-description = ECMWF Hydrological Analysis Tools
-long_description = file: README.md
-long_description_content_type=text/markdown
-url = https://hydro-analysis-toolkit.readthedocs.io
-
-[options]
-packages = find:
-include_package_data = True
-
-[options.package_data]
-hat = config_json/*.json
-
-[options.packages.find]
-include = hat*
-
-[options.entry_points]
-console_scripts =
- hat-extract-timeseries = hat.tools.extract_simulation_timeseries_cli:main
- hat-hydrostats = hat.tools.hydrostats_cli:main
- hat-stations-mapping = hat.mapping.station_mapping:main
\ No newline at end of file
diff --git a/tests/test_config.py b/tests/test_config.py
index 941a014..47258bf 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -8,7 +8,12 @@
# See https://setuptools.pypa.io/en/latest/pkg_resources.html
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
- from hat.config import DEFAULT_CONFIG, booleanify, read_config, valid_custom_config
+ from hat.config import (
+ DEFAULT_CONFIG,
+ booleanify,
+ read_config,
+ valid_custom_config,
+ )
def test_DEFAULT_CONFIG():
@@ -39,7 +44,9 @@ def test_valid_custom_config():
assert valid_custom_config(empty_dict) == DEFAULT_CONFIG
assert valid_custom_config(invalid_keys) == DEFAULT_CONFIG
- assert valid_custom_config(partially_complete).keys() == DEFAULT_CONFIG.keys()
+ assert (
+ valid_custom_config(partially_complete).keys() == DEFAULT_CONFIG.keys()
+ )
def test_read_config():
diff --git a/tests/test_interactive.py b/tests/test_interactive.py
index de24d32..d7d710f 100644
--- a/tests/test_interactive.py
+++ b/tests/test_interactive.py
@@ -32,7 +32,9 @@ def update(self, index, metadata, **kwargs):
class TestWidgetsManager:
def test_update(self):
dummy = DummyWidget()
- widgets = wd.WidgetsManager(widgets={"dummy": dummy}, index_column="station")
+ widgets = wd.WidgetsManager(
+ widgets={"dummy": dummy}, index_column="station"
+ )
feature = {
"properties": {
"station": "A",
@@ -92,7 +94,11 @@ def test_update(self):
class TestMetaDataWidget:
def test_update(self):
df = pd.DataFrame(
- {"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [0.1, 0.2, 0.3]}
+ {
+ "col1": [1, 2, 3],
+ "col2": ["a", "b", "c"],
+ "col3": [0.1, 0.2, 0.3],
+ }
)
widget = wd.MetaDataWidget(df, "col2")
assert widget.update("a") is True
@@ -183,7 +189,9 @@ def test_stats_style_fail(self):
"station": 4,
}
}
- colormap = lf.PyleafletColormap(self.config, self.stats, empty_color="black")
+ colormap = lf.PyleafletColormap(
+ self.config, self.stats, empty_color="black"
+ )
style_fct = colormap.style_callback()
style = style_fct(feature)
assert style["fillColor"] == "black"
diff --git a/tests/test_mapping.py b/tests/test_mapping.py
index 7468c13..0c310b4 100644
--- a/tests/test_mapping.py
+++ b/tests/test_mapping.py
@@ -74,7 +74,12 @@ def test_create_grid_polygon():
lat, lon = 2.5, 2.5
cell_size = 1
polygon = create_grid_polygon(lat, lon, cell_size)
- assert polygon.bounds == (2, 2, 3, 3) # Check if the polygon bounds are as expected
+ assert polygon.bounds == (
+ 2,
+ 2,
+ 3,
+ 3,
+ ) # Check if the polygon bounds are as expected
@pytest.fixture
@@ -92,7 +97,9 @@ def mock_station():
# Mock function to simulate netCDF data access
def mock_latitudes_longitudes():
"""Provides mock latitudes and longitudes arrays"""
- return np.array([0.0, 1.0, 2.0, 3.0, 4.0]), np.array([0.0, 1.0, 2.0, 3.0, 4.0])
+ return np.array([0.0, 1.0, 2.0, 3.0, 4.0]), np.array(
+ [0.0, 1.0, 2.0, 3.0, 4.0]
+ )
@pytest.fixture
@@ -105,7 +112,9 @@ def mock_nc_data():
return mock_nc
-def test_process_station_data(mock_station, mock_latitudes_longitudes, mock_nc_data):
+def test_process_station_data(
+ mock_station, mock_latitudes_longitudes, mock_nc_data
+):
max_neighboring_cells = 1
min_area_diff = 10
max_area_diff = 20
@@ -168,27 +177,37 @@ def test_process_station_data(mock_station, mock_latitudes_longitudes, mock_nc_d
processed_data["near_grid_area"], float
), "Near grid area should be a float"
- assert "near_area_diff" in processed_data, "Near area difference is missing"
+ assert (
+ "near_area_diff" in processed_data
+ ), "Near area difference is missing"
assert isinstance(
processed_data["near_area_diff"], float
), "Near area difference should be a float"
- assert "optimum_grid_lat" in processed_data, "Optimum grid latitude is missing"
+ assert (
+ "optimum_grid_lat" in processed_data
+ ), "Optimum grid latitude is missing"
assert isinstance(
processed_data["optimum_grid_lat"], float
), "Optimum grid latitude should be a float"
- assert "optimum_grid_lon" in processed_data, "Optimum grid longitude is missing"
+ assert (
+ "optimum_grid_lon" in processed_data
+ ), "Optimum grid longitude is missing"
assert isinstance(
processed_data["optimum_grid_lon"], float
), "Optimum grid longitude should be a float"
- assert "optimum_grid_area" in processed_data, "Optimum grid area is missing"
+ assert (
+ "optimum_grid_area" in processed_data
+ ), "Optimum grid area is missing"
assert isinstance(
processed_data["optimum_grid_area"], float
), "Optimum grid area should be a float"
- assert "optimum_area_diff" in processed_data, "Optimum area difference is missing"
+ assert (
+ "optimum_area_diff" in processed_data
+ ), "Optimum area difference is missing"
assert isinstance(
processed_data["optimum_area_diff"], float
), "Optimum area difference should be a float"
@@ -200,7 +219,9 @@ def test_process_station_data(mock_station, mock_latitudes_longitudes, mock_nc_d
processed_data["optimum_distance_cells"], int
), "Optimum distance should be an integer"
- assert "optimum_distance_km" in processed_data, "Optimium distance in km is missing"
+ assert (
+ "optimum_distance_km" in processed_data
+ ), "Optimium distance in km is missing"
assert isinstance(
processed_data["optimum_distance_km"], float
), "Nearest distance in km should be a float"
diff --git a/tests/test_mapping_evaluation.py b/tests/test_mapping_evaluation.py
index c284620..c332945 100644
--- a/tests/test_mapping_evaluation.py
+++ b/tests/test_mapping_evaluation.py
@@ -13,7 +13,9 @@ def test_calculate_mae():
data = {"reference": [100, 200], "evaluated": [90, 195]}
df = pd.DataFrame(data)
expected_mae = 7.5 # Calculated manually
- assert calculate_mae(df, "reference", "evaluated") == pytest.approx(expected_mae)
+ assert calculate_mae(df, "reference", "evaluated") == pytest.approx(
+ expected_mae
+ )
def test_calculate_rmse():
@@ -52,4 +54,6 @@ def test_count_and_analyze_area_distance(sample_dataframe):
"log",
)
- assert isinstance(fig, Figure), "The function should return a matplotlib figure."
+ assert isinstance(
+ fig, Figure
+ ), "The function should return a matplotlib figure."