Skip to content

Commit

Permalink
Merge pull request #664 from DHI/pfs_hints
Browse files Browse the repository at this point in the history
Pfs type hints
  • Loading branch information
ecomodeller authored Mar 11, 2024
2 parents 521193d + 5a729c7 commit 83a9359
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 51 deletions.
2 changes: 1 addition & 1 deletion mikeio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .dfs import Dfs0, Dfs1, Dfs2, Dfs3
from .dfsu import Dfsu, Mesh
from .eum import EUMType, EUMUnit, ItemInfo
from .pfs import Pfs, PfsDocument, PfsSection, read_pfs
from .pfs import PfsDocument, PfsSection, read_pfs

# Grid geometries are imported into the main module, since they are used to create dfs files
# Other geometries are available in the spatial module
Expand Down
18 changes: 13 additions & 5 deletions mikeio/pfs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from ._pfsdocument import Pfs, PfsDocument
from __future__ import annotations
from pathlib import Path
from typing import Dict, TextIO
from ._pfsdocument import PfsDocument
from ._pfssection import PfsNonUniqueList, PfsSection


def read_pfs(filename, encoding="cp1252", unique_keywords=False):
def read_pfs(
filename: str | Path | TextIO | Dict | PfsSection,
encoding: str = "cp1252",
unique_keywords: bool = False,
) -> PfsDocument:
"""Read a pfs file to a Pfs object for further analysis/manipulation
Parameters
Expand All @@ -20,15 +27,16 @@ def read_pfs(filename, encoding="cp1252", unique_keywords=False):
Returns
-------
mikeio.Pfs
Pfs object which can be used for inspection, manipulation and writing
PfsDocument
A PfsDocument object
"""
return PfsDocument(filename, encoding=encoding, unique_keywords=unique_keywords)


__all__ = [
"Pfs",
"PfsDocument",
"PfsNonUniqueList",
"PfsSection",
"read_pfs",
]
]
52 changes: 22 additions & 30 deletions mikeio/pfs/_pfsdocument.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Mapping, TextIO, Tuple
from typing import Callable, Dict, List, Mapping, Sequence, TextIO, Tuple

import yaml

Expand Down Expand Up @@ -82,12 +82,12 @@ class PfsDocument(PfsSection):

def __init__(
self,
data: TextIO | PfsSection | Dict,
data: TextIO | PfsSection | Dict | str | Path,
*,
encoding="cp1252",
names=None,
unique_keywords=False,
):
encoding: str = "cp1252",
names: Sequence[str] | None = None,
unique_keywords: bool = False,
) -> None:

if isinstance(data, (str, Path)) or hasattr(data, "read"):
if names is not None:
Expand Down Expand Up @@ -168,7 +168,9 @@ def _read_pfs_file(self, filename, encoding, unique_keywords=False):
try:
yml = self._pfs2yaml(filename, encoding)
target_list = parse_yaml_preserving_duplicates(yml, unique_keywords)
except AttributeError: # This is the error raised if parsing fails, try again with the normal loader
except (
AttributeError
): # This is the error raised if parsing fails, try again with the normal loader
target_list = yaml.load(yml, Loader=yaml.CFullLoader)
except FileNotFoundError as e:
raise FileNotFoundError(str(e))
Expand All @@ -179,7 +181,10 @@ def _read_pfs_file(self, filename, encoding, unique_keywords=False):
return names, sections

@staticmethod
def _parse_non_file_input(input, names=None):
def _parse_non_file_input(
input: Dict | PfsSection | Sequence[PfsSection] | Sequence[Dict],
names: Sequence[str] | None = None,
) -> Tuple[Sequence[str], List[PfsSection]]:
"""dict/PfsSection or lists of these can be parsed"""
if names is None:
assert isinstance(input, Mapping), "input must be a mapping"
Expand All @@ -189,11 +194,6 @@ def _parse_non_file_input(input, names=None):
sec, Mapping
), "all targets must be PfsSections/dict (no key-value pairs allowed in the root)"
return names, sections
# else:
# warnings.warn(
# "Creating a PfsDocument with names argument is deprecated, provide instead the names as keys in a dictionary",
# FutureWarning,
# )

if isinstance(names, str):
names = [names]
Expand All @@ -202,9 +202,9 @@ def _parse_non_file_input(input, names=None):
sections = [input]
elif isinstance(input, dict):
sections = [PfsSection(input)]
elif isinstance(input, (List, Tuple)):
elif isinstance(input, Sequence):
if isinstance(input[0], PfsSection):
sections = input
sections = input # type: ignore
elif isinstance(input[0], dict):
sections = [PfsSection(d) for d in input]
else:
Expand Down Expand Up @@ -245,12 +245,14 @@ def _add_FM_alias(self, alias: str, module: str) -> None:
setattr(self, alias, self.targets[0][module])
self._ALIAS_LIST.append(alias)

def _pfs2yaml(self, filename, encoding=None) -> str:
def _pfs2yaml(
self, filename: str | Path | TextIO, encoding: str | None = None
) -> str:

if hasattr(filename, "read"): # To read in memory strings StringIO
pfsstring = filename.read()
else:
with (open(filename, encoding=encoding)) as f:
with open(filename, encoding=encoding) as f:
pfsstring = f.read()

lines = pfsstring.split("\n")
Expand Down Expand Up @@ -331,16 +333,10 @@ def _parse_param(self, value: str) -> str:

_COMMA_MATCHER = re.compile(r",(?=(?:[^\"']*[\"'][^\"']*[\"'])*[^\"']*$)")

