Feature/module action #29

Merged · 15 commits · Jul 22, 2024

Changes from all commits

53 changes: 53 additions & 0 deletions .github/workflows/cd.yml
@@ -0,0 +1,53 @@
name: cd-module

on:
  workflow_dispatch:
    inputs:
      module-name:
        description: 'Module name'
        required: true
        type: string
  push:
    tags:
      - '[0-9]+.[0-9]+.[0-9]+'
jobs:
  install_module:
    runs-on: [hpc]
    steps:
      - uses: ecmwf-actions/reusable-workflows/ci-hpc-generic@v2
        with:
          troika_user: ${{ secrets.HPC_CI_SSH_USER }}
          template: |
            # Declare an array to store the module versions
            python3_versions=("3.10.10-01" "3.11.8-01")

            HAT_VERSION=${{ github.event_name == 'workflow_dispatch' && inputs.module-name || github.ref_name }}

            PREFIX=/usr/local/apps/hat/${HAT_VERSION}
            rm -rf $PREFIX
            mkdir -p $PREFIX

            # Loop through the module versions
            for version in "${python3_versions[@]}"
            do
                # Load the module for the current version
                module load python3/$version

                PYTHONUSERBASE=$PREFIX pip3 install --user git+https://github.com/ecmwf/hat.git@${HAT_VERSION}

                module unload python3
            done

            cat > $PREFIX/README.txt << EOF
            -- [hat] ($HAT_VERSION) [$HAT_VERSION]
            EOF

            software-sync -s local -p hat
            module load modulemgr
            modulemgr -f -m all sync hat
          sbatch_options: |
            #SBATCH --job-name=cd_hat
            #SBATCH --time=00:10:00
            #SBATCH --qos=deploy
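
A note on the version selection above: GitHub Actions has no ternary operator, so `${{ github.event_name == 'workflow_dispatch' && inputs.module-name || github.ref_name }}` uses the `&&`/`||` idiom to pick the manually supplied `module-name` input on a `workflow_dispatch` run and the pushed tag name otherwise. The install loop then amounts to the following (a minimal local sketch in Python, for illustration only; it omits the HPC `module load`/`unload` steps and assumes `pip3` is on `PATH`):

import os
import subprocess

PYTHON3_VERSIONS = ["3.10.10-01", "3.11.8-01"]


def install_hat(hat_version: str) -> None:
    """Install hat once per Python module version into a shared prefix."""
    prefix = f"/usr/local/apps/hat/{hat_version}"
    os.makedirs(prefix, exist_ok=True)
    for _version in PYTHON3_VERSIONS:
        # In the real workflow, `module load python3/$version` selects the
        # interpreter here; PYTHONUSERBASE redirects `pip install --user`
        # into the shared prefix.
        env = {**os.environ, "PYTHONUSERBASE": prefix}
        subprocess.run(
            [
                "pip3", "install", "--user",
                f"git+https://github.com/ecmwf/hat.git@{hat_version}",
            ],
            env=env,
            check=True,
        )
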
2 changes: 1 addition & 1 deletion .github/workflows/on-push.yml
@@ -5,7 +5,7 @@ on:
    branches:
      - main
    tags:
-     - "*"
+     - "[0-9]+.[0-9]+.[0-9]+"
  pull_request:
    branches:
      - main
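
The effect of this change: pushing an arbitrary tag no longer triggers the workflow; only release-style tags such as `1.2.3` do. In GitHub's filter-pattern syntax, `+` matches one or more of the preceding character, so the glob is roughly equivalent to the regex below (a sketch for intuition only; GitHub evaluates its own glob syntax, not Python regexes):

import re

TAG_RE = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")

for tag in ["1.2.3", "0.10.1", "v1.2.3", "latest"]:
    print(tag, bool(TAG_RE.match(tag)))
# 1.2.3 and 0.10.1 would trigger the workflow; v1.2.3 and latest would not
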
1 change: 0 additions & 1 deletion environment.yml
@@ -14,7 +14,6 @@ dependencies:
   - tqdm
   - typer
   - humanize
-  - typer
   - ipyleaflet
   - ipywidgets
   - pip
8 changes: 6 additions & 2 deletions hat/config.py
@@ -13,7 +13,9 @@ def load_package_config(fname):
     resource_package = "hat"
     resource_path = os.path.join("config_json", fname)

-    config_string = pkg_resources.resource_string(resource_package, resource_path)
+    config_string = pkg_resources.resource_string(
+        resource_package, resource_path
+    )
     config = json.loads(config_string.decode())

     return config
@@ -50,7 +52,9 @@ def booleanify(config, key):
         return

     if str(config[key]).lower() not in ["true", "false"]:
-        raise ValueError(f'"{key}" configuration variable must be "True" or "False"')
+        raise ValueError(
+            f'"{key}" configuration variable must be "True" or "False"'
+        )

     if config[key].lower() == "true":
         config[key] = True
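
These two hunks, like most of the Python changes in this PR, only re-wrap lines that exceeded the formatter's line-length limit; behavior is unchanged. For context, the two helpers are used roughly like this (a sketch; the config file name and key below are hypothetical):

from hat.config import booleanify, load_package_config

# Reads hat/config_json/<fname> from the installed package
config = load_package_config("timeseries.json")

# Converts a "True"/"False" string to a real bool in place,
# raising ValueError for anything else
booleanify(config, "recursive_search")
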
28 changes: 21 additions & 7 deletions hat/data.py
@@ -77,7 +77,9 @@ def find_files(simulation_files):
     fpaths = glob.glob(simulation_files)

     if not fpaths:
-        raise Exception(f"Could not find any file from regex {simulation_files}")
+        raise Exception(
+            f"Could not find any file from regex {simulation_files}"
+        )
     else:
         print("Found following simulation files:")
         print(*fpaths, sep="\n")
@@ -113,7 +115,9 @@ def dirsize(simulation_datadir, bytesize=False):
     """given a root directory return total size of all files
     in a directory in human readable format"""

-    if not os.path.exists(simulation_datadir) or not os.path.isdir(simulation_datadir):
+    if not os.path.exists(simulation_datadir) or not os.path.isdir(
+        simulation_datadir
+    ):
         print("Not a directory", simulation_datadir)
         return

@@ -178,7 +182,9 @@ def raster_loader(fpath: str, epsg: int = 4326, engine=None):
     xarr.attrs["fname"] = os.path.basename(fpath)
     xarr.attrs["experiment_name"] = pathlib.Path(fpath).stem
     xarr.attrs["file_extension"] = pathlib.Path(fpath).suffix
-    xarr.attrs["river_network"] = river_network_from_filename(xarr.attrs["fname"])
+    xarr.attrs["river_network"] = river_network_from_filename(
+        xarr.attrs["fname"]
+    )

     return xarr

@@ -261,7 +267,9 @@ def save_dataset_to_netcdf(ds: xr.Dataset, fpath: str):


 def find_main_var(ds, min_dim=3):
-    variable_names = [k for k in ds.variables if len(ds.variables[k].dims) >= min_dim]
+    variable_names = [
+        k for k in ds.variables if len(ds.variables[k].dims) >= min_dim
+    ]
     if len(variable_names) > 1:
         raise Exception("More than one variable in dataset")
     elif len(variable_names) == 0:
@@ -291,10 +299,16 @@ def read_simulation_as_xarray(options):
     fs = earthkit.data.from_source(src_type, *args)

     xarray_kwargs = {}
-    if isinstance(fs, earthkit.data.readers.netcdf.fieldlist.NetCDFMultiFieldList):
-        xarray_kwargs["xarray_open_mfdataset_kwargs"] = {"chunks": {"time": "auto"}}
+    if isinstance(
+        fs, earthkit.data.readers.netcdf.fieldlist.NetCDFMultiFieldList
+    ):
+        xarray_kwargs["xarray_open_mfdataset_kwargs"] = {
+            "chunks": {"time": "auto"}
+        }
     else:
-        xarray_kwargs["xarray_open_dataset_kwargs"] = {"chunks": {"time": "auto"}}
+        xarray_kwargs["xarray_open_dataset_kwargs"] = {
+            "chunks": {"time": "auto"}
+        }

     # xarray dataset
     ds = fs.to_xarray(**xarray_kwargs)
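
`find_main_var` treats the single variable with at least `min_dim` dimensions as the dataset's main variable and raises if that variable is not unique. A minimal sketch of the selection logic on a toy dataset (the variable names here are invented):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "dis": (("time", "y", "x"), np.zeros((2, 3, 4))),  # 3-D: candidate
        "mask": (("y", "x"), np.zeros((3, 4))),  # 2-D: filtered out
    }
)

variable_names = [
    k for k in ds.variables if len(ds.variables[k].dims) >= 3
]
print(variable_names)  # ['dis'] -> exactly one main variable
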
6 changes: 5 additions & 1 deletion hat/extract_simulation_timeseries.py
@@ -36,7 +36,11 @@ def extract_timeseries(

     # all timeseries (i.e. handle proximal and duplicate stations)
     da_stations = assign_stations(
-        stations, station_mask, da_points, coords, config["station_id_column_name"]
+        stations,
+        station_mask,
+        da_points,
+        coords,
+        config["station_id_column_name"],
     )

     return da_stations
25 changes: 19 additions & 6 deletions hat/filters.py
@@ -4,7 +4,10 @@

 # @st.cache_data
 def temporal_filter(
-    _metadata, _observations: pd.DataFrame, timeperiod, station_id_name="station_id"
+    _metadata,
+    _observations: pd.DataFrame,
+    timeperiod,
+    station_id_name="station_id",
 ):
     """
     filter station metadata and timeseries by timeperiod
@@ -89,7 +92,9 @@ def stations_with_discharge(obs, timeperiod, metadata):
     return metadata, obsdis


-def apply_filter(df: pd.DataFrame, key: str, operator: str, value: str) -> pd.DataFrame:
+def apply_filter(
+    df: pd.DataFrame, key: str, operator: str, value: str
+) -> pd.DataFrame:
     """
     Apply the filter on the DataFrame based on the provided
     key, operator, and value.
@@ -104,7 +109,9 @@ def apply_filter(df: pd.DataFrame, key: str, operator: str, value: str) -> pd.DataFrame:
     }

     if key not in df.columns:
-        raise ValueError(f"Key '{key}' does not exist as column name in dataframe")
+        raise ValueError(
+            f"Key '{key}' does not exist as column name in dataframe"
+        )

     if operator not in operators:
         raise ValueError(f"Operator '{operator}' is not supported")
@@ -134,19 +141,25 @@ def filter_dataframe(df, filters: str):
             continue
         parts = filter_str.split()
         if len(parts) != 3:
-            raise ValueError("Invalid filter format. Expected 'key operator value'.")
+            raise ValueError(
+                "Invalid filter format. Expected 'key operator value'."
+            )

         key, operator, value = parts

         df = apply_filter(df, key, operator, value)

     if len(df) == 0:
-        raise ValueError("There are no remaining rows (try different filters?)")
+        raise ValueError(
+            "There are no remaining rows (try different filters?)"
+        )

     return df


-def filter_timeseries(sims_ds: xr.DataArray, obs_ds: xr.DataArray, threshold=80):
+def filter_timeseries(
+    sims_ds: xr.DataArray, obs_ds: xr.DataArray, threshold=80
+):
     """Clean the simulation and observation timeseries

     Only keep..
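
`filter_dataframe` parses each filter as a whitespace-separated `key operator value` triple and hands it to `apply_filter`, which looks the operator up in a dict of comparison functions. A usage sketch (the exact operator tokens and any type coercion of `value` are not visible in this diff, so the `==` token below is an assumption):

import pandas as pd

from hat.filters import filter_dataframe

df = pd.DataFrame(
    {"station_id": ["G001", "G002", "G003"], "river": ["Po", "Rhine", "Po"]}
)

# Keep only rows where the 'river' column equals "Po" (assumed operator token)
filtered = filter_dataframe(df, "river == Po")
print(filtered)
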
24 changes: 18 additions & 6 deletions hat/geo.py
@@ -116,7 +116,9 @@ def river_network_geometry(metadata, river_network, version=1):
     if river_network == "EPSG:4326":
         return metadata

-    lon_name, lat_name = name_of_adjusted_coords(river_network, version=version)
+    lon_name, lat_name = name_of_adjusted_coords(
+        river_network, version=version
+    )

     gdf = metadata.copy(deep=True)
     gdf["geometry"] = gpd.points_from_xy(gdf[lon_name], gdf[lat_name])
@@ -157,7 +159,9 @@ def river_network_to_coord_names(river_network: str = "") -> dict:
         x, y = river_network_to_coords[river_network]
         return {"x": x, "y": y}
     else:
-        print(f"River network '{river_network}' not in: {valid_river_networks}")
+        print(
+            f"River network '{river_network}' not in: {valid_river_networks}"
+        )


 def geojson_schema():
@@ -230,7 +234,9 @@ def geojson_schema():
                             {"$ref": "#/definitions/pointCoordinates"},
                             {"$ref": "#/definitions/multiPointCoordinates"},
                             {"$ref": "#/definitions/lineStringCoordinates"},
-                            {"$ref": "#/definitions/multiLineStringCoordinates"},
+                            {
+                                "$ref": "#/definitions/multiLineStringCoordinates"
+                            },
                             {"$ref": "#/definitions/polygonCoordinates"},
                             {"$ref": "#/definitions/multiPolygonCoordinates"},
                         ]
@@ -307,8 +313,12 @@ def geopoints_to_array(gdf, array_coords) -> np.ndarray:
     point_ys = np.array([point.y for point in points])

     # nearest neighbour indices of the points in the array
-    x_indices = [(np.abs(array_coords["x"] - point_x)).argmin() for point_x in point_xs]
-    y_indices = [(np.abs(array_coords["y"] - point_y)).argmin() for point_y in point_ys]
+    x_indices = [
+        (np.abs(array_coords["x"] - point_x)).argmin() for point_x in point_xs
+    ]
+    y_indices = [
+        (np.abs(array_coords["y"] - point_y)).argmin() for point_y in point_ys
+    ]

     # create an empty boolean array
     shape = (len(array_coords["y"]), len(array_coords["x"]))
@@ -320,7 +330,9 @@ def geopoints_to_array(gdf, array_coords) -> np.ndarray:
     return arr


-def geopoints_from_csv(fpath: str, lat_name: str, lon_name: str) -> gpd.GeoDataFrame:
+def geopoints_from_csv(
+    fpath: str, lat_name: str, lon_name: str
+) -> gpd.GeoDataFrame:
     """Load georeferenced points from file.
     Requires name of latitude and longitude columns"""

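
The wrapped list comprehensions in `geopoints_to_array` implement nearest-neighbour snapping: each station coordinate is mapped to the index of the closest grid coordinate via `argmin` of absolute differences. The same operation in isolation (toy grid, invented point):

import numpy as np

grid_x = np.linspace(-180.0, 180.0, 7200)  # toy ~0.05-degree grid
point_x = 12.34

# Index of the grid cell whose x-coordinate is closest to the point
ix = int(np.abs(grid_x - point_x).argmin())
print(ix, grid_x[ix])
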
19 changes: 15 additions & 4 deletions hat/interactive/explorers.py
@@ -106,7 +106,9 @@ def prepare_observations_data(observations, sim_ds, obs_var_name):
     return obs_ds


-def find_common_stations(station_index, stations_metadata, obs_ds, sim_ds, statistics):
+def find_common_stations(
+    station_index, stations_metadata, obs_ds, sim_ds, statistics
+):
     """
     Find common stations between observations, simulations and station
     metadata.
@@ -260,7 +262,11 @@ def __init__(self, config):
         self.title_label = ipywidgets.Label(
             "Interactive Map Visualisation for Hydrological Model Performance",
             layout=ipywidgets.Layout(justify_content="center"),
-            style={"font_weight": "bold", "font_size": "24px", "font_family": "Arial"},
+            style={
+                "font_weight": "bold",
+                "font_size": "24px",
+                "font_family": "Arial",
+            },
         )

         # Create the interactive widgets
@@ -269,7 +275,9 @@ def __init__(self, config):
         widgets = {}
         widgets["plot"] = PlotlyWidget(datasets)
         widgets["stats"] = StatisticsWidget(self.statistics)
-        widgets["meta"] = MetaDataWidget(self.stations_metadata, self.station_index)
+        widgets["meta"] = MetaDataWidget(
+            self.stations_metadata, self.station_index
+        )
         self.widgets = WidgetsManager(
             widgets, config["station_id_column_name"], self.loading_widget
         )
@@ -302,7 +310,10 @@ def create_frame(self):
             width="40%",
         )
         right_layout = ipywidgets.Layout(
-            justify_content="center", align_items="center", spacing="2px", width="60%"
+            justify_content="center",
+            align_items="center",
+            spacing="2px",
+            width="60%",
         )

         # Frames
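
`find_common_stations` keeps only the stations present in all sources (station metadata, observations, simulations). The underlying idea is a plain set intersection (an illustrative sketch with invented IDs, not the function's actual body):

sim_ids = {"G001", "G002", "G003"}
obs_ids = {"G002", "G003", "G004"}
meta_ids = {"G001", "G002", "G003", "G004"}

common = sim_ids & obs_ids & meta_ids
print(sorted(common))  # ['G002', 'G003']
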
11 changes: 8 additions & 3 deletions hat/interactive/leaflet.py
@@ -39,7 +39,8 @@ def __init__(
         basemap=ipyleaflet.basemaps.OpenStreetMap.Mapnik,
     ):
         self.map = ipyleaflet.Map(
-            basemap=basemap, layout=ipywidgets.Layout(width="100%", height="600px")
+            basemap=basemap,
+            layout=ipywidgets.Layout(width="100%", height="600px"),
         )
         self.legend_widget = ipywidgets.Output()

@@ -56,7 +57,10 @@ def _set_boundaries(self, stations_metadata, coord_names):
         min_lat, max_lat = min(lats), max(lats)
         min_lon, max_lon = min(lons), max(lons)

-        bounds = [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))]
+        bounds = [
+            (float(min_lat), float(min_lon)),
+            (float(max_lat), float(max_lon)),
+        ]
         self.map.fit_bounds(bounds)

     def add_geolayer(self, geodata, colormap, widgets, coord_names=None):
@@ -221,7 +225,8 @@ def legend(self):
"""
# Convert the colormap to a list of RGB values
rgb_values = [
mpl.colors.rgb2hex(self.colormap(i)) for i in np.linspace(0, 1, 256)
mpl.colors.rgb2hex(self.colormap(i))
for i in np.linspace(0, 1, 256)
]

# Create a gradient style using the RGB values
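
The `legend` change only re-wraps the comprehension that samples the matplotlib colormap into 256 hex colours for the CSS gradient. Standalone, the sampling looks like this (using `viridis` as a stand-in for `self.colormap`):

import matplotlib as mpl
import numpy as np

cmap = mpl.colormaps["viridis"]  # stand-in for self.colormap
rgb_values = [mpl.colors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, 256)]
print(rgb_values[0], rgb_values[-1])  # '#440154' '#fde725'
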