Skip to content

Commit

Permalink
clean up extra_coords/string/chararray coordinate outputs for flow graph
Browse files Browse the repository at this point in the history
  • Loading branch information
jmccreight committed Oct 8, 2024
1 parent 3f6242b commit 5056023
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 39 deletions.
5 changes: 3 additions & 2 deletions autotest/test_pass_through_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def parameters_flow_graph(parameters_prms, discretization_prms):
nnodes = parameters_prms.dims["nsegment"] + 1
node_maker_name = ["prms_channel"] * nnodes
node_maker_name[-1] = "pass_throughs"
node_maker_name = np.array(node_maker_name, dtype="U")
node_maker_index = np.arange(nnodes)
node_maker_index[-1] = 0
node_maker_id = np.arange(nnodes)
Expand All @@ -78,11 +79,11 @@ def parameters_flow_graph(parameters_prms, discretization_prms):
)
# have to map to the graph from an index found in prms_channel
wh_intervene_above_graph = np.where(
(np.array(node_maker_name) == "prms_channel")
(node_maker_name == "prms_channel")
& (node_maker_index == wh_intervene_above_nhm[0][0])
)
wh_intervene_below_graph = np.where(
(np.array(node_maker_name) == "prms_channel")
(node_maker_name == "prms_channel")
& np.isin(node_maker_index, wh_intervene_below_nhm)
)

