Skip to content

Commit

Permalink
Overhaul of Python architecture (#731)
Browse files Browse the repository at this point in the history
Fixes #512
Groundwork for #630 & #252

The code has become much simpler in some places (e.g. write_toml).

```python
from ribasim import Model

m = Model(filepath="generated_testmodels/basic/ribasim.toml")
m.database.node.df # Node table
m.basin.static.df # BasinStatc table
m.write("test")
```

### Some notes:
- The config.py file cannot be autogenerated anymore. The schemas still
can, but I disabled it for now to be sure (some imports error).

### Changes:
I created new (parent) classes:
- BaseModel, from Pydantic, with our own config
- FileModel, like hydrolib-core (but now Pydantic v2), which can take a
single filepath for initilization. Models who inherit require defining
_load and _save, dealing with that filepath.
- NodeModel, a class for nodes (where `add` will be).
- TableModel, a class to read/write individual tables from/to
gpkg/arrow.
- SpatialTableModel, inherits TableModel, but reads/writes spatial
stuff.

I changed:
- the Model class to be a carbon copy of Config (which has been
deleted), so it mirrors the toml.
- in turn this created a `database` NodeModel class (reflecting the
field in the toml), with only Node and Edge underneath.
- the NodeModel classes Basin from their node_type version to the one in
Config, and set the type of the underlying table with a TypeVar like so:
```python
class Terminal(NodeModel):
    static: TableModel[TerminalStaticSchema]
```

### Yet to do:
- [x] Update tests
- [x] Fix sort! rules
- [x] Delete node_types folder
- [x] Link schemas to their Pydantic class
(TableModel[TerminalStaticSchema] => TerminalStatic)

---------

Co-authored-by: Martijn Visser <[email protected]>
Co-authored-by: Hofer-Julian <[email protected]>
Co-authored-by: Hofer-Julian <[email protected]>
  • Loading branch information
4 people committed Nov 14, 2023
1 parent 95e2c6f commit a157c82
Show file tree
Hide file tree
Showing 44 changed files with 2,742 additions and 2,951 deletions.
4 changes: 2 additions & 2 deletions docs/contribute/addnode.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ from typing import Optional
import pandera as pa
from pandera.engines.pandas_engine import PydanticModel
from pandera.typing import DataFrame
from pydantic import ConfigDict

from ribasim import models
from ribasim.input_base import TableModel
Expand Down Expand Up @@ -145,8 +146,7 @@ class NewNodeType(TableModel):
static: Optional[DataFrame[StaticSchema]] = None
# possible other schemas

class Config:
validate_assignment = True
model_config = ConfigDict(validate_assignment=True)

def sort(self):
self.static.sort_values("node_id", ignore_index=True, inplace=True)
Expand Down
46 changes: 26 additions & 20 deletions docs/python/examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@
")\n",
"node_xy = gpd.points_from_xy(x=xy[:, 0], y=xy[:, 1])\n",
"\n",
"node_id, node_type = ribasim.Node.get_node_ids_and_types(\n",
"node_id, node_type = ribasim.Node.node_ids_and_types(\n",
" basin,\n",
" manning_resistance,\n",
" rating_curve,\n",
Expand All @@ -333,7 +333,7 @@
"\n",
"# Make sure the feature id starts at 1: explicitly give an index.\n",
"node = ribasim.Node(\n",
" static=gpd.GeoDataFrame(\n",
" df=gpd.GeoDataFrame(\n",
" data={\"type\": node_type},\n",
" index=pd.Index(node_id, name=\"fid\"),\n",
" geometry=node_xy,\n",
Expand Down Expand Up @@ -363,7 +363,7 @@
")\n",
"lines = ribasim.utils.geometry_from_connectivity(node, from_id, to_id)\n",
"edge = ribasim.Edge(\n",
" static=gpd.GeoDataFrame(\n",
" df=gpd.GeoDataFrame(\n",
" data={\n",
" \"from_node_id\": from_id,\n",
" \"to_node_id\": to_id,\n",
Expand All @@ -390,8 +390,10 @@
"outputs": [],
"source": [
"model = ribasim.Model(\n",
" node=node,\n",
" edge=edge,\n",
" database=ribasim.Database(\n",
" node=node,\n",
" edge=edge,\n",
" ),\n",
" basin=basin,\n",
" level_boundary=level_boundary,\n",
" flow_boundary=flow_boundary,\n",
Expand Down Expand Up @@ -520,7 +522,7 @@
" .to_xarray()\n",
")\n",
"\n",
"basin_ids = model.basin.static[\"node_id\"].to_numpy()\n",
"basin_ids = model.basin.static.df[\"node_id\"].to_numpy()\n",
"basin_nodes = xr.DataArray(\n",
" np.ones(len(basin_ids)), coords={\"node_id\": basin_ids}, dims=[\"node_id\"]\n",
")\n",
Expand Down Expand Up @@ -548,8 +550,8 @@
"metadata": {},
"outputs": [],
"source": [
"model.basin.time = forcing\n",
"model.basin.state = state"
"model.basin.time.df = forcing\n",
"model.basin.state.df = state"
]
},
{
Expand Down Expand Up @@ -675,7 +677,7 @@
"\n",
"# Make sure the feature id starts at 1: explicitly give an index.\n",
"node = ribasim.Node(\n",
" static=gpd.GeoDataFrame(\n",
" df=gpd.GeoDataFrame(\n",
" data={\"type\": node_type},\n",
" index=pd.Index(np.arange(len(xy)) + 1, name=\"fid\"),\n",
" geometry=node_xy,\n",
Expand Down Expand Up @@ -705,7 +707,7 @@
"\n",
"lines = ribasim.utils.geometry_from_connectivity(node, from_id, to_id)\n",
"edge = ribasim.Edge(\n",
" static=gpd.GeoDataFrame(\n",
" df=gpd.GeoDataFrame(\n",
" data={\"from_node_id\": from_id, \"to_node_id\": to_id, \"edge_type\": edge_type},\n",
" geometry=lines,\n",
" crs=\"EPSG:28992\",\n",
Expand Down Expand Up @@ -903,8 +905,10 @@
"outputs": [],
"source": [
"model = ribasim.Model(\n",
" node=node,\n",
" edge=edge,\n",
" database=ribasim.Database(\n",
" node=node,\n",
" edge=edge,\n",
" ),\n",
" basin=basin,\n",
" pump=pump,\n",
" level_boundary=level_boundary,\n",
Expand Down Expand Up @@ -994,7 +998,7 @@
"\n",
"ax = df_basin_wide[\"level\"].plot()\n",
"\n",
"greater_than = model.discrete_control.condition.greater_than\n",
"greater_than = model.discrete_control.condition.df.greater_than\n",
"\n",
"ax.hlines(\n",
" greater_than,\n",
Expand Down Expand Up @@ -1103,7 +1107,7 @@
"\n",
"# Make sure the feature id starts at 1: explicitly give an index.\n",
"node = ribasim.Node(\n",
" static=gpd.GeoDataFrame(\n",
" df=gpd.GeoDataFrame(\n",
" data={\"type\": node_type},\n",
" index=pd.Index(np.arange(len(xy)) + 1, name=\"fid\"),\n",
" geometry=node_xy,\n",
Expand All @@ -1130,7 +1134,7 @@
"\n",
"lines = ribasim.utils.geometry_from_connectivity(node, from_id, to_id)\n",
"edge = ribasim.Edge(\n",
" static=gpd.GeoDataFrame(\n",
" df=gpd.GeoDataFrame(\n",
" data={\n",
" \"from_node_id\": from_id,\n",
" \"to_node_id\": to_id,\n",
Expand Down Expand Up @@ -1325,8 +1329,10 @@
"outputs": [],
"source": [
"model = ribasim.Model(\n",
" node=node,\n",
" edge=edge,\n",
" database=ribasim.Database(\n",
" node=node,\n",
" edge=edge,\n",
" ),\n",
" basin=basin,\n",
" flow_boundary=flow_boundary,\n",
" level_boundary=level_boundary,\n",
Expand Down Expand Up @@ -1415,8 +1421,8 @@
"ax.set_ylabel(\"level [m]\")\n",
"\n",
"# Plot target level\n",
"target_levels = model.pid_control.time.target.to_numpy()[::2]\n",
"times = date2num(model.pid_control.time.time)[::2]\n",
"target_levels = model.pid_control.time.df.target.to_numpy()[::2]\n",
"times = date2num(model.pid_control.time.df.time)[::2]\n",
"ax.plot(times, target_levels, color=\"k\", ls=\":\", label=\"target level\")\n",
"pass"
]
Expand Down Expand Up @@ -1445,7 +1451,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit a157c82

Please sign in to comment.