Skip to content

Commit

Permalink
test standardization for PRMSSolarGeom, PRMSAtmosphere
Browse files Browse the repository at this point in the history
  • Loading branch information
jmccreight committed Oct 19, 2023
1 parent 02346c4 commit ff9435b
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 120 deletions.
133 changes: 41 additions & 92 deletions autotest/test_prms_atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,15 @@
from pywatershed.base.control import Control
from pywatershed.base.parameters import Parameters
from pywatershed.parameters import PrmsParameters
from utils_compare import compare_in_memory, compare_netcdfs

params = ["params_sep", "params_one"]


# @pytest.fixture(scope="function", params=params)
# def control(domain, request):
# if request.param == "params_one":
# params = PrmsParameters.load(domain["param_file"])
# dis = None

# else:
# # channel needs both hru and seg dis files
# dis_hru_file = domain["dir"] / "parameters_dis_hru.nc"
# dis_data = Parameters.merge(
# Parameters.from_netcdf(dis_hru_file, encoding=False),
# )
# dis = {"dis_hru": dis_data}
# compare in memory (faster) or full output files? or both!
do_compare_output_files = True
do_compare_in_memory = True
rtol = 1.0e-5
atol = 1.0e-5 # why is this relatively low accuracy?

# param_file = domain["dir"] / "parameters_PRMSAtmosphere.nc"
# params = {"PRMSAtmosphere": PrmsParameters.from_netcdf(param_file)}

# return Control.load(domain["control_file"], params=params, dis=dis)
params = ["params_sep", "params_one"]


@pytest.fixture(scope="function")
Expand All @@ -55,31 +42,11 @@ def parameters(domain, request):


def test_compare_prms(domain, control, discretization, parameters, tmp_path):
comparison_var_names = PRMSAtmosphere.get_variables()

output_dir = domain["prms_output_dir"]
cbh_dir = domain["cbh_inputs"]["prcp"].parent.resolve()

# get the answer data
comparison_var_names = [
"tmaxf",
"tminf",
"hru_ppt",
"hru_rain",
"hru_snow",
"swrad",
"potet",
"transp_on",
"tmaxc",
"tavgc",
"tminc",
"prmx",
"pptmix",
"orad_hru",
]
ans = {}
for key in comparison_var_names:
nc_pth = output_dir / f"{key}.nc"
ans[key] = adapter_factory(nc_pth, variable_name=key, control=control)

input_variables = {}
for key in PRMSAtmosphere.get_inputs():
dir = ""
Expand All @@ -96,53 +63,35 @@ def test_compare_prms(domain, control, discretization, parameters, tmp_path):
netcdf_output_dir=tmp_path,
)

all_success = True
for istep in range(control.n_times):
control.advance()
atm.advance()
atm.calculate(1.0)

# compare along the way
for key, val in ans.items():
val.advance()

for key in ans.keys():
a1 = ans[key].current
a2 = atm[key].current

tol = 1e-5
if key == "swrad":
tol = 5e-4
warn(f"using tol = {tol} for variable {key}")
if key == "tavgc":
tol = 1e-5
warn(f"using tol = {tol} for variable {key}")

success_a = np.allclose(a2, a1, atol=tol, rtol=0.00)
success_r = np.allclose(a2, a1, atol=0.00, rtol=tol)
success = False
if (not success_a) and (not success_r):
diff = a2 - a1
diffratio = abs(diff / a2)
if (diffratio < 1e-6).all():
success = True
continue
all_success = False
diffmin = diff.min()
diffmax = diff.max()
abs_diff = abs(diff)
absdiffmax = abs_diff.max()
wh_absdiffmax = np.where(abs_diff)[0]
print(f"time step {istep}")
print(f"output variable {key}")
print(f"prms {a1.min()} {a1.max()}")
print(f"pywatershed {a2.min()} {a2.max()}")
print(f"diff {diffmin} {diffmax}")
print(f"absdiffmax {absdiffmax}")
print(f"wh_absdiffmax {wh_absdiffmax}")
assert success

atm.finalize()

if not all_success:
raise Exception("pywatershed results do not match prms results")
if do_compare_in_memory:
answers = {}
for var in comparison_var_names:
var_pth = output_dir / f"{var}.nc"
answers[var] = adapter_factory(
var_pth, variable_name=var, control=control
)

# check the advance/calculate the state
tmaxf_id = id(atm.tmaxf)

for ii in range(control.n_times):
control.advance()
atm.advance()
if ii == 0:
atm.output()
atm.calculate(1.0)

compare_in_memory(atm, answers, atol=atol, rtol=rtol)
assert id(atm.tmaxf) == tmaxf_id

