Skip to content

Commit

Permalink
prms_channel_flow_graph _postprocess and _model_dict both handle new …
Browse files Browse the repository at this point in the history
…nodes grouped in-series; various FlowGraph details and testing
  • Loading branch information
jmccreight committed Oct 22, 2024
1 parent 7612552 commit e3d4001
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 31 deletions.
78 changes: 51 additions & 27 deletions autotest/test_starfit_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def test_starfit_flow_graph_postprocess(
):
input_dir = simulation["output_dir"]

# We'll test adding multiple new nodes in-series into a FlowGraph. Above
# and below the starfit we'll add pass-through nodes and check these
# match.

# We add a random passthrough node with no upstream node, to test if that
# works.

# in this test we'll cover the case of disconnected nodes and allowing
# them or not
def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
Expand Down Expand Up @@ -151,17 +158,26 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
"node_storages",
]

# Currently this is the same as notebook 06

new_nodes_maker_names = ["starfit", "pass_through"]
new_nodes_maker_indices = [0, 0]
new_nodes_maker_ids = [-2, -1]

new_nodes_maker_names = ["starfit"] + ["pass_through"] * 3
new_nodes_maker_indices = [0, 0, 1, 2]
new_nodes_maker_ids = [-2, -1, -100, -1000]
# The starfit node flows to the third passthrough node, in index 3.
# The first passthrough node flows to some random nhm_seg, not connected to
# the other new nodes.
# The second passthrough flows to the starfit node in index 0.
# The last passthrough node flows to the seg above which the reservoir
# is placed.
new_nodes_flow_to_nhm_seg = [-3, 44409, 0, 44426]

# the first in the list is for the disconnected node
check_names = ["prms_channel"] + new_nodes_maker_names
check_indices = [dis_ds.dims["nsegment"] - 1] + new_nodes_maker_indices
check_ids = [dis_ds.nhm_seg[-1]] + new_nodes_maker_ids
check_ids = [dis_ds.nhm_seg[-1].values.tolist()] + new_nodes_maker_ids

with pytest.warns(UserWarning):
# This warning should say: TODO
with pytest.warns(
UserWarning, match="Disconnected nodes present in FlowGraph."
):
flow_graph = prms_channel_flow_graph_postprocess(
control=control,
input_dir=input_dir,
Expand All @@ -179,13 +195,12 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
new_nodes_maker_names=new_nodes_maker_names,
new_nodes_maker_indices=new_nodes_maker_indices,
new_nodes_maker_ids=new_nodes_maker_ids,
new_nodes_flow_to_nhm_seg=[
44426,
44435,
],
addtl_output_vars=["_lake_spill", "_lake_release"],
new_nodes_flow_to_nhm_seg=new_nodes_flow_to_nhm_seg,
addtl_output_vars=["spill", "release"],
allow_disconnected_nodes=True,
)

# <
# get the segments unaffected by flow, where the PRMS solutions should
# match
wh_44426 = np.where(discretization.parameters["nhm_seg"] == 44426)[0][0]
Expand Down Expand Up @@ -224,6 +239,9 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
var.advance()

answers_conv_vol = {}
# The -5 is not -4 because of the synthetic disconnected node also
# present on the FlowGraph (which is added via parameters, not
# using prms_channel_flow_graph_postprocess)
for key, val in answers.items():
if key in convert_to_vol:
current = val.current / (24 * 60 * 60)
Expand All @@ -241,23 +259,23 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
current[(wh_ignore,)] = flow_graph[key][(wh_ignore,)]
# Fill in the last two nodes
answers_conv_vol[key] = np.concatenate(
[current, flow_graph[key][-3:]]
[current, flow_graph[key][-5:]]
)

# <<
# there are no expected sources or sinks in this test
answers_conv_vol["node_sink_source"] = np.concatenate(
[val.current * zero, flow_graph["node_sink_source"][-3:]]
[val.current * zero, flow_graph["node_sink_source"][-5:]]
)
answers_conv_vol["node_negative_sink_source"] = np.concatenate(
[
val.current * zero,
flow_graph["node_negative_sink_source"][-3:],
flow_graph["node_negative_sink_source"][-5:],
]
)

answers_conv_vol["node_storages"] = np.concatenate(
[val.current * nan, flow_graph["node_storages"][-3:]]
[val.current * nan, flow_graph["node_storages"][-5:]]
)

compare_in_memory(
Expand All @@ -270,19 +288,26 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
)

# this checks that the budget was actually active for the starfit node
assert flow_graph._nodes[-2].budget is not None
assert flow_graph._nodes[-4].budget is not None