def _split_line_by_comma(self, s: str):
def _split_line_by_comma(self, s: str) -> List[str]:
return self._COMMA_MATCHER.split(s)
# import shlex
# lexer = shlex.shlex(s)
# lexer.whitespace += ","
# lexer.quotes += "|"
# lexer.wordchars += ",.-"
# return list(lexer)

def _parse_token(self, token: str, context="") -> str:

def _parse_token(self, token: str, context: str = "") -> str:
s = token.strip()

# Example of complicated string:
Expand Down Expand Up @@ -384,7 +380,3 @@ def write(self, filename=None):
f.write("\n\n")

self._write_with_func(f.write, level=0)


# TODO remove this alias
Pfs = PfsDocument
28 changes: 17 additions & 11 deletions mikeio/pfs/_pfssection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
from typing import Any, Callable, List, Mapping, MutableMapping, Sequence
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Sequence

import pandas as pd


def _merge_dict(a: MutableMapping[str, Any], b: Mapping[str, Any]) -> Mapping[str, Any]:
def _merge_dict(a: Dict[str, Any], b: Mapping[str, Any]) -> Dict[str, Any]:
"""merges dict b into dict a; handling non-unique keys"""
for key in b:
if key in a:
Expand Down Expand Up @@ -180,9 +180,9 @@ def search(
*,
key: str | None = None,
section: str | None = None,
param=None,
param: str | bool | int | float | None = None,
case: bool = False,
):
) -> PfsSection:
"""Find recursively all keys, sections or parameters
matching a pattern
Expand Down Expand Up @@ -225,7 +225,11 @@ def search(
keypat=key, parampat=param, secpat=section, case=case
):
results.append(item)
return self.__class__._merge_PfsSections(results) if len(results) > 0 else None
return (
self.__class__._merge_PfsSections(results)
if len(results) > 0
else PfsSection({})
)

def _find_patterns_generator(
self, keypat=None, parampat=None, secpat=None, keylist=[], case=False
Expand Down Expand Up @@ -256,7 +260,7 @@ def _yield_deep_dict(keys, val):
yield d

@staticmethod
def _param_match(parampat, v, case):
def _param_match(parampat: Any, v: Any, case: bool) -> bool:
if parampat is None:
return False
if type(v) != type(parampat):
Expand All @@ -267,7 +271,7 @@ def _param_match(parampat, v, case):
else:
return parampat == v

def find_replace(self, old_value, new_value):
def find_replace(self, old_value: Any, new_value: Any) -> None:
"""Update recursively all old_value with new_value"""
for k, v in self.items():
if isinstance(v, PfsSection):
Expand All @@ -284,12 +288,14 @@ def copy(self) -> "PfsSection":
d[key] = value.to_dict().copy()
return self.__class__(d)

def _to_txt_lines(self):
lines = []
def _to_txt_lines(self) -> List[str]:
lines: List[str] = []
self._write_with_func(lines.append, newline="")
return lines

def _write_with_func(self, func: Callable, level: int = 0, newline: str = "\n"):
def _write_with_func(
self, func: Callable, level: int = 0, newline: str = "\n"
) -> None:
"""Write pfs nested objects
Parameters
Expand Down Expand Up @@ -432,7 +438,7 @@ def to_dataframe(self, prefix: str | None = None) -> pd.DataFrame:
return pd.DataFrame(res, index=range(1, n_sections + 1))

@classmethod
def _merge_PfsSections(cls, sections: Sequence) -> "PfsSection":
def _merge_PfsSections(cls, sections: Sequence[Dict]) -> "PfsSection":
"""Merge a list of PfsSections/dict"""
assert len(sections) > 0
a = sections[0]
Expand Down
10 changes: 6 additions & 4 deletions tests/test_pfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def test_read_write_she2(tmp_path):
pfs2 = mikeio.PfsDocument(outfilename)
assert pfs1.MIKESHE_FLOWMODEL == pfs2.MIKESHE_FLOWMODEL


def test_read_write_filenames(tmp_path):
infilename = "tests/testdata/pfs/filenames.pfs"
pfs1 = mikeio.PfsDocument(infilename)
Expand Down Expand Up @@ -1095,7 +1096,7 @@ def test_search_keyword(pfs_ABC_text):
assert "A2" in pfs.ROOT

r0 = pfs.search(key="not_there")
assert r0 is None
assert len(r0) == 0

r1 = pfs.search(key="float")
assert r1.ROOT.A1.B.float_1 == 4.5
Expand Down Expand Up @@ -1124,7 +1125,7 @@ def test_search_param(pfs_ABC_text):
pfs = mikeio.PfsDocument(StringIO(pfs_ABC_text))

r0 = pfs.search(param="not_there")
assert r0 is None
assert len(r0) == 0

r1 = pfs.search(param=0)
assert len(r1.ROOT) == 2
Expand All @@ -1143,7 +1144,7 @@ def test_search_section(pfs_ABC_text):
pfs = mikeio.PfsDocument(StringIO(pfs_ABC_text))

r0 = pfs.search(section="not_there")
assert r0 is None
assert len(r0) == 0

r1 = pfs.search(section="A")
assert len(r1.ROOT) == 2
Expand Down Expand Up @@ -1208,6 +1209,7 @@ def test_clob_can_contain_pipe_characters():
== '<CLOB:22,1,1,false,1,0,"",0,"",0,"",0,"",0,"",0,"",0,"",0,"",||,false>'
)


def test_write_read_clob(tmp_path):
clob_text = """
[WQRiverPfs_0]
Expand All @@ -1229,4 +1231,4 @@ def test_write_read_clob(tmp_path):
assert (
sct.Clob
== '<CLOB:22,1,1,false,1,0,"",0,"",0,"",0,"",0,"",0,"",0,"",0,"",||,false>'
)
)

0 comments on commit 83a9359

Please sign in to comment.