diff --git a/autotest/test_starfit_flow_graph.py b/autotest/test_starfit_flow_graph.py index d796dafa..1c9610c5 100644 --- a/autotest/test_starfit_flow_graph.py +++ b/autotest/test_starfit_flow_graph.py @@ -108,6 +108,13 @@ def test_starfit_flow_graph_postprocess( ): input_dir = simulation["output_dir"] + # We'll test adding multiple new nodes in-series into a FlowGraph. Above + # and below the starfit we'll add pass-through nodes and check these + # match. + + # We add a random passthrough node with no upstream node, to test if that + # works. + # in this test we'll cover the case of disconnected nodes and allowing # them or not def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset: @@ -151,17 +158,26 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset: "node_storages", ] - # Currently this is the same as notebook 06 - - new_nodes_maker_names = ["starfit", "pass_through"] - new_nodes_maker_indices = [0, 0] - new_nodes_maker_ids = [-2, -1] - + new_nodes_maker_names = ["starfit"] + ["pass_through"] * 3 + new_nodes_maker_indices = [0, 0, 1, 2] + new_nodes_maker_ids = [-2, -1, -100, -1000] + # The starfit node flows to the third passthrough node, in index 3. + # The first passthrough node flows to some random nhm_seg, not connected to + # the other new nodes. + # The second passthrough flows to the starfit node in index 0. + # The last passthrough node flows to the seg above which the reservoir + # is placed. + new_nodes_flow_to_nhm_seg = [-3, 44409, 0, 44426] + + # the first in the list is for the disconnected node check_names = ["prms_channel"] + new_nodes_maker_names check_indices = [dis_ds.dims["nsegment"] - 1] + new_nodes_maker_indices - check_ids = [dis_ds.nhm_seg[-1]] + new_nodes_maker_ids + check_ids = [dis_ds.nhm_seg[-1].values.tolist()] + new_nodes_maker_ids - with pytest.warns(UserWarning): + # This warning should say: TODO + with pytest.warns( + UserWarning, match="Disconnected nodes present in FlowGraph." 
+ ): flow_graph = prms_channel_flow_graph_postprocess( control=control, input_dir=input_dir, @@ -179,13 +195,12 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset: new_nodes_maker_names=new_nodes_maker_names, new_nodes_maker_indices=new_nodes_maker_indices, new_nodes_maker_ids=new_nodes_maker_ids, - new_nodes_flow_to_nhm_seg=[ - 44426, - 44435, - ], - addtl_output_vars=["_lake_spill", "_lake_release"], + new_nodes_flow_to_nhm_seg=new_nodes_flow_to_nhm_seg, + addtl_output_vars=["spill", "release"], + allow_disconnected_nodes=True, ) + # < # get the segments un affected by flow, where the PRMS solutions should # match wh_44426 = np.where(discretization.parameters["nhm_seg"] == 44426)[0][0] @@ -224,6 +239,9 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset: var.advance() answers_conv_vol = {} + # The -5 is not -4 because of the synthetic disconnected node also + # present on the FlowGraph (which is added via parameters, not + # using prms_channel_flow_graph_postprocess) for key, val in answers.items(): if key in convert_to_vol: current = val.current / (24 * 60 * 60) @@ -241,23 +259,23 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset: current[(wh_ignore,)] = flow_graph[key][(wh_ignore,)] # Fill in the last two nodes answers_conv_vol[key] = np.concatenate( - [current, flow_graph[key][-3:]] + [current, flow_graph[key][-5:]] ) # << # there are no expected sources or sinks in this test answers_conv_vol["node_sink_source"] = np.concatenate( - [val.current * zero, flow_graph["node_sink_source"][-3:]] + [val.current * zero, flow_graph["node_sink_source"][-5:]] ) answers_conv_vol["node_negative_sink_source"] = np.concatenate( [ val.current * zero, - flow_graph["node_negative_sink_source"][-3:], + flow_graph["node_negative_sink_source"][-5:], ] ) answers_conv_vol["node_storages"] = np.concatenate( - [val.current * nan, flow_graph["node_storages"][-3:]] + [val.current * nan, flow_graph["node_storages"][-5:]] ) compare_in_memory( @@ 
-270,19 +288,26 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset: ) # this checks that the budget was actually active for the starfit node - assert flow_graph._nodes[-2].budget is not None + assert flow_graph._nodes[-4].budget is not None flow_graph.finalize() for vv in control.options["netcdf_output_var_names"]: da = xr.load_dataarray(tmp_path / f"{vv}.nc", concat_characters=True) - assert da.node_maker_index[-3:].values.tolist() == check_indices - assert da.node_maker_name[-3:].values.tolist() == check_names - assert da.node_maker_id[-3:].values.tolist() == check_ids + assert da.node_maker_index[-5:].values.tolist() == check_indices + assert da.node_maker_name[-5:].values.tolist() == check_names + assert da.node_maker_id[-5:].values.tolist() == check_ids # test additional output files match and are working da_no = xr.load_dataarray(tmp_path / "node_outflows.nc") - da_lr = xr.load_dataarray(tmp_path / "_lake_release.nc") - da_ls = xr.load_dataarray(tmp_path / "_lake_spill.nc") + da_in = xr.load_dataarray(tmp_path / "node_upstream_inflows.nc") + + # full time check of the passthrough nodes (which is probably gratuitous + # given all the above checks already passed). 
+ assert (abs(da_no[:, -1] - da_no[:, -4]) < 1e-12).all() + assert (abs(da_in[:, -2] - da_in[:, -4]) < 1e-12).all() + + da_lr = xr.load_dataarray(tmp_path / "release.nc") + da_ls = xr.load_dataarray(tmp_path / "spill.nc") da_no = da_no.where(da_no.node_maker_name == "starfit", drop=True) da_lr = da_lr.where(da_lr.node_maker_name == "starfit", drop=True) @@ -371,7 +396,7 @@ def test_starfit_flow_graph_model_dict( new_nodes_maker_ids=new_nodes_maker_ids, new_nodes_flow_to_nhm_seg=new_nodes_flow_to_nhm_seg, graph_budget_type="error", # move to error - addtl_output_vars=["_lake_spill", "_lake_release"], + addtl_output_vars=["spill", "release"], ) model = Model(model_dict) @@ -452,7 +477,7 @@ def test_starfit_flow_graph_model_dict( rtol=rtol, skip_missing_ans=False, fail_after_all_vars=False, - verbose=True, + verbose=False, ) # this checks that the budget was actually active for the starfit node @@ -464,6 +489,5 @@ def test_starfit_flow_graph_model_dict( ds = xr.open_dataset(tmp_path / "FlowGraph.nc") ds_starfit = ds.where(ds.node_maker_name == "starfit", drop=True) assert ( - ds_starfit.node_outflows - == ds_starfit._lake_release + ds_starfit._lake_spill + ds_starfit.node_outflows == ds_starfit.release + ds_starfit.spill ).all() diff --git a/pywatershed/base/flow_graph.py b/pywatershed/base/flow_graph.py index 66dd11bc..05d3b924 100644 --- a/pywatershed/base/flow_graph.py +++ b/pywatershed/base/flow_graph.py @@ -329,7 +329,7 @@ def __init__( addtl_output_vars: list[str] = None, params_not_to_netcdf: list[str] = None, budget_type: Literal["defer", None, "warn", "error"] = "defer", - allow_disconnected_nodes: bool = True, # todo, make False + allow_disconnected_nodes: bool = False, type_check_nodes: bool = False, verbose: bool = None, ): diff --git a/pywatershed/hydrology/prms_channel_flow_graph.py b/pywatershed/hydrology/prms_channel_flow_graph.py index c252b795..81f3421c 100644 --- a/pywatershed/hydrology/prms_channel_flow_graph.py +++ 
b/pywatershed/hydrology/prms_channel_flow_graph.py @@ -1,4 +1,5 @@ import pathlib as pl +from copy import deepcopy from typing import Literal, Union import numba as nb @@ -623,6 +624,7 @@ def prms_channel_flow_graph_postprocess( new_nodes_maker_ids: list, new_nodes_flow_to_nhm_seg: list, addtl_output_vars: list[str] = None, + allow_disconnected_nodes: bool = False, budget_type: Literal["defer", None, "warn", "error"] = "defer", type_check_nodes: bool = False, ) -> FlowGraph: @@ -657,7 +659,11 @@ def prms_channel_flow_graph_postprocess( NodeMaker. new_nodes_maker_ids: Collated list of ids relative to each NodeMaker. new_nodes_flow_to_nhm_seg: collated list describing the nhm_seg to - which the node will flow. + which the node will flow. Use of non-positive entries specifies + the zero-based index for flowing to nodes specified in these + collated parameters, allowing these new nodes to be added in + groups, in series to the existing NHM FlowGraph. Note that a new + node may not be placed below any outflow point of the domain. budget_type: one of ["defer", None, "warn", "error"] with "defer" being the default and defering to control.options["budget_type"] when available. When control.options["budget_type"] is not avaiable, @@ -726,6 +732,7 @@ def advance(self) -> None: addtl_output_vars=addtl_output_vars, budget_type=budget_type, type_check_nodes=type_check_nodes, + allow_disconnected_nodes=allow_disconnected_nodes, ) return flow_graph @@ -742,6 +749,7 @@ def prms_channel_flow_graph_to_model_dict( new_nodes_flow_to_nhm_seg: list, addtl_output_vars: list[str] = None, graph_budget_type: Literal["defer", None, "warn", "error"] = "defer", + allow_disconnected_nodes: bool = False, type_check_nodes: bool = False, ) -> dict: """Add nodes to a PRMSChannel-based FlowGraph within a Model's model_dict. @@ -775,7 +783,11 @@ def prms_channel_flow_graph_to_model_dict( NodeMaker new_nodes_maker_ids: Collated list of ids relative to each NodeMaker. 
new_nodes_flow_to_nhm_seg: collated list describing the nhm_seg to - which the node will flow. + which the node will flow. Use of non-positive entries specifies + the zero-based index for flowing to nodes specified in these + collated parameters, allowing these new nodes to be added in + groups, in series to the existing NHM FlowGraph. Note that a new + node may not be placed below any outflow point of the domain. graph_budget_type: one of ["defer", None, "warn", "error"] with "defer" being the default and defering to control.options["budget_type"] when available. When @@ -868,6 +880,7 @@ def exchange_calculation(self) -> None: "dis": None, "budget_type": graph_budget_type, "addtl_output_vars": addtl_output_vars, + "allow_disconnected_nodes": allow_disconnected_nodes, }, } @@ -899,7 +912,7 @@ def _build_flow_graph_inputs( # new_nodes_flow_to_nhm_seg assert len(new_nodes_flow_to_nhm_seg) == len( np.unique(new_nodes_flow_to_nhm_seg) - ), "Cant have more than one new node flowing to an existing node" + ), "Cant have more than one new node flowing to an existing or new node." nseg = prms_channel_params.dims["nsegment"] nnew = len(new_nodes_maker_names) @@ -921,7 +934,17 @@ def _build_flow_graph_inputs( tosegment = dis_params["tosegment"] - 1 # fortan to python indexing to_graph_index[0:nseg] = tosegment + # The new nodes which flow to other new_nodes have to be added after + # the nodes flowing to existing nodes with nhm_seg ids. + to_new_nodes_inds_in_added = {} + added_new_nodes_inds_in_graph = {} for ii, nhm_seg in enumerate(new_nodes_flow_to_nhm_seg): + if nhm_seg < 1: + # non-positive entries are indices into the collated inputs lists + # for new nodes being in-series. 
+ to_new_nodes_inds_in_added[ii] = -1 * nhm_seg + continue + wh_intervene_above_nhm = np.where(dis_params["nhm_seg"] == nhm_seg) wh_intervene_below_nhm = np.where( tosegment == wh_intervene_above_nhm[0][0] @@ -938,6 +961,29 @@ def _build_flow_graph_inputs( to_graph_index[nseg + ii] = wh_intervene_above_graph[0][0] to_graph_index[wh_intervene_below_graph] = nseg + ii + added_new_nodes_inds_in_graph[ii] = nseg + ii + + # < + to_new_nodes_inds_remaining = deepcopy(to_new_nodes_inds_in_added) + # worst case scenario is that we have to iterate the length of this list + # if the items in the list are in the wrong order + for itry in range(len(to_new_nodes_inds_remaining)): + # for input_ind, to_ind_remain in to_new_nodes_inds_remaining.items(): + for input_ind in list(to_new_nodes_inds_remaining.keys()): + to_ind_remain = to_new_nodes_inds_remaining[input_ind] + if to_ind_remain not in added_new_nodes_inds_in_graph.keys(): + continue + flows_to_ind = added_new_nodes_inds_in_graph[to_ind_remain] + flows_from_inds = np.where(to_graph_index == flows_to_ind) + + to_graph_index[nseg + input_ind] = flows_to_ind + to_graph_index[flows_from_inds] = nseg + input_ind + added_new_nodes_inds_in_graph[input_ind] = nseg + input_ind + del to_new_nodes_inds_remaining[input_ind] + + if len(to_new_nodes_inds_remaining): + msg = "Unable to connect some new nodes in-series." + raise ValueError(msg) # < param_dict = dict( diff --git a/pywatershed/hydrology/starfit.py b/pywatershed/hydrology/starfit.py index 11886fac..c5400132 100644 --- a/pywatershed/hydrology/starfit.py +++ b/pywatershed/hydrology/starfit.py @@ -860,6 +860,16 @@ def storage_change(self) -> np.float64: def storage(self) -> np.float64: return self._lake_storage[0] + @property + def release(self) -> np.float64: + "The release component of the STARFIT outflow." + return self._lake_release[0] + + @property + def spill(self) -> np.float64: + "The spill component of the STARFIT outflow." 
+ return self._lake_spill[0] + @property def sink_source(self) -> np.float64: return zero