flow_graph.finalize()
for vv in control.options["netcdf_output_var_names"]:
da = xr.load_dataarray(tmp_path / f"{vv}.nc", concat_characters=True)
assert da.node_maker_index[-3:].values.tolist() == check_indices
assert da.node_maker_name[-3:].values.tolist() == check_names
assert da.node_maker_id[-3:].values.tolist() == check_ids
assert da.node_maker_index[-5:].values.tolist() == check_indices
assert da.node_maker_name[-5:].values.tolist() == check_names
assert da.node_maker_id[-5:].values.tolist() == check_ids

# test additional output files match and are working
da_no = xr.load_dataarray(tmp_path / "node_outflows.nc")
da_lr = xr.load_dataarray(tmp_path / "_lake_release.nc")
da_ls = xr.load_dataarray(tmp_path / "_lake_spill.nc")
da_in = xr.load_dataarray(tmp_path / "node_upstream_inflows.nc")

# full time check of the passthrough nodes (which is probably gratuitous
# given all the above checks already passed).
assert (abs(da_no[:, -1] - da_no[:, -4]) < 1e-12).all()
assert (abs(da_in[:, -2] - da_in[:, -4]) < 1e-12).all()

da_lr = xr.load_dataarray(tmp_path / "release.nc")
da_ls = xr.load_dataarray(tmp_path / "spill.nc")

da_no = da_no.where(da_no.node_maker_name == "starfit", drop=True)
da_lr = da_lr.where(da_lr.node_maker_name == "starfit", drop=True)
Expand Down Expand Up @@ -371,7 +396,7 @@ def test_starfit_flow_graph_model_dict(
new_nodes_maker_ids=new_nodes_maker_ids,
new_nodes_flow_to_nhm_seg=new_nodes_flow_to_nhm_seg,
graph_budget_type="error", # move to error
addtl_output_vars=["_lake_spill", "_lake_release"],
addtl_output_vars=["spill", "release"],
)
model = Model(model_dict)

Expand Down Expand Up @@ -452,7 +477,7 @@ def test_starfit_flow_graph_model_dict(
rtol=rtol,
skip_missing_ans=False,
fail_after_all_vars=False,
verbose=True,
verbose=False,
)

