
FlowGraph capability to cull additional, heterogeneous variables from Nodes and output to NetCDF

jmccreight committed Oct 19, 2024
1 parent b29a899 commit 7612552
Showing 5 changed files with 86 additions and 13 deletions.
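
For orientation, a minimal sketch (not part of the commit) of how the new `addtl_output_vars` keyword might be used once this change is in place. The `control`, `parameters`, `inflows`, and `node_maker_dict` objects are assumed to be built as in the tests below; only the keywords visible in this diff (`addtl_output_vars`, `budget_type`, `initialize_netcdf(output_dir)`) are taken as given, and `discretization=None` is an assumption mirroring the `"dis": None` entry in the model-dict helper changed below.

```python
import pathlib as pl

from pywatershed.base.flow_graph import FlowGraph


def graph_with_extra_node_outputs(
    control, parameters, inflows, node_maker_dict, output_dir: pl.Path
) -> FlowGraph:
    """Sketch: request per-node variables that only some FlowNodes provide."""
    flow_graph = FlowGraph(
        control,
        discretization=None,  # assumed, mirroring "dis": None in the model dict
        parameters=parameters,
        inflows=inflows,
        node_maker_dict=node_maker_dict,
        # These variables exist only on STARFIT reservoir nodes; nodes whose
        # maker does not define them are left as NaN in the output.
        addtl_output_vars=["_lake_spill", "_lake_release"],
        budget_type="warn",
    )
    # With the default separate files, _lake_spill.nc and _lake_release.nc are
    # written alongside the standard FlowGraph output variables.
    flow_graph.initialize_netcdf(output_dir)
    return flow_graph
```
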
23 changes: 22 additions & 1 deletion autotest/test_starfit_flow_graph.py
@@ -183,6 +183,7 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
44426,
44435,
],
addtl_output_vars=["_lake_spill", "_lake_release"],
)

# get the segments unaffected by flow, where the PRMS solutions should
@@ -278,6 +279,17 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:
assert da.node_maker_name[-3:].values.tolist() == check_names
assert da.node_maker_id[-3:].values.tolist() == check_ids

# test additional output files match and are working
da_no = xr.load_dataarray(tmp_path / "node_outflows.nc")
da_lr = xr.load_dataarray(tmp_path / "_lake_release.nc")
da_ls = xr.load_dataarray(tmp_path / "_lake_spill.nc")

da_no = da_no.where(da_no.node_maker_name == "starfit", drop=True)
da_lr = da_lr.where(da_lr.node_maker_name == "starfit", drop=True)
da_ls = da_ls.where(da_ls.node_maker_name == "starfit", drop=True)

assert (da_no == da_lr + da_ls).all()


def test_starfit_flow_graph_model_dict(
simulation,
@@ -359,6 +371,7 @@ def test_starfit_flow_graph_model_dict(
new_nodes_maker_ids=new_nodes_maker_ids,
new_nodes_flow_to_nhm_seg=new_nodes_flow_to_nhm_seg,
graph_budget_type="error", # move to error
addtl_output_vars=["_lake_spill", "_lake_release"],
)
model = Model(model_dict)

@@ -375,7 +388,7 @@ def test_starfit_flow_graph_model_dict(

if do_compare_output_files:
# not really feasible as noted in the NB section at the top
model.initialize_netcdf(tmp_path)
model.initialize_netcdf(tmp_path, separate_files=False)

if do_compare_in_memory:
answers = {}
@@ -446,3 +459,11 @@ def test_starfit_flow_graph_model_dict(
assert flow_graph._nodes[-2].budget is not None

flow_graph.finalize()

# test single file output has extra coords and additional vars
ds = xr.open_dataset(tmp_path / "FlowGraph.nc")
ds_starfit = ds.where(ds.node_maker_name == "starfit", drop=True)
assert (
ds_starfit.node_outflows
== ds_starfit._lake_release + ds_starfit._lake_spill
).all()
2 changes: 2 additions & 0 deletions pywatershed/base/conservative_process.py
@@ -216,6 +216,7 @@ def initialize_netcdf(
budget_args: dict = None,
output_vars: list = None,
extra_coords: dict = None,
addtl_output_vars: list = None,
) -> None:
if self._netcdf_initialized:
msg = (
@@ -230,6 +231,7 @@
separate_files=separate_files,
output_vars=output_vars,
extra_coords=extra_coords,
addtl_output_vars=addtl_output_vars,
)

if self.budget is not None:
50 changes: 50 additions & 0 deletions pywatershed/base/flow_graph.py
@@ -326,6 +326,7 @@ def __init__(
parameters: Parameters,
inflows: adaptable,
node_maker_dict: dict,
addtl_output_vars: list[str] = None,
params_not_to_netcdf: list[str] = None,
budget_type: Literal["defer", None, "warn", "error"] = "defer",
allow_disconnected_nodes: bool = True, # todo, make False
@@ -345,6 +346,13 @@
node_maker_dict: A dictionary of FlowNodeMaker instances with
keys/names supplied in the parameters, e.g.
{key1: flow_node_maker_instance, ...}.
params_not_to_netcdf: A list of string names for parameters to NOT
write to NetCDF output files. By default all parameters are
included in each file written.
addtl_output_vars: A list of string names for variables to collect
for NetCDF output from FlowNodes. These variables do not have to
be available in all FlowNodes but must be present in at least
one.
budget_type: one of ["defer", None, "warn", "error"] with "defer"
being the default and deferring to
control.options["budget_type"] when
@@ -524,6 +532,40 @@ def _init_graph(self) -> None:
)
]

# <
# Deal with additional output variables requested
if self._addtl_output_vars is None:
self._addtl_output_vars = []

unique_makers = np.unique(params["node_maker_name"])

self._addtl_output_vars_wh_collect = {}
for vv in self._addtl_output_vars:
inds_to_collect = []
for uu in unique_makers:
wh_uu = np.where(params["node_maker_name"] == uu)
if hasattr(self._nodes[wh_uu[0][0]], vv):
# Do we need to get/set the type here? Would have to
# check the type over all nodes/node makers.
# For now I'll just throw an error if it is not float64.
msg = "Only currently handling float64, new code required"
assert self._nodes[wh_uu[0][0]][vv].dtype == "float64", msg
inds_to_collect += wh_uu[0].tolist()
# <<
if len(inds_to_collect):
self._addtl_output_vars_wh_collect[vv] = inds_to_collect

msg = "Variable already set on FlowGraph."
for kk in self._addtl_output_vars_wh_collect.keys():
assert not hasattr(self, kk), msg
self[kk] = np.full([self.nnodes], np.nan)
# TODO: find some other way of getting metadata here.
# It could come from the node itself, I suppose, or
# could come from arguments or a static source. Node seems most
# elegant. I suppose there could be conflicts if multiple
# nodes have the same variable and different metadata.
self.meta[kk] = {"dims": ("nnodes",), "type": "float64"}

def initialize_netcdf(
self,
output_dir: [str, pl.Path] = None,
@@ -559,6 +601,7 @@ def initialize_netcdf(
separate_files=separate_files,
output_vars=output_vars,
extra_coords=extra_coords,
addtl_output_vars=list(self._addtl_output_vars_wh_collect.keys()),
)

return
@@ -627,6 +670,13 @@ def calculate(self, time_length: float, n_substeps: int = 24) -> None:
self.node_storages[ii] = self._nodes[ii].storage
self.node_sink_source[ii] = self._nodes[ii].sink_source

for (
add_var_name,
add_var_inds,
) in self._addtl_output_vars_wh_collect.items():
for ii in add_var_inds:
self[add_var_name][ii] = self._nodes[ii][add_var_name]

self.node_negative_sink_source[:] = -1 * self.node_sink_source

# global mass balance term
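
Distilled from the `_init_graph` and `calculate` additions above, the collection pattern is: scan the heterogeneous node instances for each requested attribute, remember which indices provide it, allocate a NaN-filled array over all nodes, and copy values in each timestep. A self-contained illustration with hypothetical node classes (not pywatershed code):

```python
import numpy as np


class LakeNode:
    def __init__(self):
        self._lake_release = np.float64(1.5)
        self._lake_spill = np.float64(0.25)


class ChannelNode:
    pass  # does not provide the lake variables


nodes = [ChannelNode(), LakeNode(), ChannelNode(), LakeNode()]
requested = ["_lake_release", "_lake_spill"]

# Which node indices provide each requested variable?
wh_collect = {
    var: [ii for ii, node in enumerate(nodes) if hasattr(node, var)]
    for var in requested
}
wh_collect = {var: inds for var, inds in wh_collect.items() if len(inds)}

# Allocate NaN-filled arrays over all nodes and fill only where available,
# mirroring what FlowGraph does before writing NetCDF output.
collected = {var: np.full(len(nodes), np.nan) for var in wh_collect}
for var, inds in wh_collect.items():
    for ii in inds:
        collected[var][ii] = getattr(nodes[ii], var)

print(collected["_lake_release"])  # -> [nan, 1.5, nan, 1.5]
```
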
20 changes: 8 additions & 12 deletions pywatershed/base/process.py
@@ -468,6 +468,7 @@ def initialize_netcdf(
separate_files: bool = None,
output_vars: list = None,
extra_coords: dict = None,
addtl_output_vars: list = None,
) -> None:
"""Initialize NetCDF output.
@@ -526,16 +527,14 @@
self._netcdf_initialized = False
return

if addtl_output_vars is not None:
self._netcdf_output_vars += addtl_output_vars

self._netcdf = {}

if self._netcdf_separate:
# make working directory
self._netcdf_output_dir.mkdir(parents=True, exist_ok=True)
for variable_name in self.variables:
if (self._netcdf_output_vars is not None) and (
variable_name not in self._netcdf_output_vars
):
continue
for variable_name in self._netcdf_output_vars:
nc_path = self._netcdf_output_dir / f"{variable_name}.nc"
self._netcdf[variable_name] = NetCdfWrite(
name=nc_path,
@@ -559,6 +558,7 @@
coordinates=self._params.coords,
variables=self._netcdf_output_vars,
var_meta=self.meta,
extra_coords=extra_coords,
global_attrs={"process class": self.name},
)
for variable in the_out_vars[1:]:
@@ -575,11 +575,7 @@ def _output_netcdf(self) -> None:
"""
if self._netcdf_initialized:
time_added = False
for variable in self.variables:
if (self._netcdf_output_vars is not None) and (
variable not in self._netcdf_output_vars
):
continue
for variable in self._netcdf_output_vars:
if not time_added or self._netcdf_separate:
time_added = True
self._netcdf[variable].add_simulation_time(
@@ -600,7 +596,7 @@ def _finalize_netcdf(self) -> None:
None
"""
if self._netcdf_initialized:
for idx, variable in enumerate(self.variables):
for idx, variable in enumerate(self._netcdf_output_vars):
if (self._netcdf_output_vars is not None) and (
variable not in self._netcdf_output_vars
):
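
The net effect of the process.py changes, in plain Python (illustrative names, not the real `Process` attributes): the additional variables are appended to the list of NetCDF output variables, and both the separate-file and single-file paths now iterate over that merged list instead of over the process's declared `variables`.

```python
# Illustrative bookkeeping only; attribute names are hypothetical.
variables = ["node_outflows", "node_storages"]   # declared by the process
output_vars = list(variables)                    # user-requested output subset
addtl_output_vars = ["_lake_release", "_lake_spill"]

netcdf_output_vars = output_vars + (addtl_output_vars or [])

# separate_files=True: one NetCDF file per entry in the merged list.
separate_paths = {var: f"{var}.nc" for var in netcdf_output_vars}
assert "_lake_spill.nc" in separate_paths.values()
```
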
4 changes: 4 additions & 0 deletions pywatershed/hydrology/prms_channel_flow_graph.py
@@ -622,6 +622,7 @@ def prms_channel_flow_graph_postprocess(
new_nodes_maker_indices: list,
new_nodes_maker_ids: list,
new_nodes_flow_to_nhm_seg: list,
addtl_output_vars: list[str] = None,
budget_type: Literal["defer", None, "warn", "error"] = "defer",
type_check_nodes: bool = False,
) -> FlowGraph:
@@ -722,6 +723,7 @@ def advance(self) -> None:
parameters=params_flow_graph,
inflows=inflows_graph,
node_maker_dict=node_maker_dict,
addtl_output_vars=addtl_output_vars,
budget_type=budget_type,
type_check_nodes=type_check_nodes,
)
@@ -738,6 +740,7 @@
new_nodes_maker_indices: list,
new_nodes_maker_ids: list,
new_nodes_flow_to_nhm_seg: list,
addtl_output_vars: list[str] = None,
graph_budget_type: Literal["defer", None, "warn", "error"] = "defer",
type_check_nodes: bool = False,
) -> dict:
@@ -864,6 +867,7 @@ def exchange_calculation(self) -> None:
"parameters": params_flow_graph,
"dis": None,
"budget_type": graph_budget_type,
"addtl_output_vars": addtl_output_vars,
},
}

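
Finally, a sketch of the model-dict route exercised by `test_starfit_flow_graph_model_dict` above. The leading arguments to `prms_channel_flow_graph_to_model_dict` are not shown in this diff and are folded into `model_dict_args` here; `model.run()` is assumed as the driver, and the id/index values are illustrative.

```python
import pathlib as pl

import xarray as xr

from pywatershed import Model
from pywatershed.hydrology.prms_channel_flow_graph import (
    prms_channel_flow_graph_to_model_dict,
)


def run_with_extra_lake_outputs(model_dict_args: dict, output_dir: pl.Path):
    """Sketch: single-file output carrying the additional node variables."""
    model_dict = prms_channel_flow_graph_to_model_dict(
        **model_dict_args,  # stands in for the arguments not shown in this diff
        new_nodes_maker_indices=[0, 1, 2],          # illustrative values
        new_nodes_maker_ids=[44426, 44435, 44434],
        new_nodes_flow_to_nhm_seg=[44426, 44435, 44434],
        graph_budget_type="error",
        addtl_output_vars=["_lake_spill", "_lake_release"],
    )
    model = Model(model_dict)
    model.initialize_netcdf(output_dir, separate_files=False)
    model.run()  # assumed driver call; the test advances the model itself

    # The single FlowGraph.nc carries the extra coords and requested variables.
    ds = xr.open_dataset(output_dir / "FlowGraph.nc")
    starfit = ds.where(ds.node_maker_name == "starfit", drop=True)
    assert (
        starfit.node_outflows == starfit._lake_release + starfit._lake_spill
    ).all()
```
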