Expand Down
29 changes: 21 additions & 8 deletions autotest/test_prms_channel_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_prms_channel_flow_graph_compare_prms(
"node_coord": np.arange(nnodes),
},
data_vars={
"node_maker_name": ["prms_channel"] * nnodes,
"node_maker_name": np.array(["prms_channel"] * nnodes, dtype="U"),
"node_maker_index": np.arange(nnodes),
"node_maker_id": np.arange(nnodes),
"to_graph_index": discretization.parameters["tosegment"] - 1,
Expand Down Expand Up @@ -398,6 +398,8 @@ def calculation(self) -> None:
def test_prms_channel_flow_graph_to_model_dict(
simulation, control, discretization, parameters, tmp_path
):
# This also tests the netcdf output contains the correct coordinates, at
# the end.
domain_dir = simulation["dir"]
input_dir = simulation["output_dir"]
run_dir = tmp_path
Expand Down Expand Up @@ -445,15 +447,19 @@ def test_prms_channel_flow_graph_to_model_dict(
random_seg_ids = discretization.parameters["nhm_seg"][rando]
n_new_nodes = len(random_seg_ids)

check_names = ["pass"] * n_new_nodes
check_indices = list(range(n_new_nodes))
check_ids = list(range(n_new_nodes))

model_dict = prms_channel_flow_graph_to_model_dict(
model_dict=model_dict,
prms_channel_dis=discretization,
prms_channel_dis_name="dis_both",
prms_channel_params=parameters,
new_nodes_maker_dict={"pass": PassThroughNodeMaker()},
new_nodes_maker_names=["pass"] * n_new_nodes,
new_nodes_maker_indices=list(range(n_new_nodes)),
new_nodes_maker_ids=list(range(n_new_nodes)),
new_nodes_maker_names=check_names,
new_nodes_maker_indices=check_indices,
new_nodes_maker_ids=check_ids,
new_nodes_flow_to_nhm_seg=random_seg_ids,
graph_budget_type="warn", # move to error
)
Expand Down Expand Up @@ -484,13 +490,20 @@ def test_prms_channel_flow_graph_to_model_dict(
model.finalize()

ans_dir = simulation["output_dir"]
outflow_ans = xr.open_dataarray(ans_dir / "seg_outflow.nc")
outflow_act = xr.open_dataarray(run_dir / "node_outflows.nc")[:, 0:(nsegs)]

outflow_ans = xr.load_dataarray(ans_dir / "seg_outflow.nc")
outflow_act = xr.load_dataarray(run_dir / "node_outflows.nc")
outflow_act_compare = outflow_act[:, 0:(nsegs)]
for tt in range(control.n_times):
np.testing.assert_allclose(
outflow_ans.values[tt, :],
outflow_act.values[tt, :],
outflow_act_compare.values[tt, :],
rtol=rtol,
atol=atol,
)

# check that the coordinates match what was provided
for vv in control.options["netcdf_output_var_names"]:
da = xr.load_dataarray(tmp_path / f"{vv}.nc")
assert da.node_maker_index[-3:].values.tolist() == check_indices
assert da.node_maker_name[-3:].values.tolist() == check_names
assert da.node_maker_id[-3:].values.tolist() == check_ids
15 changes: 1 addition & 14 deletions pywatershed/base/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,6 @@ def initialize_netcdf(
output_vars: list = None,
extra_coords: dict = None,
) -> None:
from netCDF4 import stringtochar

if self._netcdf_initialized:
msg = (

Check warning on line 536 in pywatershed/base/flow_graph.py

View check run for this annotation

Codecov / codecov/patch

pywatershed/base/flow_graph.py#L536

Added line #L536 was not covered by tests
f"{self.name} class previously initialized netcdf output "
Expand All @@ -553,18 +551,7 @@ def initialize_netcdf(
if param_name in skip_params + ["node_coord"]:
continue

if param_name == "node_maker_name":
# https://unidata.github.io/netcdf4-python/#dealing-with-strings # noqa: E501
mknames = params["node_maker_name"]
maxlen = np.array([len(nn) for nn in mknames]).max()
param_value = stringtochar(
np.array(mknames, dtype=f"S{maxlen}")
)

else:
param_value = params[param_name]

extra_coords["node_coord"][param_name] = param_value
extra_coords["node_coord"][param_name] = params[param_name]

# this gets the budget initialization too
super().initialize_netcdf(
Expand Down
1 change: 0 additions & 1 deletion pywatershed/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
np.dtype("uint16"): "u2",
np.dtype("uint8"): "u1",
np.dtype("bool"): None,
np.dtype("|S1"): "S#",
}

inch2cm = 2.54
Expand Down
7 changes: 5 additions & 2 deletions pywatershed/hydrology/prms_channel_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,9 @@ def _build_flow_graph_inputs(
nnodes = nseg + nnew

node_maker_name = ["prms_channel"] * nseg + new_nodes_maker_names
maxlen = np.array([len(nn) for nn in node_maker_name]).max()
# need this to be unicode U# for keys and searching below
node_maker_name = np.array(node_maker_name, dtype=f"|U{maxlen}")
node_maker_index = np.array(
np.arange(nseg).tolist() + new_nodes_maker_indices
)
Expand All @@ -921,11 +924,11 @@ def _build_flow_graph_inputs(
)
# have to map to the graph from an index found in prms_channel
wh_intervene_above_graph = np.where(
(np.array(node_maker_name) == "prms_channel")
(node_maker_name == "prms_channel")
& (node_maker_index == wh_intervene_above_nhm[0][0])
)
wh_intervene_below_graph = np.where(
(np.array(node_maker_name) == "prms_channel")
(node_maker_name == "prms_channel")
& np.isin(node_maker_index, wh_intervene_below_nhm)
)

Expand Down
43 changes: 31 additions & 12 deletions pywatershed/utils/netcdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,8 @@ def __init__(
complevel: int = 4,
chunk_sizes: dict = {"time": 1, "hruid": 0},
):
from netCDF4 import stringtochar

"""Output the csv output data to a netcdf file
Args:
Expand Down Expand Up @@ -522,32 +524,49 @@ def __init__(
)
self.node_coord[:] = coordinates["node_coord"]

dims_created = []
char_dims_created = []
for x_dim, x_data_dict in extra_coords.items():
for x_var_name, x_data in x_data_dict.items():
nc_type = np_type_to_netcdf_type_dict[x_data.dtype]
# https://unidata.github.io/netcdf4-python/#dealing-with-strings # noqa: E501
type = x_data.dtype
type_str = str(type)

dim = (x_dim,)
if nc_type == "S#":
sdimlen = len(x_data[0])
if "U" in type_str or "S" in type_str:
# https://unidata.github.io/netcdf4-python/#dealing-with-strings # noqa: E501
# S1 gives "char" type in the file whereas another
# number gives "string" type. The former is properly
# handled by xarray
nc_type = "S1"
if nc_type in dims_created:

# if it is a string array, convert it to a character array
# I dont understand the particulars here, may need more
# work
if "U" in type_str:
char_array = stringtochar(x_data.astype("S"))
else:
char_array = stringtochar(x_data)

Check warning on line 547 in pywatershed/utils/netcdf_utils.py

View check run for this annotation

Codecov / codecov/patch

pywatershed/utils/netcdf_utils.py#L547

Added line #L547 was not covered by tests

char_dim_len = char_array.shape[1]
dim_name = f"char{char_dim_len}"
if dim_name in char_dims_created:
continue

Check warning on line 552 in pywatershed/utils/netcdf_utils.py

View check run for this annotation

Codecov / codecov/patch

pywatershed/utils/netcdf_utils.py#L552

Added line #L552 was not covered by tests

dim = (x_dim, nc_type)
_ = self.dataset.createDimension(nc_type, sdimlen)
dim = (x_dim, dim_name)
_ = self.dataset.createDimension(
dimname=dim_name, size=char_dim_len
)

char_dims_created += [dim_name]

dims_created += [sdimlen]
else:
nc_type = np_type_to_netcdf_type_dict[type]

# <
self[x_var_name] = self.dataset.createVariable(
x_var_name, nc_type, dim
varname=x_var_name, datatype=nc_type, dimensions=dim
)
if "S" == nc_type[0]:
self[x_var_name][:, :] = x_data
if "S1" == nc_type:
self[x_var_name][:, :] = char_array
self[x_var_name]._Encoding = "utf-8"

else:
Expand Down

0 comments on commit 5056023

Please sign in to comment.