Skip to content

Commit

Permalink
refactor: TrouteOutput to handle validation
Browse files Browse the repository at this point in the history
  • Loading branch information
aaraney committed Sep 20, 2024
1 parent 10277ec commit 06fa1b8
Showing 1 changed file with 68 additions and 17 deletions.
85 changes: 68 additions & 17 deletions python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,30 @@

import pandas as pd
from ngen.cal import hookimpl
from pydantic import BaseModel

if TYPE_CHECKING:
from ngen.cal.meta import JobMeta
from ngen.cal.model import ModelExec
from ngen.cal.model import ModelExec, ValidationOptions, EvaluationOptions
from ngen.config.realization import NgenRealization


class _NgenCalModelOutputFn(typing.Protocol):
def __call__(self, id: str) -> pd.Series: ...

class TrouteOutputSettings(BaseModel):
validation_routing_output: Path


@typing.final
class TrouteOutput:
def __init__(self, filepath: Path) -> None:
self._output_file = filepath
self._settings: TrouteOutputSettings | None = None

self._ngen_realization: NgenRealization | None = None
self._validation_options: ValidationOptions | None = None
self._eval_options: EvaluationOptions | None = None

@hookimpl
def ngen_cal_model_configure(self, config: ModelExec) -> None:
Expand All @@ -33,36 +41,79 @@ def ngen_cal_model_configure(self, config: ModelExec) -> None:
assert config.ngen_realization is not None
self._ngen_realization = config.ngen_realization

# Try external provided output hooks, if those fail, try this one
# this will only execute if all other hooks return None (or they don't exist)
@hookimpl(specname="ngen_cal_model_output", trylast=True)
def get_output(self, id: str) -> pd.Series | None:
if (eval_options := config.eval_params) is not None:
self._eval_options = eval_options

if (validation_config := config.val_params) is not None:
self._validation_options = validation_config

# maybe pull in plugin settings
if (plugin_settings := config.plugin_settings.get("ngen_cal_troute_output")) is not None:
self._settings = TrouteOutputSettings.parse_obj(plugin_settings)

def _sim_eval_interval(self) -> tuple[datetime.datetime, datetime.datetime]:
assert (
self._ngen_realization is not None
), "ngen realization required; ensure `ngen_cal_model_configure` was called and the plugin was properly configured"

if not self._output_file.exists():
print(
f"{self._output_file} not found. Current working directory is {Path.cwd()!s}"
)
print("Setting output to None")
return None
if self._eval_options is not None and self._eval_options.evaluation_start is not None:
assert self._eval_options.evaluation_stop is not None
return self._eval_options.evaluation_start, self._eval_options.evaluation_stop

return self._ngen_realization.time.start_time, self._ngen_realization.time.end_time

filetype = self._output_file.suffix.lower()
def _validation_eval_interval(self) -> tuple[datetime.datetime, datetime.datetime]:
if self._validation_options is None:
print("validation options not provided, using sim evaluation interval")
return self._sim_eval_interval()
return self._validation_options.evaluation_interval()

def _output_handler_factory(self, output_file: Path) -> _NgenCalModelOutputFn:
filetype = output_file.suffix.lower()
if filetype == ".csv":
fn = self._factory_handler_csv(self._output_file)
fn = self._factory_handler_csv(output_file)
# TODO: fix. dont know if this format still works
# elif filetype == ".hdf5":
# fn = _model_output_legacy_hdf5(self._output_file)
elif filetype == ".nc":
fn = _stream_output_netcdf_v1(self._output_file)
fn = _stream_output_netcdf_v1(output_file)
elif filetype == ".parquet":
fn = _stream_output_parquet_v1(self._output_file)
fn = _stream_output_parquet_v1(output_file)
else:
raise RuntimeError(
f"unsupported t-route output filetype: {self._output_file.suffix}"
f"unsupported t-route output filetype: {output_file.suffix}"
)
return fn

# Try external provided output hooks, if those fail, try this one
# this will only execute if all other hooks return None (or they don't exist)
@hookimpl(specname="ngen_cal_model_output", trylast=True)
def get_output(self, id: str) -> pd.Series | None:
assert (
self._ngen_realization is not None
), "ngen realization required; ensure `ngen_cal_model_configure` was called and the plugin was properly configured"

if self._settings is not None and self._settings.validation_routing_output.exists():
output_file = self._settings.validation_routing_output
print(f"retrieving simulation data from validation output file: {output_file!s}")

start, end = self._validation_eval_interval()
print(f"validation: {start=} {end=}")
elif self._output_file.exists():
output_file = self._output_file
print(f"retrieving simulation data from output file: {output_file!s}")

start, end = self._sim_eval_interval()
print(f"{start=} {end=}")
else:
print(
f"{self._output_file} not found. Current working directory is {Path.cwd()!s}"
)
print("Setting output to None")
return None

# TODO: I dont think all output handlers can handle validation (csv comes to mind). circle back to this
fn = self._output_handler_factory(output_file)
ds = fn(id)
ds.name = "sim_flow"

Expand All @@ -74,7 +125,7 @@ def get_output(self, id: str) -> pd.Series | None:
seconds=self._ngen_realization.time.output_interval
)
start = self._ngen_realization.time.start_time
ds = ds.loc[start + ngen_dt :]
ds = ds.loc[start + ngen_dt :end]
ds = ds.resample("1h").first()
return ds

Expand Down

0 comments on commit 06fa1b8

Please sign in to comment.