diff --git a/fm2prof/MaskOutputFile.py b/fm2prof/MaskOutputFile.py
deleted file mode 100644
index 0f2c06b5..00000000
--- a/fm2prof/MaskOutputFile.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""
-Copyright (C) Stichting Deltares 2019. All rights reserved.
-
-This file is part of the Fm2Prof.
-
-The Fm2Prof is free software: you can redistribute it and/or modify
-it under the terms of the GNU Lesser General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
-
-This program is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-GNU Affero General Public License for more details.
-
-You should have received a copy of the GNU Affero General Public License
-along with this program. If not, see .
-
-All names, logos, and references to "Deltares" are registered trademarks of
-Stichting Deltares and remain full property of Stichting Deltares at all times.
-All rights reserved.
-"""
-from pathlib import Path
-from typing import Union, Optional
-import geojson
-
-
-
-class MaskOutputFile:
- @staticmethod
- def create_mask_point(coords: geojson.coords, properties: Optional[dict] = None) -> geojson.Feature:
- """Creates a Point based on the properties and coordinates given.
-
- Arguments:
- coords {geojson.coords} --
- Coordinates tuple (x,y) for the mask point.
- properties {dict} -- Dictionary of properties
- """
- if not coords:
- raise ValueError("coords cannot be empty.")
- output_mask = geojson.Feature(geometry=geojson.Point(coords))
-
- if properties:
- output_mask.properties = properties
- return output_mask
-
- @staticmethod
- def validate_extension(file_path: Union[str, Path]) -> None:
- if not isinstance(file_path, (str, Path)):
- err_msg = f"file_path should be string or Path, not {type(file_path)}"
- raise TypeError(err_msg)
- if Path(file_path).suffix not in [".json", ".geojson"]:
- raise IOError(
- "Invalid file path extension, should be .json or .geojson."
- )
-
- @staticmethod
- def read_mask_output_file(file_path: Union[str, Path]) -> dict:
- """Imports a GeoJson from a given json file path.
-
- Arguments:
- file_path {str} -- Location of the json file
- """
- file_path = Path(file_path)
- if not file_path.exists():
- err_msg = f"File path {file_path} not found."
- raise FileNotFoundError(err_msg)
-
- MaskOutputFile.validate_extension(file_path)
- with file_path.open("r") as geojson_file:
- geojson_data = geojson.load(geojson_file)
- if not isinstance(geojson_data, geojson.FeatureCollection):
- raise IOError("File is empty or not a valid geojson file.")
- return geojson_data
-
- @staticmethod
- def write_mask_output_file(file_path: Union[Path, str], mask_points: list) -> None:
- """Writes a .geojson file with a Feature collection containing
- the mask_points list given as input.
-
- Arguments:
- file_path {str} -- file_path where to store the geojson.
- mask_points {list} -- List of features to output.
- """
- if not file_path:
- raise ValueError("file_path is required.")
- file_path = Path(file_path)
- if not mask_points:
- raise ValueError("mask_points cannot be empty.")
- MaskOutputFile.validate_extension(file_path)
- feature_collection = geojson.FeatureCollection(mask_points)
- with file_path.open("w") as f:
- geojson.dump(feature_collection, f, indent=4)
diff --git a/fm2prof/__init__.py b/fm2prof/__init__.py
index 3438156e..f0669b17 100644
--- a/fm2prof/__init__.py
+++ b/fm2prof/__init__.py
@@ -1,3 +1,3 @@
__version__ = "2.3.3"
-from fm2prof.Fm2ProfRunner import Project
+from fm2prof.fm2prof_runner import Project
diff --git a/fm2prof/cli.py b/fm2prof/cli.py
index 6691efc3..bc5d953b 100644
--- a/fm2prof/cli.py
+++ b/fm2prof/cli.py
@@ -1,47 +1,49 @@
+"""CLI for Fm2Prof."""
+
from pathlib import Path
from typing import Optional
import typer
from tqdm import tqdm
+
from fm2prof import Project, __version__
-from fm2prof.IniFile import IniFile
+from fm2prof.ini_file import IniFile
from fm2prof.utils import Compare1D2D, VisualiseOutput
app = typer.Typer()
-def _display_version(value: bool) -> None:
+def _display_version(value: bool) -> None: # noqa:FBT001
if value:
typer.echo(f"Fm2Prof v{__version__}")
- raise typer.Exit()
+ raise typer.Exit
@app.command("create")
-def cli_create_new_project(projectname: str):
- """Creates a new project configuration from scratch, then exit"""
+def cli_create_new_project(projectname: str) -> None:
+ """Create a new project configuration from scratch, then exit."""
inifile = IniFile().print_configuration()
- ini_path = f"{projectname}.ini"
+ ini_path = Path(f"{projectname}.ini")
- with open(ini_path, "w") as f:
+ with ini_path.open("w") as f:
f.write(inifile)
typer.echo(f"{ini_path} written to file")
- raise typer.Exit()
+ raise typer.Exit
@app.command("check")
def cli_check_project(projectname: str) -> None:
- """Load project, check filepaths, print errors then exit"""
+ """Load project, check filepaths, print errors then exit."""
cf = Path(projectname).with_suffix(".ini")
- project = Project(cf)
- raise typer.Exit()
+ Project(cf)
+ raise typer.Exit
@app.command("compare")
def cli_compare_1d2d(
projectname: str, output_1d: str, output_2d: str, routes: str
) -> None:
- """BETA FUNCTIONALITY - compares 1D and 2D results"""
-
+ """BETA FUNCTIONALITY - compares 1D and 2D results."""
cf = Path(projectname).with_suffix(".ini")
project = Project(cf)
@@ -58,17 +60,21 @@ def cli_compare_1d2d(
@app.command("run")
def cli_load_project(
projectname: str,
+ *,
overwrite: bool = typer.Option(
- False, "--overwrite", "-o", help="Overwrite if output already exists"
+ False, # noqa: FBT003
+ "--overwrite",
+ "-o",
+ help="Overwrite if output already exists",
),
pp: bool = typer.Option(
- False,
+ False, # noqa: FBT003
"--post-process",
"-p",
help="Post-process the results, generates figures",
),
) -> None:
- """Loads and runs a project"""
+ """Load and run a project."""
cf = Path(projectname).with_suffix(".ini")
project = Project(cf)
project.run(overwrite=overwrite)
@@ -80,12 +86,12 @@ def cli_load_project(
for css in tqdm(vis.cross_sections):
vis.figure_cross_section(css)
- raise typer.Exit()
+ raise typer.Exit
@app.callback()
def cli(
- version: Optional[bool] = typer.Option(
+    version: Optional[bool] = typer.Option(  # noqa: ARG001, FA100
None,
"--version",
"-v",
@@ -94,9 +100,9 @@ def cli(
is_eager=True,
),
) -> None:
+ """Fm2Prof Command-line interface."""
typer.echo("Welcome to Fm2Prof")
- return
-def main():
+def main(): # noqa: ANN201, D103
app()
diff --git a/fm2prof/common.py b/fm2prof/common.py
index 86c2d3bf..4f14d685 100644
--- a/fm2prof/common.py
+++ b/fm2prof/common.py
@@ -1,16 +1,15 @@
-# -*- coding: utf-8 -*-
-"""
-Base classes and data containers
-"""
+"""Base classes and data containers."""
+
+from __future__ import annotations
import logging
# Imports from standard library
-import os
from datetime import datetime
from logging import Logger, LogRecord
+from pathlib import Path
from time import time
-from typing import AnyStr, Mapping
+from typing import TYPE_CHECKING
import colorama
@@ -20,16 +19,19 @@
from colorama import Back, Fore, Style
# Import from package
-# none
-
-IniFile = "fm2prof.IniFile.IniFile"
+if TYPE_CHECKING:
+    from fm2prof.ini_file import IniFile
class TqdmLoggingHandler(logging.StreamHandler):
- def __init__(self):
+ """Logging handler for tqdm package."""
+
+ def __init__(self) -> None:
+ """Instantiate a TqdmLoggingHandler."""
super().__init__()
- def emit(self, record):
+ def emit(self, record: LogRecord) -> None:
+ """Write progressbar to logstream."""
try:
msg = self.format(record)
if self.formatter.pbar:
@@ -39,14 +41,17 @@ def emit(self, record):
stream.write(msg + self.terminator)
self.flush()
- except Exception as e:
+ except Exception:
self.handleError(record)
class ElapsedFormatter:
+ """ElapsedFormatter class."""
+
__new_iteration = 1
- def __init__(self):
+ def __init__(self) -> None:
+ """Instantiate an ElapsedFormatter object."""
self.start_time = time()
self.number_of_iterations: int = 1
self.current_iteration: int = 0
@@ -66,44 +71,43 @@ def __init__(self):
self._loglibrary: dict = {"ERROR": 0, "WARNING": 0}
@property
- def pbar(self):
+ def pbar(self) -> None | tqdm.tqdm | tqdm.std.tqdm:
+ """Progress bar."""
return self._pbar
@pbar.setter
- def pbar(self, pbar):
+ def pbar(self, pbar: tqdm.std.tqdm | tqdm.tqdm | None) -> None:
+ """Set progress bar."""
if isinstance(pbar, (tqdm.std.tqdm, type(None))):
self._pbar = pbar
else:
- raise ValueError
+ raise TypeError
- def format(self, record):
+ def format(self, record: LogRecord) -> str:
+ """Format log record."""
if self._intro:
return self.__format_intro(record)
if self.__new_iteration > 0:
return self.__format_header(record)
if self.__new_iteration == -1:
return self.__format_footer(record)
- else:
- return self.__format_message(record)
+ return self.__format_message(record)
- def __format_intro(self, record: LogRecord):
+ def __format_intro(self, record: LogRecord) -> str:
return f"{record.getMessage()}"
- def __format_header(self, record: LogRecord):
- """Formats the header of a new task"""
-
+ def __format_header(self, record: LogRecord) -> str:
self.__new_iteration -= 1
message = record.getMessage()
current_time = datetime.now().strftime("%Y-%m-%d %H:%M")
return f"╔═════╣ {self._resetStyle}{current_time} {message}{self._resetStyle}"
- def __format_footer(self, record: LogRecord):
+ def __format_footer(self, record: LogRecord) -> str:
self.__new_iteration -= 1
elapsed_seconds = record.created - self.start_time
- message = record.getMessage()
return f"╚═════╣ {self._resetStyle}Task finished in {elapsed_seconds:.2f}sec{self._resetStyle}"
- def __format_message(self, record: LogRecord):
+ def __format_message(self, record: LogRecord) -> str:
elapsed_seconds = record.created - self.start_time
color = self._colors
@@ -114,42 +118,52 @@ def __format_message(self, record: LogRecord):
if level in self._loglibrary:
self._loglibrary[level] += 1
- formatted_string = (
+ return (
f"║ {color[level][0]} {level:>7} "
- + f"{self._resetStyle}{color[level][1]}{self._resetStyle} T+ {elapsed_seconds:.2f}s {message}"
+ f"{self._resetStyle}{color[level][1]}{self._resetStyle} T+ {elapsed_seconds:.2f}s {message}"
)
- return formatted_string
-
- def __reset(self):
+ def __reset(self) -> None:
self.start_time = time()
- def start_new_iteration(self, pbar: tqdm.tqdm = None):
+ def start_new_iteration(self, pbar: tqdm.tqdm | None = None) -> None:
+ """Start a new iteration with a progress bar."""
self.current_iteration += 1
self.new_task()
self.pbar = pbar
- def new_task(self):
+ def new_task(self) -> None:
+ """Reset ElapsedTimeFormatter."""
self.__new_iteration = 1
self.__reset()
- def finish_task(self):
+ def finish_task(self) -> None:
+ """Finish task."""
self.__new_iteration = -1
- def set_number_of_iterations(self, n):
- assert n > 0, "Total number of iterations should be higher than zero"
+ def set_number_of_iterations(self, n: int) -> None:
+ """Set numbber of iterations."""
+ if n < 1:
+ err_msg = "Total number of iterations should be higher than zero"
+ raise ValueError(err_msg)
self.number_of_iterations = n
- def set_intro(self, flag: bool = True):
+ def set_intro(self, flag: bool = True) -> None: # noqa: FBT001, FBT002
+ """Indicate intro section for formatter."""
self._intro = flag
- def get_elapsed_time(self):
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M")
- return current_time - self.start_time
+    def get_elapsed_time(self) -> float:
+        """Get elapsed time in seconds."""
+        return time() - self.start_time
class ElapsedFileFormatter(ElapsedFormatter):
- def __init__(self):
+ """Elapsed file formatter class."""
+
+ def __init__(self) -> None:
+ """Instantiate an ElapsedFileFormatter object."""
super().__init__()
self._resetStyle = ""
self._colors = {
@@ -160,18 +174,11 @@ def __init__(self):
"RESET": "",
}
- @property
- def pbar(self):
- return self._pbar
-
- @pbar.setter
- def pbar(self, pbar):
- self._pbar = None
-
class FM2ProfBase:
- """
- Base class for FM2PROF types. Implements methods for logging, project specific parameters
+ """Base class for FM2PROF types.
+
+    Implements methods for logging and project-specific parameters.
"""
__logger = None
@@ -182,64 +189,82 @@ class FM2ProfBase:
__copyright__ = "Copyright 2016-2020, University of Twente & Deltares"
__license__ = "LPGL"
- def __init__(self, logger: Logger = None, inifile: IniFile = None):
+ def __init__(self, logger: Logger | None = None, inifile: IniFile | None = None) -> None:
+ """Instatiate a FM2ProfBase object.
+
+ Args:
+ ----
+            logger (Logger | None, optional): Logger instance. Defaults to None.
+ inifile (IniFile | None, optional): IniFile instance. Defaults to None.
+
+ """
if logger:
self.set_logger(logger)
if inifile:
self.set_inifile(inifile)
- def _create_logger(self):
+ def _create_logger(self) -> None:
# Create logger
self.__logger = logging.getLogger(__name__)
self.__logger.setLevel(logging.DEBUG)
# create formatter
- self.__logger.__logformatter = ElapsedFormatter()
- self.__logger._Filelogformatter = ElapsedFileFormatter()
+ self.__logger.__logformatter = ElapsedFormatter() # noqa: SLF001
+ self.__logger._Filelogformatter = ElapsedFileFormatter() # noqa: SLF001
# create console handler
if TqdmLoggingHandler not in map(type, self.__logger.handlers):
ch = TqdmLoggingHandler()
ch.setLevel(logging.DEBUG)
- ch.setFormatter(self.__logger.__logformatter)
+ ch.setFormatter(self.__logger.__logformatter) # noqa: SLF001
self.__logger.addHandler(ch)
def get_logger(self) -> Logger:
- """Use this method to return logger object"""
+ """Use this method to return logger object."""
return self.__logger
def set_logger(self, logger: Logger) -> None:
- """
- Use to set logger
+ """Use to set logger.
- Parameters:
+ Args:
+ ----
logger (Logger): Logger instance
+
"""
- assert isinstance(logger, Logger), (
- "" + "logger should be instance of Logger class"
- )
+ if not isinstance(logger, Logger):
+ err_msg = "logger should be instance of Logger class"
+ raise TypeError(err_msg)
self.__logger = logger
def set_logger_message(
- self, err_mssg: str = "", level: str = "info", header: bool = False
+ self,
+ err_mssg: str = "",
+ level: str = "info",
+ *,
+ header: bool = False,
) -> None:
- """Sets message to logger if this is set.
+ """Set message to logger if this is set.
+
+ Args:
+ ----
+ err_mssg (str, optional): Error message to log. Defaults to "".
+ level (str, optional): Log level. Defaults to "info".
+ header (bool, optional): Set error message as header. Defaults to False.
- Arguments:
- err_mssg {str} -- Error message to send to logger.
"""
if not self.__logger:
return
if header:
self.get_logformatter().set_intro(True)
- self.get_logger()._Filelogformatter.set_intro(True)
+ self.get_logger()._Filelogformatter.set_intro(True) # noqa: SLF001
else:
self.get_logformatter().set_intro(False)
- self.get_logger()._Filelogformatter.set_intro(False)
+ self.get_logger()._Filelogformatter.set_intro(False) # noqa: SLF001
if level.lower() not in ["info", "debug", "warning", "error", "critical"]:
- self.__logger.error("{} is not valid logging level.".format(level.lower()))
+ err_msg = f"{level.lower()} is not valid logging level."
+ raise ValueError(err_msg)
if level.lower() == "info":
self.__logger.info(err_mssg)
@@ -253,10 +278,11 @@ def set_logger_message(
self.__logger.critical(err_mssg)
def start_new_log_task(
- self, task_name: str = "NOT DEFINED", pbar: tqdm.tqdm = None
+ self,
+ task_name: str = "NOT DEFINED",
+ pbar: tqdm.tqdm = None,
) -> None:
- """
- Use this method to start a new task. Will reset the internal clock.
+ """Use this method to start a new task. Will reset the internal clock.
:param task_name: task name, will be displayed in log message
"""
@@ -265,8 +291,7 @@ def start_new_log_task(
self.set_logger_message(f"Starting new task: {task_name}")
def finish_log_task(self) -> None:
- """
- Use this method to finish task.
+ """Use this method to finish task.
:param task_name: task name, will be displayed in log message
"""
@@ -275,54 +300,74 @@ def finish_log_task(self) -> None:
self.pbar = None
def get_logformatter(self) -> ElapsedFormatter:
- """Returns formatter"""
- return self.get_logger().__logformatter
+ """Return log formatter."""
+ return self.get_logger().__logformatter # noqa: SLF001
def get_filelogformatter(self) -> ElapsedFormatter:
- """Returns formatter"""
- return self.get_logger()._Filelogformatter
+ """Return file log formatter."""
+ return self.get_logger()._Filelogformatter # noqa: SLF001
+
+ def set_logfile(self, output_dir: str | Path, filename: str = "fm2prof.log") -> None:
+ """Set log file.
- def set_logfile(self, output_dir: str, filename: str = "fm2prof.log") -> None:
+ Args:
+ ----
+        output_dir (str | Path): directory in which to create the log file.
+        filename (str, optional): name of the log file. Defaults to "fm2prof.log".
+
+ """
# create file handler
- fh = logging.FileHandler(os.path.join(output_dir, filename), encoding="utf-8")
+ if not output_dir:
+ err_msg = "output_dir is required."
+ raise ValueError(err_msg)
+ fh = logging.FileHandler(Path(output_dir).joinpath(filename), encoding="utf-8")
fh.setLevel(logging.DEBUG)
- fh.setFormatter(self.get_logger()._Filelogformatter)
+ fh.setFormatter(self.get_logger()._Filelogformatter) # noqa: SLF001
self.__logger.addHandler(fh)
- def set_inifile(self, inifile: IniFile = None):
- """
- Use this method to set configuration file object.
+    def set_inifile(self, inifile: IniFile | None = None) -> None:
+ """Use this method to set configuration file object.
For loading from file, use ``load_inifile`` instead
- Parameters:
+ Args:
+ ----
inifile (IniFile): inifile object. Obtain using e.g. ``get_inifile``.
+
"""
self.__iniFile = inifile
def get_inifile(self) -> IniFile:
- """ "Use this method to get the inifile object"""
+ """Get the inifile object."""
return self.__iniFile
class FrictionTable:
- """
- Container for friction table
- """
+ """Container for friction table."""
- def __init__(self, level, friction):
+ def __init__(self, level: np.ndarray, friction: np.ndarray) -> None:
+ """Instantiate a FrictionTable object."""
if self._validate_input(level, friction):
self.level = level
self.friction = friction
- def interpolate(self, new_z):
+ def interpolate(self, new_z: np.ndarray) -> None:
+ """Interpolate friction.
+
+ Args:
+            new_z (np.ndarray): new levels to which the friction table is interpolated.
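+
+        Example (illustrative):
+            >>> ft = FrictionTable(np.array([0.0, 1.0]), np.array([30.0, 40.0]))
+            >>> ft.interpolate(np.array([0.0, 0.5, 1.0]))
+            >>> ft.friction
+            array([30., 35., 40.])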
+ """
self.friction = np.interp(new_z, self.level, self.friction)
self.level = new_z
@staticmethod
- def _validate_input(level, friction):
- assert isinstance(level, np.ndarray)
- assert isinstance(friction, np.ndarray)
- assert level.shape == friction.shape
-
+ def _validate_input(level: np.ndarray, friction: np.ndarray) -> bool:
+ if not isinstance(level, np.ndarray):
+ err_msg = f"level argument not of type {np.ndarray}."
+ raise TypeError(err_msg)
+        if not isinstance(friction, np.ndarray):
+            err_msg = f"friction argument not of type {np.ndarray}."
+            raise TypeError(err_msg)
+ if level.shape != friction.shape:
+ err_msg = "level and friction arrays should have the same shape."
+ raise ValueError(err_msg)
return True
diff --git a/fm2prof/CrossSection.py b/fm2prof/cross_section.py
similarity index 70%
rename from fm2prof/CrossSection.py
rename to fm2prof/cross_section.py
index cebf4535..228212d1 100644
--- a/fm2prof/CrossSection.py
+++ b/fm2prof/cross_section.py
@@ -1,9 +1,13 @@
+"""Cross section module."""
+
+from __future__ import annotations
+
import math
import pickle
import traceback
from functools import reduce
-from logging import Logger
-from typing import Dict, List
+from pathlib import Path
+from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
@@ -11,32 +15,42 @@
from scipy.integrate import cumulative_trapezoid
from tqdm import tqdm
-from fm2prof import Functions as FE
+from fm2prof import functions as funcs
from fm2prof.common import FM2ProfBase, FrictionTable
-from fm2prof.IniFile import IniFile
-from fm2prof.MaskOutputFile import MaskOutputFile
+from fm2prof import mask_output_file
+from fm2prof.ini_file import IniFile
+
+from .lib import polysimplify as ps
-from .lib import polysimplify as PS
+if TYPE_CHECKING:
+ from logging import Logger
pd.options.mode.chained_assignment = None # default='warn'
+NODATA = -999 # this should be equal to the missing number value from D-HYDRO
+
class CrossSectionHelpers(FM2ProfBase):
- """
- Collection of function(s) to help with post-processing of cross-sections.
+ """Collection of function(s) to help with post-processing of cross-sections.
+
Wrapped in a class to provide access to the shared logger.
"""
_friction_zstep = 0.1
- def __init__(self, logger=None, inifile=None):
+ def __init__(
+ self,
+ logger: Logger | None = None,
+ inifile: IniFile | None = None,
+ ) -> None:
+ """Initialize a CrossSectionHelpers instance."""
super().__init__(logger=logger, inifile=inifile)
def interpolate_friction_across_cross_sections(
- self, cross_section_list: List["CrossSection"]
+ self,
+ cross_section_list: list[CrossSection],
) -> bool:
- """
- Creates a uniform matrix of z/chezy values for all cross-sections by linear interpolation.
+ """Reates a uniform matrix of z/chezy values for all cross-sections by linear interpolation.
-        This function loops over a list of cross-sections, determines the minimim and maximum values,
+        This function loops over a list of cross-sections, determines the minimum and maximum values,
and uses the `CrossSection.friction_tables` DataFrame to interpolate towards those values.
@@ -44,27 +58,35 @@ def interpolate_friction_across_cross_sections(
.. warning::
The function modifies the input!
- outputs TRUE if succesful, FALSE if not. Does not raise exception.
+
+
+ Args:
+ ----
+        cross_section_list (list[CrossSection]): cross-sections whose friction tables are interpolated in place.
+
+ Returns:
+ -------
+        bool: TRUE if successful, FALSE if not. Does not raise exception.
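+
+        Example (illustrative; assumes ``css_list`` is a list of built CrossSection objects):
+            >>> helpers = CrossSectionHelpers()
+            >>> helpers.interpolate_friction_across_cross_sections(css_list)
+            True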
"""
try:
return self._interpolate_friction_across_cross_sections(cross_section_list)
except Exception:
self.set_logger_message(
- "There was an error while making friction tables", "error"
+ "There was an error while making friction tables",
+ "error",
)
for line in traceback.format_exc().splitlines():
self.set_logger_message(line, "debug")
return False
def _interpolate_friction_across_cross_sections(
- self, cross_section_list: List["CrossSection"]
+ self,
+ cross_section_list: list[CrossSection],
) -> bool:
- """Private function"""
-
# Get a list of all sections
- all_sections: List[str] = [
- s for css in cross_section_list for s in css.friction_tables.keys()
+ all_sections: list[str] = [
+ s for css in cross_section_list for s in css.friction_tables
]
sections: np.ndarray[str] = np.unique(all_sections)
@@ -82,10 +104,10 @@ def _interpolate_friction_across_cross_sections(
for css in cross_section_list:
if section in list(css.friction_tables.keys()):
minimal_z = np.min(
- [minimal_z, np.min(css.friction_tables.get(section).level)]
+ [minimal_z, np.min(css.friction_tables.get(section).level)],
)
maximal_z = np.max(
- [maximal_z, np.max(css.friction_tables.get(section).level)]
+ [maximal_z, np.max(css.friction_tables.get(section).level)],
)
# interpolate each cross-section to min-max level
pbar = tqdm(total=len(cross_section_list))
@@ -100,6 +122,8 @@ def _interpolate_friction_across_cross_sections(
class CrossSection(FM2ProfBase):
+ """Cross section class."""
+
__cs_parameter_transitionheight_sd = "SDTransitionHeight"
__cs_parameter_conveyance_detection_method = "ConveyanceDetectionMethod"
__cs_parameter_velocity_threshold = "AbsoluteVelocityThreshold"
@@ -112,24 +136,20 @@ class CrossSection(FM2ProfBase):
def __init__(
self,
- data: Dict,
+ data: dict,
logger: Logger | None = None,
inifile: IniFile | None = None,
- ):
- """
- Use this class to derive cross-sections from fm_data (2D model results).
- See docs how to acquire fm_data and how to prepare a proper 2D model.
+ ) -> None:
+ """Derive cross-sections from fm_data (2D model results).
- Deprecated:
- 1.2: The `foo` attribute is deprecated.
+ See docs how to acquire fm_data and how to prepare a proper 2D model.
- >>> example co
- hello!
- Parameters:
- data: contains stuff
- logger: ala
- inifile: woow
+ Args:
+ ----
+            data (dict): cross-section input data; must contain the keys id, length, xy, branchid, chainage and fm_data.
+ logger (Logger): logger to log to.
+ inifile (IniFile): IniFile instance.
"""
# If inifile is not given, use default configuration
@@ -137,13 +157,12 @@ def __init__(
inifile = IniFile()
super().__init__(logger=logger, inifile=inifile)
- try:
- assert all(
- key in data
- for key in ["id", "length", "xy", "branchid", "chainage", "fm_data"]
- )
- except AssertionError:
- raise KeyError("Input data does not have all required keys")
+ if not all(
+ key in data
+ for key in ["id", "length", "xy", "branchid", "chainage", "fm_data"]
+ ):
+ err_msg = "Input data does not have all required keys"
+ raise KeyError(err_msg)
# Cross-section meta data
self.name = data.get("id") # cross-section id
@@ -151,14 +170,14 @@ def __init__(
self.location = data.get("xy") # (x,y)
self.branch = data.get("branchid") # name of 1D branch for cross-section
self.chainage = data.get("chainage") # offset from beginning of branch
- self._fm_data: Dict = data.get("fm_data") # dictionary with fmdata
+ self._fm_data: dict = data.get("fm_data") # dictionary with fmdata
# Cross-section geometry
self.z = []
self.total_width = []
self.flow_width = []
self.section_widths = {"main": 0, "floodplain1": 0, "floodplain2": 0}
- self.friction_tables = dict()
+ self.friction_tables = {}
self.roughness_sections = np.array([])
# delta h corrections ("summerdike option")
@@ -171,7 +190,7 @@ def __init__(
self.transition_height = 0.5
self.extra_flow_area = 0.0
self.extra_total_volume = 0.0
- self.extra_area_percentage = list()
+ self.extra_area_percentage = []
self.extra_total_area = 0
self.extra_flow_area = 0
@@ -206,47 +225,49 @@ def __init__(
}
@property
- def alluvial_width(self):
- for key in [1, "1", "main", "Main"]:
- try:
+    def alluvial_width(self) -> float:
+ """Get alluvial width."""
+ for key in ["1", "main", "Main"]:
+ if key in self.section_widths:
return self.section_widths[key]
- except KeyError:
- pass
return 0
@property
- def nonalluvial_width(self):
- for key in [2, "2", "floodplain", "FloodPlain1"]:
- try:
+    def nonalluvial_width(self) -> float:
+ """Get non-alluvial width."""
+ for key in ["2", "floodplain", "FloodPlain1"]:
+ if key in self.section_widths:
return self.section_widths[key]
- except KeyError:
- pass
return 0
@property
- def face_points_list(self):
+ def face_points_list(self) -> list:
+ """Get list of face points."""
return self.__output_face_list
@property
- def edge_points_list(self):
+ def edge_points_list(self) -> list:
+ """Get list of edge points."""
return self.__output_edge_list
- def get_point_list(self, pointtype):
- if pointtype == "face":
+ def get_point_list(self, point_type: str) -> list:
+ """Get list of points based on point type."""
+ if point_type == "face":
return self.face_points_list
- elif pointtype == "edge":
+ if point_type == "edge":
return self.edge_points_list
- else:
- raise ValueError('pointtype must be "face" or "edge"')
+        err_msg = 'point_type must be "face" or "edge"'
+ raise ValueError(err_msg)
# Public functions
- def build_geometry(self) -> None:
- """
- This methods builds 1D geometrical cross-section from 2D data.
+ def build_geometry(self) -> None: # noqa: PLR0915
+ """Build 1D geometrical cross-section from 2D data.
+
-        The 2D data is set on initalisation of the `CrossSection` object.
-        The methods modifies the following attributes
+        The 2D data is set on initialisation of the `CrossSection` object.
+        This method modifies the following attributes:
Attributes:
+ ----------
_fm_wet_area
_fm_flow_area
_fm_total_volume
@@ -257,25 +278,19 @@ def build_geometry(self) -> None:
_css_flow_width
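+
+        Example (illustrative):
+            >>> css = CrossSection(data)  # data: dict with the required input keys
+            >>> css.build_geometry()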
"""
- fm_data: Dict = self._fm_data
+ fm_data: dict = self._fm_data
# Unpack FM data
- def get_timeseries(name: str):
- """
- Returns data from fm_data after applying the
- skip_maps and checking for missing numbers
- """
- imiss = (
- -999
- ) # this should be equal to the missing number value from D-HYDRO
-
+        def get_timeseries(name: str) -> pd.DataFrame:
+ """Return data from fm_data after applying the skip_maps and checking for missing numbers."""
data = fm_data[name].iloc[
- :, self.get_parameter(self.__cs_parameter_skip_maps) :
+ :,
+ self.get_parameter(self.__cs_parameter_skip_maps) :,
]
# map missing value numbers to
- if imiss in data.values:
+ if NODATA in data.to_numpy():
self.set_logger_message(r"Missing data found in {name}", "warning")
- data[data == imiss] = np.nan
+ data[data == NODATA] = np.nan
return data
waterlevel = get_timeseries("waterlevel")
@@ -289,16 +304,19 @@ def get_timeseries(name: str):
# (much more efficient than for-loops)
area_matrix = pd.DataFrame({col: area for col in waterdepth.columns})
area_matrix.index = area.index
-
+
bedlevel_matrix = pd.DataFrame({col: area for col in waterdepth.columns})
bedlevel_matrix.index = bedlevel.index
-
# Retrieve the water-depth
# & water level nearest to the cross-section location
self.set_logger_message("Retrieving centre point values")
- (centre_depth, centre_level) = FE.get_centre_values(
- self.location, fm_data["x"], fm_data["y"], waterdepth, waterlevel
+ (centre_depth, centre_level) = funcs.get_centre_values(
+ self.location,
+ fm_data["x"],
+ fm_data["y"],
+ waterdepth,
+ waterlevel,
)
# Identify river lakes (plassen)
@@ -316,7 +334,9 @@ def get_timeseries(name: str):
if self.get_inifile().get_parameter("ExportCSSData"): # pickle css data
output_dir = self.get_inifile().get_output_directory()
self.set_logger_message(f"pickling to {output_dir}")
- with open(output_dir.joinpath(f"{self.name}_flowmask.pickle"), "wb") as f:
+ with Path(output_dir).joinpath(f"{self.name}_flowmask.pickle").open(
+ "wb",
+ ) as f:
pickle.dump(flow_mask, f)
# Calculate area and volume as function of waterlevel & waterdepth
@@ -339,8 +359,15 @@ def get_timeseries(name: str):
# Compute 2D volume as sum of area times depth
self._fm_total_volume = np.array(
np.nansum(
-                area_matrix[wet_not_plas_mask] * waterdepth[wet_not_plas_mask], axis=0
-            )
+                area_matrix[wet_not_plas_mask] * waterdepth[wet_not_plas_mask],
+                axis=0,
+            ),
         )
         self._fm_flow_volume = np.array(
-            np.nansum(area_matrix[flow_mask] * waterdepth[flow_mask], axis=0)
+            np.nansum(area_matrix[flow_mask] * waterdepth[flow_mask], axis=0),
         )
# Compute 1D volume as integral of width with respect to z times length
self._css_total_volume = np.append(
- [0], cumulative_trapezoid(self._css_total_width, self._css_z) * self.length
+ [0],
+ cumulative_trapezoid(self._css_total_width, self._css_z) * self.length,
)
self._css_flow_volume = np.append(
- [0], cumulative_trapezoid(self._css_flow_width, self._css_z) * self.length
+ [0],
+ cumulative_trapezoid(self._css_flow_width, self._css_z) * self.length,
)
# If sd correction is run, these attributes will be updated.
@@ -387,11 +416,8 @@ def get_timeseries(name: str):
# (apparently entries can be float32)
self._css_z = np.array(self._css_z, dtype=np.dtype("float64"))
- def check_requirements(self):
- """
- Performs check on cross-section such that it
- hold up to requirements.
- """
+ def check_requirements(self) -> None:
+ """Perform check on cross-section such that it hold up to requirements."""
# Remove multiple zeroes in the bottom of the cross-section
self._css_index_of_first_nonzero = self._check_remove_duplicate_zeroes()
self._css_total_width = self._check_remove_zero_widths(self._css_total_width)
@@ -406,10 +432,10 @@ def check_requirements(self):
self._check_total_width_greater_than_flow_width()
def calculate_correction(self) -> None:
- """
- This method computes a volume correction that cannot be captured within the
- constraints of 1D cross-sectional geometry, which forces an increasing width
- at increasing elevation and cannot deal with varying water levels. In reality,
+ r"""Compute a volume correction.
+
+ A volume correction that cannot be captured within the constraints of 1D cross-sectional geometry, which forces
+ an increasing width at increasing elevation and cannot deal with varying water levels. In reality,
compartimentation of floodplains can cause a sudden
increase of volume while the water level does not increase.
@@ -440,15 +466,14 @@ def calculate_correction(self) -> None:
In SOBEK, the extra volume is released following a polynomial function. In |project|, we approximate
this with the following logistic function:
- :math:`C(h_k)=\Xi(1+e^{\log(\delta)\\tau^{-1}(h_k-(\gamma+\\tau/2))})^{-1}`
+        :math:`C(h_k)=\Xi(1+e^{\log(\delta)\tau^{-1}(h_k-(\gamma+\tau/2))})^{-1}`
- where :math:`\Xi` is the volume correction (m3/s), :math:`\\tau` is the transition height,
- :math:`\delta` is an accuracy parameter, :math:`\gamma` is the crest level, and :math:`C(h_k)`
- is the added volume given water level :math:`h_k`. The default value of :math:`\delta` is
+        where :math:`\Xi` is the volume correction (m3/s), :math:`\tau` is the transition height,
+        :math:`\delta` is an accuracy parameter, :math:`\gamma` is the crest level, and :math:`C(h_k)`
+        is the added volume given water level :math:`h_k`. The default value of :math:`\delta` is
0.00001, and not configurable.
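+
+        A minimal sketch of this correction term (illustrative only; the names are not part of the API):
+
+        .. code-block:: python
+
+            import numpy as np
+
+            def correction(h, xi, tau, gamma, delta=1e-5):
+                # logistic release of the extra volume xi around crest level gamma
+                return xi / (1 + np.exp(np.log(delta) / tau * (h - (gamma + tau / 2))))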
"""
-
# Set initial values for optimisation of parameters
initial_total_error = self._css_total_volume - self._fm_total_volume
initial_flow_error = self._css_flow_volume - self._fm_flow_volume
@@ -461,15 +486,19 @@ def calculate_correction(self) -> None:
"Initial crest: {:.4f} m".format(initial_crest), level="debug"
)
self.set_logger_message(
- "Initial extra total area: {:.4f} m2".format(
- initial_total_volume / self.length
- ),
+ f"Initial crest: {initial_crest:.4f} m",
level="debug",
)
self.set_logger_message(
- "Initial extra flow area: {:.4f} m2".format(
- initial_flow_volume / self.length
- ),
+ f"Initial crest: {initial_crest:.4f} m",
+ level="debug",
+ )
+ self.set_logger_message(
+ f"Initial extra total area: {initial_total_volume / self.length:.4f} m2",
+ level="debug",
+ )
+ self.set_logger_message(
+ f"Initial extra flow area: {initial_flow_volume / self.length:.4f} m2",
level="debug",
)
@@ -487,25 +516,30 @@ def calculate_correction(self) -> None:
extra_flow_volume = opt.get("extra_flow_volume")
self.set_logger_message(
- "final costs: {:.2f}".format(opt.get("final_cost")), level="debug"
+ "final costs: {:.2f}".format(opt.get("final_cost")),
+ level="debug",
)
self.set_logger_message(
- "Optimizer msg: {}".format(opt.get("message")), level="debug"
+ "Optimizer msg: {}".format(opt.get("message")),
+ level="debug",
)
self.set_logger_message(
- "Final crest: {:.4f} m".format(crest_level), level="debug"
+ f"Final crest: {crest_level:.4f} m",
+ level="debug",
)
self.set_logger_message(
- "Final total area: {:.4f} m2".format(extra_total_volume / self.length),
+ f"Final total area: {extra_total_volume / self.length:.4f} m2",
level="debug",
)
self.set_logger_message(
- "Final flow area: {:.4f} m2".format(extra_flow_volume / self.length),
+ f"Final flow area: {extra_flow_volume / self.length:.4f} m2",
level="debug",
)
extra_area_percentage = self._get_extra_total_area(
- self._css_z, crest_level, transition_height
+ self._css_z,
+ crest_level,
+ transition_height,
)
# Write to attributes
@@ -525,13 +559,12 @@ def calculate_correction(self) -> None:
self._css_is_corrected = True
def assign_roughness(self) -> None:
- """
- This function builds a table of Chezy values as function of water level
+ """Build a table of Chezy values as function of water level.
+
-        The roughnes is divides into two sections on the assumption of
+        The roughness is divided into two sections on the assumption of
an alluvial (smooth) and nonalluvial (rough) part of the total
cross-section. This division is made based on the final timestep.
"""
-
# Compute roughness tabels
self.set_logger_message("Building roughness table", "debug")
self._build_roughness_tables()
@@ -544,15 +577,16 @@ def assign_roughness(self) -> None:
# done
def get_number_of_faces(self) -> int:
- """use this method to return the number of 2D faces within control volume"""
+ """Return the number of 2D faces within control volume."""
return len(self._fm_data.get("x"))
def get_number_of_vertices(self) -> int:
- """Use this method to return the current number of geometry vertices"""
+ """Return the current number of geometry vertices."""
return len(self._css_total_width)
def reduce_points(self, count_after: int = 20) -> None:
- """
+ """Reduce points while preserving the shape of the geometry.
+
The cross-section geometry generated by `fm2prof` contains one point per output
timestep in the 2D map file. This resolution is often too high given the
complexity of the cross-sections, and results in very large input files for the
@@ -564,25 +598,28 @@ def reduce_points(self, count_after: int = 20) -> None:
We use the Visvalingam-Whyatt method of poly-line vertex reduction[^1].
The [total width](glossary.md#total-width) is leading for the simplification of the geometry meaning
that the choice for which points to remove to simplify the geometry is based on
- the total width. Subsequently, the corresponding point are removed from the [flow width](glossary.md#flow-width).
+ the total width. Subsequently, the corresponding point are removed from the
+ [flow width](glossary.md#flow-width).
[^1]:
- Visvalingam, M and Whyatt J D (1993) "Line Generalisation by Repeated Elimination of Points", Cartographic J., 30 (1), 46 - 51 URL: http://web.archive.org/web/20100428020453/http://www2.dcs.hull.ac.uk/CISRG/publications/DPs/DP10/DP10.html
+ Visvalingam, M and Whyatt J D (1993) "Line Generalisation by Repeated Elimination of Points",
+ Cartographic J., 30 (1), 46 - 51
+ URL: http://web.archive.org/web/20100428020453/http://www2.dcs.hull.ac.uk/CISRG/publications/DPs/DP10/DP10.html
Implemented vertex reduction methods:
- Parameters:
- count_after: number of points in cross-section after application of this function
+ Args:
+ ----
+ count_after (int): number of points in cross-section after application of this function
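+
+        Example (illustrative):
+            >>> css.reduce_points(count_after=20)  # css: a built CrossSection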
"""
-
n_before_reduction = self.get_number_of_vertices()
points = np.array(
[
[self._css_z[i], self._css_total_width[i]]
for i in range(n_before_reduction)
- ]
+ ],
)
# The number of points is equal to n, it cannot be further reduced
@@ -590,11 +627,11 @@ def reduce_points(self, count_after: int = 20) -> None:
if n_before_reduction > count_after:
try:
- simplifier = PS.VWSimplifier(points)
+ simplifier = ps.VWSimplifier(points)
reduced_index = simplifier.from_number_index(count_after)
except Exception as e:
self.set_logger_message(
- "Exception thrown while using polysimplify: " + "{}".format(str(e)),
+ "Exception thrown while using polysimplify: " + f"{e!s}",
"error",
)
@@ -604,24 +641,13 @@ def reduce_points(self, count_after: int = 20) -> None:
self.flow_width = self._css_flow_width[reduced_index]
self.set_logger_message(
- "Cross-section reduced "
- + "from {} ".format(n_before_reduction)
- + "to {} points".format(len(self.total_width))
+ "Cross-section reduced from {n_before_reduction} to {len(self.total_width)} points",
)
self._css_is_reduced = True
- def set_face_output_list(self):
- """
- Generates a list of output mask points based on
- their values in the mask.
-
- writes to self.__output_mask_list
-
- Paramters:
- fm_data {dict} -- Dictionary containing x,y values.
- mask_array {NP.array} -- Array of values.
- """
+ def set_face_output_list(self) -> None:
+ """Generate a list of output mask points based on their values in the mask."""
fm_data = self._fm_data
# Properties keys
@@ -651,47 +677,35 @@ def set_face_output_list(self):
}
mask_coords = (x_coords[i], y_coords[i])
# Create the actual geojson element.
- output_mask = MaskOutputFile.create_mask_point(
- mask_coords, mask_properties
+ output_mask = mask_output_file.create_mask_point(
+ mask_coords,
+ mask_properties,
)
if output_mask.is_valid:
self.__output_face_list.append(output_mask)
- # self.set_logger_message(
- # 'Added output mask at {} '.format(mask_coords) +
- # 'for Cross Section {}.'.format(self.name),
- # level='debug')
else:
self.set_logger_message(
- "Invalid output mask at {} ".format(mask_coords)
- + "for Cross Section {}, not added. ".format(self.name)
- + "Reason {}".format(output_mask.errors()),
+ f"Invalid output mask at {mask_coords} "
+ f"for Cross Section {self.name}, not added. "
+ f"Reason {output_mask.errors()}",
level="error",
)
except Exception as e_error:
self.set_logger_message(
- "Error setting output masks "
- + "for Cross Section {}. ".format(self.name)
- + "Reason: {}".format(str(e_error)),
+ f"Error setting output masks for Cross Section {self.name}. Reason: {e_error!s}",
level="error",
)
- def set_edge_output_list(self):
- """
- Generates a list of output mask points based on
- their values in the mask.
+ def set_edge_output_list(self) -> None:
+ """Generate a list of output mask points based on their values in the mask.
writes to self.__output_mask_list
-
- Parameters:
- fm_data {dict} -- Dictionary containing x,y values.
- mask_array {NP.array} -- Array of values.
"""
fm_data = self._fm_data
# Properties keys
cross_section_id_key = "cross_section_id"
- cross_section_region_key = "region"
roughness_section_key = "section"
try:
@@ -707,55 +721,46 @@ def set_edge_output_list(self):
}
mask_coords = (x_coords[i], y_coords[i])
# Create the actual geojson element.
- output_mask = MaskOutputFile.create_mask_point(
- mask_coords, mask_properties
+ output_mask = mask_output_file.create_mask_point(
+ mask_coords,
+ mask_properties,
)
if output_mask.is_valid:
self.__output_edge_list.append(output_mask)
else:
self.set_logger_message(
- "Invalid output mask at {} ".format(mask_coords)
- + "for Cross Section {}, not added. ".format(self.name)
- + "Reason {}".format(output_mask.errors()),
+ f"Invalid output mask at {mask_coords} "
+ f"for Cross Section {self.name}, not added. "
+ f"Reason {output_mask.errors()}",
level="error",
)
except Exception as e_error:
self.set_logger_message(
- "Error setting output masks "
- + "for Cross Section {}. ".format(self.name)
- + "Reason: {}".format(str(e_error)),
+ f"Error setting output masks for Cross Section {self.name}. Reason: {e_error!s}",
level="error",
)
- def _check_remove_duplicate_zeroes(self):
- """
- Removes duplicate zeroes in the total width
- """
-
+    def _check_remove_duplicate_zeroes(self) -> int:
+ """Remove duplicate zeroes in the total width."""
# Remove multiple 0s in the total width
- index_of_first_nonzero = max(1, np.argwhere(self._css_total_width != 0)[0][0])
-
- return index_of_first_nonzero
+ return max(1, np.argwhere(self._css_total_width != 0)[0][0])
@staticmethod
- def _return_first_item_and_after_index(listin, after_index):
+    def _return_first_item_and_after_index(listin: np.ndarray, after_index: int) -> np.ndarray:
return np.append(listin[0], listin[after_index:].tolist())
- def _check_remove_zero_widths(self, width_array):
- """
- A zero width may lead to numerical instability
- """
+    def _check_remove_zero_widths(self, width_array: np.ndarray) -> np.ndarray:
+        """Remove zero width values.
+
+        A zero width may lead to numerical instability.
+ """
minwidth = self.get_parameter(self.__cs_parameter_minwidth)
width_array[width_array < minwidth] = minwidth
-
return width_array
- def _combined_optimisation_func(self, opt_in):
- """
- Cost function, combines total volume error and flow volume error
- """
+    def _combined_optimisation_func(self, opt_in: tuple[float, float, float]) -> np.ndarray:
+ """Cost function, combines total volume error and flow volume error."""
(crest_level, extra_total_volume, extra_flow_volume) = opt_in
transition_height = self.get_parameter(self.__cs_parameter_transitionheight_sd)
@@ -776,9 +781,9 @@ def _combined_optimisation_func(self, opt_in):
self._fm_total_volume + self._fm_flow_volume,
)
- def _optimisation_func(self, opt_in, *args):
- """
- Objective function used in optimising a delta-h correction
+    def _optimisation_func(self, opt_in: tuple[float, ...], *args: tuple) -> np.ndarray:
+ """Objective function used in optimising a delta-h correction.
+
for parameters:
crest_level : level at which the correction begins
transition_height : height over which volume is released
@@ -806,24 +811,26 @@ def _optimisation_func(self, opt_in, *args):
return self._return_volume_error(predicted_volume, self._fm_total_volume)
def _optimize_sd_storage(
- self, initial_crest, initial_total_volume, initial_flow_volume
+ self,
+ initial_crest: float,
+ initial_total_volume: float,
+ initial_flow_volume: float,
) -> dict:
- """
- Optimised the crest level and volumes
+ """Optimised the crest level and volumes.
Returns:
+ -------
Dictionary with optimised values, final cost and optimisation message
+
"""
# Default option
sdoptimisationmethod = self.get_parameter(
- self.__cs_parameter_sdoptimisationmethod
+ self.__cs_parameter_sdoptimisationmethod,
)
if sdoptimisationmethod not in [0, 1, 2]:
# this should be handled in inifile instead
self.set_logger_message(
- "sdoptimisationmethod is {} but should be 0, 1, or 2. Defaulting to 0".format(
- sdoptimisationmethod
- ),
+ f"sdoptimisationmethod is {sdoptimisationmethod} but should be 0, 1, or 2. Defaulting to 0",
level="warning",
)
sdoptimisationmethod = 0
@@ -877,9 +884,10 @@ def _optimize_sd_storage(
)
extra_total_volume = np.max([np.max([opt2["x"][0], 0]), extra_flow_volume])
- elif self.get_parameter(self.__cs_parameter_sdoptimisationmethod) == 2:
+ elif self.get_parameter(self.__cs_parameter_sdoptimisationmethod) == 2: # noqa: PLR2004
self.set_logger_message(
- "Optimising SD on both flow and total volumes", level="debug"
+ "Optimising SD on both flow and total volumes",
+ level="debug",
)
opt = so.minimize(
self._combined_optimisation_func,
@@ -900,17 +908,17 @@ def _optimize_sd_storage(
"message": opt["message"],
}
- def _check_increasing_order(self, list_points):
- """runs"""
+ def _check_increasing_order(self, list_points: list) -> list:
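+        """Force the list to be strictly increasing by nudging non-increasing values up by 0.001."""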
for i in range(1, len(list_points)):
if list_points[i] <= list_points[i - 1]:
list_points[i] = list_points[i - 1] + 0.001
return list_points
- def _build_roughness_tables(self):
+ def _build_roughness_tables(self) -> None:
# Find roughness tables for each section
chezy_fm = self._fm_data.get("chezy").iloc[
- :, self.get_parameter(self.__cs_parameter_skip_maps) :
+ :,
+ self.get_parameter(self.__cs_parameter_skip_maps) :,
]
sections = np.unique(self._fm_data.get("edge_section"))
@@ -922,27 +930,31 @@ def _build_roughness_tables(self):
elif self.get_parameter(self.__cs_parameter_Frictionweighing) == 1:
friction = self._friction_weighing_area(chezy_section, section)
else:
- raise ValueError(
- "unknown option for roughness weighing: {}".format(
- self.get_parameter(self.__cs_parameter_Frictionweighing)
- )
+ err_msg = (
+ "unknown option for roughness weighing: "
+ f"{self.get_parameter(self.__cs_parameter_Frictionweighing)}"
)
+ raise ValueError(err_msg)
self.friction_tables[self._section_map[str(section)]] = FrictionTable(
- level=self._css_z_roughness, friction=friction
+ level=self._css_z_roughness,
+ friction=friction,
)
- def _friction_weighing_simple(self, link_chezy):
- """Simple mean, no weight"""
+    def _friction_weighing_simple(self, link_chezy: pd.DataFrame) -> np.ndarray:
+ """Calculate mean of chezy values."""
# Remove chezy where zero
- link_chezy = link_chezy.replace(0, np.NaN)
- output = link_chezy.mean(axis=0).replace(np.NaN, 0)
+ link_chezy = link_chezy.replace(0, np.nan)
+ output = link_chezy.mean(axis=0).replace(np.nan, 0)
- return output.values
+ return output.to_numpy()
- def _friction_weighing_area(self, link_chezy, section):
- """
- Compute chezy by weighted average. Weights are determined based on area.
+ def _friction_weighing_area(
+ self,
+ link_chezy: pd.DataFrame,
+ section: str,
+ ) -> np.ndarray:
+ """Compute chezy by weighted average. Weights are determined based on area.
Friction values are known at flow links, while areas are known at flow faces.
@@ -950,34 +962,29 @@ def _friction_weighing_area(self, link_chezy, section):
"""
# Remove chezy where zero
- link_chezy = link_chezy.replace(0, np.NaN)
+ link_chezy = link_chezy.replace(0, np.nan)
# efs are the two faces the edge connects to
efs = self._fm_data["edge_faces"][self._fm_data["edge_section"] == section]
- link_area = []
- for ef in efs:
- # compute the mean area for the two connecting faces
- link_area.append(self._fm_data.get("area_full").reindex(ef).mean())
-
+ # compute the mean area for the two connecting faces
+ link_area = [self._fm_data.get("area_full").reindex(ef).mean() for ef in efs]
# the weight of one link is defined as the sum of the linked areas
link_weight = link_area / np.sum(link_area)
- output = np.sum(link_chezy.values.T * link_weight, axis=1)
+ output = np.sum(link_chezy.to_numpy().T * link_weight, axis=1)
output[np.isnan(output)] = 0
return output
def _compute_section_widths(self) -> None:
- """
- Computes sections widths by dividing the area assigned to a section
- by the length of the cross-section.
+ """Compute sections widths by dividing the area assigned to a section by the length of the cross-section.
If the sum of the section widths is smaller than the flow width, the
width is increase proportionally
"""
-
- unassigned_area = sum(self._fm_data["area"][self._fm_data["section"] == -999])
+ unassigned_area = sum(self._fm_data["area"][self._fm_data["section"] == NODATA])
if unassigned_area > 0:
self.set_logger_message(
- f"{unassigned_area} m2 was not assigned to any section in input files, and is added to the main section",
+ f"{unassigned_area} m2 was not assigned to any section in input files, and"
+ " is added to the main section",
"warning",
)
@@ -999,17 +1006,17 @@ def _compute_section_widths(self) -> None:
self._check_section_widths_greater_than_minimum_width()
def _compute_floodplain_base(self) -> None:
- """
- Sets the self.floodplain_base attribute. The floodplain
- will be set at least 0.5 meter below the crest of the
+ """Set the self.floodplain_base attribute.
+
+ The floodplain will be set at least 0.5 meter below the crest of the
-        embankment, and otherwise at the average hight of the floodplain
+        embankment, and otherwise at the average height of the floodplain.
"""
tolerance = self.get_inifile().get_parameter("sdfloodplainbase")
# Mean bed level in section 2 (floodplain)
- floodplain_mask = self._fm_data.get("section") == 2
+ floodplain_mask = self._fm_data.get("section") == 2 # noqa: PLR2004
if floodplain_mask.sum():
mean_floodplain_elevation = np.nanmean(
- self._fm_data["bedlevel"][floodplain_mask]
+ self._fm_data["bedlevel"][floodplain_mask],
)
# Tolerance. Base level must at least be some below the crest to prevent
@@ -1019,8 +1026,8 @@ def _compute_floodplain_base(self) -> None:
self.floodplain_base = self.crest_level - tolerance
self.set_logger_message(
f"Mean floodpl. elev. ({mean_floodplain_elevation:.2f} m)"
- + f"higher than crest level ({self.crest_level:.2f}) + "
- + f"tolerance ({tolerance} m)",
+ f"higher than crest level ({self.crest_level:.2f}) + "
+ f"tolerance ({tolerance} m)",
"warning",
)
else:
@@ -1032,31 +1039,33 @@ def _compute_floodplain_base(self) -> None:
else:
self.floodplain_base = self.crest_level - tolerance
self.set_logger_message(
- f"No Floodplain found, floodplain defaults to {self.crest_level - tolerance}"
+ f"No Floodplain found, floodplain defaults to {self.crest_level - tolerance}",
)
- def _calc_chezy(self, depth, manning):
- return depth ** (1 / float(6)) / manning
-
def _identify_lakes(self, waterdepth: pd.DataFrame) -> np.ndarray:
- """
- This algorithms determines whether a 2D cell should
- be marked as [Lake](glossary.md#Lakes).
+ """Determine whether a 2D cell should be marked as [Lake](glossary.md#Lakes).
Cells are marked as lake if the following conditions are both met:
- the waterdepth on timestep [LakeTimeSteps](configuration.md#exec-1--laketimesteps) is positive
- - the waterdepth on timestep [LakeTimeSteps](configuration.md#exec-1--laketimesteps) is at least 1 cm higher than the waterlevel on timestep 0.
+        - the waterdepth on timestep [LakeTimeSteps](configuration.md#exec-1--laketimesteps) has risen
+        by no more than 1 cm relative to timestep 0.
Next, the following steps are taken
- - It is determined at what timestep the waterlevel in the lake starts rising. From that point on the lake counts as regular geometry and counts toward the total volume. A cell is considered active if its waterlevel has risen by 1 mm.
- - A correction matrix is built that contains the 'lake water level' for each lake cell. This matrix is subtracted from the waterdepth to compute volumes.
+ - It is determined at what timestep the waterlevel in the lake starts rising. From that point on the lake counts
+ as regular geometry and counts toward the total volume. A cell is considered active if its waterlevel has risen
+ by 1 mm.
+ - A correction matrix is built that contains the 'lake water level' for each lake cell. This matrix is
+ subtracted from the waterdepth to compute volumes.
- Parameters:
- waterdepth: a DataFrame containing all waterdepth output in the [control volume](glossary.md#control-volume)
+ Args:
+ ----
+            waterdepth (pd.DataFrame): a DataFrame containing all waterdepth output in the
+            [control volume](glossary.md#control-volume)
Returns:
+ -------
lake_mask: mask of all cells that are a 'lake'
wet_not_lake_mask: mask of all cells that are wet, but not a lake
lake_depth_correction: the depth of a lake at the start of the 2D computation
@@ -1069,7 +1078,7 @@ def _identify_lakes(self, waterdepth: pd.DataFrame) -> np.ndarray:
waterdepth_diff = np.diff(waterdepth, n=1, axis=-1)
# find all wet cells
- wet_mask = waterdepth > 0
+ wet_mask: np.ndarray = waterdepth > 0
# find all lakes
lake_mask = (
@@ -1080,9 +1089,9 @@ def _identify_lakes(self, waterdepth: pd.DataFrame) -> np.ndarray:
waterdepth.T.iloc[
self.get_parameter(self.__cs_parameter_plassen_timesteps)
]
- - waterdepth.T.iloc[0]
+ - waterdepth.T.iloc[0],
)
- <= 0.01
+ <= 0.01 # noqa: PLR2004 NOTE: What is this value?
)
self.plassen_mask = lake_mask
@@ -1097,25 +1106,28 @@ def _identify_lakes(self, waterdepth: pd.DataFrame) -> np.ndarray:
# when to unmask a plassen cell
for i, diff in enumerate(waterdepth_diff.T):
final_mask = reduce(
- np.logical_and, [(diff <= 0.001), (plassen_mask_time[i] == True)]
+ np.logical_and,
+ [(diff <= 0.001), (plassen_mask_time[i] == True)], # noqa: E712, PLR2004, NOTE what does 0.001 represent?
)
plassen_mask_time[i + 1, :] = final_mask
plassen_mask_time = pd.DataFrame(plassen_mask_time).T
# The depth of a lake is the waterdepth at t=0
- for i, depths in enumerate(waterdepth):
+ for i, _depths in enumerate(waterdepth):
plassen_depth_correction[lake_mask, i] = -waterdepth.T.iloc[0][lake_mask]
# correct wet cells for plassen
wet_not_plas_mask = reduce(
- np.logical_and, [(wet_mask == True), np.asarray(plassen_mask_time == False)]
+ np.logical_and,
+ [(wet_mask == True), np.asarray(plassen_mask_time == False)], # noqa: E712
)
return lake_mask, wet_not_plas_mask, plassen_depth_correction
- def _compute_css_above_z0(self, centre_level) -> None:
- """
+ def _compute_css_above_z0(self, centre_level: np.ndarray) -> None:
+ """Compute total and flow width for each level (z) aboove the water level at the first 2D output(z0).
+
This method computes for each level (z) above the water level at the first 2D output (z0),
the corresponding :term:`Total width` and :term:`Flow width`. This is done in the following way:
@@ -1124,13 +1136,15 @@ def _compute_css_above_z0(self, centre_level) -> None:
3. Correct the flow width such that flow width is always increasing.
Args:
- centre_level: the water level at the cross-section location (x, y), which is typically at the centre of the control volume
+ ----
+ centre_level: the water level at the cross-section location (x, y), which is typically at the centre of the
+ control volume
Return:
-
+ ------
None: this method writes to the cross-section attributes _css_z, _css_total_width and _css_flow_width
- """
+ """
# Set the level
self._css_z = centre_level
@@ -1144,13 +1158,16 @@ def _compute_css_above_z0(self, centre_level) -> None:
# Flow width must increase at each z
for i in range(2, len(self._css_flow_width) + 1):
self._css_flow_width[-i] = np.min(
- [self._css_flow_width[-i], self._css_flow_width[-i + 1]]
+ [self._css_flow_width[-i], self._css_flow_width[-i + 1]],
)
def _distinguish_conveyance_from_storage(
- self, waterdepth: pd.DataFrame, velocity: pd.DataFrame
+ self,
+ waterdepth: pd.DataFrame,
+ velocity: pd.DataFrame,
) -> pd.DataFrame:
- """
+ """Distinguish conveyance from storage.
+
In 1D hydrodynamic models, flow through a cross-section is resolved assuming a
cross-sectionally average velocity. This assumed that the entire cross-section
-        is available to for conveyance. However in reality some parts of the cross-section
+        is available for conveyance. However, in reality some parts of the cross-section
@@ -1161,31 +1178,39 @@ def _distinguish_conveyance_from_storage(
methods to resolve from 2D model output which cells add to the 'flow volume' within a
[control volume](glossary.md#control-volume) and which to the storage volume.
- `fm2prof` implements two methods. The configuration parameter [`ConveyanceDetectionMethod`](configuration.md#exec-1--conveyancedetectionmethod) is used
+ `fm2prof` implements two methods. The configuration parameter
+ [`ConveyanceDetectionMethod`](configuration.md#exec-1--conveyancedetectionmethod) is used
to determine which method is used.
**`max_method`**
A cell is considered flowing if the velocity magnitude is more than the average
of the three higher flow velocities per outputmap multiplied by the
[`relative velocity threshold`](configuration.md#exec-1--relativevelocitythreshold) OR
- if the flow velocity meets the absolute threshold [`absolute velocity threshold`](configuration.md#exec-1--absolutevelocitythreshold)
+ if the flow velocity meets the absolute threshold
+ [`absolute velocity threshold`](configuration.md#exec-1--absolutevelocitythreshold)
**`mean_method`**
Not recommended. Legacy method.
- Parameters:
- waterdepth: dataframe of a control volume with waterdepths per cel per output map
- velocity: dataframe of a control volume with velocity magnitude per cel per output map
+ Args:
+ ----
+ waterdepth (pd.DataFrame): dataframe of a control volume with waterdepths per cell per output map
+ velocity (pd.DataFrame): dataframe of a control volume with velocity magnitude per cell per output map
Returns:
- flow_mask: dataframe of a control volume with the flow condition per cel per output map. `True` means flowing, `False` storage.
+ -------
+ flow_mask: dataframe of a control volume with the flow condition per cell per output map. `True` means
+ flowing, `False` storage.
+
"""
@staticmethod
def max_velocity_method(
- waterdepth: pd.DataFrame, velocity: pd.DataFrame
+ waterdepth: pd.DataFrame,
+ velocity: pd.DataFrame,
) -> pd.DataFrame:
- """
+ """Calculate the max velocity.
+
This method was added in version 2.3 because the mean_velocity_method
led to unreasonably high conveyance if the river was connected to
an inland harbour.
@@ -1200,24 +1225,26 @@ def max_velocity_method(
# Relative to max condition
relative_velocity_condition = velocity > maxv * self.get_parameter(
- self.__cs_parameter_relative_threshold
+ self.__cs_parameter_relative_threshold,
)
# Absolute flow condition
absolute_velocity_condition = velocity > self.get_parameter(
- self.__cs_parameter_velocity_threshold
+ self.__cs_parameter_velocity_threshold,
)
# Flow mask determines which cells are conveyance (TRUE)
- flow_mask = waterdepth_condition & (
+ return waterdepth_condition & (
relative_velocity_condition | absolute_velocity_condition
)
- return flow_mask
-
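A hedged sketch of the `max_method` criterion described in the docstring above. The threshold values (0.03 relative, 0.25 m/s absolute) and the wet condition `waterdepth > 0` are illustrative assumptions, not the project defaults:

```python
# Sketch of the max_method criterion (thresholds are assumed, not defaults).
import pandas as pd

velocity = pd.DataFrame([[0.02, 0.30], [0.50, 0.60], [0.40, 0.55]])
waterdepth = pd.DataFrame([[0.4, 0.5], [1.0, 1.2], [0.8, 0.9]])

# mean of the three highest velocities per output map (column-wise)
maxv = velocity.apply(lambda col: col.nlargest(3).mean())

# conveyance (True) if wet AND (relative OR absolute velocity condition)
flow_mask = (waterdepth > 0) & ((velocity > maxv * 0.03) | (velocity > 0.25))
```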
@staticmethod
- def mean_velocity_method(waterdepth, velocity):
- """
+ def mean_velocity_method(
+ waterdepth: np.ndarray,
+ velocity: np.ndarray,
+ ) -> np.ndarray:
+ """Calculate mean velocity.
+
This was the default method < 2.3. This method leads to unreasonably
high conveyance if the river was connected to an inland harbour.
"""
@@ -1225,7 +1252,7 @@ def mean_velocity_method(waterdepth, velocity):
# to smooth out extreme values
velocity = velocity.rolling(window=10, min_periods=1, center=True).mean()
- flow_mask = (
+ return (
(waterdepth > 0)
& (
velocity
@@ -1238,10 +1265,8 @@ def mean_velocity_method(waterdepth, velocity):
)
)
- return flow_mask
-
match self.get_inifile().get_parameter(
- self.__cs_parameter_conveyance_detection_method
+ self.__cs_parameter_conveyance_detection_method,
):
case 0:
return mean_velocity_method(waterdepth, velocity)
@@ -1249,18 +1274,20 @@ def mean_velocity_method(waterdepth, velocity):
return max_velocity_method(waterdepth, velocity)
case _:
self.set_logger_message(
- "Invalid conveyance method. Defaulting to [1]", "warning"
+ "Invalid conveyance method. Defaulting to [1]",
+ "warning",
)
return max_velocity_method(waterdepth, velocity)
def _extend_css_below_z0(
self,
- centre_level,
- centre_depth,
- waterlevel,
- wet_not_plas_mask,
+ centre_level: np.ndarray,
+ centre_depth: np.ndarray,
+ waterlevel: pd.DataFrame,
+ wet_not_plas_mask: np.ndarray,
) -> None:
- """
+ """Compute the total and flow width for the level (z) below the water level.
+
This method computes, for each level (z) below the water level at the first 2D output (z0), the corresponding
:term:`Total width` and :term:`Flow width`. This is done in the following way:
@@ -1268,10 +1295,12 @@ def _extend_css_below_z0(
2. for each step, determine which cells should be counted
- the bed level of the cell should be higher than the water level plus the tolerance (see note below)
- the cell should not be part of a :term:`Lakes`
- 3. Since there is no information on flow velocities, we cannot determine which cells are flowing and width are storage.
- Therefore is is decided like this
- - If :ref:`ExtrapolateStorage ` is True, the flow area is the minimum of the flow area at t0 (from :ref:`wl_dependent_css`)
- - if :ref:`ExtrapolateStorage ` is False, the flow area is equal to the total area
+ 3. Since there is no information on flow velocities, we cannot determine which cells are flowing and which
+ are storage. Therefore it is decided as follows:
+ - If :ref:`ExtrapolateStorage ` is True, the flow area is the minimum of the
+ flow area at t0 (from :ref:`wl_dependent_css`)
+ - if :ref:`ExtrapolateStorage ` is False, the flow area is equal to the total
+ area
.. note::
- the number of steps between z at t0 to bed level is hard-coded at 10.
@@ -1279,6 +1308,7 @@ def _extend_css_below_z0(
Attributes:
+ ----------
_css_z
_css_total_width
_css_flow_width
@@ -1286,16 +1316,16 @@ def _extend_css_below_z0(
_fm_flow_area
_fm_flow_volume
_fm_total_volume
- """
- bedlevel = self._fm_data.get("bedlevel").values
- cell_area = self._fm_data.get("area").values
+ """
+ bedlevel = self._fm_data.get("bedlevel").to_numpy()
+ cell_area = self._fm_data.get("area").to_numpy()
flow_area_at_z0 = self._fm_flow_area[0]
lowest_level_of_css = (
centre_level[0] - centre_depth[0]
) # this is in fact the bed level at centre point
centre_level_at_t0 = centre_level[0]
- waterlevel_at_t0 = waterlevel.values[:, 0]
+ waterlevel_at_t0 = waterlevel.to_numpy()[:, 0]
waterdepth_at_t0 = waterlevel_at_t0 - bedlevel
waterdepth_at_t0[np.isnan(waterdepth_at_t0)] = 0
tolerance = -1e-3 # at last point, this is still considered wet.
@@ -1304,7 +1334,7 @@ def _extend_css_below_z0(
for dz in np.linspace(0, centre_level_at_t0 - lowest_level_of_css, 10):
centre_level_at_dz = centre_level_at_t0 - dz
total_wet_area = np.nansum(
- cell_area[((waterdepth_at_t0 - dz) > tolerance) & wet_not_plas_mask]
+ cell_area[((waterdepth_at_t0 - dz) > tolerance) & wet_not_plas_mask],
)
# Extension of flow/storage below z0
@@ -1315,22 +1345,28 @@ def _extend_css_below_z0(
self._css_z = self._append_to_start(self._css_z, centre_level_at_dz)
self._css_total_width = self._append_to_start(
- self._css_total_width, total_wet_area / self.length
+ self._css_total_width,
+ total_wet_area / self.length,
)
self._css_flow_width = self._append_to_start(
- self._css_flow_width, total_flow_area / self.length
+ self._css_flow_width,
+ total_flow_area / self.length,
)
self._fm_wet_area = self._append_to_start(self._fm_wet_area, total_wet_area)
self._fm_flow_area = self._append_to_start(
- self._fm_flow_area, total_flow_area
+ self._fm_flow_area,
+ total_flow_area,
)
self._fm_flow_volume = np.insert(self._fm_flow_volume, 0, np.nan)
self._fm_total_volume = np.insert(self._fm_total_volume, 0, np.nan)
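A toy sketch (assumed numbers) of the downward extension above: the level is lowered in 10 steps from z0 to the bed, and at each step the wet width is prepended to the profile:

```python
# Toy sketch of the extension loop; all values are illustrative.
import numpy as np

cell_area = np.array([10.0, 10.0, 10.0])
waterdepth_at_t0 = np.array([1.5, 0.6, 0.0])       # per cell at the first map
wet_not_plas_mask = np.array([True, True, False])  # exclude lake cells
tolerance = -1e-3
length = 100.0
centre_level_at_t0, lowest_level = 2.0, 0.5

css_z, css_total_width = np.array([2.0]), np.array([0.3])
for dz in np.linspace(0, centre_level_at_t0 - lowest_level, 10):
    wet = ((waterdepth_at_t0 - dz) > tolerance) & wet_not_plas_mask
    css_z = np.insert(css_z, 0, centre_level_at_t0 - dz)
    css_total_width = np.insert(css_total_width, 0, np.nansum(cell_area[wet]) / length)
```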
- def _get_extra_total_area(self, waterlevel, crest_level, transition_height: float):
- """
- releases extra area dependent on waterlevel using a logistic (sigmoid) function
- """
+ def _get_extra_total_area(
+ self,
+ waterlevel: np.ndarray,
+ crest_level: np.ndarray,
+ transition_height: float,
+ ) -> float:
+ """Releases extra area dependent on waterlevel using a logistic (sigmoid) function."""
delta = 0.00001 # accuracy parameter
return 1 / (
1
@@ -1342,81 +1378,90 @@ def _get_extra_total_area(self, waterlevel, crest_level, transition_height: floa
)
)
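The full expression is truncated by the hunk boundary above, so for reference here is a generic logistic transition consistent with the docstring; the steepness and midpoint terms are assumptions for illustration, not fm2prof's exact constants:

```python
# Hedged sketch of a logistic summerdike transition (parameterization assumed).
import numpy as np

def extra_area_fraction(waterlevel: np.ndarray, crest_level: float, transition_height: float) -> np.ndarray:
    """Fraction (0..1) of extra total area released above the crest."""
    delta = 0.00001  # accuracy parameter, as in the code above
    steepness = np.log(1.0 / delta - 1.0) / (0.5 * transition_height)  # assumed
    midpoint = crest_level + 0.5 * transition_height
    return 1.0 / (1.0 + np.exp(-steepness * (waterlevel - midpoint)))

# At crest_level the fraction is ~delta; at crest_level + transition_height it is ~1 - delta.
```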
- def _append_to_start(self, array, to_add):
- """
- adds ``to add`` to beginning of array
- """
+ def _append_to_start(
+ self,
+ array: np.ndarray,
+ to_add: float | np.ndarray,
+ ) -> np.ndarray:
+ """Add ``to add`` to beginning of array."""
return np.insert(array, 0, to_add)
- def _return_volume_error(self, predicted, measured):
- """
- Returns the squared relative error
- """
+ def _return_volume_error(
+ self,
+ predicted: np.ndarray,
+ measured: np.ndarray,
+ ) -> np.ndarray:
+ """Calculate the squared relative error."""
non_nan_mask = ~np.isnan(predicted) & ~np.isnan(measured)
predicted = predicted[non_nan_mask]
measured = measured[non_nan_mask]
error = np.array(predicted - measured) / np.maximum(
- np.array(measured), np.ones(len(measured))
+ np.array(measured),
+ np.ones(len(measured)),
)
return np.sum(error**2)
@staticmethod
- def _check_monotonicity(arr, method=2):
- """
- for given input array, create mask such that when applied to the array,
- all values are monotonically rising
+ def _check_monotonicity(arr: np.ndarray, method: int = 2) -> np.ndarray:
+ """For given input array, create mask such that when applied to the array.
+
+ All values are monotonically rising.
method 1: remove values were z is falling from array
method 2: sort array such that z is always rising (default)
Arguments:
+ ---------
arr: 1d numpy array
+ method: int
- return:
+ Return:
+ ------
mask such that arr[mask] is monotonically rising
+
"""
if method == 1:
mask = np.array([True])
for i in range(1, len(arr)):
# Last index that had rising value
j = np.argwhere(mask)[-1][0]
- if arr[i] > arr[j]:
- mask = np.append(mask, True)
- else:
- mask = np.append(mask, False)
+ mask = (
+ np.append(mask, True) if arr[i] > arr[j] else np.append(mask, False)
+ )
return mask
- elif method == 2:
+ if method == 2: # noqa: PLR2004
return np.argsort(arr)
+ err_msg = f"method argument, {method}, not understood. Choose between 1 or 2."
+ raise ValueError(err_msg)
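A quick usage sketch of the two methods above on a toy array:

```python
# method 1 keeps only samples that rise relative to the last kept one;
# method 2 returns a sort order instead of a boolean mask.
import numpy as np

arr = np.array([0.0, 1.0, 0.5, 2.0])
# method 1: mask == [True, True, False, True]  ->  arr[mask] == [0., 1., 2.]
# method 2:
print(arr[np.argsort(arr)])  # [0.  0.5 1.  2. ]
```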
- def _check_total_width_greater_than_flow_width(self):
- """
- If total width is smaller than flow width, set flow width to total width
- """
+ def _check_total_width_greater_than_flow_width(self) -> None:
+ """If total width is smaller than flow width, set flow width to total width."""
mask = self._css_flow_width > self._css_total_width
self._css_flow_width[mask] = self._css_total_width[mask]
self.set_logger_message(
- f"Reduces flow widths at {sum(mask)} points to be same as total", "debug"
+ f"Reduces flow widths at {sum(mask)} points to be same as total",
+ "debug",
)
- def _check_section_widths_greater_than_flow_width(self):
+ def _check_section_widths_greater_than_flow_width(self) -> None:
total_section_width = 0
- for key, width in self.section_widths.items():
+ for width in self.section_widths.values():
total_section_width += width
dif = self.flow_width[-1] - total_section_width
if dif > 0:
self.section_widths["main"] += dif
self.set_logger_message(
- f"Increased main section width by {dif:.2f} m", "warning"
+ f"Increased main section width by {dif:.2f} m",
+ "warning",
)
def _check_section_widths_greater_than_minimum_width(self) -> bool:
- """
- Main section width must be greater than minimum profile width, or
- it is ignored by SOBEK 3
- """
+ """Check section widths that are greater than the minimum width.
+ Main section width must be greater than minimum profile width, or it is ignored by SOBEK 3.
+ """
dif = self.section_widths["main"] - self._css_flow_width[0]
# cm accuracy, and at least 10 cm difference
@@ -1430,10 +1475,12 @@ def _check_section_widths_greater_than_minimum_width(self) -> bool:
self.section_widths["main"] -= dif - tol
self.section_widths["floodplain1"] += dif - tol
self.set_logger_message(
- f"Increased main section width by {-1*(dif-tol):.2f}", "warning"
+ f"Increased main section width by {-1*(dif-tol):.2f}",
+ "warning",
)
return True
return False
- def get_parameter(self, key: str):
+ def get_parameter(self, key: str) -> str | bool | int | float | None:
+ """Retrieve parameter from ini file."""
return self.get_inifile().get_parameter(key)
diff --git a/fm2prof/Import.py b/fm2prof/data_import.py
similarity index 62%
rename from fm2prof/Import.py
rename to fm2prof/data_import.py
index 0e3d7588..0696fa4f 100644
--- a/fm2prof/Import.py
+++ b/fm2prof/data_import.py
@@ -1,10 +1,9 @@
-"""
-This module contains code for importing files to fm2prof
-"""
+"""Contains code for importing files to fm2prof."""
+
+from __future__ import annotations
# import from standard library
-import os
-from typing import Dict, Mapping
+from pathlib import Path
# import from dependencies
import numpy as np
@@ -16,35 +15,45 @@
class FMDataImporter(FM2ProfBase):
- dflow2d_face_keys = {
- "x": "mesh2d_face_x",
- "y": "mesh2d_face_y",
- "area": "mesh2d_flowelem_ba",
- "bedlevel": "mesh2d_flowelem_bl",
- }
-
- dflow2d_edge_keys = {
- "x": "mesh2d_edge_x",
- "y": "mesh2d_edge_y",
- "edge_faces": "mesh2d_edge_faces",
- "edge_nodes": "mesh2d_edge_nodes",
- }
-
- dflow2d_result_keys = {
- "waterdepth": "mesh2d_waterdepth",
- "waterlevel": "mesh2d_s1",
- "chezy_mean": "mesh2d_czs", # not used anymore!
- "chezy_edge": "mesh2d_czu",
- "velocity_x": "mesh2d_ucx",
- "velocity_y": "mesh2d_ucy",
- "velocity_edge": "mesh2d_u1",
- }
-
- def import_dflow2d(self, file_path):
- """
- Method to read input from a dflow2d output file
+ """FM Data importer class."""
+
+ @property
+ def dflow2d_face_keys(self) -> dict:
+ """Mapping with dflow2d face keys."""
+ return {
+ "x": "mesh2d_face_x",
+ "y": "mesh2d_face_y",
+ "area": "mesh2d_flowelem_ba",
+ "bedlevel": "mesh2d_flowelem_bl",
+ }
+
+ @property
+ def dflow2d_edge_keys(self) -> dict:
+ """Mapping with dflow2d edge keys."""
+ return {
+ "x": "mesh2d_edge_x",
+ "y": "mesh2d_edge_y",
+ "edge_faces": "mesh2d_edge_faces",
+ "edge_nodes": "mesh2d_edge_nodes",
+ }
+
+ @property
+ def dflow2d_result_keys(self) -> dict:
+ """Mapping with dflow2d_result_keys."""
+ return {
+ "waterdepth": "mesh2d_waterdepth",
+ "waterlevel": "mesh2d_s1",
+ "chezy_mean": "mesh2d_czs", # not used anymore!
+ "chezy_edge": "mesh2d_czu",
+ "velocity_x": "mesh2d_ucx",
+ "velocity_y": "mesh2d_ucy",
+ "velocity_edge": "mesh2d_u1",
+ }
+
+ def import_dflow2d(self, file_path: Path | str) -> tuple[pd.DataFrame | None, dict, pd.DataFrame, dict]:
+ """Read input from a dflow2d output file.
- Arguments:
+ Args:
file_path: path to *_map.nc file
Results:
@@ -52,6 +61,7 @@ def import_dflow2d(self, file_path):
tid_edge - DataFrame with time-independent data on flow links
node_coordinates -
td - DataFrame with time-dependent data (e.g. water levels, ..)
+
"""
self.set_logger_message("hello from dflow2d importer")
@@ -62,31 +72,25 @@ def import_dflow2d(self, file_path):
tid_face = None
for key, nckey in self.dflow2d_face_keys.items():
if tid_face is None:
- tid_face = pd.DataFrame(
- columns=[key], data=np.array(map_file.variables[nckey])
- )
+ tid_face = pd.DataFrame(columns=[key], data=np.array(map_file.variables[nckey]))
else:
tid_face[key] = np.array(map_file.variables[nckey])
# These variables are preallocated with the correct size for later use
- tid_face["region"] = [""] * len(
- tid_face["y"]
- ) # region id (see RegionPolygon). By default, no regions
+ tid_face["region"] = [""] * len(tid_face["y"]) # region id (see RegionPolygon). By default, no regions
tid_face["section"] = ["main"] * len(
- tid_face["y"]
+ tid_face["y"],
) # section id (see SectionPolygon). By default, all sections are 'main'
tid_face["sclass"] = [""] * len(tid_face["y"]) # cross-section id
tid_face["islake"] = [False] * len(
- tid_face["y"]
+ tid_face["y"],
) # whether or not cell is in a lake. By default, nothing is a lake
# Time-invariant variables from FM 2D at edges
# -----------------------------------------------
- internal_edges = (
- map_file.variables["mesh2d_edge_type"][:] == 1
- ) # edgetype 1 = 'internal'
+ internal_edges = map_file.variables["mesh2d_edge_type"][:] == 1 # edgetype 1 = 'internal'
- tid_edge = dict()
+ tid_edge = {}
for key, nckey in self.dflow2d_edge_keys.items():
try:
tid_edge[key] = np.array(map_file.variables[nckey])[internal_edges]
@@ -99,21 +103,17 @@ def import_dflow2d(self, file_path):
)
tid_edge["sclass"] = np.array([""] * np.sum(internal_edges), dtype="U99")
- tid_edge["section"] = np.array(
- ["main"] * np.sum(internal_edges), dtype="U99"
- )
+ tid_edge["section"] = np.array(["main"] * np.sum(internal_edges), dtype="U99")
tid_edge["region"] = np.array([""] * np.sum(internal_edges), dtype="U99")
# node data (- Is this data still used??)
# ----------------------------------------------
- node_coordinates = pd.DataFrame(
- columns=["x"], data=np.array(map_file.variables["mesh2d_node_x"])
- )
+ node_coordinates = pd.DataFrame(columns=["x"], data=np.array(map_file.variables["mesh2d_node_x"]))
node_coordinates["y"] = np.array(map_file.variables["mesh2d_node_y"])
# Time-variant variables
# ----------------------------------------------
- td = dict()
+ td = {}
for key, nckey in self.dflow2d_result_keys.items():
if key == "chezy_edge":
# this one we treat slightly differently:
@@ -127,13 +127,12 @@ def import_dflow2d(self, file_path):
)
except KeyError:
td[key] = pd.DataFrame(
- data=np.array(map_file.variables["mesh2d_cftrt"]).T[
- internal_edges
- ],
+ data=np.array(map_file.variables["mesh2d_cftrt"]).T[internal_edges],
columns=map_file.variables["time"],
)
self.set_logger_message(
- "The Dflow2D output does not have the 'mesh2d_czu' key. Reverting to mesh2d_cftrt. Make sure that the UnifFrictType is set to 0 (Cheyz) in the Dflow2d mdu file.",
+ "The Dflow2D output does not have the 'mesh2d_czu' key. Reverting to mesh2d_cftrt. "
+ "Make sure that the UnifFrictType is set to 0 (Cheyz) in the Dflow2d mdu file.",
"warning",
)
else:
@@ -146,9 +145,7 @@ def import_dflow2d(self, file_path):
class FmModelData:
- """
- Used to read and store data from the 2D model
- """
+ """Used to read and store data from the 2D model."""
time_dependent_data = None
time_independent_data = None
@@ -156,36 +153,41 @@ class FmModelData:
node_coordinates = None
css_data_list = None
- def __init__(self, arg_list: list):
- if not arg_list:
- raise Exception("FM model data was not read correctly.")
- if len(arg_list) != 5:
- raise Exception(
- "Fm model data expects 5 arguments but only "
- + "{} were given".format(len(arg_list))
- )
-
- (
- self.time_dependent_data,
- self.time_independent_data,
- self.edge_data,
- self.node_coordinates,
- css_data_dictionary,
- ) = arg_list
+ def __init__(
+ self,
+ time_dependent_data: pd.DataFrame,
+ time_independent_data: pd.DataFrame,
+ edge_data: dict,
+ node_coordinates: pd.DataFrame,
+ css_data_dictionary: dict,
+ ) -> None:
+ """Instantiate a FmModelData object.
+
+ Args:
+ time_dependent_data (pd.DataFrame): time-dependent data (water levels, roughness, velocities)
+ time_independent_data (pd.DataFrame): time-independent data (e.g. bathymetry)
+ edge_data (dict): time-independent data on flow links (edges)
+ node_coordinates (pd.DataFrame): x/y coordinates of the mesh nodes
+ css_data_dictionary (dict): cross-section input data, ordered by key
+ """
+ self.time_dependent_data = time_dependent_data
+ self.time_independent_data = time_independent_data
+ self.edge_data = edge_data
+ self.node_coordinates = node_coordinates
self.css_data_list = self.get_ordered_css_list(css_data_dictionary)
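Minimal construction sketch reflecting the new keyword-argument API (dummy empty containers; in practice the values come from `FMDataImporter.import_dflow2d`):

```python
import pandas as pd
from fm2prof.data_import import FmModelData

fm_data = FmModelData(
    time_dependent_data=pd.DataFrame(),
    time_independent_data=pd.DataFrame(),
    edge_data={},
    node_coordinates=pd.DataFrame(),
    css_data_dictionary={},
)
```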
@staticmethod
- def get_ordered_css_list(css_data_dict: Mapping[str, str]):
- """Returns an ordered list where every element
- represents a Cross Section structure
+ def get_ordered_css_list(css_data_dict: dict[str, str]) -> list[dict[str, str]]:
+ """Return an ordered list where every element represents a Cross Section structure.
- Arguments:
- css_data_dict {Mapping[str,str]} -- Dictionary ordered by the keys
+ Args:
+ css_data_dict (dict[str, str]): Dictionary ordered by the keys
Returns:
- {list} -- List where every element contains a dictionary
+ (list): List where every element contains a dictionary
to create a Cross Section.
+
"""
if not css_data_dict or not isinstance(css_data_dict, dict):
return []
@@ -193,23 +195,22 @@ def get_ordered_css_list(css_data_dict: Mapping[str, str]):
number_of_css = len(css_data_dict[next(iter(css_data_dict))])
css_dict_keys = css_data_dict.keys()
css_dict_values = css_data_dict.values()
- css_data_list = [
+ return [
dict(
zip(
css_dict_keys,
[value[idx] for value in css_dict_values if idx < len(value)],
- )
+ ),
)
for idx in range(number_of_css)
]
- return css_data_list
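For reference, the reordering performed above turns a dict of equal-length lists into a list of per-cross-section dicts:

```python
# Example of the reordering (toy values).
css_data = {"id": ["A", "B"], "length": [100.0, 250.0]}
# get_ordered_css_list(css_data) ->
# [{"id": "A", "length": 100.0}, {"id": "B", "length": 250.0}]
```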
- def get_selection(self, css_name: str) -> Dict:
- """
- create a dictionary that holds all the 2D data for the cross-section with name 'css_name'
+ def get_selection(self, css_name: str) -> dict:
+ """Create a dictionary that holds all the 2D data for the cross-section with name 'css_name'.
- Parameters
+ Args:
css_name (str): name of the cross-section
+
"""
dti = self.time_independent_data
dtd = self.time_dependent_data
@@ -231,22 +232,17 @@ def get_selection(self, css_name: str) -> Dict:
edge_faces = edge_data["edge_faces"][edge_data["sclass"] == css_name]
except KeyError:
edge_faces = None
- # edge_nodes = edge_data['edge_nodes'][edge_data['sclass'] == css_name]
+
edge_x = edge_data["x"][edge_data["sclass"] == css_name]
edge_y = edge_data["y"][edge_data["sclass"] == css_name]
- edge_section = edge_data["section"][
- edge_data["sclass"] == css_name
- ] # roughness section number
+ edge_section = edge_data["section"][edge_data["sclass"] == css_name] # roughness section number
- # retrieve the full set for face_nodes and area, needed for the roughness calculation
- # face_nodes = edge_data['face_nodes'][dti['sclass'] == css_name]
- # face_nodes_full = edge_data['face_nodes']
bedlevel = dti["bedlevel"][dti["sclass"] == css_name]
velocity = (vx**2 + vy**2) ** 0.5
waterlevel[waterdepth == 0] = np.nan
- return_dict = {
+ return {
"x": x,
"y": y,
"area": area,
@@ -264,31 +260,20 @@ def get_selection(self, css_name: str) -> Dict:
"edge_section": edge_section,
}
- return return_dict
-
class ImportInputFiles(FM2ProfBase):
- """
- This class contains all functions related to the import of files
- """
+ """Contains all functions related to the import of files."""
- def css_file(self, file_path: str, delimiter: str = ",") -> Dict:
- """
- Reads the cross-section location file
- """
- skipLine = False # flag to skip line if file has header
+ def css_file(self, file_path: Path | str, delimiter: str = ",") -> dict:
+ """Read the cross-section location file."""
+ skip_line = False # flag to skip line if file has header
- if not file_path or not os.path.exists(file_path):
- raise IOError(
- "No file path for Cross Section location file was given, or could not be found at {}".format(
- file_path
- )
- )
+ if not file_path or not Path(file_path).exists():
+ err_msg = f"No file path for Cross Section location file was given, or could not be found at {file_path}"
+ raise OSError(err_msg)
- with open(file_path, "r") as fid:
- input_data = dict(
- xy=list(), id=list(), branchid=list(), length=list(), chainage=list()
- )
+ with Path(file_path).open("r") as fid:
+ input_data = {"xy": [], "id": [], "branchid": [], "length": [], "chainage": []}
for lineno, line in enumerate(fid):
try:
(cssid, x, y, length, branchid, chainage) = line.split(delimiter)
@@ -301,14 +286,14 @@ def css_file(self, file_path: str, delimiter: str = ",") -> Dict:
except ValueError:
if lineno == 0:
# file has header. Skip header and try again next
- skipLine = True
- if not skipLine:
+ skip_line = True
+ if not skip_line:
input_data["xy"].append((float(x), float(y)))
input_data["id"].append(cssid)
input_data["length"].append(float(length))
input_data["branchid"].append(branchid.strip())
input_data["chainage"].append(float(chainage))
- skipLine = False
+ skip_line = False
# Convert everything to ndarray
for key in input_data:
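A hedged example of one line of the cross-section location file, inferred from the unpacking in the parser above (comma-delimited; an optional header line is skipped):

```python
# Column order inferred from the tuple unpacking in css_file(); the sample
# values are illustrative.
line = "case1_0,25.0,75.0,150.0,channel1,0.0\n"
(cssid, x, y, length, branchid, chainage) = line.split(",")
```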
diff --git a/fm2prof/Export.py b/fm2prof/export.py
similarity index 54%
rename from fm2prof/Export.py
rename to fm2prof/export.py
index 12973290..158b617e 100644
--- a/fm2prof/Export.py
+++ b/fm2prof/export.py
@@ -1,18 +1,27 @@
-"""Input/output"""
+"""Classes for exporting data."""
+
+from __future__ import annotations
# import from standar library
from dataclasses import astuple, dataclass
-from typing import List
+from pathlib import Path
+from typing import TYPE_CHECKING, Iterator
+
import numpy as np
# import from package
-from fm2prof import Functions as FE
from fm2prof.common import FM2ProfBase
-from fm2prof.CrossSection import CrossSection
+
+if TYPE_CHECKING:
+ from io import TextIOWrapper
+
+ from fm2prof.cross_section import CrossSection
@dataclass
class OutputFiles:
+ """Class for grouping output files."""
+
dimr_css_locations: str = "CrossSectionLocations.ini"
dimr_css_definitions: str = "CrossSectionDefinitions.ini"
dimr_roughness_main: str = "roughness-Main.ini"
@@ -24,23 +33,38 @@ class OutputFiles:
test_roughness: str = "roughness_test.csv"
fm2prof_volume: str = "volumes.csv"
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> str:
+ """Get item by index from OutputFiles dataclass."""
return astuple(self)[index]
- def __iter__(self):
+ def __iter__(self) -> Iterator:
+ """Get iterator of OutputFiles items."""
return iter(astuple(self))
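A usage sketch of the dataclass helpers above:

```python
from fm2prof.export import OutputFiles

files = OutputFiles()
print(files[0])         # "CrossSectionLocations.ini"
print(list(files)[:2])  # ['CrossSectionLocations.ini', 'CrossSectionDefinitions.ini']
```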
class Export1DModelData(FM2ProfBase):
- """
- This class contains all functions related to exporting to various output export.
+ """Contains all functions related to exporting to various outputs.
+
In the future, split into different classes for SOBEK, D-Hydro, etc. formats.
"""
def export_geometry(
- self, cross_sections: List[CrossSection], file_path, fmt="sobek3"
- ):
- with open(file_path, "w") as f:
+ self,
+ cross_sections: list[CrossSection],
+ file_path: Path | str,
+ fmt: str = "sobek3",
+ ) -> None:
+ """Export cross section geometries in different formats.
+
+ Args:
+ ----
+ cross_sections (list[CrossSection]): List of cross sections.
+ file_path (Path | str): File path to write to.
+ fmt (str): Format to write to. Options are sobek3, dflow1d, and testformat.
+ Defaults to sobek3
+
+ """
+ with Path(file_path).open("w") as f:
if fmt == "sobek3":
"""SOBEK 3 style csv"""
self._write_geometry_sobek3(f, cross_sections)
@@ -52,9 +76,23 @@ def export_geometry(
self._write_geometry_testformat(f, cross_sections)
def export_roughness(
- self, cross_sections, file_path, fmt="sobek3", roughness_section="Main"
- ):
- with open(file_path, "w") as f:
+ self,
+ cross_sections: list[CrossSection],
+ file_path: str | Path,
+ fmt: str = "sobek3",
+ roughness_section: str = "Main",
+ ) -> None:
+ """Export roughnes in different formats.
+
+ Args:
+ ----
+ cross_sections (list[CrossSection]): List of cross sections
+ file_path (str | Path): File path to write to
+ fmt (str, optional): Format to write to. Defaults to "sobek3".
+ roughness_section (str, optional): Name of roughness section. Defaults to "Main".
+
+ """
+ with Path(file_path).open("w") as f:
if fmt == "sobek3":
"""SOBEK 3 style csv"""
self._write_roughness_sobek3(f, cross_sections)
@@ -65,76 +103,101 @@ def export_roughness(
"""test format for system tests, only has geometry (no summerdike)"""
self._write_roughness_testformat(f, cross_sections)
- def export_volumes(self, cross_sections, file_path):
- """Write to file the volume/waterlevel information"""
+ def export_volumes(
+ self,
+ cross_sections: list[CrossSection],
+ file_path: str | Path,
+ ) -> None:
+ """Write to file the volume/waterlevel information.
- with open(file_path, "w") as f:
+ Args:
+ ----
+ cross_sections (list[CrossSection]): List of cross sections.
+ file_path (str | Path): File path to write to.
+
+ """
+ with Path(file_path).open("w") as f:
# Write header
f.write(
- "id,z,2D_total_volume,2D_flow_volume,2D_wet_area,2D_flow_area,1D_total_volume_sd,1D_total_volume,1D_flow_volume_sd,1D_flow_volume,1D_total_width,1D_flow_width\n"
+ "id,z,2D_total_volume,2D_flow_volume,2D_wet_area,2D_flow_area,1D_total_volume_sd,1D_total_volume,1D_flow_volume_sd,1D_flow_volume,1D_total_width,1D_flow_width\n",
)
for css in cross_sections:
for i in range(len(css._css_z)):
- outputdata = dict(
- id=css.name,
- z=css._css_z[i],
- tv2d=css._fm_total_volume[i],
- fv2d=css._fm_flow_volume[i],
- wa2d=css._fm_wet_area[i],
- fa2d=css._fm_flow_area[i],
- tvsd1d=css._css_total_volume_corrected[i],
- tv1d=css._css_total_volume[i],
- fvsd1d=css._css_flow_volume_corrected[i],
- fv1d=css._css_flow_volume[i],
- tw1d=css._css_total_width[i],
- fw1d=css._css_flow_width[i],
- )
+ outputdata = {
+ "id": css.name,
+ "z": css._css_z[i],
+ "tv2d": css._fm_total_volume[i],
+ "fv2d": css._fm_flow_volume[i],
+ "wa2d": css._fm_wet_area[i],
+ "fa2d": css._fm_flow_area[i],
+ "tvsd1d": css._css_total_volume_corrected[i],
+ "tv1d": css._css_total_volume[i],
+ "fvsd1d": css._css_flow_volume_corrected[i],
+ "fv1d": css._css_flow_volume[i],
+ "tw1d": css._css_total_width[i],
+ "fw1d": css._css_flow_width[i],
+ }
f.write(
"{id},{z},{tv2d},{fv2d},{wa2d},{fa2d},{tvsd1d},{tv1d},{fvsd1d},{fv1d},{tw1d},{fw1d}\n".format(
- **outputdata
- )
+ **outputdata,
+ ),
)
- def export_crossSectionLocations(self, cross_sections, file_path):
- """DIMR format"""
- with open(file_path, "w") as fid:
+ def export_cross_section_locations(
+ self,
+ cross_sections: list[CrossSection],
+ file_path: str | Path,
+ ) -> None:
+ """Export cross section locations.
+
+ Args:
+ ----
+ cross_sections (list[CrossSection]): List of cross sections.
+ file_path (str | Path): file path to write to.
+
+ """
+ with Path(file_path).open("w") as fid:
# Write general secton
fid.write(
- "[General]\nmajorVersion\t\t\t= 1\nminorversion\t\t\t= 0\nfileType\t\t\t\t= crossLoc\n\n"
+ "[General]\nmajorVersion\t\t\t= 1\nminorversion\t\t\t= 0\nfileType\t\t\t\t= crossLoc\n\n",
)
for css in cross_sections:
fid.write("[CrossSection]\n")
- fid.write("\tid\t\t\t\t\t= {}\n".format(css.name))
- fid.write("\tbranchid\t\t\t= {}\n".format(css.branch))
- fid.write("\tchainage\t\t\t= {}\n".format(css.chainage))
+ fid.write(f"\tid\t\t\t\t\t= {css.name}\n")
+ fid.write(f"\tbranchid\t\t\t= {css.branch}\n")
+ fid.write(f"\tchainage\t\t\t= {css.chainage}\n")
fid.write("\tshift\t\t\t\t= 0.000\n")
- fid.write("\tdefinition\t\t\t= {}\n\n".format(css.name))
+ fid.write(f"\tdefinition\t\t\t= {css.name}\n\n")
""" test file formats """
- def _write_geometry_testformat(self, fid, cross_sections):
+ def _write_geometry_testformat(
+ self,
+ fid: TextIOWrapper,
+ cross_sections: list[CrossSection],
+ ) -> None:
# write header
fid.write("chainage,z,total width,flow width\n")
- for index, cross_section in enumerate(cross_sections):
+ for _index, cross_section in enumerate(cross_sections):
for i in range(len(cross_section.z)):
fid.write(
- "{}, {}, {}, {}\n".format(
- cross_section.chainage,
- cross_section.z[i],
- cross_section.total_width[i],
- cross_section.flow_width[i],
- )
+ f"{cross_section.chainage}, {cross_section.z[i]}, {cross_section.total_width[i]},"
+ f"{cross_section.flow_width[i]}\n",
)
- def _write_roughness_testformat(self, fid, cross_sections):
+ def _write_roughness_testformat(
+ self,
+ fid: TextIOWrapper,
+ cross_sections: list[CrossSection],
+ ) -> None:
# write header
fid.write("chainage,type,waterlevel,chezy roughness\n")
for roughnesstype in ("alluvial", "nonalluvial"):
- for index, cross_section in enumerate(cross_sections):
+ for cross_section in cross_sections:
waterlevels = cross_section.alluvial_friction_table[0]
if roughnesstype == "alluvial":
@@ -147,73 +210,72 @@ def _write_roughness_testformat(self, fid, cross_sections):
chezy = table[1][index]
except IndexError:
break
- if np.isnan(chezy) == False:
+ if not np.isnan(chezy):
fid.write(
- "{}, {}, {}, {}\n".format(
- cross_section.chainage, roughnesstype, level, chezy
- )
+ f"{cross_section.chainage}, {roughnesstype}, {level}, {chezy}\n",
)
""" FM 1D file formats """
- def _write_geometry_fm1d(self, fid, cross_sections):
- """FM1D uses a configuration 'Delft' file style format"""
-
+ def _write_geometry_fm1d(
+ self,
+ fid: TextIOWrapper,
+ cross_sections: list[CrossSection],
+ ) -> None:
+ """FM1D uses a configuration 'Delft' file style format."""
# Write general secton
fid.write(
- "[General]\nmajorVersion\t\t\t= 1\nminorversion\t\t\t= 0\nfileType\t\t\t\t= crossDef\n\n"
+ "[General]\nmajorVersion\t\t\t= 1\nminorversion\t\t\t= 0\nfileType\t\t\t\t= crossDef\n\n",
)
- for index, css in enumerate(cross_sections):
- z = ["{:.4f}".format(iz) for iz in css.z]
- fw = ["{:.4f}".format(iz) for iz in css.flow_width]
- tw = ["{:.4f}".format(iz) for iz in css.total_width]
-
- # check for nan, because a channel with only one roughness value (ideal case) will not have this value
- if np.isnan(css.floodplain_base) == False:
- floodplain_base = str(css.floodplain_base)
- else:
- floodplain_base = str(css.crest_level)
+ for css in cross_sections:
+ z = [f"{iz:.4f}" for iz in css.z]
+ fw = [f"{iz:.4f}" for iz in css.flow_width]
+ tw = [f"{iz:.4f}" for iz in css.total_width]
fid.write("[Definition]\n")
fid.write(
- "\tid\t\t\t\t\t= {}\n".format(css.name)
- + "\ttype\t\t\t\t= tabulated\n"
- + "\tthalweg\t\t\t\t= 0.000\n"
- + "\tnumLevels\t\t\t= {}\n".format(len(z))
+ f"\tid\t\t\t\t\t= {css.name}\n" # noqa: ISC003
+ "\ttype\t\t\t\t= tabulated\n"
+ "\tthalweg\t\t\t\t= 0.000\n"
+ + f"\tnumLevels\t\t\t= {len(z)}\n"
+ "\tlevels\t\t\t\t= {}\n".format(" ".join(z))
+ "\tflowWidths\t\t\t= {}\n".format(" ".join(fw))
+ "\ttotalWidths\t\t\t= {}\n".format(" ".join(tw))
- + "\tsd_crest\t\t\t= {:.4f}\n".format(css.crest_level)
- + "\tsd_flowArea\t\t\t= {}\n".format(css.extra_flow_area)
- + "\tsd_totalArea\t\t= {:.4f}\n".format(css.extra_total_area)
- + "\tsd_baseLevel\t\t= {}\n".format(css.floodplain_base)
+ + f"\tsd_crest\t\t\t= {css.crest_level:.4f}\n"
+ + f"\tsd_flowArea\t\t\t= {css.extra_flow_area}\n"
+ + f"\tsd_totalArea\t\t= {css.extra_total_area:.4f}\n"
+ + f"\tsd_baseLevel\t\t= {css.floodplain_base}\n"
+ "\tmain\t\t\t\t= {:.4f}\n".format(css.section_widths["main"])
+ "\tfloodPlain1\t\t\t= {:.4f}\n".format(
- css.section_widths["floodplain1"]
+ css.section_widths["floodplain1"],
)
+ "\tfloodPlain2\t\t\t= {:.4f}\n".format(
- css.section_widths["floodplain2"]
+ css.section_widths["floodplain2"],
)
+ "\tgroundlayerUsed\t\t= 0\n"
- + "\tgroundLayer\t\t\t= 0.000\n\n"
+ + "\tgroundLayer\t\t\t= 0.000\n\n",
)
- def _write_roughness_fm1d(self, fid, cross_sections, section):
- """"""
- general_sec = """[General]
- majorVersion = 1
- minorVersion = 0
- fileType = roughness
+ def _write_roughness_fm1d(
+ self,
+ fid: TextIOWrapper,
+ cross_sections: list[CrossSection],
+ section: str,
+ ) -> None:
+ general_sec = f"""[General]
+ majorVersion = 1
+ minorVersion = 0
+ fileType = roughness
[Content]
- sectionId = {}
- flowDirection = False
- interpolate = 1
- globalType = 1
- globalValue = 45
+ sectionId = {section}
+ flowDirection = False
+ interpolate = 1
+ globalType = 1
+ globalValue = 45
- """.format(section)
+ """
branch_sec = self._get_fm1d_branch_sec(cross_sections, section.lower())
definition_sec = self._get_fm1d_definition_sec(cross_sections, section.lower())
@@ -221,15 +283,19 @@ def _write_roughness_fm1d(self, fid, cross_sections, section):
fid.write(branch_sec)
fid.write(definition_sec)
- def _get_fm1d_definition_sec(self, cross_sections, section):
+ def _get_fm1d_definition_sec(
+ self,
+ cross_sections: list[CrossSection],
+ section: str,
+ ) -> str:
def_sec = ""
for css in cross_sections:
- try:
+ if section in css.friction_tables:
table = css.friction_tables[section]
def_sec += """[Definition]
- branchId = {}
- chainage = {}
+ branchId = {}
+ chainage = {}
values = {}
""".format(
@@ -237,55 +303,55 @@ def _get_fm1d_definition_sec(self, cross_sections, section):
css.chainage,
" ".join(map("{:.4}".format, table.friction)),
)
- except KeyError:
- # this section does not exist in this cross-section
- pass
-
return def_sec
- def _get_fm1d_branch_sec(self, cross_sections, section):
+ def _get_fm1d_branch_sec(
+ self,
+ cross_sections: list[CrossSection],
+ section: str,
+ ) -> str:
branch_sec = ""
branch_list = []
for css in cross_sections:
- if section in css.friction_tables:
- try:
- if css.branch not in branch_list:
- branch_list.append(css.branch)
- branch_sec += """[BranchProperties]
- branchId = {}
- roughnessType = 1
- functionType = 2
- numLevels = {}
+ if section in css.friction_tables and css.branch not in branch_list:
+ branch_list.append(css.branch)
+ branch_sec += """[BranchProperties]
+ branchId = {}
+ roughnessType = 1
+ functionType = 2
+ numLevels = {}
levels = {}
-
""".format(
- css.branch,
- len(css.friction_tables[section].level),
- " ".join(
- map("{:.4f}".format, css.friction_tables[section].level)
- ),
- )
- except KeyError:
- # section not in this croess-section
- pass
+ css.branch,
+ len(css.friction_tables[section].level),
+ " ".join(
+ map("{:.4f}".format, css.friction_tables[section].level),
+ ),
+ )
+
return branch_sec
""" SOBEK 3 file formats """
- def _write_geometry_sobek3(self, fid, cross_sections):
+ def _write_geometry_sobek3(
+ self,
+ fid: TextIOWrapper,
+ cross_sections: list[CrossSection],
+ ) -> None:
# write meta
# note, the chainage is currently set to the X-coordinate of the cross-section (straight channel)
# note, the channel naming strategy must be discussed, currently set to 'Channel' for all cross-sections
# write header
fid.write(
- "id,Name,Data_type,level,Total width,Flow width,Profile_type,branch,chainage,width main channel,width floodplain 1,width floodplain 2,width sediment transport,Use Summerdike,Crest level summerdike,Floodplain baselevel behind summerdike,Flow area behind summerdike,Total area behind summerdike,Use groundlayer,Ground layer depth\n"
+ "id,Name,Data_type,level,Total width,Flow width,Profile_type,branch,chainage,width main channel,width"
+ "floodplain 1,width floodplain 2,width sediment transport,Use Summerdike,Crest level summerdike,Floodplain"
+ "baselevel behind summerdike,Flow area behind summerdike,Total area behind summerdike,Use groundlayer,"
+ "Ground layer depth\n",
)
- for index, cross_section in enumerate(cross_sections):
+ for cross_section in cross_sections:
try:
- total_width = cross_section.total_width[-1]
-
b_summerdike = "0"
crest_level = ""
floodplain_base = ""
@@ -296,12 +362,13 @@ def _write_geometry_sobek3(self, fid, cross_sections):
crest_level = str(cross_section.crest_level)
total_area = str(cross_section.extra_total_area)
- # check for nan, because a channel with only one roughness value (ideal case) will not have this value
- if np.isnan(cross_section.floodplain_base) == False:
+ # check for nan, because a channel with only one roughness value (ideal case) will not
+ # have this value
+ if not np.isnan(cross_section.floodplain_base):
floodplain_base = str(cross_section.floodplain_base)
else:
floodplain_base = str(
- cross_section.crest_level
+ cross_section.crest_level,
) # virtual summer dike
fid.write(
@@ -331,10 +398,11 @@ def _write_geometry_sobek3(self, fid, cross_sections):
+ ","
+ total_area
+ ",,,,,,"
- + "\n"
+ + "\n",
)
- # this is to avoid the unique z-value error in sobek, the added 'error' depends on the total_width, this is to make sure the order or points is correct
+ # this is to avoid the unique z-value error in SOBEK; the added 'error' depends on the total_width,
+ # this is to make sure the order of points is correct
z_format = "{:.8f}"
increment = np.array(range(1, cross_section.z.size + 1)) * 1e-5
z_value = cross_section.z + increment
@@ -353,30 +421,33 @@ def _write_geometry_sobek3(self, fid, cross_sections):
+ ","
+ str(flow_width)
+ ",,,,,,,,,,,,,,"
- + "\n"
+ + "\n",
)
except:
self.set_logger_message(
f"Cross-section {cross_section.name} not written to sobek format cross-section file",
"error",
)
- pass
- def _write_roughness_sobek3(self, fid, cross_sections):
+ def _write_roughness_sobek3(
+ self,
+ fid: TextIOWrapper,
+ cross_sections: list[CrossSection],
+ ) -> None:
# note, the chainage is currently set to the X-coordinate of the cross-section (straight channel)
# note, the channel naming strategy must be discussed, currently set to 'Channel' for all cross-sections
# write header
fid.write(
- "Name,Chainage,RoughnessType,SectionType,Dependance,Interpolation,Pos/neg,R_pos_constant,Q_pos,R_pos_f(Q),H_pos,R_pos__f(h),R_neg_constant,Q_neg,R_neg_f(Q),H_neg,R_neg_f(h)\n"
+ "Name,Chainage,RoughnessType,SectionType,Dependance,Interpolation,Pos/neg,R_pos_constant,Q_pos,R_pos_f(Q),H_pos,R_pos__f(h),R_neg_constant,Q_neg,R_neg_f(Q),H_neg,R_neg_f(h)\n",
)
sections = np.unique(
- [s for css in cross_sections for s in css.friction_tables.keys()]
+ [s for css in cross_sections for s in css.friction_tables],
)
for section in sections:
- for index, cross_section in enumerate(cross_sections):
+ for cross_section in cross_sections:
if section in list(cross_section.friction_tables.keys()):
if section == "main":
table = cross_section.friction_tables[section]
@@ -388,7 +459,8 @@ def _write_roughness_sobek3(self, fid, cross_sections):
table = cross_section.friction_tables[section]
plain = "FloodPlain2"
else:
- raise Exception("Unknown section name: {}".format(section))
+ err_msg = f"Unknown section name: {section}"
+ raise ValueError(err_msg)
for level, friction in zip(table.level, table.friction):
fid.write(
@@ -410,5 +482,5 @@ def _write_roughness_sobek3(self, fid, cross_sections):
+ ","
+ str(friction)
+ ",,,,,"
- + "\n"
+ + "\n",
)
diff --git a/fm2prof/Fm2ProfRunner.py b/fm2prof/fm2prof_runner.py
similarity index 59%
rename from fm2prof/Fm2ProfRunner.py
rename to fm2prof/fm2prof_runner.py
index e2583526..b960cf77 100644
--- a/fm2prof/Fm2ProfRunner.py
+++ b/fm2prof/fm2prof_runner.py
@@ -1,77 +1,77 @@
+"""Module for running Fm2Prof processess."""
+
+from __future__ import annotations
+
import datetime
-import os
import pickle
import traceback
from pathlib import Path
-from typing import Dict, Generator, List, Mapping, Union
+from typing import Generator
import geojson
import numpy as np
+import pandas as pd
import tqdm
from geojson import Feature, FeatureCollection, Polygon
from netCDF4 import Dataset
-import pandas as pd
from scipy.spatial import ConvexHull
-from fm2prof import Functions as FE
from fm2prof import __version__
+from fm2prof import functions as funcs
+from fm2prof import mask_output_file
from fm2prof.common import FM2ProfBase
-from fm2prof.CrossSection import CrossSection, CrossSectionHelpers
-from fm2prof.Export import Export1DModelData, OutputFiles
-from fm2prof.Import import FMDataImporter, FmModelData, ImportInputFiles
-from fm2prof.IniFile import IniFile
-from fm2prof.MaskOutputFile import MaskOutputFile
-from fm2prof.RegionPolygonFile import RegionPolygonFile, SectionPolygonFile
+from fm2prof.cross_section import CrossSection, CrossSectionHelpers
+from fm2prof.data_import import FMDataImporter, FmModelData, ImportInputFiles
+from fm2prof.export import Export1DModelData, OutputFiles
+from fm2prof.ini_file import IniFile
+from fm2prof.region_polygon_file import RegionPolygonFile, SectionPolygonFile
class InitializationError(Exception):
- pass
+ """Exception class for initialization errors."""
class Fm2ProfRunner(FM2ProfBase):
+ """Main class that executes all functionality."""
+
__map_key = "2DMapOutput"
__css_key = "CrossSectionLocationFile"
__key_frictionweighingmethod = "FrictionweighingMethod"
__key_skipmaps = "SkipMaps"
- """
- Main class that executes all functionality.
-
- Arguments:
- iniFilePath (str): path to configuration file
- """
+ def __init__(self, ini_file_path: Path | str = "") -> None:
+ """Initialize the project.
- def __init__(self, iniFilePath: Path | str = ""):
- """
- Initializes the project
+ Args:
+ ----
+ ini_file_path (Path | str): path to configuration file.
- Parameters:
- iniFilePath: path to a configuration file. If not given,
- default values will be used.
"""
self.fm_model_data: FmModelData = None
self._output_files: OutputFiles = OutputFiles()
self._create_logger()
- iniFilePath = Path(iniFilePath)
+ ini_file_path = Path(ini_file_path)
self.start_new_log_task("Loading configuration file")
try:
- self.load_inifile(iniFilePath)
- except (FileNotFoundError, IOError) as e:
+ self.load_inifile(ini_file_path)
+ except (OSError, FileNotFoundError) as e:
self.set_logger_message(f"Exiting {e}", "error")
return
if not self.get_inifile().has_output_directory:
self.set_logger_message(
- "Output directory must be set in configuration file", "error"
+ "Output directory must be set in configuration file",
+ "error",
)
return
# Add a log file
self.set_logfile(
- output_dir=self.get_inifile().get_output_directory(), filename="fm2prof.log"
+ output_dir=self.get_inifile().get_output_directory(),
+ filename="fm2prof.log",
)
self.finish_log_task()
@@ -81,17 +81,18 @@ def __init__(self, iniFilePath: Path | str = ""):
# Print configuration to log
self.set_logger_message(self.get_inifile().print_configuration(), header=True)
- def run(self, overwrite: bool = False) -> None:
- """
- Executes FM2PROF routines.
+ def run(self, *, overwrite: bool = False) -> None:
+ """Execute FM2PROF routines.
- Parameters:
- overwrite: if True, overwrites existing output. If False, exits if output detected
- """
+ Args:
+ ----
+ overwrite (bool): if True, overwrites existing output. If False, exits if output detected.
+ """
if self.get_inifile() is None:
self.set_logger_message(
- "No ini file was specified: the run cannot go further.", "Warning"
+ "No ini file was specified: the run cannot go further.",
+ "Warning",
)
return
@@ -111,17 +112,18 @@ def run(self, overwrite: bool = False) -> None:
else:
self.set_logger_message("Program finished", "info")
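Typical end-to-end usage after the rename; note that `overwrite` is now keyword-only. The config path is illustrative:

```python
from fm2prof.fm2prof_runner import Fm2ProfRunner

runner = Fm2ProfRunner("cases/case1/fm2prof.ini")  # path is an example
runner.run(overwrite=True)
```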
- def load_inifile(self, iniFilePath: str):
- """
- use this method to load a configuration file from path.
+ def load_inifile(self, ini_file_path: str) -> None:
+ """Use this method to load a configuration file from path.
+
+ Args:
+ ----
+ ini_file_path (str): path to configuration file
- Parameters:
- iniFilePath (str): path to configuration file
"""
- IniFileObject = IniFile(iniFilePath, logger=self.get_logger())
- self.set_inifile(IniFileObject)
+ ini_file_object = IniFile(ini_file_path, logger=self.get_logger())
+ self.set_inifile(ini_file_object)
- def _print_header(self):
+ def _print_header(self) -> None:
header_text = [
"=" * 80,
f"FM2PROF version {__version__}",
@@ -137,8 +139,7 @@ def _print_header(self):
self.set_logger_message(line, header=True)
def _run_inifile(self) -> bool:
- """
- Executes main program from the configuration file.
+ """Execute main program from the configuration file.
The main steps in the program are:
@@ -146,12 +147,7 @@ def _run_inifile(self) -> bool:
2. Generate cross-sections
3. Finalization
- Arguments:
- iniFile {IniFile}
- -- Object containing all the information
- needed to execute the program
"""
-
# Initialise the project
self.start_new_log_task("Initialising FM2PROF")
try:
@@ -160,7 +156,8 @@ def _run_inifile(self) -> bool:
return False
except:
self.set_logger_message(
- "Unexpected exception during initialisation", "error"
+ "Unexpected exception during initialisation",
+ "error",
)
for line in traceback.format_exc().splitlines():
self.set_logger_message(line, "debug")
@@ -194,27 +191,23 @@ def _run_inifile(self) -> bool:
self._print_log_report()
except:
self.set_logger_message(
- "Unexpected exception during printing of log report", "error"
+ "Unexpected exception during printing of log report",
+ "error",
)
self.finish_log_task()
return True
def _initialise_fm2prof(self) -> None:
- """
-
- Loads data, inifile
- """
-
- iniFile: IniFile = self.get_inifile()
- raiseFileNotFoundError: bool = False
+ """Load data, inifile."""
+ ini_file: IniFile = self.get_inifile()
+ raise_file_not_found: bool = False
# shorter local variables
- map_file = iniFile.get_input_file(self.__map_key)
- css_file = iniFile.get_input_file(self.__css_key)
- region_file = iniFile.get_input_file("RegionPolygonFile")
- section_file = iniFile.get_input_file("SectionPolygonFile")
- output_dir = iniFile.get_output_directory()
+ map_file = ini_file.get_input_file(self.__map_key)
+ css_file = ini_file.get_input_file(self.__css_key)
+ region_file = ini_file.get_input_file("RegionPolygonFile")
+ section_file = ini_file.get_input_file("SectionPolygonFile")
# Read region & section polygon
regions: RegionPolygonFile = None
@@ -229,22 +222,39 @@ def _initialise_fm2prof(self) -> None:
# Check if mandatory input exists
if not Path(map_file).is_file():
self.set_logger_message(
- f"File for {self.__map_key} not found at {map_file}", "error"
+ f"File for {self.__map_key} not found at {map_file}",
+ "error",
)
- raiseFileNotFoundError = True
+ raise_file_not_found = True
if not Path(css_file).is_file():
self.set_logger_message(
- f"File for {self.__css_key} not found at {css_file}", "error"
+ f"File for {self.__css_key} not found at {css_file}",
+ "error",
)
- raiseFileNotFoundError = True
- if raiseFileNotFoundError:
+ raise_file_not_found = True
+ if raise_file_not_found:
raise InitializationError
# Read FM model data
- fm2prof_fm_model_data = self._set_fm_model_data(
- map_file, css_file, regions, sections
+ (
+ time_dependent_data,
+ time_independent_data,
+ edge_data,
+ node_coordinates,
+ css_data_dictionary,
+ ) = self._set_fm_model_data(
+ map_file,
+ css_file,
+ regions,
+ sections,
+ )
+ self.fm_model_data = FmModelData(
+ time_dependent_data=time_dependent_data,
+ time_independent_data=time_independent_data,
+ edge_data=edge_data,
+ node_coordinates=node_coordinates,
+ css_data_dictionary=css_data_dictionary,
)
- self.fm_model_data = FmModelData(fm2prof_fm_model_data)
# Validate config file
success: bool = self._validate_config_after_initalization()
@@ -261,18 +271,18 @@ def _initialise_fm2prof(self) -> None:
nedges: int = self.fm_model_data.edge_data.get("x").shape[0]
self.set_logger_message("finished reading FM and cross-sectional data data")
self.set_logger_message(
- "Number of: timesteps ({}), ".format(ntsteps)
- + "faces ({}), ".format(nfaces)
- + "edges ({})".format(nedges),
+ f"Number of: timesteps ({ntsteps}), "
+ + f"faces ({nfaces}), "
+ + f"edges ({nedges})",
level="debug",
)
return success
def _validate_config_after_initalization(self) -> bool:
- """
- Performs validation checks on config file. Returns
- True if all checks succesfull, False if check fails.
+ """Perform validation checks on config file.
+
+ Returns True if all checks are successful, False otherwise.
"""
success: bool = True
@@ -298,28 +308,28 @@ def _validate_config_after_initalization(self) -> bool:
)
# Check if edge/face data is available
- if "edge_faces" not in self.fm_model_data.edge_data:
- if self.get_inifile().get_parameter(self.__key_frictionweighingmethod) == 1:
- self.set_logger_message(
- "Friction weighing set to 1 (area-weighted average"
- + "but FM map file does contain the *edge_faces* keyword."
- + "Area weighting is not possible. Defaulting to simple unweighted"
- + "averaging",
- level="warning",
- )
+ if (
+ "edge_faces" not in self.fm_model_data.edge_data
+ and self.get_inifile().get_parameter(self.__key_frictionweighingmethod) == 1
+ ):
+ self.set_logger_message(
+ "Friction weighing set to 1 (area-weighted average"
+ "but FM map file does contain the *edge_faces* keyword."
+ "Area weighting is not possible. Defaulting to simple unweighted"
+ "averaging",
+ level="warning",
+ )
return success
- def _finalise_fm2prof(self, cross_sections: List) -> None:
- """
- Write to output, perform checks
- """
+ def _finalise_fm2prof(self, cross_sections: list[CrossSection]) -> None:
+ """Write to output, perform checks."""
self.set_logger_message("Interpolating roughness")
CrossSectionHelpers().interpolate_friction_across_cross_sections(cross_sections)
# Export cross sections
output_dir = self.get_inifile().get_output_directory()
- self.set_logger_message("Export model input files to {}".format(output_dir))
+ self.set_logger_message(f"Export model input files to {output_dir}")
self._write_output(cross_sections, output_dir)
# Generate output geojson
@@ -330,7 +340,7 @@ def _finalise_fm2prof(self, cross_sections: List) -> None:
# We need a better solution for this (inifile.getparam?.. handle defaults there?)
export_mapfiles = False
if export_mapfiles:
- self.set_logger_message("Export geojson output to {}".format(output_dir))
+ self.set_logger_message(f"Export geojson output to {output_dir}")
self._generate_geojson_output(output_dir, cross_sections)
# Export bounding boxes of cross-section control volumes
@@ -341,18 +351,19 @@ def _finalise_fm2prof(self, cross_sections: List) -> None:
self.set_logger_message("Error while exporting bounding boxes", "error")
self.set_logger_message(e_message, "error")
- def _export_envelope(self, output_dir, cross_sections):
- """
- # Export envelopes around cross-sections
- """
- output = {"type": "FeatureCollection"}
- css_hulls = list()
+ def _export_envelope(
+ self,
+ output_dir: Path | str,
+ cross_sections: list[CrossSection],
+ ) -> None:
+ """Export envelopes around cross-sections."""
+ css_hulls = []
for css in cross_sections:
pointlist = np.array(
[
point["geometry"]["coordinates"]
for point in css.get_point_list("face")
- ]
+ ],
)
# construct envelope
try:
@@ -361,27 +372,34 @@ def _export_envelope(self, output_dir, cross_sections):
Feature(
properties={"name": css.name},
geometry=Polygon([list(map(tuple, pointlist[hull.vertices]))]),
- )
+ ),
)
except IndexError:
self.set_logger_message(f"No Hull Exported For {css.name}")
-
- with open(os.path.join(output_dir, "cross_section_volumes.geojson"), "w") as f:
+ with Path(output_dir).joinpath("cross_section_volumes.geojson").open("w") as f:
geojson.dump(FeatureCollection(css_hulls), f, indent=2)
- def _set_fm_model_data(self, res_file, css_file, regions, sections):
- """
- Reads input files for 'FM2PROF'. See documentation for file format descriptions.
+ def _set_fm_model_data(
+ self,
+ res_file: str | Path,
+ css_file: str | Path,
+ regions: RegionPolygonFile | None,
+ sections: SectionPolygonFile | None,
+ ) -> tuple:
+ """Read input files for 'FM2PROF'.
+
+ See documentation for file format descriptions.
+
+ Args:
+ res_file (str | Path): path to FlowFM map netcdf file (*_map.nc)
+ css_file (str | Path): path to cross-section definition file
+ regions (RegionPolygonFile | None): RegionPolygonFile object
+ sections (SectionPolygonFile | None): SectionPolygonFile object
- Data is saved in three major structures:
- time_independent_data: holds bathymetry information
- time_dependent_data: waterlevels, roughnesses and velocities
- edge_data: the nodes that relate to edges
+ Returns:
+ tuple: Tuple containing time dependent data, time independent data, edge data, node coordinates,
+ and cross section data.
- :param res_file: str, path to FlowFM map netcfd file (*_map.nc)
- :param css_file: str, path to cross-section definition file
- :param region_file: str, path to region geojson file
- :return:
"""
importer = ImportInputFiles(logger=self.get_logger())
ini_file = self.get_inifile()
@@ -394,7 +412,6 @@ def _set_fm_model_data(self, res_file, css_file, regions, sections):
node_coordinates,
time_dependent_data,
) = FMDataImporter().import_dflow2d(res_file)
- # (ctime_independent_data, cedge_data, cnode_coordinates, ctime_dependent_data) = FE._read_fm_model(res_file)
self.set_logger_message("Closed FM Map file")
# Load locations and names of cross-sections
@@ -405,49 +422,65 @@ def _set_fm_model_data(self, res_file, css_file, regions, sections):
# Classify regions & set cross-sections
if (ini_file.get_parameter("classificationmethod") == 0) or (regions is None):
self.set_logger_message(
- "All 2D points assigned to the same region and classifying points to cross-sections"
+ "All 2D points assigned to the same region and classifying points to cross-sections",
)
- time_independent_data, edge_data = FE.classify_without_regions(
- cssdata, time_independent_data, edge_data
+ time_independent_data, edge_data = funcs.classify_without_regions(
+ cssdata,
+ time_independent_data,
+ edge_data,
)
elif ini_file.get_parameter("classificationmethod") == 1:
self.set_logger_message(
- "Assigning 2D points to regions using DeltaShell and classifying points to cross-sections"
+ "Assigning 2D points to regions using DeltaShell and classifying points to cross-sections",
)
time_independent_data, edge_data = self._classify_with_deltashell(
- time_independent_data, edge_data, cssdata, regions, polytype="region"
+ time_independent_data,
+ edge_data,
+ cssdata,
+ regions,
+ polytype="region",
)
else:
self.set_logger_message(
- "Assigning 2D points to regions using Built-In method and classifying points to cross-sections"
+ "Assigning 2D points to regions using Built-In method and classifying points to cross-sections",
)
time_independent_data, edge_data = self._classify_with_builtin_methods(
- time_independent_data, edge_data, cssdata, regions
+ time_independent_data,
+ edge_data,
+ cssdata,
+ regions,
)
# Classify sections for roughness tables
if (ini_file.get_parameter("classificationmethod") == 0) or (sections is None):
self.set_logger_message("Assigning point to sections without polygons")
edge_data = self._classify_roughness_sections_by_variance(
- edge_data, time_dependent_data["chezy_edge"]
+ edge_data,
+ time_dependent_data["chezy_edge"],
)
time_independent_data = self._classify_roughness_sections_by_variance(
- time_independent_data, time_dependent_data["chezy_mean"]
+ time_independent_data,
+ time_dependent_data["chezy_mean"],
)
elif ini_file.get_parameter("classificationmethod") == 1:
self.set_logger_message("Assigning 2D points to sections using DeltaShell")
time_independent_data, edge_data = self._classify_section_with_deltashell(
- time_independent_data, edge_data
+ time_independent_data,
+ edge_data,
)
else:
self.set_logger_message(
- "Assigning 2D points to sections using Built-In method"
+ "Assigning 2D points to sections using Built-In method",
)
- edge_data = FE.classify_roughness_sections_by_polygon(
- sections, edge_data, self.get_logger()
+ edge_data = funcs.classify_roughness_sections_by_polygon(
+ sections,
+ edge_data,
+ self.get_logger(),
)
- time_independent_data = FE.classify_roughness_sections_by_polygon(
- sections, time_independent_data, self.get_logger()
+ time_independent_data = funcs.classify_roughness_sections_by_polygon(
+ sections,
+ time_independent_data,
+ self.get_logger(),
)
return (
@@ -459,18 +492,27 @@ def _set_fm_model_data(self, res_file, css_file, regions, sections):
)
def _classify_with_builtin_methods(
- self, time_independent_data, edge_data, cssdata, polygons
- ):
+ self,
+ time_independent_data: pd.DataFrame,
+ edge_data: dict,
+ cssdata: dict,
+ regions: RegionPolygonFile,
+ ) -> tuple[pd.DataFrame, dict]:
# Determine in which region each cross-section lies
- css_regions = polygons.classify_points(cssdata["xy"])
+ css_regions = regions.classify_points(cssdata["xy"])
# Determine in which region each 2d point lies
+
+ nr_of_time_independent_data_values = len(time_independent_data.get("x"))
+ x_tid_array = time_independent_data.get("x").to_numpy()
+ y_tid_array = time_independent_data.get("y").to_numpy()
+
xy_tuples_2d = [
(
- time_independent_data.get("x").values[i],
- time_independent_data.get("y").values[i],
+ x_tid_array[i],
+ y_tid_array[i],
)
- for i in range(len(time_independent_data.get("x")))
+ for i in range(nr_of_time_independent_data_values)
]
time_independent_data["region"] = regions.classify_points(xy_tuples_2d)
@@ -480,58 +522,87 @@ def _classify_with_builtin_methods(
for i in range(len(edge_data.get("x")))
]
- edge_data["region"] = polygons.classify_points(xy_tuples_2d)
+ edge_data["region"] = regions.classify_points(xy_tuples_2d)
# Do Nearest neighbour cross-section for each region
- time_independent_data, edge_data = FE.classify_with_regions(
- regions, cssdata, time_independent_data, edge_data, css_regions
+ time_independent_data, edge_data = funcs.classify_with_regions(
+ regions,
+ cssdata,
+ time_independent_data,
+ edge_data,
+ css_regions,
)
return time_independent_data, edge_data
- def _classify_section_with_deltashell(self, time_independent_data, edge_data):
+ def _classify_section_with_deltashell(
+ self,
+ time_independent_data: pd.DataFrame,
+ edge_data: dict,
+ ) -> tuple[pd.DataFrame, dict]:
# Determine in which section each 2D point lies
self.set_logger_message("Assigning faces...")
time_independent_data = self._assign_polygon_using_deltashell(
- time_independent_data, dtype="face", polytype="section"
+ time_independent_data,
+ dtype="face",
+ polytype="section",
)
self.set_logger_message("Assigning edges...")
edge_data = self._assign_polygon_using_deltashell(
- edge_data, dtype="edge", polytype="section"
+ edge_data,
+ dtype="edge",
+ polytype="section",
)
return time_independent_data, edge_data
def _classify_with_deltashell(
- self, time_independent_data, edge_data, cssdata, polygons, polytype="region"
- ):
+ self,
+ time_independent_data: pd.DataFrame,
+ edge_data: dict,
+ cssdata: dict,
+ polygons: RegionPolygonFile,
+ polytype: str = "region",
+ ) -> tuple[pd.DataFrame, dict]:
# Determine in which region each 2D point lies
self.set_logger_message("Assigning faces...")
time_independent_data = self._assign_polygon_using_deltashell(
- time_independent_data, dtype="face", polytype=polytype
+ time_independent_data,
+ dtype="face",
+ polytype=polytype,
)
self.set_logger_message("Assigning edges...")
edge_data = self._assign_polygon_using_deltashell(
- edge_data, dtype="edge", polytype=polytype
+ edge_data,
+ dtype="edge",
+ polytype=polytype,
)
self.set_logger_message(
- "Assigning cross-sections using nearest neighbour within regions..."
+ "Assigning cross-sections using nearest neighbour within regions...",
)
# Determine in which region each cross-section lies
css_regions = polygons.classify_points(cssdata["xy"])
# Do Nearest neighbour cross-section for each region
- time_independent_data, edge_data = FE.classify_with_regions(
- polygons, cssdata, time_independent_data, edge_data, css_regions
+ time_independent_data, edge_data = funcs.classify_with_regions(
+ polygons,
+ cssdata,
+ time_independent_data,
+ edge_data,
+ css_regions,
)
return time_independent_data, edge_data
- def _classify_roughness_sections_by_variance(self, data, variable):
- """
- This method classifies the region into main channel and floodplain based on roughness. It
- is used when the user does not specify a section polygon.
+ def _classify_roughness_sections_by_variance(
+ self,
+ data: pd.DataFrame | dict,
+ variable: pd.DataFrame,
+ ) -> pd.DataFrame | dict:
+ """Classify the region into main channel and floodplain based on roughness.
+
+ It is used when the user does not specify a section polygon.
This method assumes that the main channel is much deeper than the floodplain. Therefore,
the Chézy values will be higher than those in the floodplain. The objective is now to
@@ -544,27 +615,23 @@ def _classify_roughness_sections_by_variance(self, data, variable):
"""
# Get chezy values at last timestep
- end_values = variable.T.iloc[-1].values
+ end_values = variable.T.iloc[-1].to_numpy()
key = "section"
# Find split point (chezy value) by variance minimisation
- variance_list = list()
split_candidates = np.arange(min(end_values), max(end_values), 1)
- if (
- len(split_candidates) < 2
- ): # this means that all end values are very close together, so do not split
+ if len(split_candidates) < 2: # noqa: PLR2004, this means that all end values are very close together, so do not split
data[key][:] = 1
else:
- for split in split_candidates:
- variance_list.append(
- np.max(
- [
- np.var(end_values[end_values > split]),
- np.var(end_values[end_values <= split]),
- ]
- )
+ variance_list = [
+ np.max(
+ [
+ np.var(end_values[end_values > split]),
+ np.var(end_values[end_values <= split]),
+ ],
)
-
+ for split in split_candidates
+ ]
splitpoint = split_candidates[np.nanargmin(variance_list)]
# High chezy values are assigned to section number '1' (Main channel)
@@ -572,23 +639,23 @@ def _classify_roughness_sections_by_variance(self, data, variable):
if isinstance(data, pd.DataFrame):
data.loc[end_values > splitpoint, key] = 1
data.loc[end_values <= splitpoint, key] = 2
- else:
- data[key][end_values > splitpoint] = 1
+ else:
+ data[key][end_values > splitpoint] = 1
data[key][end_values <= splitpoint] = 2
return data
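The variance-minimisation split above is easy to verify in isolation. A minimal sketch, assuming only numpy and a made-up end_values array standing in for the Chézy values at the last timestep:

    import numpy as np

    # Two roughness clusters: floodplain (~20) and main channel (~45)
    end_values = np.array([18.0, 19.5, 21.0, 22.0, 43.0, 44.5, 46.0, 47.5])

    split_candidates = np.arange(min(end_values), max(end_values), 1)
    # For each candidate split, keep the larger of the two within-group variances
    variance_list = [
        max(np.var(end_values[end_values > split]), np.var(end_values[end_values <= split]))
        for split in split_candidates
    ]
    splitpoint = split_candidates[np.nanargmin(variance_list)]

    section = np.where(end_values > splitpoint, 1, 2)  # 1 = main channel, 2 = floodplain
    print(splitpoint, section)  # 22.0 [2 2 2 2 1 1 1 1]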
- def _get_region_map_file(self, polytype):
- """Returns the path to a NC file with region ifnormation in the bathymetry data"""
- map_file_path = self.get_inifile().get_input_file("2DMapOutput")
- filepath, ext = os.path.splitext(map_file_path)
- modified_file_path = f"{filepath}_{polytype.upper()}BATHY{ext}"
- return modified_file_path
+ def _get_region_map_file(self, polytype: str) -> str:
+ """Return the path to a NC file with region ifnormation in the bathymetry data."""
+ map_file_path = Path(self.get_inifile().get_input_file("2DMapOutput"))
+ return f"{map_file_path.parent / map_file_path.stem}_{polytype.upper()}BATHY{map_file_path.suffix}"
def _assign_polygon_using_deltashell(
- self, data, dtype: str = "face", polytype: str = "region"
- ):
- """Assign all 2D points using DeltaShell method"""
-
+ self,
+ data: dict | pd.DataFrame,
+ dtype: str = "face",
+ polytype: str = "region",
+ ) -> pd.DataFrame | dict:
+ """Assign all 2D points using DeltaShell method."""
# NOTE
self.set_logger_message(f"Looking for _{polytype.upper()}BATHY.nc", "debug")
@@ -610,17 +677,17 @@ def _assign_polygon_using_deltashell(
return data
- def _generate_geojson_output(self, output_dir: str, cross_sections: list):
- """Generates geojson file based on cross sections.
+ def _generate_geojson_output(self, output_dir: str, cross_sections: list) -> None:
+ """Generate geojson file based on cross sections.
+
+ Args:
+ ----
+ output_dir (str): Output directory path.
+ cross_sections (list): List of Cross Sections.
- Arguments:
- output_dir {str} -- Output directory path.
- cross_sections {list} -- List of Cross Sections.
"""
for pointtype in ["face", "edge"]:
- output_file_path = os.path.join(
- output_dir, "{}_output.geojson".format(pointtype)
- )
+ output_file_path = Path(output_dir) / f"{pointtype}_output.geojson"
try:
node_points = [
node_point
@@ -628,32 +695,28 @@ def _generate_geojson_output(self, output_dir: str, cross_sections: list):
for node_point in cs.get_point_list(pointtype)
]
self.set_logger_message(
- "Collected points, dumping to file", level="debug"
+ "Collected points, dumping to file",
+ level="debug",
)
- MaskOutputFile.write_mask_output_file(output_file_path, node_points)
+ mask_output_file.write_mask_output_file(output_file_path, node_points)
self.set_logger_message("Done", level="debug")
except Exception as e_info:
self.set_logger_message(
"Error while generation .geojson file,"
- + "at {}".format(output_file_path)
- + "Reason: {}".format(str(e_info)),
+ + f"at {output_file_path}"
+ + f"Reason: {e_info!s}",
level="error",
)
- def _generate_cross_section_list(self):
- """Generates cross sections based on the given fm_model_data
-
- Arguments:
- input_param_dict {Mapping[str, list]}
- -- Dictionary of parameters read from IniFile
- fm_model_data {FmModelData}
- -- Class with all necessary data for generating Cross Sections
+ def _generate_cross_section_list(self) -> list[CrossSection]:
+ """Generate cross sections based on the given fm_model_data.
Returns:
- {list} -- List of generated cross sections
- """
+ -------
+ (list): List of generated cross sections
- cross_sections = list()
+ """
+ cross_sections = []
if not self.fm_model_data:
return cross_sections
@@ -669,10 +732,12 @@ def _generate_cross_section_list(self):
pbar = tqdm.tqdm(total=len(selected_list))
for i, css_data in enumerate(selected_list):
self.start_new_log_task(
- f"{css_data.get('id')} ({i}/{len(selected_list)})", pbar=pbar
+ f"{css_data.get('id')} ({i}/{len(selected_list)})",
+ pbar=pbar,
)
generated_cross_section = self._generate_cross_section(
- css_data, self.fm_model_data
+ css_data,
+ self.fm_model_data,
)
if generated_cross_section is not None:
cross_sections.append(generated_cross_section)
@@ -680,60 +745,64 @@ def _generate_cross_section_list(self):
return cross_sections
- def _get_css_range(self, number_of_css: int):
- """parses the CssSelection keyword from the inifile"""
- cssSelection = self.get_inifile().get_parameter("CssSelection")
- if not cssSelection:
- cssSelection = np.arange(0, number_of_css)
- else:
- cssSelection = np.array(cssSelection)
- return cssSelection
+ def _get_css_range(self, number_of_css: int) -> np.ndarray:
+ """Parse the CssSelection keyword from the inifile."""
+ css_selection = self.get_inifile().get_parameter("CssSelection")
+ return (
+ np.arange(0, number_of_css)
+ if not css_selection
+ else np.array(css_selection)
+ )
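A standalone sketch of the same selection logic (function and variable names here are illustrative, not part of the patch):

    import numpy as np

    def css_range(css_selection, number_of_css):
        # Empty/None selection means: process all cross-sections
        return np.arange(0, number_of_css) if not css_selection else np.array(css_selection)

    print(css_range(None, 5))    # [0 1 2 3 4]
    print(css_range([1, 3], 5))  # [1 3]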
def _generate_cross_section(
- self, css_data: Dict, fm_model_data: FmModelData
+ self,
+ css_data: dict,
+ fm_model_data: FmModelData,
) -> CrossSection:
- """Generates a cross section and configures its values based
+ """Generate a cross section and configures its values based.
+
on the input parameter dictionary
- Arguments:
- css_data {dict}
- -- Dictionary of data for the current cross section
- input_param_dict {Mapping[str,list]}
- -- Dictionary with input parameters
- fm_model_data {FmModelData}
- -- Data to assign to the new cross section
+ Args:
+ ----
+ css_data (dict): Dictionary of data for the current cross section.
+ fm_model_data (FmModelData): Data to assign to the new cross section
Raises:
+ ------
ValueError: If no css_data is given.
ValueError: If no fm_model_data is given.
Returns:
- {CrossSection} -- New Cross Section
+ -------
+ (CrossSection): New Cross Section
+
"""
if css_data is None:
- raise Exception("No data was given to create a Cross Section")
+ err_msg = "No data was given to create a Cross Section"
+ raise ValueError(err_msg)
css_name = css_data.get("id")
if not css_name:
css_name = "new_cross_section"
if fm_model_data is None:
- raise Exception(
- "No FM data given for new cross section {}".format(css_name)
- )
+ err_msg = f"No FM data given for new cross section {css_name}"
+ raise ValueError(err_msg)
# Create cross section
created_css = self._create_new_cross_section(css_data=css_data)
if created_css is None:
self.set_logger_message(
- f"No Cross-section could be generated for {css_name}", "error"
+ f"No Cross-section could be generated for {css_name}",
+ "error",
)
return None
- if created_css.get_number_of_faces() < 10:
+ if created_css.get_number_of_faces() < 10: # noqa: PLR2004
self.set_logger_message(
- f"There are too little 2D points in control volume to construct cross-section",
+ "There are too little 2D points in control volume to construct cross-section",
"error",
)
return None
@@ -751,18 +820,19 @@ def _generate_cross_section(
return created_css
def _build_cross_section_geometry(
- self, cross_section: CrossSection
+ self,
+ cross_section: CrossSection,
) -> CrossSection:
- """
- This method manages the options of building the cross-section geometry
+ """Manage the options of building the cross-section geometry.
- Parameters:
- cross_section {CrossSection}
- -- Given Cross Section.
- """
+ Args:
+ ----
+ cross_section (CrossSection): Given Cross Section.
+ """
if cross_section is None:
- raise Exception
+ err_msg = "Cross section cannot be none."
+ raise ValueError(err_msg)
# Build cross-section
self.set_logger_message("Start building geometry", "debug")
@@ -774,23 +844,21 @@ def _build_cross_section_geometry(
cross_section = self._perform_2D_volume_correction(cross_section)
else:
self.set_logger_message(
- "SD Correction not enable in configuration file, skipping", "info"
+ "SD Correction not enable in configuration file, skipping",
+ "info",
)
# Perform sanity check on cross-section
cross_section.check_requirements()
# Reduce number of points in cross-section
- cross_section = self._reduce_css_points(cross_section)
-
- return cross_section
+ return self._reduce_css_points(cross_section)
def _build_cross_section_roughness(
- self, cross_section: CrossSection
+ self,
+ cross_section: CrossSection,
) -> CrossSection:
- """
- Build the roughness tables
- """
+ """Build the roughness tables."""
# Assign roughness
self.set_logger_message("Starting computing roughness tables", "debug")
cross_section.assign_roughness()
@@ -798,17 +866,17 @@ def _build_cross_section_roughness(
return cross_section
- def _create_new_cross_section(self, css_data: Mapping[str, str]):
- """Creates a cross section with the given input param dictionary.
+ def _create_new_cross_section(self, css_data: dict) -> CrossSection | None:
+ """Create a cross section with the given input param dictionary.
- Arguments:
- css_data {Mapping[str, str]}
- -- FM Model data for cross section.
- input_param_dict {Mapping[str, str]}
- -- Dictionary with parameters for Cross Section.
+ Args:
+ ----
+ css_data (dict): FM Model data for cross section.
Returns:
- {CrossSection} -- New cross section object.
+ -------
+ (CrossSection): New cross section object.
+
"""
# Get id data and id index
if not css_data:
@@ -829,8 +897,8 @@ def _create_new_cross_section(self, css_data: Mapping[str, str]):
css_data["fm_data"] = self.fm_model_data.get_selection(css_data.get("id"))
if self.get_inifile().get_parameter("ExportCSSData"):
- output_dir = self.get_inifile().get_output_directory()
- with open(output_dir.joinpath(f"{css_data.get('id')}.pickle"), "wb") as f:
+ output_dir = Path(self.get_inifile().get_output_directory())
+ with output_dir.joinpath(f"{css_data.get('id')}.pickle").open("wb") as f:
pickle.dump(css_data, f)
try:
css = CrossSection(
@@ -842,36 +910,33 @@ def _create_new_cross_section(self, css_data: Mapping[str, str]):
except Exception as e_info:
self.set_logger_message(
"Exception thrown while creating cross-section "
- + "{}, message: {}".format(css_data_id, str(e_info)),
+ + f"{css_data.get('id')}, message: {e_info!s}",
"error",
)
return None
return css
- def _write_output(self, cross_sections: list, output_dir: Path):
- """Exports all cross sections to the necessary file formats
+ def _write_output(self, cross_sections: list, output_dir: Path) -> None:
+ """Export all cross sections to the necessary file formats.
- Arguments:
- cross_sections {list}
- -- List of created cross sections
- output_dir {str}
- -- target directory where to export all the cross sections
- """
- if not cross_sections:
- return
+ Args:
+ ----
+ cross_sections (list): List of created cross sections
+ output_dir (str): target directory where to export all the cross sections
- if not output_dir or not os.path.exists(output_dir):
+ """
+ if not cross_sections or not output_dir or not output_dir.exists():
return
- OutputExporter = Export1DModelData(logger=self.get_logger())
+ output_exporter = Export1DModelData(logger=self.get_logger())
# File paths
css_location_ini_file = output_dir.joinpath(
- self._output_files.dimr_css_locations
+ self._output_files.dimr_css_locations,
)
css_definitions_ini_file = output_dir.joinpath(
- self._output_files.dimr_css_definitions
+ self._output_files.dimr_css_definitions,
)
# Legacy file formats
@@ -879,27 +944,28 @@ def _write_output(self, cross_sections: list, output_dir: Path):
csv_roughness_file = output_dir.joinpath(self._output_files.sobek3_roughness)
csv_geometry_test_file = output_dir.joinpath(self._output_files.test_geometry)
- csv_roughness_test_file = output_dir.joinpath(self._output_files.test_roughness)
-
csv_volumes_file = output_dir.joinpath(self._output_files.fm2prof_volume)
# export fm1D format
try:
# Export locations
- OutputExporter.export_crossSectionLocations(
- cross_sections, file_path=css_location_ini_file
+ output_exporter.export_cross_section_locations(
+ cross_sections,
+ file_path=css_location_ini_file,
)
# Export definitions
- OutputExporter.export_geometry(
- cross_sections, file_path=css_definitions_ini_file, fmt="dflow1d"
+ output_exporter.export_geometry(
+ cross_sections,
+ file_path=css_definitions_ini_file,
+ fmt="dflow1d",
)
# Export roughness
sections = np.unique(
- [s for css in cross_sections for s in css.friction_tables.keys()]
+ [s for css in cross_sections for s in css.friction_tables],
)
- sectionFileKeyDict = {
+ section_file_key_dict = {
"main": [self._output_files.dimr_roughness_main, "Main"],
"floodplain1": [
self._output_files.dimr_roughness_floodplain1,
@@ -912,71 +978,79 @@ def _write_output(self, cross_sections: list, output_dir: Path):
}
for section in sections:
csv_roughness_ini_file = output_dir.joinpath(
- sectionFileKeyDict[section][0]
+ section_file_key_dict[section][0],
)
- OutputExporter.export_roughness(
+ output_exporter.export_roughness(
cross_sections,
file_path=csv_roughness_ini_file,
fmt="dflow1d",
- roughness_section=sectionFileKeyDict[section][1],
+ roughness_section=section_file_key_dict[section][1],
)
except Exception as e_info:
self.set_logger_message(
"An error was produced while exporting files to DIMR format,"
- + " not all output files might be exported. "
- + "{}".format(str(e_info)),
+ " not all output files might be exported. "
+ f"{e_info!s}",
level="error",
)
# Export SOBEK 3 format
try:
# Cross-sections
- OutputExporter.export_geometry(
- cross_sections, file_path=csv_geometry_file, fmt="sobek3"
+ output_exporter.export_geometry(
+ cross_sections,
+ file_path=csv_geometry_file,
+ fmt="sobek3",
)
# Roughness
- OutputExporter.export_roughness(
- cross_sections, file_path=csv_roughness_file, fmt="sobek3"
+ output_exporter.export_roughness(
+ cross_sections,
+ file_path=csv_roughness_file,
+ fmt="sobek3",
)
except Exception as e_info:
self.set_logger_message(
"An error was produced while exporting files to SOBEK format,"
- + " not all output files might be exported. "
- + "{}".format(str(e_info)),
+ " not all output files might be exported. "
+ f"{e_info!s}",
level="error",
)
# Other files:
try:
- OutputExporter.export_geometry(
- cross_sections, file_path=csv_geometry_test_file, fmt="testformat"
+ output_exporter.export_geometry(
+ cross_sections,
+ file_path=csv_geometry_test_file,
+ fmt="testformat",
)
- OutputExporter.export_volumes(cross_sections, file_path=csv_volumes_file)
+ output_exporter.export_volumes(cross_sections, file_path=csv_volumes_file)
except Exception as e_info:
self.set_logger_message(
"An error was produced while exporting files,"
- + " not all output files might be exported. "
- + "{}".format(str(e_info)),
+ " not all output files might be exported. "
+ f"{e_info!s}",
level="error",
)
self.set_logger_message("Exported output files, FM2PROF finished")
- def _reduce_css_points(self, cross_section: CrossSection):
- """Returns a valid value for the number of css points read from ini file.
+ def _reduce_css_points(self, cross_section: CrossSection) -> CrossSection:
+ """Return a valid value for the number of css points read from ini file.
- Parameters:
+ Parameters
+ ----------
cross_section (CrossSection)
Returns:
+ -------
cross_section (CrossSection): modified
- """
+ """
maximum_number_of_css_points = self.get_inifile().get_parameter(
- "MaximumPointsInProfile"
+ "MaximumPointsInProfile",
)
try:
@@ -985,27 +1059,31 @@ def _reduce_css_points(self, cross_section: CrossSection):
e_message = str(e_error)
self.set_logger_message(
"Exception thrown while trying to reduce the css points. "
- + "{}".format(e_message),
+ + f"{e_message}",
"error",
)
return cross_section
- def _get_time_stamp_seconds(self, start_time: datetime):
- """Returns a time stamp with the time difference
+ def _get_time_stamp_seconds(self, start_time: datetime) -> float:
+ """Return a time stamp with the time difference.
- Arguments:
- start_time {datetime} -- Initial date time
+ Args:
+ ----
+ start_time (datetime): Initial date time
Returns:
- {float} -- difference of time between start and now in seconds
+ -------
+ (float): difference of time between start and now in seconds
+
"""
time_now = datetime.datetime.now()
time_difference = time_now - start_time
return time_difference.total_seconds()
- def _perform_2D_volume_correction(self, css: CrossSection) -> CrossSection:
- """
+ def _perform_2D_volume_correction(self, css: CrossSection) -> CrossSection: # noqa: N802
+ """Calculate a logistic correction term which may be applied in 1D models.
+
In 2D, the volume available in a profile can rise rapidly
while the water level changes little due to compartimentalisation
of the floodplain. This method calculates a logistic correction
@@ -1015,7 +1093,6 @@ def _perform_2D_volume_correction(self, css: CrossSection) -> CrossSection:
Calculates the Cross Section correction if needed.
"""
-
try:
css.calculate_correction()
self.set_logger_message("correction finished")
@@ -1024,7 +1101,7 @@ def _perform_2D_volume_correction(self, css: CrossSection) -> CrossSection:
self.set_logger_message(
"Exception thrown "
+ "while trying to calculate the correction. "
- + "{}".format(e_message),
+ + f"{e_message}",
"error",
)
return css
@@ -1035,9 +1112,7 @@ def _print_log_report(self) -> None:
self.set_logger_message(f"Errors: {ll.get('ERROR')}")
def _output_exists(self) -> bool:
- """
- Checks whether output exists
- """
+ """Check whether output exists."""
for output_file in self._output_files:
if (
self.get_inifile()
@@ -1046,13 +1121,11 @@ def _output_exists(self) -> bool:
.is_file()
):
return True
- else:
- return False
+ return False
class Project(Fm2ProfRunner):
- """
- Provides the python API for running FM2PROF.
+ """Provides the python API for running FM2PROF.
Instantiate by providing the path to a configuration file
@@ -1060,83 +1133,96 @@ class Project(Fm2ProfRunner):
"""
- def set_parameter(self, name: str, value: Union[str, float, int]):
- """
- Use this method to set the value of a parameter
+ def set_parameter(self, name: str, value: str | float) -> None:
+ """Use this method to set the value of a parameter.
+
+ Args:
+ ----
+ name (str): name of the parameter (case insensitive).
- Arguments:
- name: name of the parameter (case insensitive).
+ value (str | float): value of the parameter.
+ An error will be given if the value has the wrong type (e.g. string if int was expected).
- value: value of the parameter. An error will be given if the value has the wrong type (e.g. string if int was expected).
"""
self.get_inifile().set_parameter(name, value)
- def get_parameter(self, name: str) -> Union[str, float, int]:
- """
- Use this method to get the value of a parameter
+ def get_parameter(self, name: str) -> str | float:
+ """Use this method to get the value of a parameter.
- Arguments:
- name: name of the parameter (case insensitive)
+ Args:
+ ----
+ name (str): name of the parameter (case insensitive)
Returns:
- The current value of the parameter
+ -------
+ (str | float): The current value of the parameter
+
"""
return self.get_inifile().get_parameter(name)
- def set_input_file(self, name: str, value: Union[str, float, int]) -> None:
- """
- Use this method to set the path to an input file
+ def set_input_file(self, name: str, value: str | float) -> None:
+ """Use this method to set the path to an input file.
- Arguments:
+ Args:
+ ----
name: name of the input file in the configuration (case insensitive).
value: path to the input file
+
"""
return self.get_inifile().set_input_file(name, value)
- def get_input_file(self, name: str) -> Union[str,]:
- """
- Use this method to retrieve the path to an input file
+ def get_input_file(self, name: str) -> str:
+ """Use this method to retrieve the path to an input file.
- Arguments:
+ Args:
+ ----
name (str): case-insensitive key of the input file (e.g.'2dmapoutput')
+
"""
return self.get_inifile().get_input_file(name)
- def set_output_directory(self, path) -> None:
- """
- Use this method to set the output directory.
+ def set_output_directory(self, path: str | Path) -> None:
+ """Use this method to set the output directory.
.. warning::
calling this function will also create the output directory,
if it does not already exist!
- Arguments:
- path: path to the output path
+ Args:
+ ----
+ path (Path | str): path to the output directory
+
"""
self.get_inifile().set_output_directory(path)
def get_output_directory(self) -> str:
- """
- Returns the current output directory
- """
+ """Return the current output directory."""
return self.get_inifile().get_output_directory()
def print_configuration(self) -> str:
- """
- Use this method to obtain string representation of the
- configuration. Use this string to write to file, e.g.:
+ """Use this method to obtain string representation of the configuration.
+
+ Use this string to write to file, e.g.:
>> with open('EmptyProject.ini', 'w') as f:
>> f.write(project.print_configuration())
Returns:
- string
+ -------
+ (str): string representation of the configuration
"""
return self.get_inifile().print_configuration()
@property
def output_files(self) -> Generator[Path, None, None]:
+ """Get a generator object with the output files.
+
+ Yields:
+ ------
+ Generator[Path, None, None]: generator of output files.
+
+ """
for of in self._output_files:
yield self.get_output_directory().joinpath(of)
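Together, these methods form the public API of Project. A minimal usage sketch, assuming Project is importable from the package root and using hypothetical paths:

    from fm2prof import Project

    project = Project("cases/case01/fm2prof.ini")           # hypothetical configuration file
    project.set_input_file("2dmapoutput", "FlowFM_map.nc")  # keys are case insensitive
    project.set_parameter("MaximumPointsInProfile", 20)
    project.set_output_directory("cases/case01/output")     # created if it does not exist
    project.run()

    for output_file in project.output_files:
        print(output_file)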
diff --git a/fm2prof/Functions.py b/fm2prof/functions.py
similarity index 63%
rename from fm2prof/Functions.py
rename to fm2prof/functions.py
index b022d739..672e09e2 100644
--- a/fm2prof/Functions.py
+++ b/fm2prof/functions.py
@@ -1,7 +1,5 @@
#! /usr/bin/env python
-"""
-This module contains functions used for the emulation/reduction of 2D models to 1D models for Delft3D FM (D-Hydro).
-
+"""Contains functions used for the emulation/reduction of 2D models to 1D models for Delft3D FM (D-Hydro).
Dependencies
------------------
@@ -35,10 +33,21 @@
Stichting Deltares and remain full property of Stichting Deltares at all times.
All rights reserved.
"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
+if TYPE_CHECKING:
+ from logging import Logger
+
+ from fm2prof.CrossSection import CrossSection
+ from fm2prof.RegionPolygonFile import RegionPolygonFile, SectionPolygonFile
+
__author__ = "Koen Berends"
__copyright__ = "Copyright 2016, University of Twente & Deltares"
__credits__ = ["Koen Berends"]
@@ -52,8 +61,12 @@
# region // public functions
-def classify_roughness_sections_by_polygon(sections, data, logger):
- """assigns edges to a roughness section based on polygon data"""
+def classify_roughness_sections_by_polygon(
+ sections: SectionPolygonFile,
+ data: dict | pd.DataFrame,
+ logger: Logger,
+) -> pd.DataFrame | dict:
+ """Assign edges to a roughness section based on polygon data."""
logger.debug("....gathering points")
points = [(data["x"][i], data["y"][i]) for i in range(len(data["x"]))]
logger.debug("....classifying points")
@@ -62,20 +75,21 @@ def classify_roughness_sections_by_polygon(sections, data, logger):
def extract_point_from_np(data: dict, pos: int) -> list:
+ """Extract points."""
return (data["x"][pos], data["y"][pos])
def classify_with_regions(
- regions, cssdata, time_independent_data, edge_data, css_regions
-):
- """
- Assigns cross-section id's based on region polygons.
+ regions: RegionPolygonFile,
+ cssdata: dict,
+ time_independent_data: pd.DataFrame,
+ edge_data: dict,
+ css_regions: list,
+) -> tuple[pd.DataFrame, dict]:
+ """Assign cross-section id's based on region polygons.
+
Within a region, assignment will be done by k nearest neighbour
"""
-
time_independent_data["sclass"] = time_independent_data["region"].astype(str)
- # edge_data['sclass'] = edge_data['region']
-
# Nearest Neighbour within regions
for region in np.unique(css_regions):
# Select cross-sections within this region
@@ -97,22 +111,25 @@ def classify_with_regions(
css_2d_edges = neigh.predict(np.array([x_2d_edge, y_2d_edge]).T)
# Update data in main structures
- time_independent_data.loc[node_mask, "sclass"] = (
- css_2d_nodes # sclass = cross-section id
- )
+ time_independent_data.loc[node_mask, "sclass"] = css_2d_nodes # sclass = cross-section id
edge_data["sclass"][edge_mask] = css_2d_edges
return time_independent_data, edge_data
-def classify_without_regions(cssdata, time_independent_data, edge_data):
+def classify_without_regions(
+ cssdata: dict,
+ time_independent_data: pd.DataFrame,
+ edge_data: dict,
+) -> tuple[pd.DataFrame, dict]:
+ """Classify without regions."""
# Create a class identifier to map points to cross-sections
neigh = _get_class_tree(cssdata["xy"], cssdata["id"])
# Expand time-independent dataset with cross-section names
time_independent_data["sclass"] = neigh.predict(
- np.array([time_independent_data["x"], time_independent_data["y"]]).T
+ np.array([time_independent_data["x"], time_independent_data["y"]]).T,
)
# Assign cross-section names to edge coordinates as well
@@ -121,25 +138,21 @@ def classify_without_regions(cssdata, time_independent_data, edge_data):
return time_independent_data, edge_data
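The classifier behind both classify functions is a 1-nearest-neighbour fit on the cross-section coordinates. A self-contained sketch with made-up coordinates:

    import numpy as np
    from sklearn.neighbors import KNeighborsClassifier

    css_xy = np.array([[0.0, 0.0], [100.0, 0.0]])  # two cross-section locations
    css_id = np.array(["css_A", "css_B"])

    neigh = KNeighborsClassifier(n_neighbors=1)
    neigh.fit(css_xy, css_id)

    points_2d = np.array([[10.0, 5.0], [90.0, -3.0], [55.0, 0.0]])
    print(neigh.predict(points_2d))  # ['css_A' 'css_B' 'css_B']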
-def mirror(array, reverse_sign=False):
- """
- Mirrors array
-
- :param array:
- :param reverse_sign:
- :return:
- """
+def mirror(array: np.ndarray, *, reverse_sign: bool = False) -> np.ndarray:
+ """Mirror an array, optionally reversing the sign of the mirrored half."""
if reverse_sign:
return np.append(np.flipud(array) * -1, array)
- else:
- return np.append(np.flipud(array), array)
-
+ return np.append(np.flipud(array), array)
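For example, assuming mirror is imported from fm2prof.functions:

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    print(mirror(a))                     # [3. 2. 1. 1. 2. 3.]
    print(mirror(a, reverse_sign=True))  # [-3. -2. -1.  1.  2.  3.]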
-def get_centre_values(location, x, y, waterdepth, waterlevel):
- """
- Find output point closest to x,y location, output depth and water level as nd arrays
- """
+def get_centre_values(
+ location: np.ndarray,
+ x: np.ndarray,
+ y: np.ndarray,
+ waterdepth: pd.DataFrame,
+ waterlevel: pd.DataFrame,
+) -> tuple[np.ndarray, np.ndarray]:
+ """Find output point closest to x,y location, output depth and water level as nd arrays."""
nn = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit(np.array([x, y]).T)
# conversion to 2d array, as 1d arrays are deprecated for kneighbors
@@ -153,46 +166,51 @@ def get_centre_values(location, x, y, waterdepth, waterlevel):
# When starting from a dry bed, the centre_level may have nan values
#
bed_level = np.nanmin(centre_level - centre_depth)
- # centre_depth[np.isnan(centre_depth)] = np.nanmin(centre_depth)
centre_level[np.isnan(centre_level)] = bed_level
return centre_depth[0], centre_level[0]
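The nearest-point lookup at the heart of get_centre_values, sketched in isolation (coordinates are made up):

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    x = np.array([0.0, 10.0, 20.0])
    y = np.array([0.0, 0.0, 0.0])

    nn = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit(np.array([x, y]).T)
    location = np.array([[12.0, 1.0]])  # query must be a 2D array
    _, index = nn.kneighbors(location)
    print(index[0][0])  # 1 -> the point at (10, 0) is closest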
-def empirical_ppf(qs, p, val=None, single_value=False):
- """
- Constructs empirical cdf, then draws quantile by linear interpolation
- qs : array of quantiles (e.g. [2.5, 50, 97.5])
- p : array of random inputs
+def empirical_ppf(
+ qs: np.ndarray,
+ p: np.ndarray,
+ val: list | np.ndarray | None = None,
+ *,
+ single_value: bool = False,
+) -> list | np.ndarray:
+ """Construct empirical cdf, then draws quantile by linear interpolation.
+
+ Args:
+ ----
+ qs (np.ndarray): array of quantiles
+ p (np.ndarray): array of random inputs
+ val (np.ndarray | None, optional): array or list of values. Defaults to None.
+ single_value (bool, optional): boolean for indicating single value. Defaults to False.
+
+ Returns:
+ -------
+ list | np.ndarray
- return
"""
if val is None:
p, val = get_empirical_cdf(p)
- if not single_value:
- output = list()
- for q in qs:
- output.append(np.interp(q / 100.0, p, val))
- else:
- output = np.interp(qs / 100.0, p, val)
- return output
-
+ return [np.interp(q / 100.0, p, val) for q in qs] if not single_value else np.interp(qs / 100.0, p, val)
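A usage sketch, assuming empirical_ppf is imported from fm2prof.functions and fed a synthetic sample:

    import numpy as np

    rng = np.random.default_rng(0)
    p = rng.normal(loc=5.0, scale=2.0, size=1000)  # synthetic random inputs

    # Median and a 95% interval drawn from the empirical cdf
    print(empirical_ppf([2.5, 50, 97.5], p))
    # A single quantile returned as a scalar
    print(empirical_ppf(50.0, p, single_value=True))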
-def get_empirical_cdf(sample, method=1, ignore_nan=True):
- """
- Returns an experimental/empirical cdf from data.
- Arguments:
+def get_empirical_cdf(sample: list, *, ignore_nan: bool = True) -> tuple[np.ndarray, np.ndarray]:
+ """Return an experimental/empirical cdf from data.
- p : list
+ Args:
+ ----
+ sample (list): list of sample values
+ ignore_nan (bool, optional): Defaults to True.
Returns:
-
- (x, y) : lists of values (x) and cumulative probability (y)
+ -------
+ tuple[np.ndarray, np.ndarray]: tuple containing arrays of values (x) and cumulative probability (y)
"""
-
sample = np.array(sample)
if ignore_nan:
sample = sample[~np.isnan(sample)]
@@ -209,15 +227,19 @@ def get_empirical_cdf(sample, method=1, ignore_nan=True):
# region // protected functions
-def _get_class_tree(xy, c):
- X = xy
+def _get_class_tree(xy: np.ndarray, c: np.ndarray) -> KNeighborsClassifier:
+ x = xy
y = c
neigh = KNeighborsClassifier(n_neighbors=1)
- neigh.fit(X, y)
+ neigh.fit(x, y)
return neigh
-def _interpolate_roughness_css(cross_section, alluvial_range, nonalluvial_range):
+def _interpolate_roughness_css(
+ cross_section: CrossSection,
+ alluvial_range: np.ndarray,
+ nonalluvial_range: np.ndarray,
+) -> None:
# change nan's to zeros
chezy_alluvial = np.nan_to_num(cross_section.alluvial_friction_table[1])
chezy_nonalluvial = np.nan_to_num(cross_section.nonalluvial_friction_table[1])
@@ -234,12 +256,8 @@ def _interpolate_roughness_css(cross_section, alluvial_range, nonalluvial_range)
# only interpolate and assign if nonzero elements exist in the chezy table
if np.sum(alluvial_nonzero_mask) > 0:
- waterlevel_alluvial_trimmed = waterlevel_alluvial[
- alluvial_nonzero_mask[0] : alluvial_nonzero_mask[-1] + 1
- ]
- alluvial_interp = np.interp(
- alluvial_range, waterlevel_alluvial_trimmed, chezy_alluvial_trimmed
- )
+ waterlevel_alluvial_trimmed = waterlevel_alluvial[alluvial_nonzero_mask[0] : alluvial_nonzero_mask[-1] + 1]
+ alluvial_interp = np.interp(alluvial_range, waterlevel_alluvial_trimmed, chezy_alluvial_trimmed)
# assign
cross_section.alluvial_friction_table[0] = alluvial_range
@@ -249,9 +267,7 @@ def _interpolate_roughness_css(cross_section, alluvial_range, nonalluvial_range)
waterlevel_nonalluvial_trimmed = waterlevel_nonalluvial[
nonalluvial_nonzero_mask[0] : nonalluvial_nonzero_mask[-1] + 1
]
- nonalluvial_interp = np.interp(
- nonalluvial_range, waterlevel_nonalluvial_trimmed, chezy_nonalluvial_trimmed
- )
+ nonalluvial_interp = np.interp(nonalluvial_range, waterlevel_nonalluvial_trimmed, chezy_nonalluvial_trimmed)
# assign
cross_section.nonalluvial_friction_table[0] = nonalluvial_range
diff --git a/fm2prof/IniFile.py b/fm2prof/ini_file.py
similarity index 51%
rename from fm2prof/IniFile.py
rename to fm2prof/ini_file.py
index 5462f244..01236d4c 100644
--- a/fm2prof/IniFile.py
+++ b/fm2prof/ini_file.py
@@ -1,54 +1,49 @@
-"""
-This module contains the IniFile class, which handles the main configuration file.
-"""
+"""Module contains the IniFile class, which handles the main configuration file."""
+
+from __future__ import annotations
import configparser
import inspect
import io
import json
import os
-from logging import Logger
from pathlib import Path
from pydoc import locate
-from typing import AnyStr, Dict, List, Mapping, Type, Union
+from typing import TYPE_CHECKING, Any, Generator, Mapping
from fm2prof.common import FM2ProfBase
+if TYPE_CHECKING:
+ from logging import Logger
-class InvalidConfigurationFileError(Exception):
- """Raised when config file is not up to snot"""
- pass
+class InvalidConfigurationFileError(Exception):
+ """Raised when config file is not up to snot."""
class ImportBoolType:
- """Custom type to parse booleans"""
+ """Custom type to parse booleans."""
- def __new__(cls, value):
- if value.lower().strip() == "true":
- return True
- else:
- return False
+ def __new__(cls, value: str) -> bool: # noqa: D102
+ return value.lower().strip() == "true"
class ImportListType:
- """Custom type to parse list of comma separated ints"""
+ """Custom type to parse list of comma separated ints."""
- def __new__(cls, value):
+ def __new__(cls, value: str) -> list: # noqa: D102
return list(map(int, value.strip("[]").split(",")))
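The two custom types coerce raw ini strings at construction time; for example, assuming the classes are imported from fm2prof.ini_file:

    print(ImportBoolType("True "))      # True
    print(ImportBoolType("false"))      # False
    print(ImportListType("[0, 2, 5]"))  # [0, 2, 5]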
class IniFile(FM2ProfBase):
- """
- This class provides all functionality to interact with the configuration file, e.g.:
+ """Class for utilizing .ini files.
+
+ Provides all functionality to interact with the configuration file, e.g.:
- reading
- validating
- retrieving parameters
- retrieving files
- Parameters:
- file_path (str): path to filestring
-
"""
_file: Path = None
@@ -57,28 +52,29 @@ class IniFile(FM2ProfBase):
__input_debug_key = "debug"
__output_key = "output"
__output_directory_key = "OutputDirectory"
- __ini_keys = dict(
- map_file="2dmapoutput",
- css_file="crosssectionlocationfile",
- region_file="regionpolygonfile",
- section_file="sectionpolygonfile",
- export_mapfiles="exportmapfiles",
- css_selection="cssselection",
- classificationmethod="classificationmethod",
- sdfloodplainbase="sdfloodplainbase",
- sdstorage="sdstorage",
- transitionheight_sd="transitionheight_sd",
- number_of_css_points="number_of_css_points",
- minimum_width="minimum_width",
- )
+ __ini_keys = { # noqa:RUF012
+ "map_file": "2dmapoutput",
+ "css_file": "crosssectionlocationfile",
+ "region_file": "regionpolygonfile",
+ "section_file": "sectionpolygonfile",
+ "export_mapfiles": "exportmapfiles",
+ "css_selection": "cssselection",
+ "classificationmethod": "classificationmethod",
+ "sdfloodplainbase": "sdfloodplainbase",
+ "sdstorage": "sdstorage",
+ "transitionheight_sd": "transitionheight_sd",
+ "number_of_css_points": "number_of_css_points",
+ "minimum_width": "minimum_width",
+ }
_output_dir = None
_input_file_paths = None
_input_parameters = None
- def __init__(self, file_path: Union[Path, str] = ".", logger: Logger = None):
- """
- Initializes the object Ini File which contains the path locations of all
+ def __init__(self, file_path: Path | str = ".", logger: Logger | None = None) -> None:
+ """Initialize the object Ini File.
+
+ File should contain the path locations of all
parameters needed by the Fm2ProfRunner.
The configuration file consists of three main sections: input, parameters and output.
@@ -91,8 +87,10 @@ def __init__(self, file_path: Union[Path, str] = ".", logger: Logger = None):
- **output** specifies the output directory. A check will be performed:
- whether the directory can be created
- Arguments:
- file_path {str} -- File path where the IniFile is located
+ Args:
+ file_path (str): File path where the IniFile is located.
+ logger (Logger): logger object to log messages to.
+
"""
super().__init__(logger=logger)
@@ -101,11 +99,11 @@ def __init__(self, file_path: Union[Path, str] = ".", logger: Logger = None):
if file_path is None:
self.set_logger_message(
- "No ini file given, using default options", "warning"
+ "No ini file given, using default options",
+ "warning",
)
return
- else:
- file_path = Path(file_path)
+ file_path = Path(file_path)
if isinstance(file_path, Path):
self._file = file_path
if file_path.is_file():
@@ -113,12 +111,13 @@ def __init__(self, file_path: Union[Path, str] = ".", logger: Logger = None):
self._read_inifile(file_path)
elif file_path.is_dir():
self.set_logger_message(
- "No ini file given, using default options", "warning"
+ "No ini file given, using default options",
+ "warning",
)
- pass
else:
# User has supplied a file, but the file does not exist. Raise error.
- raise IOError(f"The given file path {file_path} could not be found")
+ err_msg = f"The given file path {file_path} could not be found"
+ raise OSError(err_msg)
@property
def _file_dir(self) -> Path:
@@ -128,25 +127,23 @@ def _file_dir(self) -> Path:
@property
def has_output_directory(self) -> bool:
- """
- Verifies if the output directory has been set and exists or not.
- Arguments:
- iniFile {IniFile} -- [description]
+ """Verifies if the output directory has been set and exists or not.
+
Returns:
True - the output_dir is set and exists.
False - the output_dir is not set or does not exist.
+
"""
if self.get_output_directory() is None:
self.set_logger_message("No Output Set", "warning")
return False
- if not os.path.exists(self.get_output_directory()):
+ if not self.get_output_directory().exists():
try:
- os.makedirs(self.get_output_directory())
+ self.get_output_directory().mkdir(parents=True)
except OSError:
self.set_logger_message(
- "The output directory {}, ".format(self.get_output_directory())
- + "could not be found neither created.",
+ f"The output directory {self.get_output_directory()}, could not be found neither created.",
"warning",
)
return False
@@ -154,54 +151,72 @@ def has_output_directory(self) -> bool:
return True
def get_output_directory(self) -> Path:
- """
- Use this method to return the output directory
+ """Use this method to return the output directory.
Returns:
output directory (Path)
"""
- op = self._configuration["sections"]["output"][self.__output_directory_key][
- "value"
- ]
+ op = self._configuration["sections"]["output"][self.__output_directory_key]["value"]
return self.get_relative_path(op)
- def set_output_directory(self, name: Union[Path, str]) -> None:
- """
- Use this method to set the output directory
+ def set_output_directory(self, name: Path | str) -> Path:
+ """Use this method to set the output directory.
- Parameters
- name: name of the output-directory
- """
+ Args:
+ name (Path | str): name of the output directory
- self._configuration["sections"]["output"][self.__output_directory_key][
- "value"
- ] = self._get_valid_output_dir(name)
+ """
+ self._configuration["sections"]["output"][self.__output_directory_key]["value"] = self._get_valid_output_dir(
+ name,
+ )
return self.get_output_directory()
def get_ini_root(self, dirtype: str = "relative") -> Path:
- if dirtype == "relative":
- return self._file_dir
- elif dirtype == "absolute":
+ """Get root directory of ini file.
+
+ Args:
+ dirtype (str, optional): Abosulte or relative path. Defaults to "relative".
+
+ Returns:
+ Path: _description_
+ """
+ if dirtype == "absolute":
return Path.cwd().joinpath(self._file_dir)
+ return self._file_dir
def get_relative_path(self, path: str) -> Path:
- return self.get_ini_root().joinpath(path)
+ """Get relative path of the ini root.
- def get_input_file(self, key: str) -> AnyStr:
- """
- Use this method to retrieve the path to an input file
+ Args:
+ path (str): file path.
- Parameters:
- key (str): name of the input file
+ Returns:
+ Path: the given path joined to the ini root directory.
"""
- return self._get_from_configuration("input", key)
+ return self.get_ini_root().joinpath(path)
+
+ def get_input_file(self, file_name: str) -> str:
+ """Use this method to retrieve the path to an input file.
- def get_parameter(self, key: str) -> Union[str, bool, int, float]:
+ Args:
+ file_name (str): name of input file.
+
+ Returns:
+ str: path to the input file
"""
- Use this method to return a parameter value
+ return self._get_from_configuration("input", file_name)
+
+ def get_parameter(self, key: str) -> str | bool | int | float | None:
+ """Use this method to return a parameter value.
+
+ Args:
+ key (str): name of parameter.
+
+ Returns:
+ (str | bool | int | float | None): parameter value.
"""
try:
return self._get_from_configuration("parameters", key)
@@ -212,120 +227,122 @@ def get_parameter(self, key: str) -> Union[str, bool, int, float]:
except KeyError:
pass
self.set_logger_message(f"unknown key {key}", "error")
+ return None
- def set_input_file(self, key: str, value=None) -> None:
- """Use this method to set a input files the configuration"""
+ def set_input_file(self, key: str, value: str | Path | None = None) -> None:
+ """Use this method to set a input files the configuration."""
self._set_config_value("input", key, value)
- def set_parameter(self, key: str, value, section="parameters") -> None:
- """Use this method to set a key/value pair to the configuration"""
+ def set_parameter(self, key: str, value: str | float, section: str = "parameters") -> None:
+ """Use this method to set a key/value pair to the configuration."""
self._set_config_value(section, key, value)
- def _set_config_value(self, section, key, value) -> None:
- """Use this method to set a input files the configuration"""
+ def _set_config_value(self, section: str, key: str, value: Any) -> None: # noqa: ANN401
+ """Use this method to set a input files the configuration.
+
+ Args:
+ section (str): name of section.
+ key (str): name of key.
+ value (Any): config value to set.
+
+ """
ckey = self._get_key_from_case_insensitive_input(section=section, key_in=key)
if ckey:
self._configuration["sections"][section][ckey]["value"] = value
return True
- else:
- keys = self._configuration["sections"][section].keys()
- self.set_logger_message(f"Unknown {section}. Available keys: {list(keys)}")
- return False
+ keys = self._configuration["sections"][section].keys()
+ self.set_logger_message(f"Unknown {section}. Available keys: {list(keys)}")
+ return False
- def _set_output_directory_no_validation(self, value: str) -> None:
- """
- Use this method to set the output directory within testing framework only
+ def _set_output_directory_no_validation(self, value: str | Path) -> None:
+ """Set the output directory within testing framework only.
+
+ Args:
+ value (str | Path): output directory
"""
- self._configuration["sections"]["output"][self.__output_directory_key][
- "value"
- ] = value
+ self._configuration["sections"]["output"][self.__output_directory_key]["value"] = value
def print_configuration(self) -> str:
- """Use this method to print a string of the configuration used"""
- return self._print_configuration(self._configuration)
-
- def iter_parameters(self):
- """Use this method to iterate through the names and values of all parameters"""
- for parameter, content in (
- self._configuration["sections"].get("parameters").items()
- ):
- yield (
- parameter,
- content.get("type"),
- content.get("hint"),
- content.get("value"),
- )
-
- @staticmethod
- def _print_configuration(inputdict) -> str:
+ """Print the configuration as a string."""
f = io.StringIO()
- for sectionname, section in inputdict.get("sections").items():
+ for sectionname, section in self._configuration.get("sections").items():
f.write(f"[{sectionname}]\n")
for key, contents in section.items():
f.write(
- f"{key:<30}= {str(contents.get('value')):<10}# {contents.get('hint')}\n"
+ f"{key:<30}= {contents.get('value')!s:<10}# {contents.get('hint')}\n",
)
f.write("\n")
return f.getvalue()
- def _get_key_from_case_insensitive_input(self, section, key_in):
+ def iter_parameters(self) -> Generator[tuple[str, type, str, Any], None, None]:
+ """Iterate through the names and values of all parameters."""
+ for parameter, content in self._configuration["sections"].get("parameters").items():
+ yield (
+ parameter,
+ content.get("type"),
+ content.get("hint"),
+ content.get("value"),
+ )
+
+ def _get_key_from_case_insensitive_input(self, section: str, key_in: str) -> str:
section = self._configuration["sections"][section]
- for key in section.keys():
+ for key in section:
if key.lower() == key_in.lower():
return key
return ""
def _get_from_configuration(
- self, section: str, key: str
- ) -> Union[str, bool, int, float]:
+ self,
+ section: str,
+ key: str,
+ ) -> str | bool | int | float:
for parameter, content in self._configuration["sections"].get(section).items():
if parameter.lower() == key.lower():
return content.get("value")
- raise KeyError(f"key {key} not found")
+ err_msg = f"key {key} not found"
+ raise KeyError(err_msg)
- def _get_template_ini(self) -> Dict:
+ def _get_template_ini(self) -> dict:
# Open config file
path, _ = os.path.split(inspect.getabsfile(IniFile))
- with open(os.path.join(path, "configurationfile_template.json"), "r") as f:
- default_ini = json.load(f)
-
- # parse all types
- return default_ini
+ with Path(path).joinpath("configurationfile_template.json").open("r") as f:
+ return json.load(f)
- def _read_inifile(self, file_path: str):
- """
- Reads the inifile and extract all its parameters for later usage
+ def _read_inifile(self, file_path: str) -> None:
+ """Reads the inifile and extract all its parameters for later usage.
- Parameters:
+ Parameters
+ ----------
file_path {str} -- File path where the IniFile is located
+
"""
- if file_path is None or not file_path:
+ if not file_path:
msg = "No ini file was specified and no data could be read."
self.set_logger_message(msg, "error")
- raise IOError(msg)
+ raise OSError(msg)
try:
- if not os.path.exists(file_path):
+ if not Path(file_path).exists():
msg = f"The given file path {file_path} could not be found."
self.set_logger_message(msg, "error")
- raise IOError(msg)
- except TypeError:
+ raise OSError(msg)
+ except TypeError as err:
if not isinstance(file_path, io.StringIO):
- raise IOError("Unknown file format entered")
+ err_msg = "Unknown file format entered"
+ raise TypeError(err_msg) from err
try:
supplied_ini = self._get_inifile_params(file_path)
except Exception as e_info:
raise Exception(
- "It was not possible to extract ini parameters from the file {}. Exception thrown: {}".format(
- file_path, str(e_info)
- )
+ f"It was not possible to extract ini parameters from the file {file_path}. Exception thrown: {e_info!s}",
)
# Compare supplied with default/expected inifile
try:
self._extract_input_files(supplied_ini)
- except Exception as e_info:
+ except Exception:
self.set_logger_message(
- "Unexpected error reading input files. Check config file", "error"
+ "Unexpected error reading input files. Check config file",
+ "error",
)
try:
self._extract_parameters(supplied_ini, self.__input_parameters_key)
@@ -339,74 +356,71 @@ def _read_inifile(self, file_path: str):
self._extract_output_dir(supplied_ini)
except Exception:
self.set_logger_message(
- "Unexpected error output parameters. Check config file", "error"
+ "Unexpected error output parameters. Check config file",
+ "error",
)
- def _extract_parameters(self, supplied_ini: Mapping[str, list], section: str):
- """Extract InputParameters and convert values either integer or float from string
-
- Arguments:
- inifile_parameters {Mapping[str, list} -- Collection of parameters as read in the original file
+ def _extract_parameters(self, supplied_ini: Mapping[str, list], section: str) -> None:
+ """Extract InputParameters and convert values either integer or float from string.
- Returns:
- {Mapping[str, number]} -- Dictionary of mapped parameters to a either integer or float
+ Args:
+ supplied_ini (Mapping[str, list]): Mapping of ini config parameters
+ section (str): name of section
"""
try:
inputsection = supplied_ini.get(section)
- except KeyError:
- raise InvalidConfigurationFileError
+ except KeyError as err:
+ raise InvalidConfigurationFileError from err
for key, value in inputsection.items():
key_default, key_type = self._get_key_from_template(section, key)
try:
parsed_value = key_type(value)
- self.set_parameter(key_default, parsed_value, section)
+ self._set_config_value("parameters", key_default, parsed_value)
except ValueError:
self.set_logger_message(
- f"{key} could not be cast as {key_type}", "debug"
+ f"{key} could not be cast as {key_type}",
+ "debug",
)
except KeyError:
pass
def _extract_input_files(self, supplied_ini: Mapping[str, list]) -> None:
- """
- Extract and validates input files
+ """Extract and validate input files.
- Parameters:
- inifile_parameters {Mapping[str, list]} -- Collection of parameters as read in the original file
-
- Returns:
- {Mapping[str, str]} -- new dict containing a normalized key (file name parameter) and its value.
+ Args:
+ supplied_ini (Mapping[str, list]): Mapping of ini config parameters
"""
try:
inputsection = supplied_ini.get(self.__input_files_key)
- except KeyError:
- raise InvalidConfigurationFileError
+ except KeyError as err:
+ raise InvalidConfigurationFileError from err
for key, input_file in inputsection.items():
key_default, _ = self._get_key_from_template("input", key)
if key_default is None:
continue
- input_file = self._file_dir.joinpath(input_file)
+ input_file_path = self._file_dir.joinpath(input_file)
- if input_file.is_file():
- self.set_input_file(key_default, input_file)
+ if input_file_path.is_file():
+ self._set_config_value("input", key_default, input_file_path)
continue
if key_default in ("2DMapOutput", "CrossSectionLocationFile"):
+ err_msg = f"Could not find input file: {key_default}"
self.set_logger_message(
- f"Could not find input file: {key_default}", "error"
- )
- raise FileNotFoundError
- else:
- self.set_logger_message(
- f"Could not find optional input file for {key_default}, skipping",
- "warning",
+ err_msg,
+ "error",
)
+ raise FileNotFoundError(err_msg)
+ self.set_logger_message(
+ f"Could not find optional input file for {key_default}, skipping",
+ "warning",
+ )
- def _get_key_from_template(self, section, key) -> List[Union[str, Type]]:
- """return list of lower case keys from default configuration files"""
+ def _get_key_from_template(self, section: str, key: str) -> list:
+ """Return the matching template key and its type from the default configuration files."""
sectiondict = self._ini_template.get("sections").get(section)
for entry in sectiondict:
if key.lower() == entry.lower():
@@ -415,63 +429,58 @@ def _get_key_from_template(self, section, key) -> List[Union[str, Type]]:
self.set_logger_message(f"{key} is not a known key", "warning")
return [None, KeyError]
- def _extract_output_dir(self, supplied_ini: Mapping[str, list]):
- """
- Extract and validates output directory
+ def _extract_output_dir(self, supplied_ini: Mapping[str, list]) -> None:
+ """Extract and validates output directory.
- Parameters:
- supplied_ini {Mapping[str, list]} -- Collection of parameters as read in the original file
+ Args:
+ supplied_ini (Mapping[str, list]): Mapping of ini config parameters
- Returns:
- {str} -- Normalized output dir path
"""
try:
outputsection = supplied_ini.get(self.__output_key)
- except KeyError:
- raise InvalidConfigurationFileError
+ except KeyError as err:
+ raise InvalidConfigurationFileError from err
for key, value in outputsection.items():
if key.lower() == self.__output_directory_key.lower():
self.set_output_directory(value)
else:
self.set_logger_message(
- f"Unknown key {key} found in configuration file", "warning"
+ f"Unknown key {key} found in configuration file",
+ "warning",
)
- def _get_valid_output_dir(self, output_dir: str):
- """
- Gets a normalized output directory path. Creates it if not yet exists
+ def _get_valid_output_dir(self, output_dir: str) -> Path:
+ """Get a normalized output directory path. Creates it if not yet exists.
- Arguments:
- output_dir {str} -- Relative path to the configuration file.
+ Args:
+ output_dir (str): Relative path to the configuration file.
Returns:
- {str} -- Valid output directory path.
+ Path: Valid output directory path.
"""
- try:
- os.makedirs(output_dir)
- except FileExistsError:
- pass
+ output_dir = Path(output_dir)
+ if output_dir.exists():
+ return output_dir
+ output_dir.mkdir(parents=True)
return output_dir
@property
- def _output_dir(self):
+ def _output_dir(self) -> str | None:
try:
- return self._configuration["sections"]["output"][
- self.__output_directory_key
- ]["value"]
+ return self._configuration["sections"]["output"][self.__output_directory_key]["value"]
except KeyError:
return None
@staticmethod
- def _get_inifile_params(file_path: str) -> Dict:
+ def _get_inifile_params(file_path: str) -> dict:
"""Extracts the parameters from an ini file.
- Arguments:
- file_path {str} -- Ini file location
+ Args:
+ file_path (str): Ini file location.
Returns:
- {array} -- List of sections containing list of options
+ dict: config file parameters
"""
ini_file_params = {}
comment_delimter = "#"
@@ -481,7 +490,7 @@ def _get_inifile_params(file_path: str) -> Dict:
if isinstance(file_path, io.StringIO):
config.read_file(file_path)
else:
- with open(file_path, "r") as f:
+ with Path(file_path).open("r") as f:
config.read_file(f)
# Extract all sections and options
diff --git a/fm2prof/main.py b/fm2prof/main.py
index cbdd8f30..91764506 100644
--- a/fm2prof/main.py
+++ b/fm2prof/main.py
@@ -9,7 +9,7 @@
import sys
# Import from package
-from fm2prof import Fm2ProfRunner
+from fm2prof import fm2prof_runner
# Import from dependencies
# None
@@ -59,16 +59,12 @@ def main(argv):
# Check if input parameters are in expected order
if not __is_input(opts[0]):
- err_mssg = (
- ""
- + "The first argument should be an input file.\n"
- + "Given: {}\n".format(opts[0])
- )
+ err_mssg = "" + "The first argument should be an input file.\n" + "Given: {}\n".format(opts[0])
__report_expected_arguments(err_mssg)
# Run Fm2Prof with given arguments
ini_file_path = opts[0][1]
- runner = Fm2ProfRunner.Fm2ProfRunner(ini_file_path)
+ runner = fm2prof_runner.Fm2ProfRunner(ini_file_path)
runner.run()
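
With the module renamed from `Fm2ProfRunner` to `fm2prof_runner`, the entry point is invoked as sketched below; the ini path is a placeholder.

```python
# Usage sketch of the renamed runner module; "config.ini" is a placeholder path.
from fm2prof import fm2prof_runner

runner = fm2prof_runner.Fm2ProfRunner("config.ini")
runner.run()
```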
diff --git a/fm2prof/mask_output_file.py b/fm2prof/mask_output_file.py
new file mode 100644
index 00000000..fd993115
--- /dev/null
+++ b/fm2prof/mask_output_file.py
@@ -0,0 +1,104 @@
+"""Module for handling mask output file.
+
+Copyright (C) Stichting Deltares 2019. All rights reserved.
+
+This file is part of the Fm2Prof.
+
+The Fm2Prof is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with this program. If not, see .
+
+All names, logos, and references to "Deltares" are registered trademarks of
+Stichting Deltares and remain full property of Stichting Deltares at all times.
+All rights reserved.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import geojson
+
+
+def create_mask_point(coords: geojson.coords, properties: dict | None = None) -> geojson.Feature:
+ """Create a Point based on the properties and coordinates given.
+
+ Args:
+ coords (geojson.coords): Coordinates tuple (x,y) for the mask point.
+ properties (dict): Dictionary of properties
+
+ """
+ if not coords:
+ err_msg = "coords cannot be empty."
+ raise ValueError(err_msg)
+ output_mask = geojson.Feature(geometry=geojson.Point(coords))
+
+ if properties:
+ output_mask.properties = properties
+ return output_mask
+
+
+def validate_extension(file_path: str | Path) -> None:
+ """Validate extension of mask output file.
+
+ Args:
+ file_path (Union[str, Path]): path to output file.
+
+ """
+ if not isinstance(file_path, (str, Path)):
+ err_msg = f"file_path should be string or Path, not {type(file_path)}"
+ raise TypeError(err_msg)
+ if Path(file_path).suffix not in [".json", ".geojson"]:
+ err_msg = "Invalid file path extension, should be .json or .geojson."
+ raise OSError(err_msg)
+
+
+def read_mask_output_file(file_path: str | Path) -> dict:
+ """Import a GeoJson from a given json file path.
+
+ Args:
+ file_path (str | Path): Location of the json file
+
+ """
+ file_path = Path(file_path)
+ if not file_path.exists():
+ err_msg = f"File path {file_path} not found."
+ raise FileNotFoundError(err_msg)
+
+ validate_extension(file_path)
+ with file_path.open("r") as geojson_file:
+ geojson_data = geojson.load(geojson_file)
+ if not isinstance(geojson_data, geojson.FeatureCollection):
+ err_msg = "File is empty or not a valid geojson file."
+ raise OSError(err_msg)
+ return geojson_data
+
+
+def write_mask_output_file(file_path: Path | str, mask_points: list) -> None:
+ """Write a geojson file with a Feature collection containing the mask_points list given as input.
+
+ Arguments:
+ file_path (str): file_path where to store the geojson.
+ mask_points (list): List of features to output.
+
+ """
+ if not file_path:
+ err_msg = "file_path is required."
+ raise ValueError(err_msg)
+ file_path = Path(file_path)
+ if not mask_points:
+ err_msg = "mask_points cannot be empty."
+ raise ValueError(err_msg)
+ validate_extension(file_path)
+ feature_collection = geojson.FeatureCollection(mask_points)
+ with file_path.open("w") as f:
+ geojson.dump(feature_collection, f, indent=4)
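
A usage sketch of the new module-level API (the old class wrapper is gone); coordinates, properties, and the file name are illustrative.

```python
from fm2prof import mask_output_file

points = [
    mask_output_file.create_mask_point((4.89, 52.37), {"name": "mask-1"}),
    mask_output_file.create_mask_point((4.90, 52.38), {"name": "mask-2"}),
]
mask_output_file.write_mask_output_file("masks.geojson", points)

# Reading back returns a geojson.FeatureCollection
collection = mask_output_file.read_mask_output_file("masks.geojson")
print(len(collection["features"]))  # -> 2
```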
diff --git a/fm2prof/RegionPolygonFile.py b/fm2prof/region_polygon_file.py
similarity index 60%
rename from fm2prof/RegionPolygonFile.py
rename to fm2prof/region_polygon_file.py
index 19abcc4c..85602527 100644
--- a/fm2prof/RegionPolygonFile.py
+++ b/fm2prof/region_polygon_file.py
@@ -1,4 +1,5 @@
-"""
+"""Module handles region polygon files.
+
Copyright (C) Stichting Deltares 2019. All rights reserved.
This file is part of the Fm2Prof.
@@ -21,34 +22,42 @@
All rights reserved.
"""
-import logging
+from __future__ import annotations
+
import json
-from collections import namedtuple
from pathlib import Path
-from typing import Iterable, Union, List
-import rtree
+from typing import TYPE_CHECKING, Iterable, NamedTuple
import numpy as np
+import rtree
from shapely.geometry import Point, shape
from fm2prof.common import FM2ProfBase
-Polygon = namedtuple("Polygon", ["geometry", "properties"])
+if TYPE_CHECKING:
+ from logging import Logger
+
+
+class Polygon(NamedTuple):
+ """Polygon datastructure."""
+
+ geometry: shape
+ properties: dict
class PolygonFile(FM2ProfBase):
+ """Polygon file class."""
+
__logger = None
- def __init__(self, logger):
+ def __init__(self, logger: Logger) -> None:
+ """Instantiae a PolygonFile object."""
self.set_logger(logger)
- self._polygons = list()
+ self._polygons = []
self.undefined = -999
- def classify_points_with_property(
- self, points: Iterable[Point], property_name: str = "name"
- ) -> np.array:
- """
- Classifies points as belonging to which region
+ def classify_points_with_property(self, points: Iterable[Point], property_name: str = "name") -> np.array:
+ """Classify points as belonging to which region.
Points = list of tuples [(x,y), (x,y)]
"""
@@ -66,10 +75,11 @@ def classify_points_with_property(
return np.array(points_regions)
def classify_points_with_property_shapely_prep(
- self, points: Iterable[Point], property_name: str = "name"
- ):
- """
- Classifies points as belonging to which region
+ self,
+ points: Iterable[Point],
+ property_name: str = "name",
+ ) -> np.array:
+ """Classify points as belonging to which region.
Points = list of tuples [(x,y), (x,y)]
"""
@@ -85,27 +95,25 @@ def classify_points_with_property_shapely_prep(
for i, point in enumerate(points):
for p_id, polygon in enumerate(prep_polygons):
if polygon.intersects(point):
- points_regions[i] = self.polygons[p_id].properties.get(
- property_name
- )
+ points_regions[i] = self.polygons[p_id].properties.get(property_name)
break
return np.array(points_regions)
def classify_points_with_property_rtree_by_polygons(
- self, points: Iterable[Point], property_name: str = "name"
+ self,
+ points: Iterable[Point],
+ property_name: str = "name",
) -> list:
- """Applies RTree index to quickly classify points in polygons.
-
- Arguments:
- iterable_points {Iterable[list]} -- List of unformatted points.
+ """Apply RTree index to quickly classify points in polygons.
- Keyword Arguments:
- property_name {str}
- -- Property to retrieve from the polygons (default: {'name'})
+ Args:
+ points (Iterable[list]): List of unformatted points.
+ property_name (str): Property to retrieve from the polygons (default: {'name'})
Returns:
- list -- List of mapped points to polygon properties.
+ (list): List of mapped points to polygon properties.
+
"""
idx = rtree.index.Index()
for p_id, polygon in enumerate(self.polygons):
@@ -126,29 +134,11 @@ def classify_points_with_property_rtree_by_polygons(
del idx
return np.array(point_properties_list)
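
The R-tree pattern above can be read as the following stand-alone sketch: index the polygon bounding boxes once, then run the exact `intersects` test only on the candidates each point hits. Geometries and names are illustrative.

```python
import rtree
from shapely.geometry import Point, box

polygons = [("main", box(0, 0, 5, 5)), ("floodplain1", box(5, 0, 10, 5))]

# Index polygon bounding boxes once
idx = rtree.index.Index()
for p_id, (_, geom) in enumerate(polygons):
    idx.insert(p_id, geom.bounds)


def classify(x: float, y: float, undefined: str = "undefined") -> str:
    point = Point(x, y)
    # Only candidates whose bounding box contains the point are tested exactly
    for p_id in idx.intersection((x, y, x, y)):
        if polygons[p_id][1].intersects(point):
            return polygons[p_id][0]
    return undefined


print(classify(2.5, 2.5))   # -> "main"
print(classify(42.0, 1.0))  # -> "undefined"
```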
- def __get_polygon_property(
- self, grouped_values: list, property_name: str
- ) -> str: # TODO: Can this be removed?
- """Retrieves the polygon property from the internal list of polygons.
-
- Arguments:
- grouped_values {int}
- -- Grouped values containing point and polygon id.
- property_name {str} -- Property to search.
-
- Returns:
- str -- Property value.
- """
- polygon_id = list(grouped_values[1])[0][0]
- if polygon_id >= len(self.polygons) or polygon_id < 0:
- return self.undefined
- return self.polygons[polygon_id].properties.get(property_name)
-
- def parse_geojson_file(self, file_path: Union[Path, str]) -> None:
- """Read data from geojson file"""
+ def parse_geojson_file(self, file_path: Path | str) -> None:
+ """Read data from geojson file."""
PolygonFile._validate_extension(file_path)
- with open(file_path) as geojson_file:
+ with Path(file_path).open("r") as geojson_file:
geojson_data = json.load(geojson_file).get("features")
for feature in geojson_data:
@@ -157,65 +147,74 @@ def parse_geojson_file(self, file_path: Union[Path, str]) -> None:
Polygon(
geometry=shape(feature["geometry"]).buffer(0),
properties=feature_props,
- )
+ ),
)
- # polygons = GeometryCollection([shape(feature["geometry"]).buffer(0) for feature in geojson_data])
- # polygon_names = [feature.get('properties').get('Name') for feature in geojson_data]
-
- # for polygon, polygon_name in zip(polygons, polygon_names):
- # self.polygons[polygon_name] = polygon
@staticmethod
- def _validate_extension(file_path: Union[Path, str]) -> None:
+ def _validate_extension(file_path: Path | str) -> None:
if isinstance(file_path, str):
file_path = Path(file_path)
if file_path.suffix not in (".json", ".geojson"):
- raise IOError("Invalid file path extension, should be .json or .geojson.")
+ err_msg = "Invalid file path extension, should be .json or .geojson."
+ raise OSError(err_msg)
- def _check_overlap(self):
+ def _check_overlap(self) -> None:
for polygon in self.polygons:
for testpoly in self.polygons:
if polygon.properties.get("name") == testpoly.properties.get("name"):
# polygon will obviously overlap with itself
continue
- else:
- if polygon.geometry.intersects(testpoly.geometry):
- self.set_logger_message(
- "{} overlaps {}.".format(
- polygon.properties.get("name"),
- testpoly.properties.get("name"),
- ),
- level="warning",
- )
+ if polygon.geometry.intersects(testpoly.geometry):
+ self.set_logger_message(
+ "{} overlaps {}.".format(
+ polygon.properties.get("name"),
+ testpoly.properties.get("name"),
+ ),
+ level="warning",
+ )
@property
def polygons(self) -> list[Polygon]:
+ """Polygons."""
return self._polygons
@polygons.setter
- def polygons(self, polygons_list: List[Polygon]) -> None:
- if not all([isinstance(polygon, Polygon) for polygon in polygons_list]):
- raise ValueError("Polygons must be of type Polygon")
+ def polygons(self, polygons_list: list[Polygon]) -> None:
+ if not all([isinstance(polygon, Polygon) for polygon in polygons_list]): # noqa: C419
+ err_msg = "Polygons must be of type Polygon"
+ raise ValueError(err_msg)
# Check if properties contain the required 'name' property
names = [polygon.properties.get("name") for polygon in polygons_list]
if not all(names):
- raise ValueError("Polygon properties must contain key-word 'name'")
+ err_msg = "Polygon properties must contain key-word 'name'"
+ raise ValueError(err_msg)
# Check if 'name' property is unique, otherwise _check_overlap will produce bugs
if len(names) != len(set(names)):
- raise ValueError("Property 'name' must be unique")
+ err_msg = "Property 'name' must be unique"
+ raise ValueError(err_msg)
self._polygons = polygons_list
class RegionPolygonFile(PolygonFile):
- def __init__(self, region_file_path, logger):
+ """RegionPolygonFile class."""
+
+ def __init__(self, region_file_path: str | Path, logger: Logger) -> None:
+ """Instantiate a RegionPolygonFile object."""
super().__init__(logger)
self.read_region_file(region_file_path)
@property
def regions(self) -> list[Polygon]:
+ """Region polygons."""
return self.polygons
- def read_region_file(self, file_path) -> None:
+ def read_region_file(self, file_path: Path | str) -> None:
+ """Read region file.
+
+ Args:
+ file_path (Path | str): region file path
+
+ """
self.parse_geojson_file(file_path)
self._validate_regions()
@@ -224,37 +223,68 @@ def _validate_regions(self) -> None:
number_of_regions = len(self.regions)
- self.set_logger_message(
- "{} regions found".format(number_of_regions), level="info"
- )
+ self.set_logger_message(f"{number_of_regions} regions found", level="info")
# Test if polygons overlap
self._check_overlap()
- def classify_points(
- self, points: Iterable[Point], property_name: str = "id"
- ) -> list:
+ def classify_points(self, points: Iterable[Point], property_name: str = "id") -> list:
+ """Classify region points with a property.
+
+ Args:
+ points (Iterable[Point]): Points to classify
+ property_name (str, optional): Property. Defaults to "id".
+
+ Returns:
+ list: the matched property value for each point
+
+ """
return self.classify_points_with_property(points, property_name=property_name)
class SectionPolygonFile(PolygonFile):
- def __init__(self, section_file_path, logger: logging.Logger):
+ """SectionPolygonFile class."""
+
+ def __init__(self, section_file_path: str | Path, logger: Logger) -> None:
+ """Instantiate a SectionPolygonFile object.
+
+ Args:
+ section_file_path (str | Path): path to section polygon file.
+ logger (Logger): logger
+
+ """
super().__init__(logger)
self.read_section_file(section_file_path)
self.undefined = 1 # 1 is main
@property
- def sections(self):
+ def sections(self) -> list[Polygon]:
+ """Section polygons."""
return self.polygons
- def read_section_file(self, file_path):
+ def read_section_file(self, file_path: str | Path) -> None:
+ """Read section polygon file.
+
+ Args:
+ file_path (str | Path): path to section polygon file.
+
+ """
self.parse_geojson_file(file_path)
self._validate_sections()
- def classify_points(self, points: Iterable[Point]):
+ def classify_points(self, points: Iterable[Point]) -> np.array:
+ """Classify points with a section property name.
+
+ Args:
+ points (Iterable[Point]): List of points to classify.
+
+ Returns:
+ np.array: section value for each point
+
+ """
return self.classify_points_with_property(points, property_name="section")
- def _validate_sections(self):
+ def _validate_sections(self) -> None:
self.set_logger_message("Validating Section file")
raise_exception = False
@@ -270,27 +300,21 @@ def _validate_sections(self):
if "section" not in section.properties:
raise_exception = True
self.set_logger_message(
- 'Polygon {} has no property "section"'.format(
- section.properties.get("name")
- ),
+ 'Polygon {} has no property "section"'.format(section.properties.get("name")),
level="error",
)
- elif (
- str(section.properties.get("section")).lower() not in valid_section_keys
- ):
+ elif str(section.properties.get("section")).lower() not in valid_section_keys:
section_key = str(section.properties.get("section")).lower()
if section_key not in list(map_section_keys.keys()):
raise_exception = True
self.set_logger_message(
- "{} is not a recognized section".format(section_key),
+ f"{section_key} is not a recognized section",
level="error",
)
else:
self.set_logger_message(
- "remapped section {} to {}".format(
- section_key, map_section_keys[section_key]
- ),
+ f"remapped section {section_key} to {map_section_keys[section_key]}",
level="warning",
)
section.properties["section"] = map_section_keys.get(section_key)
@@ -298,5 +322,6 @@ def _validate_sections(self):
self._check_overlap()
if raise_exception:
- raise AssertionError("Section file is not valid")
+ err_msg = "Section file is not valid"
+ raise OSError(err_msg)
self.set_logger_message("Section file succesfully validated")
diff --git a/fm2prof/utils.py b/fm2prof/utils.py
index c7d0cd90..ff46af5e 100644
--- a/fm2prof/utils.py
+++ b/fm2prof/utils.py
@@ -1,24 +1,14 @@
+"""Utility module."""
+
+from __future__ import annotations
+
import ast
import locale
-import os
-import re
import warnings
from collections import namedtuple
-from dataclasses import dataclass
from datetime import datetime, timedelta
-from logging import Logger
from pathlib import Path
-from typing import (
- Any,
- Callable,
- Dict,
- Generator,
- Iterable,
- List,
- Optional,
- Tuple,
- Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Generator
import matplotlib as mpl
import matplotlib.dates as mdates
@@ -27,22 +17,26 @@
import numpy as np
import pandas as pd
import tqdm
-from matplotlib.figure import Figure
-from matplotlib.legend import Legend
from matplotlib.ticker import MultipleLocator
from netCDF4 import Dataset
from pandas.plotting import register_matplotlib_converters
-from scipy.interpolate import interp1d
-from fm2prof import Project
from fm2prof.common import FM2ProfBase
+if TYPE_CHECKING:
+ from io import TextIOWrapper
+ from logging import Logger
+
+ from matplotlib.axes import Axes
+ from matplotlib.figure import Figure
+ from matplotlib.legend import Legend
+
+ from fm2prof import Project
+
register_matplotlib_converters()
-FigureOutput = namedtuple("FigureOutput", ["fig", "axes", "legend"])
-StyleGuide = namedtuple(
- "StyleGuide", ["font", "major_grid", "minor_grid", "spine_width"]
-)
+FigureOutput = namedtuple("FigureOutput", ["fig", "axes", "legend"]) # noqa: PYI024
+StyleGuide = namedtuple("StyleGuide", ["font", "major_grid", "minor_grid", "spine_width"]) # noqa: PYI024
COLORSCHEMES = {
"Deltares": ["#000000", "#00cc96", "#0d38e0"],
@@ -69,8 +63,7 @@
class GenerateCrossSectionLocationFile(FM2ProfBase):
- """
- Builds a cross-section input file for FM2PROF from a SOBEK 3 DIMR network definition file.
+ """Build a cross-section input file for FM2PROF from a SOBEK 3 DIMR network definition file.
The distance between cross-section is computed from the differences between the offsets/chainages.
The beginning and end point of each branch are treated as half-distance control volumes.
@@ -82,7 +75,7 @@ class GenerateCrossSectionLocationFile(FM2ProfBase):
>>> GenerateCrossSectionLocationFile(**input)
Parameters
-
+ ----------
networkdefinitionfile: path to NetworkDefinitionFile.ini
crossectionlocationfile: path to the desired output file
@@ -136,35 +129,41 @@ class GenerateCrossSectionLocationFile(FM2ProfBase):
def __init__(
self,
- networkdefinitionfile: Union[str, Path],
- crossectionlocationfile: Union[str, Path],
- branchrulefile: Optional[Union[str, Path]] = "",
- ):
+ network_definition_file: str | Path,
+ crossection_location_file: str | Path,
+ branchrule_file: str | Path = "",
+ ) -> None:
+ """Generate cross section location file object.
+
+ Args:
+ network_definition_file (str | Path): network definition file
+ crossection_location_file (str | Path): crosssection location file
+ branchrule_file (str | Path, optional): . Defaults to "".
+
+ """
super().__init__()
- networkdefinitionfile, crossectionlocationfile, branchrulefile = map(
- Path, [networkdefinitionfile, crossectionlocationfile, branchrulefile]
+ network_definition_file, crossection_location_file, branchrule_file = map(
+ Path,
+ [network_definition_file, crossection_location_file, branchrule_file],
)
- required_files = (networkdefinitionfile.is_file(),)
- if not all(required_files):
- raise FileNotFoundError
+ if not network_definition_file.exists():
+ err_msg = "Network difinition file not found"
+ raise FileNotFoundError(err_msg)
- self._networkdeffile_to_input(
- networkdefinitionfile, crossectionlocationfile, branchrulefile
- )
+ self._network_definition_file_to_input(network_definition_file, crossection_location_file, branchrule_file)
+
+ def _parse_network_definition_file(self, network_definition_file: Path, branchrules: dict | None = None) -> dict:
+ """Parse network definition file.
- def _parse_NetworkDefinitionFile(
- self, networkdefinitionfile: Path, branchrules: Optional[Dict] = None
- ) -> Dict:
- """
Output:
- x,y : coordinates of cross-section
- cid : name of the cross-section
- cdis: half-way distance between cross-section points on either side
- bid : name of the branch
- coff: chainage of cross-section on branch
+ x,y : coordinates of cross-section
+ cid : name of the cross-section
+ cdis: half-way distance between cross-section points on either side
+ bid : name of the branch
+ coff: chainage of cross-section on branch
"""
if not branchrules:
@@ -178,12 +177,11 @@ def _parse_NetworkDefinitionFile(
coff = [] # offset of cross-section on 1D branch ('chainage')
cdis = [] # distance of 1D branch influenced by crosss-section ('vaklengte')
- with open(networkdefinitionfile, "r") as f:
+ with network_definition_file.open("r") as f:
for line in f:
if line.strip().lower() == "[branch]":
branchid = f.readline().split("=")[1].strip()
- xlength = 0
- for i in range(10):
+ for _ in range(10):
bline = f.readline().strip().lower().split("=")
if bline[0].strip() == "gridpointx":
xtmp = list(map(float, bline[1].split()))
@@ -195,9 +193,7 @@ def _parse_NetworkDefinitionFile(
cofftmp = list(map(float, bline[1].split()))
# compute distance between control volumes
- cdistmp = np.append(np.diff(cofftmp) / 2, [0]) + np.append(
- [0], np.diff(cofftmp) / 2
- )
+ cdistmp = np.append(np.diff(cofftmp) / 2, [0]) + np.append([0], np.diff(cofftmp) / 2)
cdistmp = list(cdistmp)
# Append branchids
@@ -227,9 +223,7 @@ def _parse_NetworkDefinitionFile(
cdistmp,
bidtmp,
cofftmp,
- ) = self._applyBranchRules(
- rule, xtmp, ytmp, cidtmp, cdistmp, bidtmp, cofftmp
- )
+ ) = self._apply_branch_rules(rule, xtmp, ytmp, cidtmp, cdistmp, bidtmp, cofftmp)
if exceptions:
(
xtmp,
@@ -238,9 +232,7 @@ def _parse_NetworkDefinitionFile(
cdistmp,
bidtmp,
cofftmp,
- ) = self._applyBranchExceptions(
- exceptions, xtmp, ytmp, cidtmp, cdistmp, bidtmp, cofftmp
- )
+ ) = self._apply_branch_exceptions(exceptions, xtmp, ytmp, cidtmp, cdistmp, bidtmp, cofftmp)
c = len(xtmp)
for ic in xtmp, ytmp, cidtmp, cdistmp, bidtmp, cofftmp:
if len(ic) != c:
@@ -253,27 +245,33 @@ def _parse_NetworkDefinitionFile(
cdis.extend(cdistmp)
bid.extend(bidtmp)
coff.extend(cofftmp)
+ return {"x": x, "y": y, "css_id": cid, "css_len": cdis, "branch_id": bid, "css_offset": coff}
- return dict(x=x, y=y, css_id=cid, css_len=cdis, branch_id=bid, css_offset=coff)
-
- def _networkdeffile_to_input(
+ def _network_definition_file_to_input(
self,
- networkdefinitionfile: Path,
- crossectionlocationfile: Path,
- branchrulefile: Path,
- ):
- branchrules: Dict = {}
+ network_definition_file: Path,
+ crossection_location_file: Path,
+ branchrule_file: Path,
+ ) -> None:
+ branchrules: dict = {}
- if branchrulefile.is_file():
- branchrules = self._parseBranchRuleFile(branchrulefile)
+ if branchrule_file.is_file():
+ branchrules = self._parse_branch_rule_file(branchrule_file)
- network_dict = self._parse_NetworkDefinitionFile(
- networkdefinitionfile, branchrules
- )
+ network_dict = self._parse_network_definition_file(network_definition_file, branchrules)
- self._writeCrossSectionLocationFile(crossectionlocationfile, network_dict)
+ self._write_cross_section_location_file(crossection_location_file, network_dict)
- def _applyBranchExceptions(self, exceptions, x, y, cid, cdis, bid, coff):
+ def _apply_branch_exceptions( # noqa: PLR0913
+ self,
+ exceptions: list[str],
+ x: list[float],
+ y: list[float],
+ cid: list[str],
+ cdis: list[float],
+ bid: list[str],
+ coff: list[float],
+ ) -> tuple[list[float], list[float], list[str], list[float], list[str], list[float]]:
for exc in exceptions:
if exc not in cid:
self.set_logger_message(f"{exc} not found in branch", "error")
@@ -290,7 +288,7 @@ def _applyBranchExceptions(self, exceptions, x, y, cid, cdis, bid, coff):
cdis,
bid,
coff,
- ) = self._applyBranchRules("ignorefirst", x, y, cid, cdis, bid, coff)
+ ) = self._apply_branch_rules("ignorefirst", x, y, cid, cdis, bid, coff)
elif pop_index == len(x) - 1:
(
x,
@@ -299,29 +297,35 @@ def _applyBranchExceptions(self, exceptions, x, y, cid, cdis, bid, coff):
cdis,
bid,
coff,
- ) = self._applyBranchRules("ignorelast", x, y, cid, cdis, bid, coff)
+ ) = self._apply_branch_rules("ignorelast", x, y, cid, cdis, bid, coff)
else:
- # the distance of the popped value is divided over the two on aither side.
+ # the distance of the popped value is divided over the two on either side.
cdis[pop_index - 1] += cdis[pop_index] / 2
cdis[pop_index + 1] += cdis[pop_index] / 2
-
+
# then, pop the value
- for l in [x, y, cid, cdis, bid, coff]:
- l.pop(pop_index)
-
+ for v in [x, y, cid, cdis, bid, coff]:
+ v.pop(pop_index)
return x, y, cid, cdis, bid, coff
- def _applyBranchRules(
+ def _apply_branch_rules( # noqa: PLR0913
self,
rule: str,
x: float,
y: float,
- cid: float,
+ cid: str,
cdis: float,
bid: str,
coff: float,
- ):
+ ) -> tuple[
+ list[float],
+ list[float],
+ list[str],
+ list[float],
+ list[str],
+ list[float],
+ ]:
# bfunc: what points to pop (remove from list)
bfunc = {
"norule": lambda x: x,
@@ -329,13 +333,9 @@ def _applyBranchRules(
x[0],
x[-1],
], # only keep the 2 cross-section on either end of the branch
- "ignoreedges": lambda x: x[
- 1:-1
- ], # keep everything except 2 css on either end of the branch
+ "ignoreedges": lambda x: x[1:-1], # keep everything except 2 css on either end of the branch
"ignorelast": lambda x: x[:-1], # keep everything except last css on branch
- "ignorefirst": lambda x: x[
- 1:
- ], # keep everything except first css on branch
+ "ignorefirst": lambda x: x[1:], # keep everything except first css on branch
"onlyfirst": lambda x: [x[0]], # keep only the first css on branch
"onlylast": lambda x: [x[-1]], # keep only the last css on branch
}
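
The rule lambdas prune all six parallel cross-section lists consistently; a stand-alone sketch with `ignorefirst` follows (in the real code the distance list additionally gets its own redistribution function, `disfunc`, not shown here).

```python
bfunc = {
    "norule": lambda x: x,
    "ignorefirst": lambda x: x[1:],  # drop the first cross-section on the branch
    "ignorelast": lambda x: x[:-1],  # drop the last cross-section on the branch
}

cid = ["branch1_0.00", "branch1_100.00", "branch1_200.00"]
coff = [0.0, 100.0, 200.0]

rule = bfunc["ignorefirst"]
cid, coff = rule(cid), rule(coff)
print(cid)   # -> ['branch1_100.00', 'branch1_200.00']
print(coff)  # -> [100.0, 200.0]
```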
@@ -352,39 +352,34 @@ def _applyBranchRules(
try:
bf = bfunc[rule.lower().strip()]
- df = disfunc[rule.lower().strip()]
- return bf(x), bf(y), bf(cid), df(cdis), bf(bid), bf(coff)
+ disf = disfunc[rule.lower().strip()]
+ return bf(x), bf(y), bf(cid), disf(cdis), bf(bid), bf(coff)
except KeyError:
self.set_logger_message(
f"'{rule}' is not a known branchrules. Known rules are: {list(bfunc.keys())}",
"error",
)
- def _parseBranchRuleFile(
- self, branchrulefile: Path, delimiter: str = ","
- ) -> Dict[str, Dict]:
- """
- Parses the branchrule file which is a delimited file (comma by default)
- """
+ def _parse_branch_rule_file(self, branchrulefile: Path, delimiter: str = ",") -> dict[str, dict]:
+ """Parse the branchrule file which is a delimited file (comma by default)."""
branchrules: dict = {}
- with open(branchrulefile, "r") as f:
+ with branchrulefile.open("r") as f:
lines = [line.strip().split(delimiter) for line in f if len(line) > 1]
for line in lines:
branch: str = line[0].strip()
rule: str = line[1].strip()
- exceptions: List = []
- if len(line) > 2:
+ exceptions: list = []
+ if len(line) > 2: # noqa: PLR2004
exceptions = [e.strip() for e in line[2:]]
- branchrules[branch] = dict(rule=rule, exceptions=exceptions)
+ branchrules[branch] = {"rule": rule, "exceptions": exceptions}
return branchrules
- def _writeCrossSectionLocationFile(
- self, crossectionlocationfile: Path, network_dict: Dict
- ):
- """
+ def _write_cross_section_location_file(self, crossectionlocationfile: Path, network_dict: dict) -> None:
+ """Write cross section location file.
+
List inputs:
x,y : coordinates of cross-section
@@ -400,15 +395,15 @@ def _writeCrossSectionLocationFile(
bid = network_dict.get("branch_id")
coff = network_dict.get("css_offset")
- with open(crossectionlocationfile, "w") as f:
+ with crossectionlocationfile.open("w") as f:
f.write("name,x,y,length,branch,offset\n")
for i in range(len(x)):
- f.write(
- f"{cid[i]}, {x[i]:.4f}, {y[i]:.4f}, {cdis[i]:.2f}, {bid[i]}, {coff[i]:.2f}\n"
- )
+ f.write(f"{cid[i]}, {x[i]:.4f}, {y[i]:.4f}, {cdis[i]:.2f}, {bid[i]}, {coff[i]:.2f}\n")
class VisualiseOutput(FM2ProfBase):
+ """Visaulise output class."""
+
__cssdeffile = "CrossSectionDefinitions.ini"
__volumefile = "volumes.csv"
__rmainfile = "roughness-Main.ini"
@@ -417,10 +412,9 @@ class VisualiseOutput(FM2ProfBase):
def __init__(
self,
output_directory: str,
- figure_type: str = "png",
- overwrite: bool = True,
- logger: Logger = None,
- ):
+ logger: Logger | None = None,
+ ) -> None:
+ """Instantiate a VisualiseOutput object."""
super().__init__(logger=logger)
if not logger:
@@ -438,47 +432,38 @@ def __init__(
PlotStyles.apply()
@property
- def branches(self) -> Generator[List[str], None, None]:
- def split_css(name) -> Tuple[str, float, str]:
- chainage = float(name.split("_")[-1])
- branch = "_".join(name.split("_")[:-1])
- return (name, chainage, branch)
-
- def find_branches(css_list) -> List[str]:
- branches = np.unique([i[2] for i in css_names])
- contiguous_branches = np.unique([b.split("_")[0] for b in branches])
- return branches, contiguous_branches
-
- css_names = [split_css(css.get("id")) for css in self.cross_sections]
- branches, contiguous_branches = find_branches(css_names)
+ def branches(self) -> tuple[np.ndarray, np.ndarray]:
+ """Get branches."""
+ css_names = [self.split_css(css.get("id")) for css in self.cross_sections]
+ branches = np.unique([i[2] for i in css_names])
+ contiguous_branches = np.unique([b.split("_")[0] for b in branches])
return branches, contiguous_branches
@property
def number_of_cross_sections(self) -> int:
+ """Get number of cross sections."""
return len(list(self.cross_sections))
@property
- def cross_sections(self) -> Generator[Dict, None, None]:
- """
- Generator to loop through all cross-sections in definition file.
+ def cross_sections(self) -> Generator[dict, None, None]:
+ """Generator to loop through all cross-sections in definition file.
Example use:
>>> for css in visualiser.cross_sections:
>>> visualiser.make_figure(css)
"""
- csslist = self._readCSSDefFile()
- for css in csslist:
- yield css
+ csslist = self._read_css_def_file()
+ yield from csslist
+
+ def figure_roughness_longitudinal(self, branch: str) -> None:
+ """Get figure of longitudinal roughness.
- def figure_roughness_longitudinal(self, branch: str):
- """
Assumes the following naming convention:
[branch]_[optional:branch_order]_[chainage]
"""
output_dir = self.fig_dir.joinpath("roughness")
- if not output_dir.is_dir():
- os.mkdir(output_dir)
+ output_dir.mkdir(exist_ok=True)
fig, ax = plt.subplots(1, figsize=(12, 5))
@@ -488,7 +473,7 @@ def figure_roughness_longitudinal(self, branch: str):
minmax = []
for cross_section in css:
chainage.append(cross_section[1])
- roughness = self.getRoughnessInfoForCss(cross_section[0])[1]
+ roughness = self.get_roughness_info_for_css(cross_section[0])[1]
minmax.append([min(roughness), max(roughness)])
chainage = np.array(chainage) * 1e-3
@@ -498,61 +483,73 @@ def figure_roughness_longitudinal(self, branch: str):
ax.set_ylabel("Ruwheid (Chezy)")
ax.set_xlabel("Afstand [km]")
ax.set_title(branch)
- fig, lgd = self._SetPlotStyle(fig, use_legend=True)
+ fig, lgd = self._set_plot_style(fig, use_legend=True)
plt.savefig(
output_dir.joinpath(f"roughness_longitudinal_{branch}.png"),
bbox_extra_artists=[lgd],
bbox_inches="tight",
)
- def get_cross_sections_for_branch(self, branch: str) -> Tuple[str, float, str]:
- if branch not in self.branches:
- raise KeyError(f"Branch {branch} not in known branches: {self.branches}")
+ def get_cross_sections_for_branch(self, branch: str) -> list[tuple[str, float, str]]:
+ """Get cross sections for branch name.
- def split_css(name) -> Tuple[str, float, str]:
- chainage = float(name.split("_")[-1])
- branch = "_".join(name.split("_")[:-1])
- return (name, chainage, branch)
+ Args:
+ branch (str): branch name.
+
+ Returns:
+ list[tuple[str, float, str]]: (name, chainage, branch) tuples for the branch.
- def get_css_for_branch(css_list, branchname: str):
- return [c for c in css_list if c[2].startswith(branchname)]
+ """
+ if branch not in self.branches:
+ err_msg = f"Branch {branch} not in known branches: {self.branches}"
+ raise KeyError(err_msg)
- css_list = [split_css(css.get("id")) for css in self.cross_sections]
+ css_list = [self.split_css(css.get("id")) for css in self.cross_sections]
branches, contiguous_branches = self.branches
branch_list = []
sub_branches = np.unique([b for b in branches if b.startswith(branch)])
running_chainage = 0
for i, sub_branch in enumerate(sub_branches):
- sublist = get_css_for_branch(css_list, sub_branch)
+ sublist = self.get_css_for_branch(css_list, sub_branch)
if i > 0:
- running_chainage += get_css_for_branch(css_list, sub_branches[i - 1])[
- -1
- ][1]
+ running_chainage += self.get_css_for_branch(css_list, sub_branches[i - 1])[-1][1]
branch_list.extend([(s[0], s[1] + running_chainage, s[2]) for s in sublist])
return branch_list
- def getRoughnessInfoForCss(self, cssname, rtype: str = "roughnessMain"):
- """
- Opens roughness file and reads information for a given cross-section
- name
- """
+ @staticmethod
+ def split_css(name: str) -> tuple[str, float, str]:
+ """Split cross section name."""
+ chainage = float(name.split("_")[-1])
+ branch = "_".join(name.split("_")[:-1])
+ return (name, chainage, branch)
+
+ @staticmethod
+ def get_css_for_branch(css_list: list[tuple[str, float, str]], branchname: str) -> list[tuple[str, float, str]]:
+ """Get cross section for given branch name."""
+ return [c for c in css_list if c[2].startswith(branchname)]
+
+ def get_roughness_info_for_css(self, cssname: str, rtype: str = "roughnessMain") -> tuple[list, list]:
+ """Open roughness file and reads information for a given cross-section name."""
levels = None
values = None
- with open(self.files[rtype], "r") as f:
+ with self.files[rtype].open("r") as f:
cssbranch, csschainage = self._parse_cssname(cssname)
for line in f:
- if line.strip().lower() == "[branchproperties]":
- if self._getValueFromLine(f).lower() == cssbranch:
- [f.readline() for i in range(3)]
- levels = list(map(float, self._getValueFromLine(f).split()))
- if line.strip().lower() == "[definition]":
- if self._getValueFromLine(f).lower() == cssbranch:
- if float(self._getValueFromLine(f).lower()) == csschainage:
- values = list(map(float, self._getValueFromLine(f).split()))
+ if line.strip().lower() == "[branchproperties]" and self._get_value_from_line(f).lower() == cssbranch:
+ [f.readline() for _ in range(3)]
+ levels = list(map(float, self._get_value_from_line(f).split()))
+ if (
+ line.strip().lower() == "[definition]"
+ and self._get_value_from_line(f).lower() == cssbranch
+ and float(self._get_value_from_line(f).lower()) == csschainage
+ ):
+ values = list(map(float, self._get_value_from_line(f).split()))
return levels, values
- def getVolumeInfoForCss(self, cssname):
+ def get_volume_info_for_css(self, cssname: str) -> dict:
+ """Get volume info for cross section."""
column_names = [
"z",
"2D_total_volume",
@@ -568,9 +565,9 @@ def getVolumeInfoForCss(self, cssname):
]
cssdata = {}
for column in column_names:
- cssdata[column] = list()
+ cssdata[column] = []
- with open(self.files["volumes"], "r") as f:
+ with self.files["volumes"].open("r") as f:
for line in f:
values = line.strip().split(",")
if values[0] == cssname:
@@ -579,51 +576,48 @@ def getVolumeInfoForCss(self, cssname):
return cssdata
- def get_cross_section_by_id(self, id: str) -> dict:
- """
- Get cross-section information given an id.
+ def get_cross_section_by_id(self, css_id: str) -> dict | None:
+ """Get cross-section information given an id.
+
+ Args:
+ css_id (str): cross-section name
- Arguments:
- id (str): cross-section name
"""
- csslist = self._readCSSDefFile()
+ csslist = self._read_css_def_file()
for css in csslist:
- if css.get("id") == id:
+ if css.get("id") == css_id:
return css
+ return None
def figure_cross_section(
self,
- css,
+ css: dict,
reference_geometry: tuple = (),
reference_roughness: tuple = (),
+ *,
save_to_file: bool = True,
overwrite: bool = False,
- pbar: tqdm.std.tqdm = None,
- ) -> None:
- """
- Creates a figure
-
- Arguments
-
- css: dictionary containing cross-section information. Obtain with `VisualiseOutput.cross_sections`
- generator or `VisualiseOutput.get_cross_section_by_id` method.
+ ) -> Figure:
+ """Get a figure of the cross section.
- reference_geometry (tuple): tuple(list(y), list(z))
+ Args:
+ css (dict): cross section dict
+ reference_geometry (tuple, optional): reference geometry as tuple(list(y), list(z)). Defaults to ().
+ reference_roughness (tuple, optional): reference roughness as tuple(list(z), list(n)). Defaults to ().
+ save_to_file (bool, optional): Save the figure to file. Defaults to True.
+ overwrite (bool, optional): Overwrite the figure. Defaults to False.
- reference_roughness (tuple): tuple(list(z), list(n))
-
- save_to_file (bool): if true, save figure to VisualiseOutput.fig_dir
- if false, returns pyplot figure object
+ Returns:
+ Figure: the cross-section figure, or None if the file already exists or plotting failed.
"""
output_dir = self.fig_dir.joinpath("cross_sections")
- if not output_dir.is_dir():
- os.mkdir(output_dir)
+ output_dir.mkdir(exist_ok=True)
output_file = output_dir.joinpath(f"{css['id']}.png")
if output_file.is_file() and not overwrite:
self.set_logger_message("file already exists", "debug")
- return
+ return None
try:
fig = plt.figure(figsize=(8, 12))
gs = fig.add_gridspec(2, 2)
@@ -637,7 +631,7 @@ def figure_cross_section(
self._plot_volume(css, axs[1])
self._plot_roughness(css, axs[2], reference_roughness)
- fig, lgd = self._SetPlotStyle(fig)
+ fig, lgd = self._set_plot_style(fig)
if save_to_file:
plt.savefig(
@@ -649,15 +643,17 @@ def figure_cross_section(
return fig
except Exception as e:
- self.set_logger_message(f"error processing: {css['id']} {str(e)}", "error")
+ self.set_logger_message(f"error processing: {css['id']} {e!s}", "error")
return None
finally:
plt.close()
- def plot_cross_sections(self):
- """Makes figures for all cross-sections in project,
- output to output directory of project"""
+ def plot_cross_sections(self) -> None:
+ """Plot figures for all cross-sections in project.
+
+ Outputs to output directory of project.
+ """
pbar = tqdm.tqdm(total=self.number_of_cross_sections)
self.start_new_log_task("Plotting cross-secton figures", pbar=pbar)
@@ -667,50 +663,48 @@ def plot_cross_sections(self):
self.finish_log_task()
- def _generate_output_dir(self, figure_type: str = "png", overwrite: bool = True):
- """
- Creates a new directory in the output map to store figures for each cross-section
+ def _generate_output_dir(self) -> Path:
+ """Create a new directory in the output map to store figures for each cross-section.
Arguments:
output_map - path to fm2prof output directory
Returns:
png images saved to file
- """
+ """
figdir = self.output_dir.joinpath("figures")
- if not figdir.is_dir():
- figdir.mkdir(parents=True)
+ figdir.mkdir(parents=True, exist_ok=True)
return figdir
- def _set_files(self):
+ def _set_files(self) -> None:
self.files = {
- "css_def": os.path.join(self.output_dir, self.__cssdeffile),
- "volumes": os.path.join(self.output_dir, self.__volumefile),
- "roughnessMain": os.path.join(self.output_dir, self.__rmainfile),
- "roughnessFP1": os.path.join(self.output_dir, self.__rfp1file),
+ "css_def": self.output_dir / self.__cssdeffile,
+ "volumes": self.output_dir / self.__volumefile,
+ "roughnessMain": self.output_dir / self.__rmainfile,
+ "roughnessFP1": self.output_dir / self.__rfp1file,
}
- def _getValueFromLine(self, f):
+ def _get_value_from_line(self, f: TextIOWrapper) -> str:
return f.readline().strip().split("=")[1].strip()
- def _readCSSDefFile(self) -> List[Dict]:
- csslist = list()
+ def _read_css_def_file(self) -> list[dict]:
+ csslist = []
- with open(self.files.get("css_def"), "r") as f:
+ with self.files.get("css_def").open("r") as f:
for line in f:
if line.lower().strip() == "[definition]":
css_id = f.readline().strip().split("=")[1]
[f.readline() for i in range(3)]
- css_levels = list(map(float, self._getValueFromLine(f).split()))
- css_fwidth = list(map(float, self._getValueFromLine(f).split()))
- css_twidth = list(map(float, self._getValueFromLine(f).split()))
- css_sdcrest = float(self._getValueFromLine(f))
- css_sdflow = float(self._getValueFromLine(f))
- css_sdtotal = float(self._getValueFromLine(f))
- css_sdbaselevel = float(self._getValueFromLine(f))
- css_mainsectionwidth = float(self._getValueFromLine(f))
- css_fp1sectionwidth = float(self._getValueFromLine(f))
+ css_levels = list(map(float, self._get_value_from_line(f).split()))
+ css_fwidth = list(map(float, self._get_value_from_line(f).split()))
+ css_twidth = list(map(float, self._get_value_from_line(f).split()))
+ css_sdcrest = float(self._get_value_from_line(f))
+ css_sdflow = float(self._get_value_from_line(f))
+ css_sdtotal = float(self._get_value_from_line(f))
+ css_sdbaselevel = float(self._get_value_from_line(f))
+ css_mainsectionwidth = float(self._get_value_from_line(f))
+ css_fp1sectionwidth = float(self._get_value_from_line(f))
css = {
"id": css_id.strip(),
@@ -728,36 +722,36 @@ def _readCSSDefFile(self) -> List[Dict]:
return csslist
- def _SetPlotStyle(self, *args, **kwargs):
- """todo: add preference to switch styles or
+ def _set_plot_style(self, *args: tuple, **kwargs: dict) -> tuple[Figure, Legend]:
+ """Set plot style.
+
+ TODO: add preference to switch styles or
inject own style
"""
return PlotStyles.apply(*args, **kwargs)
- def _plot_geometry(self, css, ax, reference_geometry=None):
+ def _plot_geometry(self, css: dict, ax: Axes, reference_geometry: list | None = None) -> None:
# Get data
tw = np.append([0], np.array(css["total_width"]))
fw = np.append([0], np.array(css["flow_width"]))
- l = np.append(css["levels"][0], np.array(css["levels"]))
+ levels = np.append(css["levels"][0], np.array(css["levels"]))
mainsectionwidth = css["mainsectionwidth"]
fp1sectionwidth = css["fp1sectionwidth"]
# Get the water level where water level independent computation takes over
# this is the lowest level where there is 2D information on volumes
- z_waterlevel_independent = self._get_lowest_water_level_in_2D(css)
+ z_waterlevel_independent = self._get_lowest_water_level_in_2d(css)
# Plot cross-section geometry
for side in [-1, 1]:
- h = ax.fill_betweenx(
- l, side * fw / 2, side * tw / 2, color="#44B1D5AA", hatch="////"
- )
- ax.plot(side * tw / 2, l, "-k")
- ax.plot(side * fw / 2, l, "--k")
+ h = ax.fill_betweenx(levels, side * fw / 2, side * tw / 2, color="#44B1D5AA", hatch="////")
+ ax.plot(side * tw / 2, levels, "-k")
+ ax.plot(side * fw / 2, levels, "--k")
# Plot roughness section width
ax.plot(
[-0.5 * mainsectionwidth, 0.5 * mainsectionwidth],
- [min(l) - 0.25] * 2,
+ [min(levels) - 0.25] * 2,
"-",
linewidth=2,
color="red",
@@ -768,17 +762,16 @@ def _plot_geometry(self, css, ax, reference_geometry=None):
-0.5 * (mainsectionwidth + fp1sectionwidth),
0.5 * (mainsectionwidth + fp1sectionwidth),
],
- [min(l) - 0.25] * 2,
+ [min(levels) - 0.25] * 2,
"--",
color="red",
label="Floodplain section",
)
- # ax.plot(tw-0.5*max(fp1sectionwidth),[min(l)]*len(l), '--', color='cyan', label='Floodplain section')
# Plot water level indepentent line
ax.plot(
tw - 0.5 * max(tw),
- [z_waterlevel_independent] * len(l),
+ [z_waterlevel_independent] * len(levels),
linestyle="--",
color="m",
label="Lowest water level in 2D",
@@ -810,10 +803,10 @@ def _plot_geometry(self, css, ax, reference_geometry=None):
label=sd_info.get("label"),
)
- def _plot_volume(self, css, ax):
+ def _plot_volume(self, css: dict, ax: Axes) -> None:
# Get data
- vd = self.getVolumeInfoForCss(css["id"])
- z_waterlevel_independent = self._get_lowest_water_level_in_2D(css)
+ vd = self.get_volume_info_for_css(css["id"])
+ z_waterlevel_independent = self._get_lowest_water_level_in_2d(css)
# Plot 1D volumes
ax.fill_between(
@@ -875,17 +868,15 @@ def _plot_volume(self, css, ax):
ax.set_xlabel("Water level [m]")
ax.set_ylabel("Volume [m$^3$]")
- def _plot_roughness(self, css, ax, reference_roughness):
- levels, values = self.getRoughnessInfoForCss(css["id"], rtype="roughnessMain")
+ def _plot_roughness(self, css: dict, ax: Axes, reference_roughness: tuple) -> None:
+ levels, values = self.get_roughness_info_for_css(css["id"], rtype="roughnessMain")
try:
ax.plot(levels, values, label="Main channel")
except:
pass
try:
- levels, values = self.getRoughnessInfoForCss(
- css["id"], rtype="roughnessFP1"
- )
+ levels, values = self.get_roughness_info_for_css(css["id"], rtype="roughnessFP1")
if levels is not None and values is not None:
ax.plot(levels, values, label="Floodplain1")
except FileNotFoundError:
@@ -906,60 +897,57 @@ def _plot_roughness(self, css, ax, reference_roughness):
ax.set_ylabel("Chezy coefficient [m$^{1/2}$/s]")
@staticmethod
- def _get_sd_plot_info(css):
- l = np.append(css["levels"][0], np.array(css["levels"]))
+ def _get_sd_plot_info(css: dict) -> dict:
+ v = np.append(css["levels"][0], np.array(css["levels"]))
z_crest_level = css["SD_crest"]
- if z_crest_level <= max(l):
- if z_crest_level >= min(l):
+ if z_crest_level <= max(v):
+ if z_crest_level >= min(v):
sd_linestyle = "--"
sd_label = "SD Crest Level"
else:
- z_crest_level = min(l)
+ z_crest_level = min(v)
sd_linestyle = "-"
sd_label = "SD Crest Level (cropped)"
else:
- z_crest_level = max(l)
+ z_crest_level = max(v)
sd_linestyle = "-"
sd_label = "SD Crest Level (cropped)"
return {"linestyle": sd_linestyle, "label": sd_label, "crest": z_crest_level}
- def _get_lowest_water_level_in_2D(self, css):
- vd = self.getVolumeInfoForCss(css["id"])
+ def _get_lowest_water_level_in_2d(self, css: dict) -> float:
+ vd = self.get_volume_info_for_css(css["id"])
index_waterlevel_independent = np.argmax(~np.isnan(vd.get("2D_total_volume")))
- z_waterlevel_independent = vd.get("z")[index_waterlevel_independent]
- return z_waterlevel_independent
+ return vd.get("z")[index_waterlevel_independent]
- def _parse_cssname(self, cssname):
- """
- returns name of branch and chainage
- """
- branch, chainage = cssname.rsplit(
- "_", 1
- ) # rsplit prevents error if branchname contains _
+ def _parse_cssname(self, cssname: str) -> tuple[str, float]:
+ """Return name of branch and chainage."""
+ branch, chainage = cssname.rsplit("_", 1) # rsplit prevents error if branchname contains _
chainage = round(float(chainage), 2)
return branch, chainage
class PlotStyles:
- myFmt = mdates.DateFormatter("%d-%b")
+ """Class for handling and applying plot styles."""
+
+ my_fmt = mdates.DateFormatter("%d-%b")
monthlocator = mdates.MonthLocator(bymonthday=(1, 10, 20))
daylocator = mdates.DayLocator(interval=5)
colorscheme = COLORSCHEMES["Koeln"]
@staticmethod
- def set_locale(localeString: str):
+ def set_locale(locale_string: str) -> None:
+ """Set locale."""
try:
- locale.setlocale(locale.LC_TIME, localeString)
+ locale.setlocale(locale.LC_TIME, locale_string)
except locale.Error:
# known error on linux fix:
# export LC_ALL="en_US.UTF-8" & export LC_CTYPE="en_US.UTF-8" & sudo dpkg-reconfigure locales
- print(f"could not set locale to {localeString}")
- pass
+ print(f"could not set locale to {locale_string}")
@staticmethod
- def _is_timeaxis(axis) -> bool:
+ def _is_timeaxis(axis: Axes) -> bool:
try:
label_string = axis.get_ticklabels()[0].get_text().replace("−", "-")
# if label_string is empty (e.g. because of twin_axis, return false)
@@ -976,12 +964,15 @@ def _is_timeaxis(axis) -> bool:
def van_veen(
cls,
fig: Figure | None = None,
+ *,
use_legend: bool = True,
- extra_labels: List | None = None,
+ extra_labels: list | None = None,
ax_align_legend: plt.Axes | None = None,
- ):
- warnings.warn(
- "This function is deprecated and will be removed on future versions. Use PlotStyle.apply(fig, style='van_veen') instead",
+ ) -> None:
+ """Apply van veen plotstyle."""
+ warnings.warn( # noqa: B028
+ "This function is deprecated and will be removed on future versions."
+ "Use PlotStyle.apply(fig, style='van_veen') instead",
category=DeprecationWarning,
)
cls.apply(
@@ -997,138 +988,143 @@ def apply(
cls,
fig: Figure | None = None,
style: str = "sito",
- use_legend: bool = True,
- extra_labels: List | None = None,
+ extra_labels: list | None = None,
ax_align_legend: plt.Axes | None = None,
- ) -> Tuple[Figure, Legend]:
- styles: Dict[str, StyleGuide] = dict(
- sito=StyleGuide(
+ *,
+ use_legend: bool = True,
+ ) -> tuple[Figure, Legend]:
+ """Apply style to figure."""
+ styles: dict[str, StyleGuide] = {
+ "sito": StyleGuide(
font={"family": "Franca, Arial", "weight": "normal", "size": 16},
- major_grid=dict(
- visible=True,
- which="major",
- linestyle="--",
- linewidth=1.0,
- color="#BBBBBB",
- ),
- minor_grid=dict(
- visible=True,
- which="minor",
- linestyle="--",
- linewidth=1.0,
- color="#BBBBBB",
- ),
+ major_grid={
+ "visible": True,
+ "which": "major",
+ "linestyle": "--",
+ "linewidth": 1.0,
+ "color": "#BBBBBB",
+ },
+ minor_grid={
+ "visible": True,
+ "which": "minor",
+ "linestyle": "--",
+ "linewidth": 1.0,
+ "color": "#BBBBBB",
+ },
spine_width=1,
),
- van_veen=StyleGuide(
+ "van_veen": StyleGuide(
font={"family": "Bahnschrift", "weight": "normal", "size": 18},
- major_grid=dict(
- visible=True, which="major", linestyle="-", linewidth=1, color="k"
- ),
- minor_grid=dict(
- visible=True, which="minor", linestyle="-", linewidth=0.5, color="k"
- ),
+ major_grid={"visible": True, "which": "major", "linestyle": "-", "linewidth": 1, "color": "k"},
+ minor_grid={"visible": True, "which": "minor", "linestyle": "-", "linewidth": 0.5, "color": "k"},
spine_width=2,
),
- )
+ }
if style not in styles:
- raise KeyError(f"unknown style {style}. Options are {list(styles.keys())}")
+ err_msg = f"unknown style {style}. Options are {list(styles.keys())}"
+ raise KeyError(err_msg)
style_guide: StyleGuide = styles.get(style)
- def initiate() -> None:
- # Set default locale to NL
- # TODO: add localization options (#85)
- PlotStyles.set_locale("nl_NL.UTF-8")
-
- # Color style
- mpl.rcParams["axes.prop_cycle"] = mpl.cycler(
- color=cls.colorscheme * 3,
- linestyle=["-"] * len(cls.colorscheme)
- + ["--"] * len(cls.colorscheme)
- + ["-."] * len(cls.colorscheme),
- )
+ if not fig:
+ return cls._initiate(style_guide)
+ cls._initiate(style_guide)
- # Font style
- font = style_guide.font
-
- # not all fonts support the unicode minus, so disable this option
- mpl.rc("font", **font)
- mpl.rcParams["axes.unicode_minus"] = False
-
- def style_figure(
- fig, use_legend, extra_labels, ax_align_legend
- ) -> Tuple[Figure, Legend] | Tuple[Figure, List] | None:
- if ax_align_legend is None:
- ax_align_legend = fig.axes[0]
-
- # this forces labels to be generated. Necessary to detect datetimes
- fig.canvas.draw()
-
- # Set styles for each axis
- legend_title = r"Toelichting"
- handles = list()
- labels = list()
-
- for ax in fig.axes:
- # Enable grid grid
- ax.grid(**style_guide.major_grid)
- ax.grid(**style_guide.minor_grid)
-
- for _, spine in ax.spines.items():
- spine.set_linewidth(style_guide.spine_width)
-
- if cls._is_timeaxis(ax.xaxis):
- ax.xaxis.set_major_formatter(cls.myFmt)
- ax.xaxis.set_major_locator(cls.monthlocator)
- if cls._is_timeaxis(ax.yaxis):
- ax.yaxis.set_major_formatter(cls.myFmt)
- ax.yaxis.set_major_locator(cls.monthlocator)
-
- ax.patch.set_visible(False)
- h, l = ax.get_legend_handles_labels()
- handles.extend(h)
- labels.extend(l)
-
- if extra_labels:
- handles.extend(extra_labels[0])
- labels.extend(extra_labels[1])
- fig.tight_layout()
- if use_legend:
- lgd = fig.legend(
- handles,
- labels,
- loc="upper left",
- bbox_to_anchor=(1.0, ax_align_legend.get_position().y1),
- bbox_transform=fig.transFigure,
- edgecolor="k",
- facecolor="white",
- framealpha=1,
- borderaxespad=0,
- title=legend_title.upper(),
- )
+ return cls._style_figure(
+ style_guide=style_guide,
+ fig=fig,
+ use_legend=use_legend,
+ extra_labels=extra_labels,
+ ax_align_legend=ax_align_legend,
+ )
- return fig, lgd
- else:
- return fig, handles, labels
+ @classmethod
+ def _initiate(cls, style_guide: StyleGuide) -> None:
+ # Set default locale to NL
+ # TODO: add localization options (#85)
+ PlotStyles.set_locale("nl_NL.UTF-8")
+
+ # Color style
+ mpl.rcParams["axes.prop_cycle"] = mpl.cycler(
+ color=cls.colorscheme * 3,
+ linestyle=["-"] * len(cls.colorscheme) + ["--"] * len(cls.colorscheme) + ["-."] * len(cls.colorscheme),
+ )
- if not fig:
- return initiate()
- else:
- initiate()
- return style_figure(
- fig=fig,
- use_legend=use_legend,
- extra_labels=extra_labels,
- ax_align_legend=ax_align_legend,
+ # Font style
+ font = style_guide.font
+
+ # not all fonts support the unicode minus, so disable this option
+ mpl.rc("font", **font)
+ mpl.rcParams["axes.unicode_minus"] = False
+
+ @classmethod
+ def _style_figure(
+ cls,
+ style_guide: StyleGuide,
+ fig: Figure | None,
+ extra_labels: list | None,
+ ax_align_legend: Axes | None,
+ *,
+ use_legend: bool,
+ ) -> tuple[Figure, Legend] | tuple[Figure, list] | None:
+ if ax_align_legend is None:
+ ax_align_legend = fig.axes[0]
+
+ # this forces labels to be generated. Necessary to detect datetimes
+ fig.canvas.draw()
+
+ # Set styles for each axis
+ legend_title = r"Toelichting"
+ handles = []
+ labels = []
+
+ for ax in fig.axes:
+ # Enable grid
+ ax.grid(**style_guide.major_grid)
+ ax.grid(**style_guide.minor_grid)
+
+ for spine in ax.spines.values():
+ spine.set_linewidth(style_guide.spine_width)
+
+ if cls._is_timeaxis(ax.xaxis):
+ ax.xaxis.set_major_formatter(cls.my_fmt)
+ ax.xaxis.set_major_locator(cls.monthlocator)
+ if cls._is_timeaxis(ax.yaxis):
+ ax.yaxis.set_major_formatter(cls.my_fmt)
+ ax.yaxis.set_major_locator(cls.monthlocator)
+
+ ax.patch.set_visible(False)
+ h, lab = ax.get_legend_handles_labels()
+ handles.extend(h)
+ labels.extend(lab)
+
+ if extra_labels:
+ handles.extend(extra_labels[0])
+ labels.extend(extra_labels[1])
+ fig.tight_layout()
+ if use_legend:
+ lgd = fig.legend(
+ handles,
+ labels,
+ loc="upper left",
+ bbox_to_anchor=(1.0, ax_align_legend.get_position().y1),
+ bbox_transform=fig.transFigure,
+ edgecolor="k",
+ facecolor="white",
+ framealpha=1,
+ borderaxespad=0,
+ title=legend_title.upper(),
)
+ return fig, lgd
+ return fig, handles, labels
+
class ModelOutputReader(FM2ProfBase):
- """
- This class provides methods to post-process 1D and 2D data,
- by writing csv files of output locations (observation stations)
+ """Provide methods to post-process 1D and 2D data.
+
+ The data is post-processed by writing csv files of output locations (observation stations)
that are both in 1D and 2D. It produces two csv files that
are input for :meth:`fm2prof.utils.ModelOutputPlotter`
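
A hedged usage sketch of ModelOutputReader based on the docstrings above; all paths and dates are placeholders, and the csv file names follow the class attributes shown below.

```python
from datetime import datetime

from fm2prof.utils import ModelOutputReader

reader = ModelOutputReader(
    start_time=datetime(2020, 1, 1),  # optional masking window
    stop_time=datetime(2020, 2, 1),
)
reader.path_flow1d = "sobek/output/observations.nc"  # placeholder 1D output
reader.path_flow2d = "fm/output/FlowFM_his.nc"       # placeholder 2D output

reader.load_flow1d_data()  # imports from NetCDF, caches 1D discharge/water level as csv
reader.load_flow2d_data()  # same for the 2D observation stations, matched to 1D
```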
@@ -1151,98 +1147,111 @@ class ModelOutputReader(FM2ProfBase):
__fileOutName_F2D_Q = "2D_Q.csv"
__fileOutName_F2D_H = "2D_H.csv"
- _key_1D_Q_name = "observation_id"
- _key_1D_Q = "water_discharge"
- _key_1D_time = "time"
- _key_1D_H_name = "observation_id"
- _key_1D_H = "water_level"
-
- _key_2D_Q_name = "cross_section_name"
- _key_2D_Q = "cross_section_discharge"
- _key_2D_time = "time"
- _key_2D_H_name = "station_name"
- _key_2D_H = "waterlevel"
+ _key_1d_q_name = "observation_id"
+ _key_1d_q = "water_discharge"
+ _key_1d_time = "time"
+ _key_1d_h_name = "observation_id"
+ _key_1d_h = "water_level"
+
+ _key_2d_q_name = "cross_section_name"
+ _key_2d_q = "cross_section_discharge"
+ _key_2d_time = "time"
+ _key_2d_h_name = "station_name"
+ _key_2d_h = "waterlevel"
__fileOutName_1D2DMap = "map_1d_2d.csv"
_time_fmt = "%Y-%m-%d %H:%M:%S"
def __init__(
self,
- logger=None,
+ logger: Logger | None = None,
start_time: datetime | None = None,
stop_time: datetime | None = None,
- ):
+ ) -> None:
+ """Instantiate a ModelOutputReader object.
+
+ Args:
+ logger (Logger | None, optional): logger. Defaults to None.
+ start_time (datetime | None, optional): start time. Defaults to None.
+ stop_time (datetime | None, optional): stop time. Defaults to None.
+
+ """
super().__init__(logger=logger)
- self._path_out: Path = Path(".")
- self._path_flow1d: Path = Path(".")
- self._path_flow2d: Path = Path(".")
+ self._path_out: Path = Path()
+ self._path_flow1d: Path = Path()
+ self._path_flow2d: Path = Path()
- self._data_1D_Q: pd.DataFrame = None
+ self._data_1d_q: pd.DataFrame = None
self._time_offset_1d: int = 0
self._start_time: datetime | None = start_time
self._stop_time: datetime | None = stop_time
@property
- def start_time(self) -> Union[datetime, None]:
- """if defined, used to mask data"""
+ def start_time(self) -> datetime | None:
+ """If defined, used to mask data."""
return self._start_time
@start_time.setter
- def start_time(self, input_time: datetime):
+ def start_time(self, input_time: datetime) -> None:
if isinstance(input_time, datetime):
self._start_time = input_time
@property
- def stop_time(self) -> Union[datetime, None]:
- """if defined, used to mask data"""
+ def stop_time(self) -> datetime | None:
+ """If defined, used to mask data."""
return self._stop_time
@stop_time.setter
- def stop_time(self, input_time: datetime):
+ def stop_time(self, input_time: datetime) -> None:
if isinstance(input_time, datetime):
self._stop_time = input_time
@property
- def path_flow1d(self):
+ def path_flow1d(self) -> Path:
+ """Return path to flow 1D file."""
return self._path_flow1d
@path_flow1d.setter
- def path_flow1d(self, path: Union[Path, str]):
+ def path_flow1d(self, path: Path | str) -> None:
-        # Verify path is dir
+        # Verify path is file
- assert Path(path).is_file()
+ if not Path(path).is_file():
+ err_msg = f"Given path, {path}, is not a file."
+ raise ValueError(err_msg)
# set attribute
self._path_flow1d = Path(path)
@property
- def path_flow2d(self):
+ def path_flow2d(self) -> Path:
+ """Return path to flow 2D file."""
return self._path_flow2d
@path_flow2d.setter
- def path_flow2d(self, path: Union[Path, str]):
+ def path_flow2d(self, path: Path | str) -> None:
# Verify path is file
- assert Path(path).is_file()
+ if not Path(path).is_file():
+ err_msg = f"Given path, {path}, is not a file."
+ raise ValueError(err_msg)
# set attribute
self._path_flow2d = Path(path)
- def load_flow1d_data(self):
- """
- Loads 'observations.nc' and outputs to csv file
+ def load_flow1d_data(self) -> None:
+ """Load 'observations.nc' and outputs to csv file.
.. note::
Path to the 1D model must first be set by using
>>> ModelOutputReader.path_flow1d = path_to_dir_that_contains_dimr_xml
"""
- if self.file_1D_Q.is_file() & self.file_1D_H.is_file():
- self._data_1D_Q = pd.read_csv(
- self.file_1D_Q,
+ if self.file_1d_q.is_file() & self.file_1d_h.is_file():
+ self._data_1d_q = pd.read_csv(
+ self.file_1d_q,
index_col=0,
parse_dates=True,
date_format=self._time_fmt,
)
- self._data_1D_H = pd.read_csv(
- self.file_1D_H,
+ self._data_1d_h = pd.read_csv(
+ self.file_1d_h,
index_col=0,
parse_dates=True,
date_format=self._time_fmt,
@@ -1250,30 +1259,31 @@ def load_flow1d_data(self):
self.set_logger_message("Using existing flow1d csv files")
else:
self.set_logger_message("Importing from NetCDF")
- self._data_1D_H, self._data_1D_Q = self._import_1Dobservations()
+ self._data_1d_h, self._data_1d_q = self._import_1d_observations()
self.set_logger_message("Writing to CSV (waterlevels)")
- self._data_1D_H.to_csv(self.file_1D_H)
+ self._data_1d_h.to_csv(self.file_1d_h)
self.set_logger_message("Writing to CSV (discharge)")
- self._data_1D_Q.to_csv(self.file_1D_Q)
+ self._data_1d_q.to_csv(self.file_1d_q)
- def load_flow2d_data(self):
- """
- Loads 2D output file (netCDF, must contain observation point results),
+ def load_flow2d_data(self) -> None:
+ """Load 2D output file.
+
+        The file is netCDF and must contain observation point results.
-        matches to 1D result, output to csv
+        The results are matched to the 1D stations and written to csv.
.. note::
Path to the 2D model output
>>> ModelOutputReader.path_flow2d = path_to_netcdf_file
"""
- if self.file_2D_Q.is_file() & self.file_2D_H.is_file():
- self._data_2D_Q = pd.read_csv(
- self.file_2D_Q,
+ if self.file_2d_q.is_file() & self.file_2d_h.is_file():
+ self._data_2d_q = pd.read_csv(
+ self.file_2d_q,
index_col=0,
parse_dates=True,
date_format=self._time_fmt,
)
- self._data_2D_H = pd.read_csv(
- self.file_2D_H,
+ self._data_2d_h = pd.read_csv(
+ self.file_2d_h,
index_col=0,
parse_dates=True,
date_format=self._time_fmt,
@@ -1281,199 +1291,205 @@ def load_flow2d_data(self):
self.set_logger_message("Using existing flow2d csv files")
else:
# write to file
- self._import_2Dobservations()
+ self._import_2d_observations()
# then load
- self._data_2D_Q = pd.read_csv(
- self.file_2D_Q,
+ self._data_2d_q = pd.read_csv(
+ self.file_2d_q,
index_col=0,
parse_dates=True,
date_format=self._time_fmt,
)
- self._data_2D_H = pd.read_csv(
- self.file_2D_H,
+ self._data_2d_h = pd.read_csv(
+ self.file_2d_h,
index_col=0,
parse_dates=True,
date_format=self._time_fmt,
)
- def get_1d2d_map(self):
- """Writes a map between stations in 1D and stations in 2D. Matches based on identical characters in first nine slots"""
- if self.file_1D2D_map.is_file():
+ def get_1d2d_map(self) -> None:
+ """Write a map between stations in 1D and stations in 2D.
+
+        Matching is based on the first nine characters of the station names being identical.
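+
+        A minimal usage sketch (paths are hypothetical):
+
+        >>> reader = ModelOutputReader()
+        >>> reader.path_flow1d = "sobek/output/observations.nc"
+        >>> reader.path_flow2d = "fm/output/FlowFM_his.nc"
+        >>> reader.load_flow1d_data()
+        >>> reader.get_1d2d_map()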
+ """
+ if self.file_1d2d_map.is_file():
self.set_logger_message("using existing 1d-2d map")
return
- else:
- self._get_1d2d_map()
+ self._get_1d2d_map()
def read_all_data(self) -> None:
- """ """
+ """Read all data."""
self.load_flow1d_data()
self.get_1d2d_map()
self.load_flow2d_data()
- def _dateparser(self, t):
+ def _dateparser(self, t: str) -> datetime:
# DEPRECATED
return datetime.strptime(t, self._time_fmt)
@property
def output_path(self) -> Path:
+ """Return output path."""
return self._path_out
@output_path.setter
- def output_path(self, new_path: Union[Path, str]):
+ def output_path(self, new_path: Path | str) -> None:
newpath = Path(new_path)
if newpath.is_dir():
self._path_out = newpath
else:
- raise ValueError(f"{new_path} is not a directory")
+ err_msg = f"{new_path} is not a directory"
+ raise ValueError(err_msg)
@property
- def file_1D_Q(self):
+ def file_1d_q(self) -> Path:
+ """Return path to 1D water discharge file."""
return self.output_path.joinpath(self.__fileOutName_F1D_Q)
@property
- def file_1D_H(self):
+ def file_1d_h(self) -> Path:
+ """Return path to 1D water level file."""
return self.output_path.joinpath(self.__fileOutName_F1D_H)
@property
- def file_2D_Q(self):
+ def file_2d_q(self) -> Path:
+ """Return path to 2D discharge file."""
return self.output_path.joinpath(self.__fileOutName_F2D_Q)
@property
- def file_2D_H(self):
+ def file_2d_h(self) -> Path:
+ """Return path to 2D water level."""
return self.output_path.joinpath(self.__fileOutName_F2D_H)
@property
- def file_1D2D_map(self):
+ def file_1d2d_map(self) -> Path:
+ """Return path to 1D2D map file."""
return self.output_path.joinpath(self.__fileOutName_1D2DMap)
@property
- def data_1D_H(self):
- return self._apply_startstop_time(self._data_1D_H)
+ def data_1d_h(self) -> pd.DataFrame:
+ """Apply start stop time to 1D water level data."""
+ return self._apply_startstop_time(self._data_1d_h)
@property
- def data_2D_H(self):
- return self._apply_startstop_time(self._data_2D_H)
+ def data_2d_h(self) -> pd.DataFrame:
+ """Apply start stop time to 2D water level data."""
+ return self._apply_startstop_time(self._data_2d_h)
@property
- def data_1D_Q(self):
- return self._apply_startstop_time(self._data_1D_Q)
+ def data_1d_q(self) -> pd.DataFrame:
+ """Apply start stop time to 1D discharge data."""
+ return self._apply_startstop_time(self._data_1d_q)
@property
- def data_2D_Q(self):
- return self._apply_startstop_time(self._data_2D_Q)
+ def data_2d_q(self) -> pd.DataFrame:
+ """Apply start stop time to 2D discharge data."""
+ return self._apply_startstop_time(self._data_2d_q)
@property
- def time_offset_1d(self):
+ def time_offset_1d(self) -> int:
+ """Return time offset for 1D data."""
return self._time_offset_1d
@time_offset_1d.setter
- def time_offset_1d(self, seconds: int = 0):
+ def time_offset_1d(self, seconds: int = 0) -> None:
self._time_offset_1d = seconds
def _apply_startstop_time(self, data: pd.DataFrame) -> pd.DataFrame:
- """
- Applies stop/start time to data
- """
+ """Apply stop/start time to data."""
if self.stop_time is None:
self.stop_time = data.index[-1]
if self.start_time is None:
self.start_time = data.index[0]
if self.start_time >= self.stop_time:
+ err_msg = "Stop time ({self.stop_time}) should be later than start time ({self.start_time})"
self.set_logger_message(
- "Stop time ({self.stop_time}) should be later than start time ({self.start_time})",
+ err_msg,
"error",
)
- raise ValueError
+ raise ValueError(err_msg)
if bool(self.start_time) and (self.start_time >= data.index[-1]):
+ err_msg = f"Provided start time {self.start_time} is later than last record in data ({data.index[-1]})"
self.set_logger_message(
- f"Provided start time {self.start_time} is later than last record in data ({data.index[-1]})",
+ err_msg,
"error",
)
- raise ValueError
+ raise ValueError(err_msg)
if bool(self.stop_time) and (self.stop_time <= data.index[0]):
+ err_msg = f"Provided stop time {self.stop_time} is earlier than first record in data ({data.index[0]})"
self.set_logger_message(
- f"Provided stop time {self.stop_time} is earlier than first record in data ({data.index[0]})",
+ err_msg,
"error",
)
- raise ValueError
+ raise ValueError(err_msg)
if bool(self.start_time) and bool(self.stop_time):
- return data[
- (data.index >= self.start_time) & (data.index <= self.stop_time)
- ]
- elif bool(self.start_time) and not bool(self.stop_time):
+ return data[(data.index >= self.start_time) & (data.index <= self.stop_time)]
+ if bool(self.start_time) and not bool(self.stop_time):
return data[(data.index >= self.start_time)]
- elif not bool(self.start_time) and bool(self.stop_time):
+ if not bool(self.start_time) and bool(self.stop_time):
return data[data.index <= self.stop_time]
- else:
- return data
+ return data
@staticmethod
- def _parse_names(nclist, encoding="utf-8"):
- """Parses the bytestring list of names in netcdf"""
- return [
- "".join([bstr.decode(encoding) for bstr in ncrow]).strip()
- for ncrow in nclist
- ]
+ def _parse_names(nclist: list[str], encoding: str = "utf-8") -> list[str]:
+ """Parse the bytestring list of names in netcdf."""
+ return ["".join([bstr.decode(encoding) for bstr in ncrow]).strip() for ncrow in nclist]
- def _import_2Dobservations(self) -> None:
- print("Reading 2D data")
+ def _import_2d_observations(self) -> None:
+ self.set_logger_message("Reading 2D data")
for nkey, dkey, map_key, fname in zip(
- [self._key_2D_Q_name, self._key_2D_H_name],
- [self._key_2D_Q, self._key_2D_H],
+ [self._key_2d_q_name, self._key_2d_h_name],
+ [self._key_2d_q, self._key_2d_h],
["2D_Q", "2D_H"],
- [self.file_2D_Q, self.file_2D_H],
+ [self.file_2d_q, self.file_2d_h],
):
with Dataset(self._path_flow2d) as f:
self.set_logger_message(f"loading 2D data for {map_key}")
- station_map = pd.read_csv(self.file_1D2D_map, index_col=0)
+ station_map = pd.read_csv(self.file_1d2d_map, index_col=0)
qnames = self._parse_names(f.variables[nkey][:])
qdata = f.variables[dkey][:]
time = self._parse_time(f.variables["time"])
- df = pd.DataFrame(columns=station_map.index, index=time)
+ station_map_df = pd.DataFrame(columns=station_map.index, index=time)
self.set_logger_message("Matching 1D and 2D data")
- for index, station in tqdm.tqdm(
- station_map.iterrows(), total=len(station_map.index)
- ):
+ for _, station in tqdm.tqdm(station_map.iterrows(), total=len(station_map.index)):
# Get index of the current station, or skip if ValueError
try:
si = qnames.index(station[map_key])
except ValueError:
continue
- df[station.name] = qdata[:, si]
+ station_map_df[station.name] = qdata[:, si]
- df.to_csv(f"{fname}")
+ station_map_df.to_csv(f"{fname}")
- def _import_1Dobservations(self) -> pd.DataFrame:
- """
- time_offset: offset in seconds
+ def _import_1d_observations(self) -> tuple[pd.DataFrame, pd.DataFrame]:
+ """Import 1D observations.
+
+        The configured time offset (`time_offset_1d`, in seconds) is applied to the time index.
"""
_file_his = self.path_flow1d
with Dataset(_file_his) as f:
- names = self._parse_names(
- f.variables[self._key_1D_H_name]
- ) # names are the same for Q in 1D
+ names = self._parse_names(f.variables[self._key_1d_h_name]) # names are the same for Q in 1D
- time = self._parse_time(f.variables[self._key_1D_time])
- data = f.variables[self._key_1D_H][:]
- dfH = pd.DataFrame(columns=names, index=time, data=data)
+ time = self._parse_time(f.variables[self._key_1d_time])
+ data = f.variables[self._key_1d_h][:]
+ df_h = pd.DataFrame(columns=names, index=time, data=data)
- data = f.variables[self._key_1D_Q][:]
- dfQ = pd.DataFrame(columns=names, index=time, data=data)
+ data = f.variables[self._key_1d_q][:]
+ df_q = pd.DataFrame(columns=names, index=time, data=data)
# apply index shift
- dfH.index = dfH.index + timedelta(seconds=self.time_offset_1d)
- dfQ.index = dfQ.index + timedelta(seconds=self.time_offset_1d)
+ df_h.index = df_h.index + timedelta(seconds=self.time_offset_1d)
+ df_q.index = df_q.index + timedelta(seconds=self.time_offset_1d)
- return dfH, dfQ
+ return df_h, df_q
- def _parse_time(self, timevector: pd.DataFrame):
- """seconds"""
+ def _parse_time(self, timevector: pd.DataFrame) -> list[datetime]:
+ """Parse time from seconds."""
unit = timevector.units.replace("seconds since ", "").strip()
try:
@@ -1485,21 +1501,21 @@ def _parse_time(self, timevector: pd.DataFrame):
return [start_time + timedelta(seconds=i) for i in timevector[:]]
- def _parse_1D_stations(self) -> Generator[str, None, None]:
- """Reads the names of observations stations from 1D model"""
- return list(self._data_1D_H.columns)
+ def _parse_1d_stations(self) -> list[str]:
+ """Read the names of observations stations from 1D model."""
+ return list(self._data_1d_h.columns)
- def _get_1d2d_map(self):
+ def _get_1d2d_map(self) -> None:
_file_his = self.path_flow2d
with Dataset(_file_his) as f:
- qnames = self._parse_names(f.variables[self._key_2D_Q_name][:])
- hnames = self._parse_names(f.variables[self._key_2D_H_name][:])
+ qnames = self._parse_names(f.variables[self._key_2d_q_name][:])
+ hnames = self._parse_names(f.variables[self._key_2d_h_name][:])
# get matching names based on first nine characters
- with open(self.file_1D2D_map, "w") as fw:
+ with self.file_1d2d_map.open("w") as fw:
fw.write("1D,2D_H,2D_Q\n")
- for n in tqdm.tqdm(list(self._parse_1D_stations())):
+ for n in tqdm.tqdm(list(self._parse_1d_stations())):
try:
qn = next(x for x in qnames if x.startswith(n[:9]))
except StopIteration:
@@ -1512,9 +1528,7 @@ def _get_1d2d_map(self):
class Compare1D2D(ModelOutputReader):
- """
- Utility to compare the results of a 1D and 2D model through
- visualisation and statistical post-processing.
+ """Utility to compare the results of a 1D and 2D model through visualisation and statistical post-processing.
Note:
If 2D and 1D netCDF input files are provided, they will first be
@@ -1534,32 +1548,34 @@ class Compare1D2D(ModelOutputReader):
```
- Parameters:
+ Parameters
+ ----------
project: `fm2prof.Project` object.
path_1d: path to SOBEK dimr directory
path_2d: path to his nc file
routes: list of branch abbreviations, e.g. ['NR', 'LK']
- start_time: start time for plotting and analytics. Use this to crop the time to prevent initalisation from affecting statistics.
+    start_time: start time for plotting and analytics. Use this to crop the time to prevent initialisation from
+ affecting statistics.
stop_time: stop time for plotting and analytics.
style: `PlotStyles` style
+
"""
- _routes: List[List[str]] = None
+    _routes: list[list[str]] | None = None
- def __init__(
+ def __init__( # noqa: PLR0913
self,
project: Project,
path_1d: Path | str | None = None,
path_2d: Path | str | None = None,
- routes: List[List[str]] | None = None,
+ routes: list[list[str]] | None = None,
start_time: None | datetime = None,
stop_time: None | datetime = None,
style: str = "sito",
- ):
+ ) -> None:
+ """Instantiate a Compare1D2D object."""
if project:
- super().__init__(
- logger=project.get_logger(), start_time=start_time, stop_time=stop_time
- )
+ super().__init__(logger=project.get_logger(), start_time=start_time, stop_time=stop_time)
self.output_path = project.get_output_directory()
else:
super().__init__()
@@ -1582,10 +1598,10 @@ def __init__(
# Defaults
self.routes = routes
self.statistics = None
- self._data_1D_H: pd.DataFrame = None
- self._data_2D_H: pd.DataFrame = None
- self._data_1D_H_digitized: pd.DataFrame = None
- self._data_2D_H_digitized: pd.DataFrame = None
+        self._data_1d_h: pd.DataFrame | None = None
+        self._data_2d_h: pd.DataFrame | None = None
+        self._data_1d_h_digitized: pd.DataFrame | None = None
+        self._data_2d_h_digitized: pd.DataFrame | None = None
self._qsteps = np.arange(0, 100 * np.ceil(18000 / 100), 200)
# initiate plotstyle
@@ -1610,15 +1626,10 @@ def __init__(
"figures/stations",
]
for od in output_dirs:
- try:
- os.makedirs(self.output_path.joinpath(od))
- except FileExistsError:
- pass
+ self.output_path.joinpath(od).mkdir(parents=True, exist_ok=True)
- def eval(self):
- """
- does a bunch
- """
+ def eval(self) -> None:
+ """Create multiple figures."""
for route in tqdm.tqdm(self.routes):
self.set_logger_message(f"Making figures for route {route}")
self.figure_longitudinal_rating_curve(route)
@@ -1626,131 +1637,112 @@ def eval(self):
self.heatmap_rating_curve(route)
self.heatmap_time(route)
- self.set_logger_message(f"Making figures for stations")
- for station in tqdm.tqdm(self.stations(), total=self._data_1D_H.shape[1]):
+ self.set_logger_message("Making figures for stations")
+ for station in tqdm.tqdm(self.stations(), total=self._data_1d_h.shape[1]):
self.figure_at_station(station)
@property
- def routes(self):
+ def routes(self) -> list[list[str]]:
+ """Return routes."""
return self._routes
@routes.setter
- def routes(self, routes):
+ def routes(self, routes: list[list[str]] | str) -> None:
if isinstance(routes, list):
self._routes = routes
if isinstance(routes, str):
self._routes = ast.literal_eval(routes)
@property
- def file_1D_H_digitized(self):
- return self.file_1D_H.parent.joinpath(f"{self.file_1D_H.stem}_digitized.csv")
+ def file_1d_h_digitized(self) -> Path:
+ """Return 1D water level digitized file path."""
+ return self.file_1d_h.parent.joinpath(f"{self.file_1d_h.stem}_digitized.csv")
@property
- def file_2D_H_digitized(self):
- return self.file_2D_H.parent.joinpath(f"{self.file_2D_H.stem}_digitized.csv")
+ def file_2d_h_digitized(self) -> Path:
+ """Return 2D water level digitized file path."""
+ return self.file_2d_h.parent.joinpath(f"{self.file_2d_h.stem}_digitized.csv")
@property
- def colorscheme(self):
+ def colorscheme(self) -> str:
+ """Color scheme."""
return self._colorscheme
- def digitize_data(self):
- if self.file_1D_H_digitized.is_file():
+ def digitize_data(self) -> None:
+ """Compute the average for a given bin for 1D and 2D water level data.
+
+        Use to make Q-H graphs instead of T-H graphs.
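+
+        A minimal sketch of the underlying binning (illustrative values):
+
+        ``` py
+        import numpy as np
+        bins = np.arange(0, 1000, 200)          # [0, 200, 400, 600, 800]
+        d = np.digitize([150, 450, 480], bins)  # -> array([1, 3, 3])
+        ```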
+ """
+ if self.file_1d_h_digitized.is_file():
self.set_logger_message("Using existing digitized file for 1d")
- self._data_1D_H_digitized = pd.read_csv(
- self.file_1D_H_digitized, index_col=0
- )
+ self._data_1d_h_digitized = pd.read_csv(self.file_1d_h_digitized, index_col=0)
else:
- self._data_1D_H_digitized = self._digitize_data(
- self._data_1D_H, self._data_1D_Q, self._qsteps
- )
- self._data_1D_H_digitized.to_csv(self.file_1D_H_digitized)
- if self.file_2D_H_digitized.is_file():
+ self._data_1d_h_digitized = self._digitize_data(self._data_1d_h, self._data_1d_q, self._qsteps)
+ self._data_1d_h_digitized.to_csv(self.file_1d_h_digitized)
+ if self.file_2d_h_digitized.is_file():
self.set_logger_message("Using existing digitized file for 2d")
- self._data_2D_H_digitized = pd.read_csv(
- self.file_2D_H_digitized, index_col=0
- )
+ self._data_2d_h_digitized = pd.read_csv(self.file_2d_h_digitized, index_col=0)
else:
- self._data_2D_H_digitized = self._digitize_data(
- self._data_2D_H, self._data_2D_Q, self._qsteps
- )
- self._data_2D_H_digitized.to_csv(self.file_2D_H_digitized)
+ self._data_2d_h_digitized = self._digitize_data(self._data_2d_h, self._data_2d_q, self._qsteps)
+ self._data_2d_h_digitized.to_csv(self.file_2d_h_digitized)
def stations(self) -> Generator[str, None, None]:
- for station in self._data_1D_H.columns:
- yield station
+ """Yield station names."""
+ yield from self._data_1d_h.columns
@staticmethod
- def _digitize_data(hdata, qdata, bins) -> pd.DataFrame:
- """Computes the average for a given bin. Use to make Q-H graphs instead of T-H graph"""
+    def _digitize_data(hdata: pd.DataFrame, qdata: pd.DataFrame, bins: np.ndarray) -> pd.DataFrame:
+        """Compute the average for a given bin.
+
+        Use to make Q-H graphs instead of T-H graphs.
+        """
stations = hdata.columns
- C = list()
- rkms = list()
+ c = []
-        for i, station in enumerate(stations):
+        for station in stations:
- # rkms.append(float(station.split('_')[1]))
d = np.digitize(qdata[station], bins)
- C.append([np.nanmean(hdata[station][d == i]) for i, _ in enumerate(bins)])
+ c.append([np.nanmean(hdata[station][d == i]) for i, _ in enumerate(bins)])
+ c = np.array(c) # [sort]
+ return pd.DataFrame(columns=stations, index=bins, data=c.T)
- # sort = np.argsort(rkms)
- C = np.array(C) # [sort]
- return pd.DataFrame(columns=stations, index=bins, data=C.T)
+ def _names_to_rkms(self, station_names: list[str]) -> list[float]:
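+        # Station names are assumed to follow "<branch>_<rkm>[_...]" (e.g. "NR_919.00");
+        # names that do not parse yield NaN through _catch_e.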
+ return [self._catch_e(lambda i=i: float(i.split("_")[1]), (IndexError, ValueError)) for i in station_names]
- def _names_to_rkms(self, station_names: List[str]) -> List[float]:
- return [
- self._catch_e(lambda: float(i.split("_")[1]), (IndexError, ValueError))
- for i in station_names
- ]
-
- def _names_to_branches(self, station_names: List[str]) -> List[str]:
- return [
- self._catch_e(lambda: i.split("_")[0], IndexError) for i in station_names
- ]
+ def _names_to_branches(self, station_names: list[str]) -> list[str]:
+ return [self._catch_e(lambda i=i: i.split("_")[0], IndexError) for i in station_names]
- def get_route(
- self, route: List[str]
- ) -> Tuple[List[str], List[float], List[Tuple[str, float]]]:
- """returns a sorted list of stations along a route, with rkms"""
- station_names = self._data_2D_H.columns
+ def get_route(self, route: list[str]) -> tuple[list[str], list[float], list[tuple[str, float]]]:
+ """Return a sorted list of stations along a route, with rkms."""
+ station_names = self._data_2d_h.columns
# Parse names
rkms = self._names_to_rkms(station_names)
branches = self._names_to_branches(station_names)
# select along route
- routekms = list()
- stations = list()
- lmw_stations = list()
+ routekms = []
+ stations = []
+ lmw_stations = []
for stop in route:
indices = [i for i, b in enumerate(branches) if b == stop]
routekms.extend([rkms[i] for i in indices])
stations.extend([station_names[i] for i in indices])
- lmw_stations.extend(
- [
- (station_names[i], rkms[i])
- for i in indices
- if "LMW" in station_names[i]
- ]
- )
+ lmw_stations.extend([(station_names[i], rkms[i]) for i in indices if "LMW" in station_names[i]])
# sort data
sorted_indices = np.argsort(routekms)
- sorted_stations = [
- stations[i] for i in sorted_indices if routekms[i] is not np.nan
- ]
+ sorted_stations = [stations[i] for i in sorted_indices if routekms[i] is not np.nan]
sorted_rkms = [routekms[i] for i in sorted_indices if routekms[i] is not np.nan]
# sort lmw stations
- lmw_stations = [
- lmw_stations[j] for j in np.argsort([i[1] for i in lmw_stations])
- ]
+ lmw_stations = [lmw_stations[j] for j in np.argsort([i[1] for i in lmw_stations])]
return sorted_stations, sorted_rkms, lmw_stations
def statistics_to_file(self, file_path: str = "error_statistics") -> None:
- """
- Creates and output a file `error_statistics.csv', which is a
- comma-seperated file with the following columns:
+ """Calculate statistics and write them to file.
+
+        The output file is a comma-separated file with the following columns:
,bias,rkm,branch,is_lmw,std,mae,max13,last25
@@ -1764,7 +1756,6 @@ def statistics_to_file(self, file_path: str = "error_statistics") -> None:
- mae = mean absolute error of the error
"""
-
self.statistics = self._compute_statistics()
statfile = self.output_path.joinpath(file_path).with_suffix(".csv")
@@ -1776,35 +1767,26 @@ def statistics_to_file(self, file_path: str = "error_statistics") -> None:
# summary of statistics
s = self.statistics
- with open(sumfile, "w") as f:
+ with sumfile.open("w") as f:
for branch in s.branch.unique():
bbias = s.bias[s.branch == branch].mean()
bstd = s["std"][s.branch == branch].mean()
lmw_bias = s.bias[(s.branch == branch) & s.is_lmw].mean()
lmw_std = s["std"][(s.branch == branch) & s.is_lmw].mean()
- f.write(
- f"{branch},{bbias:.2f}±({bstd:.2f}), {lmw_bias:.2f}±({lmw_std:.2f})\n"
- )
-
- def figure_at_station(
- self, station: str, func: str = "time", savefig: bool = True
- ) -> FigureOutput:
- """
- Creates a figure with the timeseries at a single observation station.
-
- ``` py
-
- ```
+ f.write(f"{branch},{bbias:.2f}±({bstd:.2f}), {lmw_bias:.2f}±({lmw_std:.2f})\n")
- Parameters:
- station: name of station. use `stations` method to list all station names
- func: use `time` for a timeseries and `qh` for rating curve
- savefig: if True, saves to png. If False, returned FigureOutput
+ def figure_at_station(self, station: str, func: str = "time", *, savefig: bool = True) -> FigureOutput | None:
+ """Create a figure with the timeseries at a single observation station.
+ Args:
+ station (str): name of station. use `stations` method to list all station names
+ func (str, optional): use `time` for a timeseries and `qh` for rating curve
+ savefig (bool, optional):if True, saves to png. If False, returned FigureOutput. Defaults to True.
+ Returns:
+ FigureOutput | None: FigureOutput object or None if savefig is set to True.
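+
+        Example (station name is illustrative; data must be read first):
+
+        ``` py
+        plotter.read_all_data()
+        plotter.figure_at_station("NR_919.00_LMW", func="time")
+        ```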
"""
-
fig, ax = plt.subplots(1, figsize=(12, 5))
error_ax = ax.twinx()
@@ -1813,14 +1795,14 @@ def figure_at_station(
case "qh":
ax.plot(
self._qsteps,
- self._data_2D_H_digitized[station],
+ self._data_2d_h_digitized[station],
"--",
linewidth=2,
label="2D",
)
ax.plot(
self._qsteps,
- self._data_1D_H_digitized[station],
+ self._data_1d_h_digitized[station],
"-",
linewidth=2,
label="1D",
@@ -1831,20 +1813,19 @@ def figure_at_station(
ax.set_ylabel("Waterstand [m+NAP]")
error_ax.plot(
self._qsteps,
- self._data_1D_H_digitized[station]
- - self._data_2D_H_digitized[station],
+ self._data_1d_h_digitized[station] - self._data_2d_h_digitized[station],
".",
color=self._color_error,
)
case "time":
- ax.plot(self.data_2D_H[station], "--", linewidth=2, label="2D")
- ax.plot(self.data_1D_H[station], "-", linewidth=2, label="1D")
+ ax.plot(self.data_2d_h[station], "--", linewidth=2, label="2D")
+ ax.plot(self.data_1d_h[station], "-", linewidth=2, label="1D")
ax.set_ylabel("Waterstand [m+NAP]")
ax.set_title(f"{station}\nTijdreeks")
error_ax.plot(
- self.data_1D_H[station] - self.data_2D_H[station],
+ self.data_1d_h[station] - self.data_2d_h[station],
".",
label="1D-2D",
color=self._color_error,
@@ -1874,36 +1855,31 @@ def figure_at_station(
if savefig:
fig.savefig(
- self.output_path.joinpath("figures/stations").joinpath(
- f"{station}.png"
- ),
+ self.output_path.joinpath("figures/stations").joinpath(f"{station}.png"),
bbox_extra_artists=[lgd],
bbox_inches="tight",
)
plt.close()
- else:
- return FigureOutput(fig=fig, axes=ax, legend=lgd)
+ return None
- def _style_error_axes(
- self, ax, ylim: List[float] = [-0.5, 0.5], ylabel: str = "1D-2D [m]"
- ):
+ return FigureOutput(fig=fig, axes=ax, legend=lgd)
+
+    def _style_error_axes(self, ax: Axes, ylim: tuple[float, float] = (-0.5, 0.5), ylabel: str = "1D-2D [m]") -> None:
ax.set_ylim(ylim)
ax.set_ylabel(ylabel)
ax.spines["right"].set_edgecolor(self._color_error)
ax.tick_params(axis="y", colors=self._color_error)
- ax.grid(False)
+ ax.grid(visible=False)
def _compute_statistics(self) -> pd.DataFrame:
- """
- Computes statistics for the difference between 1D and 2D water levels
+ """Compute statistics for the difference between 1D and 2D water levels.
Returns DataFrame with
columns: rkm, branch, is_lmw, bias, std, mae
rows: observation stations
"""
-
- diff = self.data_1D_H - self.data_2D_H
+ diff = self.data_1d_h - self.data_2d_h
station_names = diff.columns
rkms = self._names_to_rkms(station_names)
branches = self._names_to_branches(station_names)
@@ -1911,24 +1887,22 @@ def _compute_statistics(self) -> pd.DataFrame:
stats = pd.DataFrame(data=diff.mean(), columns=["bias"])
stats["rkm"] = rkms
stats["branch"] = branches
- stats["is_lmw"] = [
- True if "lmw" in name.lower() else False for name in station_names
- ]
+ stats["is_lmw"] = ["lmw" in name.lower() for name in station_names]
# stats
stats["bias"] = diff.mean()
stats["std"] = diff.std()
stats["mae"] = diff.abs().mean()
- stats["1D_last3"] = self._apply_stat(self.data_1D_H, stat="last3")
- stats["1D_last25"] = self._apply_stat(self.data_1D_H, stat="last25")
- stats["1D_max3"] = self._apply_stat(self.data_1D_H, stat="max3")
- stats["1D_max13"] = self._apply_stat(self.data_1D_H, stat="max13")
+ stats["1D_last3"] = self._apply_stat(self.data_1d_h, stat="last3")
+ stats["1D_last25"] = self._apply_stat(self.data_1d_h, stat="last25")
+ stats["1D_max3"] = self._apply_stat(self.data_1d_h, stat="max3")
+ stats["1D_max13"] = self._apply_stat(self.data_1d_h, stat="max13")
- stats["2D_last3"] = self._apply_stat(self.data_2D_H, stat="last3")
- stats["2D_last25"] = self._apply_stat(self.data_2D_H, stat="last25")
- stats["2D_max3"] = self._apply_stat(self.data_2D_H, stat="max3")
- stats["2D_max13"] = self._apply_stat(self.data_2D_H, stat="max13")
+ stats["2D_last3"] = self._apply_stat(self.data_2d_h, stat="last3")
+ stats["2D_last25"] = self._apply_stat(self.data_2d_h, stat="last25")
+ stats["2D_max3"] = self._apply_stat(self.data_2d_h, stat="max3")
+ stats["2D_max13"] = self._apply_stat(self.data_2d_h, stat="max13")
stats["diff_last3"] = self._apply_stat(diff, stat="last3")
stats["diff_last25"] = self._apply_stat(diff, stat="last25")
@@ -1937,17 +1911,21 @@ def _compute_statistics(self) -> pd.DataFrame:
return stats
- def _get_statistics(self, station):
+ def _get_statistics(self, station: str) -> pd.Series | pd.DataFrame:
if self.statistics is None:
self.statistics = self._compute_statistics()
return self.statistics.loc[station]
def figure_compare_discharge_at_stations(
- self, stations: List[str], title: str = "no_title", savefig: bool = True
+ self,
+ stations: list[str],
+ title: str = "no_title",
+ *,
+ savefig: bool = True,
) -> FigureOutput | None:
- """
- Like `Compare1D2D.figure_at_station`, but compares discharge
- distribution over two stations.
+ """Comparea discharge distribution over two stations.
+
+ Like `Compare1D2D.figure_at_station`.
Example usage:
``` py
@@ -1965,12 +1943,11 @@ def figure_compare_discharge_at_stations(
fig, axs = plt.subplots(2, 1, figsize=(12, 10))
ax_error = axs[0].twinx()
- ax_error.set_zorder(
- axs[0].get_zorder() - 1
- ) # default zorder is 0 for ax1 and ax2
+ ax_error.set_zorder(axs[0].get_zorder() - 1) # default zorder is 0 for ax1 and ax2
- if len(stations) != 2:
- print("error: must define 2 stations")
+ if len(stations) != 2: # noqa: PLR2004
+ err_msg = "Must define 2 stations"
+ raise ValueError(err_msg)
linestyles_2d = ["-", "--"]
for j, station in enumerate(stations):
@@ -1979,51 +1956,51 @@ def figure_compare_discharge_at_stations(
# tijdserie
axs[0].plot(
- self.data_2D_Q[station],
+ self.data_2d_q[station],
label=f"2D, {station.split('_')[0]}",
linewidth=2,
linestyle=linestyles_2d[j],
)
axs[0].plot(
- self.data_1D_Q[station],
+ self.data_1d_q[station],
label=f"1D, {station.split('_')[0]}",
linewidth=2,
linestyle="-",
)
ax_error.plot(
- self._data_1D_Q[station] - self._data_2D_Q[station],
+ self._data_1d_q[station] - self._data_2d_q[station],
".",
color="r",
markersize=5,
- label=f"1D-2D",
+ label="1D-2D",
)
# discharge distribution
- Q2D = self.data_2D_Q[stations]
- Q1D = self.data_1D_Q[stations]
+ q_2d = self.data_2d_q[stations]
+ q_1d = self.data_1d_q[stations]
axs[1].plot(
- Q2D.sum(axis=1),
- (Q2D.iloc[:, 0] / Q2D.sum(axis=1)) * 100,
+ q_2d.sum(axis=1),
+ (q_2d.iloc[:, 0] / q_2d.sum(axis=1)) * 100,
linewidth=2,
linestyle="--",
)
axs[1].plot(
- Q1D.sum(axis=1),
- (Q1D.iloc[:, 0] / Q1D.sum(axis=1)) * 100,
+ q_1d.sum(axis=1),
+ (q_1d.iloc[:, 0] / q_1d.sum(axis=1)) * 100,
linewidth=2,
linestyle="-",
)
axs[1].plot(
- Q2D.sum(axis=1),
- (Q2D.iloc[:, 1] / Q2D.sum(axis=1)) * 100,
+ q_2d.sum(axis=1),
+ (q_2d.iloc[:, 1] / q_2d.sum(axis=1)) * 100,
linewidth=2,
linestyle="--",
)
axs[1].plot(
- Q1D.sum(axis=1),
- (Q1D.iloc[:, 1] / Q1D.sum(axis=1)) * 100,
+ q_1d.sum(axis=1),
+ (q_1d.iloc[:, 1] / q_1d.sum(axis=1)) * 100,
linewidth=2,
linestyle="-",
)
@@ -2050,69 +2027,79 @@ def figure_compare_discharge_at_stations(
bbox_inches="tight",
)
plt.close()
- else:
- return FigureOutput(fig=fig, axes=axs, legend=lgd)
+ return FigureOutput(fig=fig, axes=axs, legend=lgd)
- def get_data_along_route_for_time(
- self, data: pd.DataFrame, route: List[str], time_index: int
- ) -> pd.Series:
- stations, rkms, _ = self.get_route(route)
+ def get_data_along_route_for_time(self, data: pd.DataFrame, route: list[str], time_index: int) -> pd.Series:
+ """Get data along route for a given time index.
+
+ Args:
+ data (pd.DataFrame): Dataframe with data
+            route (list[str]): list of branches along the route
+ time_index (int): time index
+
+ Returns:
+ pd.Series: Series containing route data
- tmp_data = list()
- for station in stations:
- tmp_data.append(data[station].iloc[time_index])
+ """
+ stations, rkms, _ = self.get_route(route)
+        tmp_data = [data[station].iloc[time_index] for station in stations]
return pd.Series(index=rkms, data=tmp_data)
- def get_data_along_route(
- self, data: pd.DataFrame, route: List[str]
- ) -> pd.DataFrame:
+ def get_data_along_route(self, data: pd.DataFrame, route: list[str]) -> pd.DataFrame:
+ """Get data along route.
+
+ Args:
+ data (pd.DataFrame): DataFrame with data
+            route (list[str]): list of branches along the route
+
+ Returns:
+            pd.DataFrame: data along the route, indexed by river kilometre
+
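+        Example (route abbreviations are illustrative):
+
+        ``` py
+        h1d = plotter.get_data_along_route(plotter.data_1d_h, route=["NR", "LK"])
+        ```
+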
+ """
stations, rkms, _ = self.get_route(route)
- tmp_data = list()
- for station in stations:
- tmp_data.append(data[station])
+        tmp_data = [data[station] for station in stations]
- df = pd.DataFrame(index=rkms, data=tmp_data)
+ route_data_df = pd.DataFrame(index=rkms, data=tmp_data)
# drop duplicates
- df = df.drop_duplicates()
- return df
+ return route_data_df.drop_duplicates()
@staticmethod
- def _sec_to_days(seconds):
+ def _sec_to_days(seconds: float) -> float:
return seconds / (3600 * 24)
@staticmethod
- def _get_nearest_time(data: pd.DataFrame, date: datetime = None) -> int:
+ def _get_nearest_time(data: pd.DataFrame, date: datetime | None = None) -> int:
try:
return list(data.index < date).index(False)
except ValueError:
-            # False is not list, return last index
+            # False not in list: date is past the last record; return last index
return len(data.index) - 1
- def _time_func(self, route) -> Dict[str, pd.Series | str]:
- first_day = self.data_1D_H.index[0] # + timedelta(days=delta_days) * 2
- last_day = self.data_1D_H.index[-1]
+    def _time_func(self, route: list[str]) -> list[dict[str, pd.Series | str]]:
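+        """Sample roughly six evenly spaced moments and collect the 1D and 2D profiles along the route."""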
+ first_day = self.data_1d_h.index[0] # + timedelta(days=delta_days) * 2
+ last_day = self.data_1d_h.index[-1]
number_of_days = (last_day - first_day).days
delta_days = int(number_of_days / 6)
- moments = [
- first_day + timedelta(days=i) for i in range(0, number_of_days, delta_days)
- ]
+ moments = [first_day + timedelta(days=i) for i in range(0, number_of_days, delta_days)]
lines = []
for day in moments:
h1d = self.get_data_along_route_for_time(
- data=self.data_1D_H,
+ data=self.data_1d_h,
route=route,
- time_index=self._get_nearest_time(data=self.data_1D_H, date=day),
+ time_index=self._get_nearest_time(data=self.data_1d_h, date=day),
)
h2d = self.get_data_along_route_for_time(
- data=self.data_2D_H,
+ data=self.data_2d_h,
route=route,
- time_index=self._get_nearest_time(data=self.data_2D_H, date=day),
+ time_index=self._get_nearest_time(data=self.data_2d_h, date=day),
)
lines.append({"1D": h1d, "2D": h2d, "label": f"{day:%b-%d}"})
@@ -2120,10 +2107,8 @@ def _time_func(self, route) -> Dict[str, pd.Series | str]:
return lines
@staticmethod
- def _apply_stat(df, stat: str = "max13"):
- """
- Applies column-wise "last25" or "max13" on 1D and 2D data
- """
+ def _apply_stat(df: pd.DataFrame, stat: str = "max13") -> pd.Series:
+ """Apply column-wise "last25" or "max13" on 1D and 2D data."""
columns = df.columns
values = []
for column in columns:
@@ -2142,29 +2127,21 @@ def _apply_stat(df, stat: str = "max13"):
values.append(af[-25:].mean())
return pd.Series(index=columns, data=values)
- def _stat_func(
- self, route: List[str], stat: str = "max13"
- ) -> List[Dict[str, pd.Series | str]]:
- """
- Applies column-wise "last25" or "max13" on 1D and 2D data
- """
- max13_1d = self._apply_stat(
- self.get_data_along_route(self.data_1D_H, route=route).T, stat=stat
- )
- max13_2d = self._apply_stat(
- self.get_data_along_route(self.data_2D_H, route=route).T, stat=stat
- )
+ def _stat_func(self, route: list[str], stat: str = "max13") -> list[dict[str, pd.Series | str]]:
+ """Apply column-wise "last25" or "max13" on 1D and 2D data."""
+ max13_1d = self._apply_stat(self.get_data_along_route(self.data_1d_h, route=route).T, stat=stat)
+ max13_2d = self._apply_stat(self.get_data_along_route(self.data_2d_h, route=route).T, stat=stat)
return [{"1D": max13_1d, "2D": max13_2d, "label": stat}]
- def _lmw_func(self, station_names, station_locs):
+ def _lmw_func(self, station_names: list[str], station_locs: list[int]) -> tuple[list, list]:
st_names = []
st_locs = []
prev_loc = -9999
for name, loc in zip(station_names, station_locs):
if "lmw" not in name.lower():
continue
- if abs(prev_loc - loc) < 5:
+ if abs(prev_loc - loc) < 5: # noqa: PLR2004
self.set_logger_message(
f"skipped labelling {name} because too close to previous station",
"warning",
@@ -2176,9 +2153,11 @@ def _lmw_func(self, station_names, station_locs):
return st_names, st_locs
- def figure_longitudinal_time(self, route: List[str]) -> None:
- warnings.warn(
- 'Method figure_longitudinal_time will be removed in the future. Use figure_longitudinal(route, stat="time") instead ',
+ def figure_longitudinal_time(self, route: list[str]) -> None:
+ """Create a figure along a `route`."""
+ warnings.warn( # noqa: B028
+            'Method figure_longitudinal_time will be removed in the future. Use figure_longitudinal(route, stat="time")'
+            " instead",
category=DeprecationWarning,
)
@@ -2186,29 +2165,37 @@ def figure_longitudinal_time(self, route: List[str]) -> None:
def figure_longitudinal(
self,
- route: List[str],
+ route: list[str],
stat: str = "time",
- savefig: bool = True,
label: str = "",
add_to_fig: FigureOutput | None = None,
+ *,
+ savefig: bool = True,
) -> FigureOutput | None:
- """
- Creates a figure along a `route`. Content of figure depends
- on `stat`. Figures are saved to `[Compare1D2D.output_path]/figures/longitudinal`
+ """Create a figure along a `route`.
+
+ Content of figure depends on `stat`. Figures are saved to `[Compare1D2D.output_path]/figures/longitudinal`
Example output:

- Parameters:
- route: List of branches (e.g. ['NK', 'LK'])
- stat: what type of longitudinal plot to make. Options are:
- - time
- - last25
- - max13
- savefig: if true, figure is saved to png file. If false, `FigureOutput`
- returned, which is input for `add_to_fig`
- add_to_fig: if `FigureOutput` is provided, adds content to figure.
+
+ Args:
+ route (list[str]): List of branches (e.g. ['NK', 'LK'])
+ stat (str, optional): What type of longitudinal plot to make (options: "time", "last3", "last25", "max3",
+ "max13"). Defaults to "time".
+ label (str, optional): Label of figure. Defaults to "".
+            add_to_fig (FigureOutput | None, optional): if `FigureOutput` is provided, adds content to figure.
+                Defaults to None.
+            savefig (bool, optional): if true, figure is saved to png file. If false, a `FigureOutput` is
+                returned, which is input for `add_to_fig`. Defaults to True.
+
+ Returns:
+ FigureOutput | None: FigureOutput object or None
+
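+        Example (route abbreviations are illustrative):
+
+        ``` py
+        plotter.figure_longitudinal(["NR", "LK"], stat="max13")
+        ```
+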
"""
# Get route and stations along route
routename = "-".join(route)
@@ -2223,7 +2210,8 @@ def figure_longitudinal(
case y if y in ["last3", "last25", "max3", "max13"]:
lines = self._stat_func(stat=y, route=route)
case _:
- raise KeyError(f"{stat} is unknown statistics")
+ err_msg = f"{stat} is unknown statistics"
+ raise KeyError(err_msg)
# Get figure object
if add_to_fig is None:
@@ -2236,7 +2224,7 @@ def figure_longitudinal(
if add_to_fig is None:
station_names, station_locs, _ = self.get_route(route)
st_names, st_locs = labelfunc(station_names, station_locs)
- h1d = self.get_data_along_route(data=self.data_1D_H, route=route)
+ h1d = self.get_data_along_route(data=self.data_1d_h, route=route)
for st_name, st_loc in zip(st_names, st_locs):
for ax in axs:
ax.axvline(x=st_loc, linestyle="--")
@@ -2279,12 +2267,11 @@ def figure_longitudinal(
)
plt.close()
- else:
- return FigureOutput(fig=fig, axes=axs, legend=lgd)
+ return FigureOutput(fig=fig, axes=axs, legend=lgd)
+
+ def figure_longitudinal_rating_curve(self, route: list[str]) -> None:
+ """Create a figure along a route with lines at various dicharges.
- def figure_longitudinal_rating_curve(self, route: List[str]) -> None:
- """
- Create a figure along a route with lines at various dicharges.
-    To to this, rating curves are generated at each point by digitizing
+    To do this, rating curves are generated at each point by digitizing
the model output.
@@ -2300,14 +2287,12 @@ def figure_longitudinal_rating_curve(self, route: List[str]) -> None:
routename = "-".join(route)
_, _, lmw_stations = self.get_route(route)
- h1d = self.get_data_along_route(data=self._data_1D_H_digitized, route=route)
- h2d = self.get_data_along_route(data=self._data_2D_H_digitized, route=route)
+ h1d = self.get_data_along_route(data=self._data_1d_h_digitized, route=route)
+ h2d = self.get_data_along_route(data=self._data_2d_h_digitized, route=route)
discharge_steps = list(self._iter_discharge_steps(h1d.T, n=8))
if len(discharge_steps) < 1:
- self.set_logger_message(
- "There is too little data to plot a QH relationship", "error"
- )
+ self.set_logger_message("There is too little data to plot a QH relationship", "error")
return
# Plot LMW station locations
@@ -2331,7 +2316,6 @@ def figure_longitudinal_rating_curve(self, route: List[str]) -> None:
)
# Plot betrekkingslijnen
- texty_previous = -999
for discharge in discharge_steps:
axs[0].plot(h1d[discharge], label=f"{discharge:.0f} m$^3$/s")
axs[0].set_ylabel("waterstand [m+nap]")
@@ -2360,10 +2344,8 @@ def figure_longitudinal_rating_curve(self, route: List[str]) -> None:
)
plt.close()
- def _iter_discharge_steps(self, data: pd.DataFrame, n: int = 5) -> List[float]:
- """
- Choose discharge steps based on increase in water level downstream
- """
+ def _iter_discharge_steps(self, data: pd.DataFrame, n: int = 5) -> Generator[float, None, None]:
+ """Choose discharge steps based on increase in water level downstream."""
station = data[data.columns[-1]]
wl_range = station.max() - station[station.index > 0].min()
@@ -2378,10 +2360,10 @@ def _iter_discharge_steps(self, data: pd.DataFrame, n: int = 5) -> List[float]:
q_at_t_previous = q_at_t
yield q_at_t
- def heatmap_time(self, route: List[str]) -> None:
- """
- Create a 2D heatmap along a route. The horizontal axis uses
- timemarks to match the 1D and 2D models
+ def heatmap_time(self, route: list[str]) -> None:
+ """Create a 2D heatmap along a route.
+
+        The horizontal axis uses timemarks to match the 1D and 2D models.
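+
+        Example (route is illustrative):
+
+        ``` py
+        plotter.heatmap_time(["NR", "LK"])
+        ```
+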
Figures are saved to `[Compare1D2D.output_path]/figures/heatmap`
@@ -2392,10 +2374,9 @@ def heatmap_time(self, route: List[str]) -> None:
example output figure
"""
-
routename = "-".join(route)
_, _, lmw_stations = self.get_route(route)
- data = self._data_1D_H - self._data_2D_H
+ data = self._data_1d_h - self._data_2d_h
routedata = self.get_data_along_route(data.dropna(how="all"), route)
fig, ax = plt.subplots(1, figsize=(12, 7))
@@ -2410,9 +2391,7 @@ def heatmap_time(self, route: List[str]) -> None:
for lmw in lmw_stations:
if lmw is None:
continue
- ax.plot(
- routedata.columns, [lmw[1]] * len(routedata.columns), "--k", linewidth=1
- )
+ ax.plot(routedata.columns, [lmw[1]] * len(routedata.columns), "--k", linewidth=1)
ax.text(routedata.columns[0], lmw[1], lmw[0], fontsize=12)
ax.set_ylabel("rivierkilometer")
@@ -2422,15 +2401,13 @@ def heatmap_time(self, route: List[str]) -> None:
cb.set_label("waterstandsverschil [m+nap]".upper(), rotation=270, labelpad=15)
PlotStyles.apply(fig, style=self._plotstyle, use_legend=False)
fig.tight_layout()
- fig.savefig(
- self.output_path.joinpath(f"figures/heatmaps/{routename}_timeseries.png")
- )
+ fig.savefig(self.output_path.joinpath(f"figures/heatmaps/{routename}_timeseries.png"))
plt.close()
- def heatmap_rating_curve(self, route: List[str]) -> None:
- """
- Create a 2D heatmap along a route. The horizontal axis uses
- the digitized rating curves to match the two models
+ def heatmap_rating_curve(self, route: list[str]) -> None:
+ """Create a 2D heatmap along a route.
+
+        The horizontal axis uses the digitized rating curves to match the two models.
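+
+        Example (route is illustrative; requires `digitize_data` to have run):
+
+        ``` py
+        plotter.digitize_data()
+        plotter.heatmap_rating_curve(["NR", "LK"])
+        ```
+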
Figures are saved to `[Compare1D2D.output_path]/figures/heatmap`
@@ -2441,10 +2418,9 @@ def heatmap_rating_curve(self, route: List[str]) -> None:
example output figure
"""
-
routename = "-".join(route)
_, _, lmw_stations = self.get_route(route)
- data = self._data_1D_H_digitized - self._data_2D_H_digitized
+ data = self._data_1d_h_digitized - self._data_2d_h_digitized
routedata = self.get_data_along_route(data.dropna(how="all"), route)
@@ -2460,9 +2436,7 @@ def heatmap_rating_curve(self, route: List[str]) -> None:
for lmw in lmw_stations:
if lmw is None:
continue
- ax.plot(
- routedata.columns, [lmw[1]] * len(routedata.columns), "--k", linewidth=1
- )
+ ax.plot(routedata.columns, [lmw[1]] * len(routedata.columns), "--k", linewidth=1)
ax.text(routedata.columns[0], lmw[1], lmw[0], fontsize=12)
ax.set_ylabel("rivierkilometer")
@@ -2472,18 +2446,13 @@ def heatmap_rating_curve(self, route: List[str]) -> None:
cb.set_label("waterstandsverschil [m+nap]".upper(), rotation=270, labelpad=15)
PlotStyles.apply(fig, style=self._plotstyle, use_legend=False)
fig.tight_layout()
- fig.savefig(
- self.output_path.joinpath(f"figures/heatmaps/{routename}_rating_curve.png")
- )
+ fig.savefig(self.output_path.joinpath(f"figures/heatmaps/{routename}_rating_curve.png"))
plt.close()
@staticmethod
- def _catch_e(
- func: Callable, exception: Union[Exception, Tuple[Exception]], *args, **kwargs
- ):
- """catch exception in function call. useful for list comprehensions"""
+    def _catch_e(
+        func: Callable,
+        exception: type[Exception] | tuple[type[Exception], ...],
+        *args: tuple,
+        **kwargs: dict,
+    ) -> Any | float:  # noqa: ANN401
+        """Catch exception in function call. Useful for list comprehensions."""
try:
return func(*args, **kwargs)
- except exception as e:
+ except exception:
return np.nan
-
diff --git a/pyproject.toml b/pyproject.toml
index 1ae8b94e..eb7fb8d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,3 +62,18 @@ docs = [
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
+
+
+[tool.ruff]
+line-length = 120
+exclude = ["scripts"]
+
+[tool.ruff.lint]
+select = ["ALL"]
+ignore = ["DTZ005", "DTZ001"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.lint.per-file-ignores]
+"tests/**" = ["D100", "D101", "D102", "D103", "D104", "PT001", "ANN201", "S101", "PLR2004", "ANN001"]
diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py
index 519c5959..ecd5cf40 100644
--- a/tests/test_acceptance.py
+++ b/tests/test_acceptance.py
@@ -6,12 +6,11 @@
import pytest
-from fm2prof.Fm2ProfRunner import Fm2ProfRunner
-from fm2prof.IniFile import IniFile
+from fm2prof.fm2prof_runner import Fm2ProfRunner
+from fm2prof.ini_file import IniFile
from fm2prof.utils import VisualiseOutput
-
from tests.CompareIdealizedModel import CompareHelper, CompareIdealizedModel
-from tests.CompareWaalModel import CompareWaalModel as CompareWaalModel
+from tests.CompareWaalModel import CompareWaalModel
from tests.TestUtils import TestUtils, skipwhenexternalsmissing
_root_output_dir = None
@@ -143,9 +142,7 @@ def get_valid_inifile_input_parameters():
def _get_base_output_dir() -> Path:
- """
- Sets up the necessary data for MainMethodTest
- """
+ """Sets up the necessary data for MainMethodTest"""
output_dir = _create_artifact_dir(_run_with_files_dir_name)
# Create it if it does not exist
if not output_dir.is_dir():
@@ -154,8 +151,7 @@ def _get_base_output_dir() -> Path:
def _create_artifact_dir(dirName: Optional[str] = None) -> Path:
- """
- Create test output directory
+ """Create test output directory
so it's easier to collect output afterwards.
"""
artifacts_dir: Path = TestUtils.get_artifacts_test_data_dir()
@@ -171,11 +167,8 @@ def _create_artifact_dir(dirName: Optional[str] = None) -> Path:
return subOutputDir
-def _check_and_create_test_case_output_dir(
- base_output_dir: Path, caseName: str
-) -> Path:
- """
- Helper to split to set up an output directory
+def _check_and_create_test_case_output_dir(base_output_dir: Path, caseName: str) -> Path:
+ """Helper to split to set up an output directory
for the generated data of each test case.
"""
output_directory = base_output_dir / caseName
@@ -205,7 +198,12 @@ class Test_Run_Testcases:
)
@skipwhenexternalsmissing
def test_when_given_input_data_then_output_is_generated(
- self, case_name, map_file, css_file, region_file, section_file
+ self,
+ case_name,
+ map_file,
+ css_file,
+ region_file,
+ section_file,
):
# 1. Set up test data.
iniFilePath = None
@@ -216,7 +214,7 @@ def test_when_given_input_data_then_output_is_generated(
base_output_dir = _get_base_output_dir()
iniFile._set_output_directory_no_validation(
- str(_check_and_create_test_case_output_dir(base_output_dir, case_name))
+ str(_check_and_create_test_case_output_dir(base_output_dir, case_name)),
)
if region_file:
@@ -232,9 +230,7 @@ def test_when_given_input_data_then_output_is_generated(
# iniFile.set_parameter('ExportMapFiles', True)
iniFile.set_parameter("skipmaps", 6)
iniFile.set_input_file("2dmapoutput", str(test_data_dir / map_file))
- iniFile.set_input_file(
- "crosssectionlocationfile", str(test_data_dir / css_file)
- )
+ iniFile.set_input_file("crosssectionlocationfile", str(test_data_dir / css_file))
iniFile.set_input_file("regionpolygonfile", region_file_path)
iniFile.set_input_file("sectionpolygonfile", section_file_path)
@@ -244,23 +240,20 @@ def test_when_given_input_data_then_output_is_generated(
# 2. Verify precondition (no output generated)
assert (
- os.path.exists(iniFile.get_output_directory())
- and not len(os.listdir(iniFile.get_output_directory())) > 1
+ os.path.exists(iniFile.get_output_directory()) and not len(os.listdir(iniFile.get_output_directory())) > 1
)
# 3. Run file:
runner.run()
# 4. Verify there is output generated:
- assert os.listdir(
- iniFile.get_output_directory()
- ), "" + "There is no output generated for {0}".format(case_name)
+ assert os.listdir(iniFile.get_output_directory()), "" + f"There is no output generated for {case_name}"
class ARCHIVED_Test_Main_Run_IniFile:
def __run_main_with_arguments(self, ini_file):
- pythonCall = "fm2prof\\main.py -i {0}".format(ini_file)
- os.system("python {0}".format(pythonCall))
+ pythonCall = f"fm2prof\\main.py -i {ini_file}"
+ os.system(f"python {pythonCall}")
def __create_test_ini_file(self, root_dir, case_name, map_file, css_file):
output_dir = os.path.join(root_dir, "OutputFiles")
@@ -293,29 +286,27 @@ def __create_test_ini_file(self, root_dir, case_name, map_file, css_file):
}
# write file
- file_path = os.path.join(root_dir, "{}_ini_file.ini".format(case_name))
+ file_path = os.path.join(root_dir, f"{case_name}_ini_file.ini")
f = open(file_path, "w+")
- f.writelines("[{}]\r\n".format(input_files_key))
+ f.writelines(f"[{input_files_key}]\r\n")
for key, value in input_file_paths.items():
- f.writelines("{} = {}\r\n".format(key, value))
+ f.writelines(f"{key} = {value}\r\n")
f.writelines("\r\n")
- f.writelines("[{}]\r\n".format(input_parameters_key))
+ f.writelines(f"[{input_parameters_key}]\r\n")
for key, value in input_parameters.items():
- f.writelines("{} = {}\r\n".format(key, value))
+ f.writelines(f"{key} = {value}\r\n")
f.writelines("\r\n")
- f.writelines("[{}]\r\n".format(output_directory_key))
- f.writelines("OutputDir = {}\r\n".format(output_dir))
- f.writelines("CaseName = {}\r\n".format(case_name))
+ f.writelines(f"[{output_directory_key}]\r\n")
+ f.writelines(f"OutputDir = {output_dir}\r\n")
+ f.writelines(f"CaseName = {case_name}\r\n")
f.close()
return (file_path, output_dir)
def _get_custom_dir(self) -> Path:
- """
- Sets up the necessary data for MainMethodTest
- """
+ """Sets up the necessary data for MainMethodTest"""
return _create_artifact_dir("RunWithCustom_IniFile")
def test_when_given_inifile_then_output_is_generated(self):
@@ -324,9 +315,7 @@ def test_when_given_inifile_then_output_is_generated(self):
map_file = "fm_map.nc"
css_file = "fm_css.xyz"
root_output_dir = self._get_custom_dir()
- (ini_file_path, output_dir) = self.__create_test_ini_file(
- root_output_dir, case_name, map_file, css_file
- )
+ (ini_file_path, output_dir) = self.__create_test_ini_file(root_output_dir, case_name, map_file, css_file)
# 2. Verify precondition (no output generated)
assert os.path.exists(ini_file_path)
@@ -344,16 +333,14 @@ def test_when_given_inifile_then_output_is_generated(self):
except Exception as e_error:
if os.path.exists(root_output_dir):
shutil.rmtree(root_output_dir)
- pytest.fail("No exception expected but was thrown {}.".format(str(e_error)))
+ pytest.fail(f"No exception expected but was thrown {e_error!s}.")
# 4. Verify there is output generated:
- output_files = os.path.join(output_dir, "{}01".format(case_name))
+ output_files = os.path.join(output_dir, f"{case_name}01")
generated_files = os.listdir(output_files)
if os.path.exists(root_output_dir):
shutil.rmtree(root_output_dir)
- assert generated_files, "" + "There is no output generated for {0}".format(
- case_name
- )
+ assert generated_files, "" + f"There is no output generated for {case_name}"
for expected_file in expected_files:
assert expected_file in generated_files
@@ -371,9 +358,8 @@ def test_when_fm2prof_output_then_use_it_for_sobek_model_input(self):
fm_dir = str(waal_test_folder / "Model_FM")
fm2prof_dir = _get_test_case_output_dir(_waal_case)
-
# 2. Try to compare.
-
+
waal_comparer = CompareWaalModel()
output_1d, _ = waal_comparer._run_waal_1d_model(
case_name=_waal_case,
@@ -381,13 +367,10 @@ def test_when_fm2prof_output_then_use_it_for_sobek_model_input(self):
sobek_dir=sobek_dir,
fm_dir=fm_dir,
)
-
# 3. Verify final expectations
assert output_1d
- assert os.path.exists(output_1d), "" + "No output found at {}.".format(
- output_1d
- )
+ assert os.path.exists(output_1d), "" + f"No output found at {output_1d}."
def test_when_sobek_output_exist_then_create_figures(self):
# 1. Set up test data
@@ -399,7 +382,7 @@ def test_when_sobek_output_exist_then_create_figures(self):
result_figures = []
# 2. Try to compare.
-
+
waal_comparer = CompareWaalModel()
result_figures = waal_comparer._compare_waal(
case_name=_waal_case,
@@ -411,15 +394,11 @@ def test_when_sobek_output_exist_then_create_figures(self):
# 3. Verify final expectations
assert result_figures
for fig_path in result_figures:
- assert os.path.exists(fig_path), "" + "Figure not found at path {}.".format(
- fig_path
- )
+ assert os.path.exists(fig_path), "" + f"Figure not found at path {fig_path}."
@pytest.mark.acceptance
@pytest.mark.requires_output
- @pytest.mark.parametrize(
- ("case_name"), _test_scenarios_ids, ids=_test_scenarios_ids
- )
+ @pytest.mark.parametrize(("case_name"), _test_scenarios_ids, ids=_test_scenarios_ids)
def test_when_output_exists_then_compare_waal_model_volume(self, case_name: str):
if case_name != _waal_case:
# print('This case is tested on another fixture.')
@@ -434,21 +413,16 @@ def test_when_output_exists_then_compare_waal_model_volume(self, case_name: str)
input_volume_file = os.path.join(fm2prof_dir, volume_file_name)
# 2. Verify / create necessary folders and directories
- assert os.path.exists(
- input_volume_file
- ), "" + "Input file {} could not be found".format(input_volume_file)
+ assert os.path.exists(input_volume_file), "" + f"Input file {input_volume_file} could not be found"
if not os.path.exists(fm2prof_fig_dir):
os.makedirs(fm2prof_fig_dir)
        # 3. Run
waal_comparer = CompareWaalModel()
waal_comparer._compare_volume(case_name, input_volume_file, fm2prof_fig_dir)
-
# 4. Final expectation
- assert os.listdir(
- fm2prof_fig_dir
- ), "" + "There is no volume output generated for {0}".format(case_name)
+ assert os.listdir(fm2prof_fig_dir), "" + f"There is no volume output generated for {case_name}"
class Test_Compare_Idealized_Model:
@@ -523,9 +497,7 @@ class Test_Compare_Idealized_Model:
# region for tests
@pytest.mark.acceptance
@pytest.mark.requires_output
- @pytest.mark.parametrize(
- ("case_name"), _test_scenarios_ids, ids=_test_scenarios_ids
- )
+ @pytest.mark.parametrize(("case_name"), _test_scenarios_ids, ids=_test_scenarios_ids)
def ARCHIVED_test_compare_generic_model_geometry(self, case_name: str):
if case_name == _waal_case:
# print('This case is tested on another fixture.')
@@ -541,9 +513,7 @@ def ARCHIVED_test_compare_generic_model_geometry(self, case_name: str):
input_geometry_file = os.path.join(fm2prof_dir, geometry_file_name)
# 2. Verify / create necessary folders and directories
- assert os.path.exists(
- input_geometry_file
- ), "" + "Input file {} could not be found".format(input_geometry_file)
+ assert os.path.exists(input_geometry_file), "" + f"Input file {input_geometry_file} could not be found"
if os.path.exists(fm2prof_fig_dir):
shutil.rmtree(fm2prof_fig_dir)
@@ -553,28 +523,19 @@ def ARCHIVED_test_compare_generic_model_geometry(self, case_name: str):
# 3. Run
tzw_values = self.__case_tzw_dict.get(case_name)
if not tzw_values or tzw_values is None:
- pytest.fail("Test failed, no values retrieved for {}".format(case_name))
+ pytest.fail(f"Test failed, no values retrieved for {case_name}")
-
generic_comparer = CompareIdealizedModel()
- generic_comparer._compare_css(
- case_name, tzw_values, input_geometry_file, fm2prof_fig_dir
- )
-
+ generic_comparer._compare_css(case_name, tzw_values, input_geometry_file, fm2prof_fig_dir)
+
# 4. Final expectation
- assert os.listdir(
- fm2prof_fig_dir
- ), "" + "There is no geometry output generated for {0}".format(case_name)
+        assert os.listdir(fm2prof_fig_dir), f"There is no geometry output generated for {case_name}"
# region for tests
@pytest.mark.acceptance
@pytest.mark.requires_output
- @pytest.mark.parametrize(
- ("case_name"), _test_scenarios_ids, ids=_test_scenarios_ids
- )
- def ARCHIVED_test_when_output_exists_then_compare_generic_model_roughness(
- self, case_name: str
- ):
+    @pytest.mark.parametrize("case_name", _test_scenarios_ids, ids=_test_scenarios_ids)
+ def ARCHIVED_test_when_output_exists_then_compare_generic_model_roughness(self, case_name: str):
if case_name == _waal_case:
# print('This case is tested on another fixture.')
return
@@ -586,35 +547,24 @@ def ARCHIVED_test_when_output_exists_then_compare_generic_model_roughness(
input_roughness_file = os.path.join(fm2prof_dir, roughness_file_name)
# 2. Verify / create necessary folders and directories
- assert os.path.exists(
- input_roughness_file
- ), "" + "Input file {} could not be found".format(input_roughness_file)
+        assert os.path.exists(input_roughness_file), f"Input file {input_roughness_file} could not be found"
if not os.path.exists(fm2prof_fig_dir):
os.makedirs(fm2prof_fig_dir)
# 3. Run
tzw_values = self.__case_tzw_dict.get(case_name)
if not tzw_values or tzw_values is None:
- pytest.fail("Test failed, no values retrieved for {}".format(case_name))
+ pytest.fail(f"Test failed, no values retrieved for {case_name}")
generic_comparer = CompareIdealizedModel()
- generic_comparer._compare_roughness(
- case_name, tzw_values, input_roughness_file, fm2prof_fig_dir
- )
-
+ generic_comparer._compare_roughness(case_name, tzw_values, input_roughness_file, fm2prof_fig_dir)
- assert os.listdir(
- fm2prof_fig_dir
- ), "" + "There is no roughness output generated for {0}".format(case_name)
+        assert os.listdir(fm2prof_fig_dir), f"There is no roughness output generated for {case_name}"
@pytest.mark.acceptance
@pytest.mark.requires_output
- @pytest.mark.parametrize(
- ("case_name"), _test_scenarios_ids, ids=_test_scenarios_ids
- )
- def ARCHIVED_test_when_output_exists_then_compare_generic_model_volume(
- self, case_name: str
- ):
+    @pytest.mark.parametrize("case_name", _test_scenarios_ids, ids=_test_scenarios_ids)
+ def ARCHIVED_test_when_output_exists_then_compare_generic_model_volume(self, case_name: str):
if case_name == _waal_case:
# print('This case is tested on another fixture.')
return
@@ -628,34 +578,24 @@ def ARCHIVED_test_when_output_exists_then_compare_generic_model_volume(
input_volume_file = os.path.join(fm2prof_dir, volume_file_name)
# 2. Verify / create necessary folders and directories
- assert os.path.exists(
- input_volume_file
- ), "" + "Input file {} could not be found".format(input_volume_file)
+        assert os.path.exists(input_volume_file), f"Input file {input_volume_file} could not be found"
if not os.path.exists(fm2prof_fig_dir):
os.makedirs(fm2prof_fig_dir)
# 3. Run
-
+
generic_comparer = CompareIdealizedModel()
- generic_comparer._compare_volume(
- case_name, input_volume_file, fm2prof_fig_dir
- )
-
+ generic_comparer._compare_volume(case_name, input_volume_file, fm2prof_fig_dir)
+
# 4. Final expectation
- assert os.listdir(
- fm2prof_fig_dir
- ), "" + "There is no volume output generated for {0}".format(case_name)
+        assert os.listdir(fm2prof_fig_dir), f"There is no volume output generated for {case_name}"
@pytest.mark.acceptance
@pytest.mark.requireoutput
- @pytest.mark.parametrize(
- ("case_name"), _test_scenarios_ids[:-3], ids=_test_scenarios_ids[:-3]
- )
+    @pytest.mark.parametrize("case_name", _test_scenarios_ids[:-3], ids=_test_scenarios_ids[:-3])
@skipwhenexternalsmissing
def test_when_output_exists_then_compare_with_reference(self, case_name: str):
- """
- This test is supposed to supercede the others
- """
+        """This test is supposed to supersede the others."""
# 1. Get all necessary output / input directories
reference_geometry = self.__case_tzw_dict.get(case_name)
fm2prof_dir = _get_test_case_output_dir(case_name)
@@ -670,9 +610,7 @@ def test_when_output_exists_then_compare_with_reference(self, case_name: str):
ref = CompareHelper.convert_ZW_to_symmetric_css(ref)
ref_friction = frictionhelper(css)
- visualiser.make_figure(
- css, reference_geometry=ref, reference_roughness=ref_friction
- )
+ visualiser.make_figure(css, reference_geometry=ref, reference_roughness=ref_friction)
class ARCHIVED_Test_WaalPerformance:
@@ -692,9 +630,7 @@ def test_when_waal_case_then_performance_is_slow(self):
ini_file_path = None
test_ini_file = IniFile(ini_file_path)
base_output_dir = _get_base_output_dir()
- test_ini_file._output_dir = str(
- _check_and_create_test_case_output_dir(base_output_dir, case_name)
- )
+ test_ini_file._output_dir = str(_check_and_create_test_case_output_dir(base_output_dir, case_name))
test_ini_file._input_file_paths = {
"fm_netcdfile": map_file,
@@ -725,13 +661,11 @@ def test_when_waal_case_then_performance_is_slow(self):
assert os.path.exists(css_file), "" + "CrossSection (test) file was not found"
# 3. Run test.
-
+
runner = Fm2ProfRunner(iniFilePath=None)
runner.run_inifile(iniFile=test_ini_file)
-
# 4. Verify final expectations.
- pass
def test_dummy_timing(self):
import timeit
@@ -749,7 +683,7 @@ def fs():
assert f_res < fs_res
def merge_names(self, a, b):
- val = "{} & {}".format(a, b)
+ val = f"{a} & {b}"
return val
def test_dummy_mp(self):
diff --git a/tests/test_crosssection.py b/tests/test_crosssection.py
index ee9e9693..add5d1da 100644
--- a/tests/test_crosssection.py
+++ b/tests/test_crosssection.py
@@ -2,53 +2,123 @@
import pytest
-from fm2prof.CrossSection import CrossSection
+from fm2prof.cross_section import CrossSection
import pickle
from tests.TestUtils import TestUtils
css_test_dir = "cross_sections"
-test_cases = [dict(
- name="waal_1_40147.826",
- css_z = np.array([-0.9141 , -0.2808732 , 0.3523536 , 0.9855804 , 1.61880719,
- 2.25203399, 2.88526079, 3.51848759, 4.15171439, 4.78494119,
- 4.78494119, 4.83562453, 5.0825522 , 5.35312778, 5.639405 ,
- 5.94391206, 6.23998542, 6.54804721, 6.84868257, 7.13847582,
- 7.38282862, 7.63388787, 7.89941093, 8.18003839, 8.4205877 ,
- 8.68096833, 8.90767707, 9.14537943, 9.39478408, 9.62215027,
- 9.81724704, 10.01426908, 10.2297726 , 10.46081785, 10.71071257,
- 10.99499479, 11.29960472, 11.62194869, 11.95376401, 12.20920782,
- 12.26486054, 12.27733635]),
- css_total_volume = np.array([ 0. , 202411.92430511, 440863.87373976,
- 714537.17597224, 1023088.46888726, 1370281.34820867,
- 1746175.98609244, 2142631.45969038, 2559469.10395432,
- 2991123.22671975, 2991123.22671975, 3028190.8376371 ,
- 3218864.47242984, 3430109.5158954 , 3655876.28183273,
- 3898019.94718616, 4136564.29328984, 4391422.83609613,
- 4652111.1476516 , 4914018.51771363, 5139429.93097691,
- 5372719.6548311 , 5620726.47479325, 5884181.24648082,
- 6110863.38850093, 6359954.26342078, 6580688.49745023,
- 6813452.44399827, 7062291.71909494, 7293350.75974847,
- 7491616.26087858, 7691838.29814542, 7910841.99672851,
- 8145639.84965338, 8399593.31489642, 8693502.87558943,
- 9016339.67243312, 9360662.51019234, 9715102.4974322 ,
- 9987963.61958959, 10047410.99351798, 10060737.46221183]),
- crest_level = 4.573300167187546,
- extra_total_volume = 689298.2236775636)]
-
+test_cases = [
+ dict(
+ name="waal_1_40147.826",
+ css_z=np.array(
+ [
+ -0.9141,
+ -0.2808732,
+ 0.3523536,
+ 0.9855804,
+ 1.61880719,
+ 2.25203399,
+ 2.88526079,
+ 3.51848759,
+ 4.15171439,
+ 4.78494119,
+ 4.78494119,
+ 4.83562453,
+ 5.0825522,
+ 5.35312778,
+ 5.639405,
+ 5.94391206,
+ 6.23998542,
+ 6.54804721,
+ 6.84868257,
+ 7.13847582,
+ 7.38282862,
+ 7.63388787,
+ 7.89941093,
+ 8.18003839,
+ 8.4205877,
+ 8.68096833,
+ 8.90767707,
+ 9.14537943,
+ 9.39478408,
+ 9.62215027,
+ 9.81724704,
+ 10.01426908,
+ 10.2297726,
+ 10.46081785,
+ 10.71071257,
+ 10.99499479,
+ 11.29960472,
+ 11.62194869,
+ 11.95376401,
+ 12.20920782,
+ 12.26486054,
+ 12.27733635,
+ ]
+ ),
+ css_total_volume=np.array(
+ [
+ 0.0,
+ 202411.92430511,
+ 440863.87373976,
+ 714537.17597224,
+ 1023088.46888726,
+ 1370281.34820867,
+ 1746175.98609244,
+ 2142631.45969038,
+ 2559469.10395432,
+ 2991123.22671975,
+ 2991123.22671975,
+ 3028190.8376371,
+ 3218864.47242984,
+ 3430109.5158954,
+ 3655876.28183273,
+ 3898019.94718616,
+ 4136564.29328984,
+ 4391422.83609613,
+ 4652111.1476516,
+ 4914018.51771363,
+ 5139429.93097691,
+ 5372719.6548311,
+ 5620726.47479325,
+ 5884181.24648082,
+ 6110863.38850093,
+ 6359954.26342078,
+ 6580688.49745023,
+ 6813452.44399827,
+ 7062291.71909494,
+ 7293350.75974847,
+ 7491616.26087858,
+ 7691838.29814542,
+ 7910841.99672851,
+ 8145639.84965338,
+ 8399593.31489642,
+ 8693502.87558943,
+ 9016339.67243312,
+ 9360662.51019234,
+ 9715102.4974322,
+ 9987963.61958959,
+ 10047410.99351798,
+ 10060737.46221183,
+ ]
+ ),
+ crest_level=4.573300167187546,
+ extra_total_volume=689298.2236775636,
+ )
+]
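+# NOTE: the literals above act as regression references for the
+# "waal_1_40147.826" pickle fixture (assumed to be captured from a
+# known-good run): css_z and css_total_volume are checked element-wise in
+# test_build_geometry, crest_level and extra_total_volume in
+# test_calculate_correction.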
class Test_generate_cross_section_instance:
-
def test_when_wrong_input_dict_is_given_then_expected_exception_risen(self):
# 1. Set up test data
test_css_name = "dummy_css"
css_data = {"id": test_css_name}
# 2. Set expectations
- expected_error = ("'Input data does not have all required keys'")
-
+ expected_error = "'Input data does not have all required keys'"
+
# 3. Run test
with pytest.raises(KeyError) as e_info:
CrossSection(data=css_data)
@@ -56,39 +126,38 @@ def test_when_wrong_input_dict_is_given_then_expected_exception_risen(self):
# 4. Verify final expectations
error_message = str(e_info.value)
assert error_message == expected_error, (
- ""
- + "Expected exception message {},".format(expected_error)
- + " retrieved {}".format(error_message)
+            f"Expected exception message {expected_error}, retrieved {error_message}"
)
-
+
def test_when_correct_input_dict_is_given_CrossSection_initialises(self):
# 1. Set up test data
tdir = TestUtils.get_local_test_data_dir(css_test_dir)
- with open(tdir.joinpath(f"{test_cases[0].get('name')}.pickle"), 'rb') as f:
+ with open(tdir.joinpath(f"{test_cases[0].get('name')}.pickle"), "rb") as f:
css_data = pickle.load(f)
- # 2. Set expectations
- # 3. Run test
+ # 2. Set expectations
+ # 3. Run test
css = CrossSection(data=css_data)
-
+
# 4. Verify final expectations
- assert css.length == css_data.get('length')
+ assert css.length == css_data.get("length")
+
class Test_cross_section_construction:
def test_build_geometry(self):
# 1. Set up test data
test_case: dict = test_cases[0]
tdir = TestUtils.get_local_test_data_dir(css_test_dir)
- with open(tdir.joinpath(f"{test_case.get('name')}.pickle"), 'rb') as f:
+ with open(tdir.joinpath(f"{test_case.get('name')}.pickle"), "rb") as f:
css_data = pickle.load(f)
tol = 1e-6
- # 2. Set expectations for
- css_z: np.array = test_case.get('css_z')
- css_total_volume: np.array = test_case.get('css_total_volume')
+        # 2. Set expectations
+        css_z: np.ndarray = test_case.get("css_z")
+        css_total_volume: np.ndarray = test_case.get("css_total_volume")
- # 3. Run test
+ # 3. Run test
css = CrossSection(data=css_data)
css.build_geometry()
@@ -98,25 +167,23 @@ def test_build_geometry(self):
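+        # The reference arrays are compared element-wise against the computed
+        # geometry with an absolute tolerance (tol = 1e-6, set above).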
assert all([abs(css_z[i] - css._css_z[i]) < tol for i in range(len(css_z))])
assert all([abs(css_total_volume[i] - css._css_total_volume[i]) < tol for i in range(len(css_total_volume))])
-
def test_calculate_correction(self):
# 1. Set up test data
test_case: dict = test_cases[0]
tdir = TestUtils.get_local_test_data_dir(css_test_dir)
- with open(tdir.joinpath(f"{test_case.get('name')}.pickle"), 'rb') as f:
+ with open(tdir.joinpath(f"{test_case.get('name')}.pickle"), "rb") as f:
css_data = pickle.load(f)
tol = 1e-6
- # 2. Set expectations for
- crest_level:float = test_case.get('crest_level') # type: ignore
- extra_total_volume: np.ndarray = test_case.get('extra_total_volume') # type: ignore
+        # 2. Set expectations
+        crest_level: float = test_case.get("crest_level")  # type: ignore
+        extra_total_volume: float = test_case.get("extra_total_volume")  # type: ignore
- # 3. Run test
+ # 3. Run test
css = CrossSection(data=css_data)
css.build_geometry()
css.calculate_correction()
-
# 4. Verify final expectations
assert abs(crest_level - css.crest_level) < tol
@@ -126,14 +193,14 @@ def test_reduce_points(self):
# 1. Set up test data
test_case: dict = test_cases[0]
tdir = TestUtils.get_local_test_data_dir(css_test_dir)
- with open(tdir.joinpath(f"{test_case.get('name')}.pickle"), 'rb') as f:
+ with open(tdir.joinpath(f"{test_case.get('name')}.pickle"), "rb") as f:
css_data = pickle.load(f)
- # 2. Run test
+ # 2. Run test
css = CrossSection(data=css_data)
css.build_geometry()
css.calculate_correction()
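+        # reduce_points should resample the cross-section down to exactly
+        # `count_after` points; the assertion below checks len(css.z).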
css.reduce_points(count_after=20)
-
+
# 3. Verify final expectations
- assert len(css.z) == 20
\ No newline at end of file
+ assert len(css.z) == 20
diff --git a/tests/test_Import.py b/tests/test_data_import.py
similarity index 53%
rename from tests/test_Import.py
rename to tests/test_data_import.py
index a046cf1e..8331df3f 100644
--- a/tests/test_Import.py
+++ b/tests/test_data_import.py
@@ -2,7 +2,7 @@
import pytest
-from fm2prof.Import import FMDataImporter, FmModelData
+from fm2prof.data_import import FMDataImporter, FmModelData
from tests.TestUtils import TestUtils, skipwhenexternalsmissing
@@ -10,9 +10,7 @@ class Test_FMDataImporter:
@skipwhenexternalsmissing
def test_when_map_file_without_czu_no_exception(self):
# 1. Set up test data
- test_map = Path(TestUtils.get_local_test_data_dir("main_test_data")).joinpath(
- "fm_map.nc"
- )
+ test_map = Path(TestUtils.get_local_test_data_dir("main_test_data")).joinpath("fm_map.nc")
assert test_map.is_file()
# 2. Set initial expectations
@@ -23,52 +21,6 @@ def test_when_map_file_without_czu_no_exception(self):
class Test_FmModelData:
- @pytest.mark.parametrize("arg_list", [(""), (None)])
- def test_when_no_arguments_exception_is_risen(self, arg_list):
- # 1. Set up test data
- arg_list = arg_list
-
- # 2. Set initial expectations
- expected_error_message = "FM model data was not read correctly."
-
- # 3. Run test
- with pytest.raises(Exception) as pytest_wrapped_e:
- FmModelData(arg_list)
-
- # 4. Verify final expectations
- recieved_error_message = str(pytest_wrapped_e.value)
- assert expected_error_message == recieved_error_message, (
- ""
- + "Expected error message {},".format(expected_error_message)
- + " does not match generated {}".format(recieved_error_message)
- )
-
- @pytest.mark.parametrize("arg_list", [(""), (None)])
- def test_when_argument_length_not_as_expected_then_exception_is_risen(
- self, arg_list
- ):
- # 1. Set up test data
- arg_list = ["arg1", "arg2"]
-
- # 2. Set initial expectations
- expected_error_message = (
- ""
- + "Fm model data expects 5 arguments but only"
- + " {} were given".format(len(arg_list))
- )
-
- # 3. Run test
- with pytest.raises(Exception) as pytest_wrapped_e:
- FmModelData(arg_list)
-
- # 4. Verify final expectations
- recieved_error_message = str(pytest_wrapped_e.value)
- assert expected_error_message == recieved_error_message, (
- ""
- + "Expected error message {},".format(expected_error_message)
- + " does not match generated {}".format(recieved_error_message)
- )
-
def test_when_given_expected_arguments_then_object_is_created(self):
# 1. Set up test data
time_dependent_data = "arg1"
@@ -76,17 +28,16 @@ def test_when_given_expected_arguments_then_object_is_created(self):
edge_data = "arg3"
node_coordinates = "arg4"
css_data = "arg5"
- arg_list = [
- time_dependent_data,
- time_independent_data,
- edge_data,
- node_coordinates,
- css_data,
- ]
return_fm_model_data = None
- # 2. Run test
- return_fm_model_data = FmModelData(arg_list)
+ # 2. Run test
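+        # FmModelData is now constructed with explicit keyword arguments; the
+        # removed tests above covered the old single arg_list signature.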
+ return_fm_model_data = FmModelData(
+ time_dependent_data=time_dependent_data,
+ time_independent_data=time_independent_data,
+ edge_data=edge_data,
+ node_coordinates=node_coordinates,
+ css_data_dictionary=css_data,
+ )
# 4. Verify final expectations
assert return_fm_model_data is not None
@@ -108,21 +59,20 @@ def test_when_given_data_dictionary_then_css_data_list_is_set(self):
dummy_key: dummy_values,
}
- arg_list = [
- time_dependent_data,
- time_independent_data,
- edge_data,
- node_coordinates,
- css_data_dict,
- ]
return_fm_model_data = None
# 2. Set expectations
expected_css_data_list = [{dummy_key: 0}, {dummy_key: 1}]
- # 3. Run test
- return_fm_model_data = FmModelData(arg_list)
-
+ # 3. Run test
+ return_fm_model_data = FmModelData(
+ time_dependent_data=time_dependent_data,
+ time_independent_data=time_independent_data,
+ edge_data=edge_data,
+ node_coordinates=node_coordinates,
+ css_data_dictionary=css_data_dict,
+ )
+
# 4. Verify final expectations
assert return_fm_model_data is not None
assert return_fm_model_data.css_data_list != css_data_dict
@@ -144,13 +94,11 @@ def test_when_given_dictionary_then_returns_list(self):
# 3. Run test
return_list = FmModelData.get_ordered_css_list(test_dict)
-
+
# 4. Verify final expectations
assert return_list is not None
assert return_list == expected_list, (
- ""
- + "Expected return value {},".format(expected_list)
- + " but return {} instead.".format(return_list)
+            f"Expected return value {expected_list}, but got {return_list} instead."
)
@pytest.mark.parametrize("test_dict", [(""), (None), ({})])
@@ -161,13 +109,11 @@ def test_when_given_unexpected_value_then_returns_empty_list(self, test_dict):
# 2. Run test
expected_list = []
- # 3. Run test
+ # 3. Run test
return_list = FmModelData.get_ordered_css_list(test_dict)
-
+
# 4. Verify final expectations
assert return_list is not None
assert return_list == expected_list, (
- ""
- + "Expected return value {},".format(expected_list)
- + " but return {} instead.".format(return_list)
+            f"Expected return value {expected_list}, but got {return_list} instead."
)
diff --git a/tests/test_fm2profrunner.py b/tests/test_fm2profrunner.py
index 81570f64..ebc98cb4 100644
--- a/tests/test_fm2profrunner.py
+++ b/tests/test_fm2profrunner.py
@@ -1,7 +1,7 @@
import os
from fm2prof import Project
-from fm2prof.Fm2ProfRunner import Fm2ProfRunner
+from fm2prof.fm2prof_runner import Fm2ProfRunner
from tests.TestUtils import TestUtils
@@ -20,9 +20,9 @@ def test_if_get_existing_parameter_then_returned(self):
project = None
value = None
- # 2. Run test
+ # 2. Run test
project = Project()
- value = project.get_parameter("LakeTimeSteps")
+ value = project.get_parameter("LakeTimeSteps")
# 3. Verify final expectations
assert project is not None
@@ -32,10 +32,10 @@ def test_if_get_nonexisting_parameter_then_no_exception(self):
# 1. Set up initial test dat
project = None
value = None
- # 2. Run test
+ # 2. Run test
project = Project()
value = project.get_parameter("IDoNoTExist")
-
+
# 3. Verify final expectations
assert project is not None
assert value is None
@@ -45,9 +45,9 @@ def test_if_get_existing_inputfile_then_returned(self):
project = None
value = None
- # 2. Run test
+ # 2. Run test
project = Project()
- value = project.get_input_file("CrossSectionLocationFile")
+ value = project.get_input_file("CrossSectionLocationFile")
# 3. Verify final expectations
assert project is not None
@@ -58,9 +58,9 @@ def test_if_get_output_directory_then_returned(self):
project = None
value = None
- # 2. Run test
+ # 2. Run test
project = Project()
- value = project.get_output_directory()
+ value = project.get_output_directory()
# 3. Verify final expectations
assert project is not None
@@ -71,9 +71,9 @@ def test_set_parameter(self):
project = None
value = 150
- # 2. Run test
+ # 2. Run test
project = Project()
- project.set_parameter("LakeTimeSteps", value)
+ project.set_parameter("LakeTimeSteps", value)
# 3. Verify final expectations
assert project.get_parameter("LakeTimeSteps") == value
@@ -82,32 +82,32 @@ def test_set_input_file(self):
# 1. Set up initial test dat
project = None
value = "RandomString"
- # 2. Run test
+ # 2. Run test
project = Project()
- project.set_input_file("CrossSectionLocationFile", value)
+ project.set_input_file("CrossSectionLocationFile", value)
# 3. Verify final expectations
assert project.get_input_file("CrossSectionLocationFile") == value
- def test_set_output_directory(self):
+ def test_set_output_directory(self, tmp_path):
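+        # tmp_path is pytest's built-in per-test temporary directory, so this
+        # test no longer writes a "test/subdir" folder into the working tree.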
# 1. Set up initial test dat
project = None
- value = "test/subdir"
- # 2. Run test
+ # 2. Run test
project = Project()
- project.set_output_directory(value)
-
+ project.set_output_directory(tmp_path)
+
def test_print_configuration(self):
# 1. Set up initial test dat
project = None
value = None
- # 2. Run test
+ # 2. Run test
project = Project()
value = project.print_configuration()
-
+
# 3. Verify final expectations
assert value is not None
+
class Test_Fm2ProfRunner:
def test_when_no_file_path_then_no_exception_is_risen(self):
# 1. Set up initial test dat
@@ -115,7 +115,6 @@ def test_when_no_file_path_then_no_exception_is_risen(self):
# 2. Run test
runner = Fm2ProfRunner()
-
# 3. Verify final expectations
assert runner is not None
@@ -129,21 +128,17 @@ def test_given_inifile_then_no_exception_is_risen(self):
runner = None
# 2. Verify the initial expectations
- assert os.path.exists(ini_file_path), "" "Test File {} was not found".format(
- ini_file_path
- )
+ assert os.path.exists(ini_file_path), f"Test File {ini_file_path} was not found"
- # 3. Run test
- runner = Fm2ProfRunner(ini_file_path)
+ # 3. Run test
+ runner = Fm2ProfRunner(ini_file_path)
# 4. Verify final expectations
assert runner is not None
def test_run_with_inifile(self):
# 1. Set up test data
- inifile = TestUtils.get_local_test_file(
- "cases/case_02_compound/fm2prof_config.ini"
- )
+ inifile = TestUtils.get_local_test_file("cases/case_02_compound/fm2prof_config.ini")
# 3. run test
project = Project(inifile)
@@ -154,9 +149,7 @@ def test_run_with_inifile(self):
def test_run_with_overwrite_false_output_unchanged(self):
# 1. Set up test data
- inifile = TestUtils.get_local_test_file(
- "cases/case_02_compound/fm2prof_config.ini"
- )
+ inifile = TestUtils.get_local_test_file("cases/case_02_compound/fm2prof_config.ini")
# 2. set expections
project = Project(inifile)
@@ -171,9 +164,7 @@ def test_run_with_overwrite_false_output_unchanged(self):
def test_run_with_overwrite_true_output_has_changed(self):
# 1. Set up test data
- inifile = TestUtils.get_local_test_file(
- "cases/case_02_compound/fm2prof_config.ini"
- )
+ inifile = TestUtils.get_local_test_file("cases/case_02_compound/fm2prof_config.ini")
# 2. set expections
project = Project(inifile)
diff --git a/tests/test_functions.py b/tests/test_functions.py
index f072591c..adf568c4 100644
--- a/tests/test_functions.py
+++ b/tests/test_functions.py
@@ -3,12 +3,11 @@
import numpy as np
import pytest
-import fm2prof.Functions as Func
+import fm2prof.functions as Func
from tests.TestUtils import TestUtils
class ARCHIVED_Test_read_css_xyz:
-
_test_scenarios_invalid_file_paths = [(None), (""), ("dummyFilePath")]
@pytest.mark.parametrize("file_path", _test_scenarios_invalid_file_paths)
@@ -38,14 +37,11 @@ def test_read_css_xyz_valid_file_path_returns_expected_input_data(self):
}
# 2. Verify the initial expectation
- assert os.path.exists(file_path), "" + "Test File {} could not be found".format(
- file_path
- )
+        assert os.path.exists(file_path), f"Test File {file_path} could not be found"
# 3. Run test
-
+
result_input_data = Func._read_css_xyz(file_path)
-
# 4. Verify final expectations
for expected_input_key in expected_input_data:
diff --git a/tests/test_inifile.py b/tests/test_inifile.py
index eacda7ea..300c0b0d 100644
--- a/tests/test_inifile.py
+++ b/tests/test_inifile.py
@@ -2,7 +2,7 @@
import pytest
-from fm2prof.IniFile import IniFile
+from fm2prof.ini_file import IniFile
_root_output_dir = None
@@ -20,16 +20,13 @@ def test_when_no_file_path_then_no_exception_is_risen(self):
# 2. Run test
IniFile(iniFilePath)
-
def test_when_non_existent_file_path_then_io_exception_is_risen(self):
# 1. Set up initial test data
ini_file_path = "nonexistent_ini_file.ini"
# 2. Set expectations
- expected_error = "" + "The given file path {} could not be found".format(
- ini_file_path
- )
+        expected_error = f"The given file path {ini_file_path} could not be found"
# 3. Run test
with pytest.raises(IOError) as e_info:
@@ -38,9 +35,7 @@ def test_when_non_existent_file_path_then_io_exception_is_risen(self):
# 4. Verify final expectations
error_message = str(e_info.value)
assert error_message == expected_error, (
- ""
- + "Expected exception message {},".format(expected_error)
- + "retrieved {}".format(error_message)
+            f"Expected exception message {expected_error}, retrieved {error_message}"
)
@pytest.mark.parametrize("output_dir, expected_value", _test_scenarios_output_dirs)
@@ -50,8 +45,8 @@ def test_set_output_dir_with_valid_input(self, output_dir, expected_value):
iniFile = IniFile(ini_file_path)
new_output_dir = None
- # 2. Run test
- new_output_dir = iniFile.set_output_directory(output_dir)
+ # 2. Run test
+ new_output_dir = iniFile.set_output_directory(output_dir)
# 3. Verify final expectations
assert Path(expected_value) == new_output_dir.relative_to(Path().cwd())
diff --git a/tests/test_maskoutputfile.py b/tests/test_maskoutputfile.py
index 633e5bbf..35c22b17 100644
--- a/tests/test_maskoutputfile.py
+++ b/tests/test_maskoutputfile.py
@@ -3,25 +3,27 @@
import geojson
import pytest
-from fm2prof.MaskOutputFile import MaskOutputFile
+from fm2prof.mask_output_file import (
+ create_mask_point,
+ write_mask_output_file,
+ read_mask_output_file,
+ validate_extension,
+)
from tests.TestUtils import TestUtils
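+# The helpers above are now plain module-level functions of
+# fm2prof.mask_output_file; the call sites below drop the former
+# MaskOutputFile class prefix accordingly.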
-
class Test_create_mask_point:
def test_when_no_coords_does_raise(self):
# 1. Set up test model
coords = None
properties = None
-
+
# 2. Set up initial expectations
expected_error = "coords cannot be empty."
# 3. Do test
- with pytest.raises(ValueError, match=expected_error):
- MaskOutputFile.create_mask_point(coords, properties)
-
-
+ with pytest.raises(ValueError, match=expected_error):
+ create_mask_point(coords, properties)
@pytest.mark.parametrize("coords_values", [(4.2, 2.4), (4.2, 2.4, 42)])
def test_when_valid_coords_does_not_raise(self, coords_values: tuple):
@@ -29,81 +31,69 @@ def test_when_valid_coords_does_not_raise(self, coords_values: tuple):
coords = coords_values
properties = None
- # 2. Do test
- mask_point = MaskOutputFile.create_mask_point(coords, properties)
-
+ # 2. Do test
+ mask_point = create_mask_point(coords, properties)
+
# 3. Verify final expectations
assert mask_point is not None, "No mask_point generated"
-
-class Test_validate_extension:
+class Test_validate_extension:
def test_when_none_file_path_does_raise(self):
with pytest.raises(TypeError, match=f"file_path should be string or Path, not {type(None)}"):
- MaskOutputFile.validate_extension(None)
-
+ validate_extension(None)
def test_when_invalid_extension_raises_expected_exception(self):
# 1. Set up test data
file_path = Path("test.sjon")
-
- # 2. Set expectations
+
+ # 2. Set expectations
expected_IOError = "Invalid file path extension, should be .json or .geojson."
expectedTypeError = f"file_path should be string or Path, not {int}"
-
+
# 3. Run test
with pytest.raises(IOError, match=expected_IOError):
- MaskOutputFile.validate_extension(str(file_path))
-
- with pytest.raises(TypeError, match=expectedTypeError):
- MaskOutputFile.validate_extension(1)
-
+ validate_extension(str(file_path))
+ with pytest.raises(TypeError, match=expectedTypeError):
+ validate_extension(1)
@pytest.mark.parametrize("file_name", [("n.json"), ("n.geojson")])
- def test_when_valid_extension_does_not_raise(self, file_name):
- MaskOutputFile.validate_extension(file_name)
-
+ def test_when_valid_extension_does_not_raise(self, file_name):
+ validate_extension(file_name)
+
class Test_read_mask_output_file:
def test_when_invalid_file_path_given(self, tmp_path):
# 1. Set up test data
- not_existing_file = tmp_path / "test.geojson"
+ not_existing_file = tmp_path / "test.geojson"
- # 2. Set expectation
+ # 2. Set expectation
expected_error = f"File path {not_existing_file} not found"
# 3. Run test
with pytest.raises(FileNotFoundError, match=expected_error):
- MaskOutputFile.read_mask_output_file(not_existing_file)
+ read_mask_output_file(not_existing_file)
-
- def test_when_valid_file_with_no_content_then_returns_expected_geojson(
- self, tmp_path
- ):
+ def test_when_valid_file_with_no_content_then_returns_expected_geojson(self, tmp_path):
# 1. Set up test data
file_path: Path = tmp_path / "empty.geojson"
with file_path.open("w") as f:
- geojson.dump({},f)
-
+ geojson.dump({}, f)
+
# 2. Set up expectations
expected_error = "File is empty or not a valid geojson file."
# Run test
with pytest.raises(IOError, match=expected_error):
- MaskOutputFile.read_mask_output_file(file_path)
-
-
+ read_mask_output_file(file_path)
def test_when_valid_file_with_content_then_returns_expected_geojson(self):
# 1. Set up test data
- file_path = (
- TestUtils.get_local_test_data_dir("maskoutputfile_test_data")
- / "mask_points.geojson"
- )
+ file_path = TestUtils.get_local_test_data_dir("maskoutputfile_test_data") / "mask_points.geojson"
# 2. Read data
- read_geojson = MaskOutputFile.read_mask_output_file(file_path)
+ read_geojson = read_mask_output_file(file_path)
# 3. Verify final expectations
assert read_geojson, "No geojson data was generated."
@@ -112,24 +102,19 @@ def test_when_valid_file_with_content_then_returns_expected_geojson(self):
class Test_write_mask_output_file:
- def test_when_valid_file_path_and_no_mask_point_then_writes_expectations(
- self, tmp_path: Path
- ):
- # 1. Set up test data
+ def test_when_valid_file_path_and_no_mask_point_then_writes_expectations(self, tmp_path: Path):
+ # 1. Set up test data
file_path = tmp_path / "no_mask_points.geojson"
-
+
# 2. Set up expectations
error_msg = "mask_points cannot be empty"
# 3. Run test
with pytest.raises(ValueError, match=error_msg):
- MaskOutputFile.write_mask_output_file(file_path=file_path, mask_points=None)
-
+ write_mask_output_file(file_path=file_path, mask_points=None)
- def test_when_valid_file_path_and_mask_points_then_writes_expectations(
- self, tmp_path
- ):
- # 1. Set up test data
+ def test_when_valid_file_path_and_mask_points_then_writes_expectations(self, tmp_path):
+ # 1. Set up test data
file_path = tmp_path / "mask_points.geojson"
mask_points = [
geojson.Feature(geometry=geojson.utils.generate_random("Point")),
@@ -139,8 +124,8 @@ def test_when_valid_file_path_and_mask_points_then_writes_expectations(
# 2. Set expectations
expected_mask_points = geojson.FeatureCollection(mask_points)
- # 3. Run test
- MaskOutputFile.write_mask_output_file(file_path, mask_points)
+ # 3. Run test
+ write_mask_output_file(file_path, mask_points)
# 4. Verify final expectations
assert file_path.exists(), f"File {file_path} not found."
diff --git a/tests/test_regionpolygonfile.py b/tests/test_regionpolygonfile.py
index b4db15f6..a30e35a7 100644
--- a/tests/test_regionpolygonfile.py
+++ b/tests/test_regionpolygonfile.py
@@ -1,8 +1,8 @@
+import json
import logging
import os
import timeit
from random import randint
-import json
import matplotlib.pyplot as plt
import numpy as np
@@ -10,9 +10,9 @@
from pytest import fixture
from shapely.geometry import Polygon
-import fm2prof.Functions as FE
-from fm2prof.RegionPolygonFile import Polygon as p_tuple
-from fm2prof.RegionPolygonFile import PolygonFile, SectionPolygonFile, RegionPolygonFile
+import fm2prof.functions as funcs
+from fm2prof.region_polygon_file import Polygon as p_tuple
+from fm2prof.region_polygon_file import PolygonFile, RegionPolygonFile, SectionPolygonFile
from tests.TestUtils import TestUtils
@@ -26,14 +26,10 @@ def test_action_regular(self):
return self.polygon_file.classify_points_with_property(self.xy_list)
def test_action_regular_prep(self):
- return self.polygon_file.classify_points_with_property_shapely_prep(
- iter(self.xy_list)
- )
+ return self.polygon_file.classify_points_with_property_shapely_prep(iter(self.xy_list))
def test_action_polygons(self):
- return self.polygon_file.classify_points_with_property_rtree_by_polygons(
- self.xy_list
- )
+ return self.polygon_file.classify_points_with_property_rtree_by_polygons(self.xy_list)
@staticmethod
def test_action_regular_static(polygon_file, xy_list):
@@ -41,9 +37,7 @@ def test_action_regular_static(polygon_file, xy_list):
@staticmethod
def test_action_prep_static(polygon_file, xy_list):
- return polygon_file.classify_points_with_property_shapely_prep(
- iter(xy_list)
- )
+ return polygon_file.classify_points_with_property_shapely_prep(iter(xy_list))
@staticmethod
def test_action_polygons_static(polygon_file, xy_list):
@@ -88,9 +82,7 @@ def __get_random_point(self, max_x, max_y):
(ClassifierApproaches.test_action_polygons_static),
],
)
- def test_given_list_of_geometries_then_classifies_correctly(
- self, classifier_function
- ):
+ def test_given_list_of_geometries_then_classifies_correctly(self, classifier_function):
# 1. Defining test input data
left_classifier = "left_classifier"
right_classifier = "right_classifier"
@@ -148,10 +140,7 @@ def test_overall_performance(self, number_of_points: int):
classifiers_names = [left_classifier, right_classifier, undefined_classifier]
polygon_list = self.__get_basic_polygon_list(left_classifier, right_classifier)
map_boundary = (10, 10)
- xy_list = [
- self.__get_random_point(map_boundary[0], map_boundary[1])
- for _ in range(number_of_points)
- ]
+ xy_list = [self.__get_random_point(map_boundary[0], map_boundary[1]) for _ in range(number_of_points)]
polygon_file = PolygonFile(logging.getLogger(__name__))
polygon_file.polygons = polygon_list
@@ -171,24 +160,18 @@ def test_classify_points_for_waal(self):
assert polygon is not None
# Read NC File
- waal_data_dir = (
- TestUtils.get_external_test_data_subdir("case_08_waal") / "Data" / "FM"
- )
+ waal_data_dir = TestUtils.get_external_test_data_subdir("case_08_waal") / "Data" / "FM"
waal_nc_file = waal_data_dir / "FlowFM_fm2prof_map.nc"
assert waal_nc_file.is_file()
- _, edge_data, _, _ = FE._read_fm_model(str(waal_nc_file))
- points = [
- (edge_data["x"][i], edge_data["y"][i]) for i in range(len(edge_data["x"]))
- ]
+ _, edge_data, _, _ = funcs._read_fm_model(str(waal_nc_file))
+ points = [(edge_data["x"][i], edge_data["y"][i]) for i in range(len(edge_data["x"]))]
assert points is not None
# 2. Run test
self.__run_performance_test(polygon, points, None)
- def __run_performance_test(
- self, polygon_file: PolygonFile, xy_list: list, classifiers_names: list
- ):
+ def __run_performance_test(self, polygon_file: PolygonFile, xy_list: list, classifiers_names: list):
number_of_points = len(xy_list)
t_repetitions = 10
@@ -221,30 +204,23 @@ def time_function(function_name) -> list:
for name, result in t_results.items():
plt.plot(range(t_repetitions), result, label=name)
plt.legend()
- plt.savefig(
- output_dir + "\\time_performance_points_{}.png".format(number_of_points)
- )
+ plt.savefig(output_dir + f"\\time_performance_points_{number_of_points}.png")
plt.close()
plt.figure()
for name, result in c_results.items():
values = [classifiers_names.index(val) for val in list(result)]
- plt.scatter(
- range(len(list(result))), values, marker=next(markers), label=name
- )
+ plt.scatter(range(len(list(result))), values, marker=next(markers), label=name)
plt.yticks(
- [c_id for c_id in range(len(classifiers_names))],
+ list(range(len(classifiers_names))),
classifiers_names,
rotation=45,
)
plt.legend()
- plt.savefig(
- output_dir + "\\classifier_results_points_{}.png".format(number_of_points)
- )
+ plt.savefig(output_dir + f"\\classifier_results_points_{number_of_points}.png")
plt.close()
-
@fixture
def polygon_list():
return [
@@ -302,9 +278,7 @@ def test_PolygonFile_classify_points_with_property_shapely_prep(polygon_list):
polygon_file = PolygonFile(logging.getLogger())
polygon_file.polygons = polygon_list
xy_list = [(4, 2), (8, 6), (8, 8)]
- classified_points = polygon_file.classify_points_with_property_shapely_prep(
- points=xy_list, property_name="name"
- )
+ classified_points = polygon_file.classify_points_with_property_shapely_prep(points=xy_list, property_name="name")
assert np.array_equal(classified_points, ["poly1", "poly2", -999])
@@ -313,7 +287,8 @@ def test_PolygonFile_classify_points_with_property_rtree_by_polygons(polygon_lis
polygon_file.polygons = polygon_list
xy_list = [(4, 2), (8, 6), (8, 8)]
classified_points = polygon_file.classify_points_with_property_rtree_by_polygons(
- points=xy_list, property_name="name"
+ points=xy_list,
+ property_name="name",
)
assert np.array_equal(classified_points, ["poly1", "poly2", -999])
@@ -334,9 +309,7 @@ def test_PolygonFile_validate_extension():
polygon_file = PolygonFile(logging.getLogger())
test_fp = "test.sjon"
- with pytest.raises(
- IOError, match="Invalid file path extension, should be .json or .geojson."
- ):
+ with pytest.raises(IOError, match="Invalid file path extension, should be .json or .geojson."):
polygon_file._validate_extension(file_path=test_fp)
test_fp = "test.json"
polygon_file._validate_extension(test_fp)
@@ -359,11 +332,9 @@ def test_PolygonFile_polygons_property(polygon_list):
p_tuple(
geometry=Polygon([[1, 1], [5, 1], [5, 4], [1, 4], [1, 1]]),
properties={"type": "poly1"},
- )
+ ),
)
- with pytest.raises(
- ValueError, match="Polygon properties must contain key-word 'name'"
- ):
+ with pytest.raises(ValueError, match="Polygon properties must contain key-word 'name'"):
polygon_file.polygons = polygon_list
polygon_list[2].properties.pop("type")
polygon_list[2].properties["name"] = "poly1"
@@ -379,16 +350,12 @@ def test_RegionPolygonFile(mocker, test_geojson, tmp_path):
file_path = tmp_path / "test.geojson"
_geojson_file_writer(test_geojson, file_path)
mock_logger = mocker.patch.object(RegionPolygonFile, "set_logger_message")
- region_polygon_file = RegionPolygonFile(
- region_file_path=file_path, logger=logging.getLogger(__name__)
- )
+ region_polygon_file = RegionPolygonFile(region_file_path=file_path, logger=logging.getLogger(__name__))
assert mock_logger.call_args_list[0][0][0] == "Validating region file"
assert mock_logger.call_args_list[1][0][0] == "2 regions found"
xy_list = [(4, 2), (8, 6), (8, 8)]
- classified_points = region_polygon_file.classify_points(
- xy_list, property_name="name"
- )
+ classified_points = region_polygon_file.classify_points(xy_list, property_name="name")
assert np.array_equal(classified_points, ["poly1", "poly2", -999])
@@ -396,25 +363,17 @@ def test_SectionPolygonFile(mocker, test_geojson, tmp_path):
file_path = tmp_path / "test_geojson.geojson"
_geojson_file_writer(test_geojson, file_path)
mock_logger = mocker.patch.object(SectionPolygonFile, "set_logger_message")
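+    # An invalid section file now raises OSError instead of AssertionError.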
- with pytest.raises(AssertionError, match="Section file is not valid"):
+ with pytest.raises(OSError, match="Section file is not valid"):
SectionPolygonFile(file_path, logger=logging.getLogger())
- assert (
- mock_logger.call_args_list[1][0][0]
- == 'Polygon poly1 has no property "section"'
- )
- assert (
- mock_logger.call_args_list[2][0][0]
- == 'Polygon poly2 has no property "section"'
- )
+ assert mock_logger.call_args_list[1][0][0] == 'Polygon poly1 has no property "section"'
+ assert mock_logger.call_args_list[2][0][0] == 'Polygon poly2 has no property "section"'
test_geojson["features"][0]["properties"]["section"] = "fake section"
_geojson_file_writer(test_geojson, file_path)
- with pytest.raises(AssertionError, match="Section file is not valid"):
+ with pytest.raises(OSError, match="Section file is not valid"):
SectionPolygonFile(file_path, logger=logging.getLogger())
- assert "fake section is not a recognized section" in [
- log_cal[0][0] for log_cal in mock_logger.call_args_list
- ]
+ assert "fake section is not a recognized section" in [log_cal[0][0] for log_cal in mock_logger.call_args_list]
test_geojson["features"][0]["properties"]["section"] = "1"
test_geojson["features"][1]["properties"]["section"] = "2"
_geojson_file_writer(test_geojson, file_path)
@@ -424,7 +383,5 @@ def test_SectionPolygonFile(mocker, test_geojson, tmp_path):
assert mock_logger.call_args_list[-1][0][0] == "Section file succesfully validated"
- classified_points = section_polygonfile.classify_points(
- points=[(4, 2), (8, 6), (8, 8)]
- )
+ classified_points = section_polygonfile.classify_points(points=[(4, 2), (8, 6), (8, 8)])
assert np.array_equal(classified_points, ["main", "floodplain1", 1])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 624f3c97..6a482d50 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,45 +1,38 @@
import os
import shutil
-import pytest
+from datetime import datetime
from pathlib import Path
+
+import pytest
+
from fm2prof import Project
-from fm2prof.utils import GenerateCrossSectionLocationFile, Compare1D2D, VisualiseOutput
+from fm2prof.utils import Compare1D2D, GenerateCrossSectionLocationFile, VisualiseOutput
from tests.TestUtils import TestUtils
-from datetime import datetime
-
_root_output_dir = None
class Test_GenerateCrossSectionLocationFile:
def test_given_networkdefinitionfile_cssloc_file_is_generated(self, tmp_path: Path):
# 1. Set up initial test data
- path_1d = TestUtils.get_local_test_file(
- "cases/case_02_compound/Model_SOBEK/dimr/dflow1d/NetworkDefinition.ini"
- )
+ path_1d = TestUtils.get_local_test_file("cases/case_02_compound/Model_SOBEK/dimr/dflow1d/NetworkDefinition.ini")
output_file = tmp_path / "cross_section_locations.xyz"
# 2. Set Expectations
# 3. Run test
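+        # Keyword names follow the renamed snake_case API (previously
+        # networkdefinitionfile / crossectionlocationfile).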
- GenerateCrossSectionLocationFile(
- networkdefinitionfile=path_1d, crossectionlocationfile=output_file
- )
+ GenerateCrossSectionLocationFile(network_definition_file=path_1d, crossection_location_file=output_file)
# 4. verify
assert output_file.is_file()
def test_given_branchrulefile_output_is_generated(self, tmp_path: Path):
# 1. Set up initial test data
- path_1d = TestUtils.get_local_test_file(
- "cases/case_02_compound/Model_SOBEK/dimr/dflow1d/NetworkDefinition.ini"
- )
+ path_1d = TestUtils.get_local_test_file("cases/case_02_compound/Model_SOBEK/dimr/dflow1d/NetworkDefinition.ini")
output_file = tmp_path / "cross_section_locations_new.xyz"
- branch_rule_file = TestUtils.get_local_test_file(
- "cases/case_02_compound/Data/branchrules_onlyfirst.ini"
- )
+ branch_rule_file = TestUtils.get_local_test_file("cases/case_02_compound/Data/branchrules_onlyfirst.ini")
# 2. Set Expectations
# 3. Run test
@@ -54,14 +47,10 @@ def test_given_branchrulefile_output_is_generated(self, tmp_path: Path):
def test_given_branchrule_exceptions_output_is_generated(self, tmp_path: Path):
# 1. Set up initial test data
- path_1d = TestUtils.get_local_test_file(
- "cases/case_02_compound/Model_SOBEK/dimr/dflow1d/NetworkDefinition.ini"
- )
+ path_1d = TestUtils.get_local_test_file("cases/case_02_compound/Model_SOBEK/dimr/dflow1d/NetworkDefinition.ini")
output_file = tmp_path / "cross_section_locations_new.xyz"
- branch_rule_file = TestUtils.get_local_test_file(
- "cases/case_02_compound/Data/branchrules_exceptions.ini"
- )
+ branch_rule_file = TestUtils.get_local_test_file("cases/case_02_compound/Data/branchrules_exceptions.ini")
# 2. Set Expectations
# 3. Run test
@@ -78,14 +67,10 @@ def test_given_branchrule_exceptions_output_is_generated(self, tmp_path: Path):
class Test_VisualiseOutput:
def test_when_branch_not_in_branches_raise_exception(self):
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "cases/case_02_compound/fm2prof_config.ini"
- )
+ project_config = TestUtils.get_local_test_file("cases/case_02_compound/fm2prof_config.ini")
project = Project(project_config)
- vis = VisualiseOutput(
- project.get_output_directory(), logger=project.get_logger()
- )
+ vis = VisualiseOutput(project.get_output_directory(), logger=project.get_logger())
# 2. Set expectations
error_snippet = "not in known branches:"
@@ -100,26 +85,20 @@ def test_when_branch_not_in_branches_raise_exception(self):
def test_when_branch_in_branches_produce_figure(self):
# 1. Set up initial test data
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "cases/case_02_compound/fm2prof_config.ini"
- )
+ project_config = TestUtils.get_local_test_file("cases/case_02_compound/fm2prof_config.ini")
project = Project(project_config)
- vis = VisualiseOutput(
- project.get_output_directory(), logger=project.get_logger()
- )
+ vis = VisualiseOutput(project.get_output_directory(), logger=project.get_logger())
# 2. Set expectations
- output_dir = TestUtils.get_local_test_file(
- "cases/case_02_compound/output/figures/roughness"
- )
+ output_dir = TestUtils.get_local_test_file("cases/case_02_compound/output/figures/roughness")
if output_dir.is_dir():
shutil.rmtree(output_dir)
os.mkdir(output_dir)
output_file = TestUtils.get_local_test_file(
- "cases/case_02_compound/output/figures/roughness/roughness_longitudinal_channel1.png"
+ "cases/case_02_compound/output/figures/roughness/roughness_longitudinal_channel1.png",
)
# 3. Run test
@@ -130,23 +109,17 @@ def test_when_branch_in_branches_produce_figure(self):
def test_when_given_output_css_figure_produced(self):
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "cases/case_02_compound/fm2prof_config.ini"
- )
+ project_config = TestUtils.get_local_test_file("cases/case_02_compound/fm2prof_config.ini")
project = Project(project_config)
- vis = VisualiseOutput(
- project.get_output_directory(), logger=project.get_logger()
- )
+ vis = VisualiseOutput(project.get_output_directory(), logger=project.get_logger())
# 2. Set expectations
- output_dir = TestUtils.get_local_test_file(
- "cases/case_02_compound/output/figures/cross_sections"
- )
+ output_dir = TestUtils.get_local_test_file("cases/case_02_compound/output/figures/cross_sections")
if output_dir.is_dir():
shutil.rmtree(output_dir)
os.mkdir(output_dir)
output_file = TestUtils.get_local_test_file(
- "cases/case_02_compound/output/figures/cross_sections/channel1_125.000.png"
+ "cases/case_02_compound/output/figures/cross_sections/channel1_125.000.png",
)
# 3. Run test
@@ -160,9 +133,7 @@ def test_when_given_output_css_figure_produced(self):
class Test_Compare1D2D:
def test_when_no_netcdf_but_csv_present_class_initialises(self):
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini"
- )
+ project_config = TestUtils.get_local_test_file("compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini")
project = Project(project_config)
# 2. Run test
@@ -179,9 +150,7 @@ def test_when_no_netcdf_but_csv_present_class_initialises(self):
def test_statistics_to_file(self):
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini"
- )
+ project_config = TestUtils.get_local_test_file("compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini")
project = Project(project_config)
plotter = Compare1D2D(
project=project,
@@ -193,9 +162,7 @@ def test_statistics_to_file(self):
# 2. Set expectations
# this file should exist
- output_file = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/output/error_statistics.csv"
- )
+ output_file = TestUtils.get_local_test_file("compare1d2d/rijn-j22_6-v1a2/output/error_statistics.csv")
# 3. Run test
plotter.statistics_to_file()
@@ -205,18 +172,14 @@ def test_statistics_to_file(self):
def test_figure_longitudinal(self):
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini"
- )
+ project_config = TestUtils.get_local_test_file("compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini")
project = Project(project_config)
- plotter = Compare1D2D(
- project=project, start_time=datetime(year=2000, month=1, day=5)
- )
+ plotter = Compare1D2D(project=project, start_time=datetime(year=2000, month=1, day=5))
# 2. Set expectations
# this file should exist
output_file = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/output/figures/longitudinal/BR-PK-IJ.png"
+ "compare1d2d/rijn-j22_6-v1a2/output/figures/longitudinal/BR-PK-IJ.png",
)
# 3. Run test
@@ -227,43 +190,31 @@ def test_figure_longitudinal(self):
def test_figure_discharge(self):
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini"
- )
+ project_config = TestUtils.get_local_test_file("compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini")
project = Project(project_config)
- plotter = Compare1D2D(
- project=project, start_time=datetime(year=2000, month=1, day=5)
- )
+ plotter = Compare1D2D(project=project, start_time=datetime(year=2000, month=1, day=5))
# 2. Set expectations
# this file should exist
output_file = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/output/figures/discharge/Pannerdensche Kop.png"
+ "compare1d2d/rijn-j22_6-v1a2/output/figures/discharge/Pannerdensche Kop.png",
)
# 3. Run test
- plotter.figure_compare_discharge_at_stations(
- stations=["WL_869.00", "PK_869.00"], title="Pannerdensche Kop"
- )
+ plotter.figure_compare_discharge_at_stations(stations=["WL_869.00", "PK_869.00"], title="Pannerdensche Kop")
# 4. Verify expectations
assert output_file.is_file()
def test_figure_at_station(self):
# 1. Set up initial test data
- project_config = TestUtils.get_local_test_file(
- "compare1d2d/cases/case1/fm2prof.ini"
- )
+ project_config = TestUtils.get_local_test_file("compare1d2d/cases/case1/fm2prof.ini")
project = Project(project_config)
- plotter = Compare1D2D(
- project=project, start_time=datetime(year=2000, month=1, day=5)
- )
+ plotter = Compare1D2D(project=project, start_time=datetime(year=2000, month=1, day=5))
# 2. Set expectations
# this file should exist
- output_file = TestUtils.get_local_test_file(
- "compare1d2d/cases/case1/output/figures/stations/NR_919.00.png"
- )
+ output_file = TestUtils.get_local_test_file("compare1d2d/cases/case1/output/figures/stations/NR_919.00.png")
# 3. Run test
plotter.figure_at_station("NR_919.00")
@@ -275,15 +226,13 @@ def test_if_style_is_given_figure_produced(self):
# 1. Set up initial test data
styles = ["van_veen", "sito"]
- project_config = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini"
- )
+ project_config = TestUtils.get_local_test_file("compare1d2d/rijn-j22_6-v1a2/sobek-rijn-j22.ini")
project = Project(project_config)
# 2. Set expectations
# this file should exist
output_file = TestUtils.get_local_test_file(
- "compare1d2d/rijn-j22_6-v1a2/output/figures/longitudinal/BR-PK-IJ.png"
+ "compare1d2d/rijn-j22_6-v1a2/output/figures/longitudinal/BR-PK-IJ.png",
)
# 3. Run test