if do_compare_output_files:
compare_netcdfs(
comparison_var_names,
tmp_path,
output_dir,
atol=atol,
rtol=rtol,
print_var_max_errs=False,
)

return
2 changes: 1 addition & 1 deletion autotest/test_prms_soilzone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# compare in memory (faster) or full output files? or both!
do_compare_output_files = False
do_compare_in_memory = True
rtol = atol = 1.0e-8
rtol = atol = 1.0e-7

calc_methods = ("numpy", "numba")
params = ("params_sep", "params_one")
Expand Down
63 changes: 36 additions & 27 deletions autotest/test_prms_solar_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from pywatershed.parameters import PrmsParameters
from utils_compare import compare_in_memory, compare_netcdfs

# in this case we'll compare netcdf files and in memory
# compare in memory (faster) or full output files? or both!
do_compare_output_files = True
do_compare_in_memory = True
rtol = atol = 1.0e-10

atol = rtol = np.finfo(np.float32).resolution

params = ("params_sep", "params_one")
Expand Down Expand Up @@ -43,6 +47,7 @@ def test_compare_prms(
domain, control, discretization, parameters, tmp_path, from_prms_file
):
output_dir = domain["prms_output_dir"]

prms_soltab_file = domain["prms_run_dir"] / "soltab_debug"
if from_prms_file:
from_prms_file = prms_soltab_file
Expand All @@ -56,33 +61,37 @@ def test_compare_prms(
from_prms_file=from_prms_file,
netcdf_output_dir=tmp_path,
)
solar_geom.output()
solar_geom.finalize()

compare_netcdfs(
PRMSSolarGeometry.get_variables(),
tmp_path,
output_dir,
atol=atol,
rtol=rtol,
)

answers = {}
for var in PRMSSolarGeometry.get_variables():
var_pth = output_dir / f"{var}.nc"
answers[var] = adapter_factory(
var_pth, variable_name=var, control=control
if do_compare_in_memory:
answers = {}
for var in PRMSSolarGeometry.get_variables():
var_pth = output_dir / f"{var}.nc"
answers[var] = adapter_factory(
var_pth, variable_name=var, control=control
)

sunhrs_id = id(solar_geom.soltab_sunhrs)

# Though the data is all calculate on the initial advance,
# we step through it using the timeseries array.
# we only need to output at time 0
for ii in range(control.n_times):
control.advance()
solar_geom.advance()
if ii == 0:
solar_geom.output()
solar_geom.calculate(1.0)

compare_in_memory(solar_geom, answers, atol=atol, rtol=rtol)
assert id(solar_geom.soltab_sunhrs) == sunhrs_id

if do_compare_output_files:
compare_netcdfs(
PRMSSolarGeometry.get_variables(),
tmp_path,
output_dir,
atol=atol,
rtol=rtol,
)

# check the advance/calculate the state
sunhrs_id = id(solar_geom.soltab_sunhrs)

for ii in range(control.n_times):
control.advance()
solar_geom.advance()
solar_geom.calculate(1.0)

compare_in_memory(solar_geom, answers, atol=atol, rtol=rtol)
assert id(solar_geom.soltab_sunhrs) == sunhrs_id

return
11 changes: 11 additions & 0 deletions autotest/utils_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def assert_allclose(
strict: bool = False,
also_check_w_np: bool = True,
error_message: str = "Comparison unsuccessful (default message)",
print_max_errs: bool = False,
var_name: str = "",
):
"""Reinvent np.testing.assert_allclose to get useful diagnostincs in debug
Expand All @@ -29,6 +31,7 @@ def assert_allclose(
handling of scalars mentioned in the Notes section is disabled.
also_check_w_np: first check using np.testing.assert_allclose using
the same options.
print_max_errs: bool=False. Print max abs and rel err for each var.
"""

if also_check_w_np:
Expand Down Expand Up @@ -61,6 +64,11 @@ def assert_allclose(

close = abs_close | rel_close

if print_max_errs:
sp = "" if len(var_name) == 0 else " "
print(f"{var_name}{sp}max abs err: {abs_diff.max()}")
print(f"{var_name}{sp}max rel err: {rel_abs_diff.max()}")

assert close.all()


Expand Down Expand Up @@ -114,6 +122,7 @@ def compare_netcdfs(
strict: bool = False,
also_check_w_np: bool = True,
error_message: str = None,
print_var_max_errs: bool = False,
):
# TODO: docstring
# TODO: improve error message
Expand All @@ -138,4 +147,6 @@ def compare_netcdfs(
strict=strict,
also_check_w_np=also_check_w_np,
error_message=error_message,
print_max_errs=print_var_max_errs,
var_name=var,
)

0 comments on commit ff9435b

Please sign in to comment.