# this checks that the budget was actually active for the starfit node
Expand All @@ -464,6 +489,5 @@ def test_starfit_flow_graph_model_dict(
ds = xr.open_dataset(tmp_path / "FlowGraph.nc")
ds_starfit = ds.where(ds.node_maker_name == "starfit", drop=True)
assert (
ds_starfit.node_outflows
== ds_starfit._lake_release + ds_starfit._lake_spill
ds_starfit.node_outflows == ds_starfit.release + ds_starfit.spill
).all()
2 changes: 1 addition & 1 deletion pywatershed/base/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def __init__(
addtl_output_vars: list[str] = None,
params_not_to_netcdf: list[str] = None,
budget_type: Literal["defer", None, "warn", "error"] = "defer",
allow_disconnected_nodes: bool = True, # todo, make False
allow_disconnected_nodes: bool = False,
type_check_nodes: bool = False,
verbose: bool = None,
):
Expand Down
52 changes: 49 additions & 3 deletions pywatershed/hydrology/prms_channel_flow_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pathlib as pl
from copy import deepcopy
from typing import Literal, Union

import numba as nb
Expand Down Expand Up @@ -623,6 +624,7 @@ def prms_channel_flow_graph_postprocess(
new_nodes_maker_ids: list,
new_nodes_flow_to_nhm_seg: list,
addtl_output_vars: list[str] = None,
allow_disconnected_nodes: bool = False,
budget_type: Literal["defer", None, "warn", "error"] = "defer",
type_check_nodes: bool = False,
) -> FlowGraph:
Expand Down Expand Up @@ -657,7 +659,11 @@ def prms_channel_flow_graph_postprocess(
NodeMaker.
new_nodes_maker_ids: Collated list of ids relative to each NodeMaker.
new_nodes_flow_to_nhm_seg: collated list describing the nhm_seg to
which the node will flow.
which the node will flow. Use of non-positive entries specifies
the zero-based index for flowing to nodes specified in these
collated parameters, allowing these new nodes to be added in
groups, in series to the existing NHM FlowGraph. Note that a new
node may not be placed below any outflow point of the domain.
budget_type: one of ["defer", None, "warn", "error"] with "defer" being
the default and deferring to control.options["budget_type"] when
available. When control.options["budget_type"] is not available,
Expand Down Expand Up @@ -726,6 +732,7 @@ def advance(self) -> None:
addtl_output_vars=addtl_output_vars,
budget_type=budget_type,
type_check_nodes=type_check_nodes,
allow_disconnected_nodes=allow_disconnected_nodes,
)
return flow_graph

Expand All @@ -742,6 +749,7 @@ def prms_channel_flow_graph_to_model_dict(
new_nodes_flow_to_nhm_seg: list,
addtl_output_vars: list[str] = None,
graph_budget_type: Literal["defer", None, "warn", "error"] = "defer",
allow_disconnected_nodes: bool = False,
type_check_nodes: bool = False,
) -> dict:
"""Add nodes to a PRMSChannel-based FlowGraph within a Model's model_dict.
Expand Down Expand Up @@ -775,7 +783,11 @@ def prms_channel_flow_graph_to_model_dict(
NodeMaker
new_nodes_maker_ids: Collated list of ids relative to each NodeMaker.
new_nodes_flow_to_nhm_seg: collated list describing the nhm_seg to
which the node will flow.
which the node will flow. Use of non-positive entries specifies
the zero-based index for flowing to nodes specified in these
collated parameters, allowing these new nodes to be added in
groups, in series to the existing NHM FlowGraph. Note that a new
node may not be placed below any outflow point of the domain.
graph_budget_type: one of ["defer", None, "warn", "error"] with
"defer" being the default and deferring to
control.options["budget_type"] when available. When
Expand Down Expand Up @@ -868,6 +880,7 @@ def exchange_calculation(self) -> None:
"dis": None,
"budget_type": graph_budget_type,
"addtl_output_vars": addtl_output_vars,
"allow_disconnected_nodes": allow_disconnected_nodes,
},
}

Expand Down Expand Up @@ -899,7 +912,7 @@ def _build_flow_graph_inputs(
# new_nodes_flow_to_nhm_seg
assert len(new_nodes_flow_to_nhm_seg) == len(
np.unique(new_nodes_flow_to_nhm_seg)
), "Cant have more than one new node flowing to an existing node"
), "Cant have more than one new node flowing to an existing or new node."

nseg = prms_channel_params.dims["nsegment"]
nnew = len(new_nodes_maker_names)
Expand All @@ -921,7 +934,17 @@ def _build_flow_graph_inputs(
tosegment = dis_params["tosegment"] - 1 # fortran to python indexing
to_graph_index[0:nseg] = tosegment

# The new nodes which flow to other new_nodes have to be added after
# the nodes flowing to existing nodes with nhm_seg ids.
to_new_nodes_inds_in_added = {}
added_new_nodes_inds_in_graph = {}
for ii, nhm_seg in enumerate(new_nodes_flow_to_nhm_seg):
if nhm_seg < 1:
# non-positive entries are indices into the collated input lists
# for new nodes added in-series.
to_new_nodes_inds_in_added[ii] = -1 * nhm_seg
continue

wh_intervene_above_nhm = np.where(dis_params["nhm_seg"] == nhm_seg)
wh_intervene_below_nhm = np.where(
tosegment == wh_intervene_above_nhm[0][0]
Expand All @@ -938,6 +961,29 @@ def _build_flow_graph_inputs(

to_graph_index[nseg + ii] = wh_intervene_above_graph[0][0]
to_graph_index[wh_intervene_below_graph] = nseg + ii
added_new_nodes_inds_in_graph[ii] = nseg + ii

# <
to_new_nodes_inds_remaining = deepcopy(to_new_nodes_inds_in_added)
# worst case scenario is that we have to iterate the length of this list
# if the items in the list are in the wrong order
for itry in range(len(to_new_nodes_inds_remaining)):
# for input_ind, to_ind_remain in to_new_nodes_inds_remaining.items():
for input_ind in list(to_new_nodes_inds_remaining.keys()):
to_ind_remain = to_new_nodes_inds_remaining[input_ind]
if to_ind_remain not in added_new_nodes_inds_in_graph.keys():
continue
flows_to_ind = added_new_nodes_inds_in_graph[to_ind_remain]
flows_from_inds = np.where(to_graph_index == flows_to_ind)

to_graph_index[nseg + input_ind] = flows_to_ind
to_graph_index[flows_from_inds] = nseg + input_ind
added_new_nodes_inds_in_graph[input_ind] = nseg + input_ind
del to_new_nodes_inds_remaining[input_ind]

if len(to_new_nodes_inds_remaining):
msg = "Unable to connect some new nodes in-series."
raise ValueError(msg)

# <
param_dict = dict(
Expand Down
10 changes: 10 additions & 0 deletions pywatershed/hydrology/starfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,16 @@ def storage_change(self) -> np.float64:
def storage(self) -> np.float64:
return self._lake_storage[0]

@property
def release(self) -> np.float64:
"The release component of the STARFIT outflow."
return self._lake_release[0]

@property
def spill(self) -> np.float64:
"The spill component of the STARFIT outflow."
return self._lake_spill[0]

@property
def sink_source(self) -> np.float64:
return zero
Expand Down

0 comments on commit e3d4001

Please sign in to comment.