From 62affca7223c3bcb5e8cc672782ee1edf2ba388c Mon Sep 17 00:00:00 2001 From: Anushan Fernando Date: Wed, 11 Sep 2024 04:04:21 -0700 Subject: [PATCH] Move ToraxSimOutputs and SimError to output.py. This is to avoid a circular import when doing follow up work to add restart functionality. PiperOrigin-RevId: 673329404 --- torax/output.py | 60 ++++++++++++++++++++++++++++++------------- torax/sim.py | 42 +++++------------------------- torax/tests/output.py | 13 +++++----- torax/tests/sim.py | 2 +- 4 files changed, 56 insertions(+), 61 deletions(-) diff --git a/torax/output.py b/torax/output.py index 42408fa3..f6cdba8f 100644 --- a/torax/output.py +++ b/torax/output.py @@ -13,14 +13,15 @@ # limitations under the License. """Module containing functions for saving and loading simulation output.""" -from typing import TypeAlias +from __future__ import annotations + +import enum from absl import logging import chex import jax from jax import numpy as jnp from torax import geometry -from torax import sim as sim_lib from torax import state from torax.config import config_args from torax.sources import source_profiles @@ -29,6 +30,34 @@ import os +@enum.unique +class SimError(enum.Enum): + """Integer enum for sim error handling.""" + + NO_ERROR = 0 + NAN_DETECTED = 1 + + +@chex.dataclass(frozen=True) +class ToraxSimOutputs: + """Output structure returned by `run_simulation()`. + + Contains the error state and the history of the simulation state. + Can be extended in the future to include more metadata about the simulation. + + Attributes: + sim_error: simulation error state: NO_ERROR for no error, NAN_DETECTED for + NaNs found in core profiles. + sim_history: history of the simulation state. + """ + + # Error state + sim_error: SimError + + # Time-dependent TORAX outputs + sim_history: tuple[state.ToraxSimState, ...] + + # Core profiles. TEMP_EL = "temp_el" TEMP_EL_RIGHT_BC = "temp_el_right_bc" @@ -80,31 +109,26 @@ # Simulation error state. SIM_ERROR = "sim_error" -# Tuple of (path_to_xarray_file, time_to_load_from). -FilepathAndTime: TypeAlias = tuple[str, float] - def load_state_file( - filepath_and_time: FilepathAndTime, data_var: str -) -> xr.DataArray: + filepath: str, time: float, +) -> xr.Dataset: """Loads a state file from a filepath.""" - path, t = filepath_and_time - if os.path.exists(path): - with open(path, "rb") as f: - logging.info("Loading %s from state file %s, time %s", data_var, path, t) - da = xr.load_dataset(f).sel(time=slice(t, None)).data_vars[data_var] - if RHO_CELL_NORM in da.coords: - return da.rename({RHO_CELL_NORM: config_args.RHO_NORM}) - else: - return da + if os.path.exists(filepath): + with open(filepath, "rb") as f: + logging.info("Loading state file %s, time %s", filepath, time) + ds = xr.load_dataset(f).sel(time=time, method="nearest").squeeze() + ds = ds.rename({RHO_CELL_NORM: config_args.RHO_NORM}) + ds.close() # Release any resources. + return ds else: - raise ValueError(f"File {path} does not exist.") + raise ValueError(f"File {filepath} does not exist.") class StateHistory: """A history of the state of the simulation and its error state.""" - def __init__(self, sim_outputs: sim_lib.ToraxSimOutputs): + def __init__(self, sim_outputs: ToraxSimOutputs): core_profiles = [ state.core_profiles.history_elem() for state in sim_outputs.sim_history diff --git a/torax/sim.py b/torax/sim.py index b0645477..2e34e107 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -27,7 +27,6 @@ from __future__ import annotations import dataclasses -import enum import time from typing import Any, Optional @@ -41,6 +40,7 @@ from torax import geometry from torax import geometry_provider as geometry_provider_lib from torax import jax_utils +from torax import output from torax import physics from torax import state from torax.config import config_args @@ -56,34 +56,6 @@ from torax.transport_model import transport_model as transport_model_lib -@enum.unique -class SimError(enum.Enum): - """Integer enum for sim error handling.""" - - NO_ERROR = 0 - NAN_DETECTED = 1 - - -@chex.dataclass(frozen=True) -class ToraxSimOutputs: - """Output structure returned by `run_simulation()`. - - Contains the error state and the history of the simulation state. - Can be extended in the future to include more metadata about the simulation. - - Attributes: - sim_error: simulation error state: NO_ERROR for no error, NAN_DETECTED for - NaNs found in core profiles. - sim_history: history of the simulation state. - """ - - # Error state - sim_error: SimError - - # Time-dependent TORAX outputs - sim_history: tuple[state.ToraxSimState, ...] - - def _log_timestep( t: jax.Array, dt: jax.Array, outer_stepper_iterations: int ) -> None: @@ -837,7 +809,7 @@ def run( self, log_timestep_info: bool = False, spectator: spectator_lib.Spectator | None = None, - ) -> ToraxSimOutputs: + ) -> output.ToraxSimOutputs: """Runs the transport simulation over a prescribed time interval. See `run_simulation` for details. @@ -967,7 +939,7 @@ def run_simulation( step_fn: SimulationStepFn, log_timestep_info: bool = False, spectator: spectator_lib.Spectator | None = None, -) -> ToraxSimOutputs: +) -> output.ToraxSimOutputs: """Runs the transport simulation over a prescribed time interval. This is the main entrypoint for running a TORAX simulation. @@ -1063,7 +1035,7 @@ def run_simulation( # Set the sim_error to NO_ERROR. If we encounter an error, we will set it to # the appropriate error code. - sim_error = SimError.NO_ERROR + sim_error = output.SimError.NO_ERROR # Keep advancing the simulation until the time_step_calculator tells us we are # done. while time_step_calculator.not_done( @@ -1148,11 +1120,11 @@ def run_simulation( Possible cause is negative temperatures or densities. Output file contains all profiles up to the last valid step. """) - sim_error = SimError.NAN_DETECTED + sim_error = output.SimError.NAN_DETECTED break # Log final timestep - if log_timestep_info and sim_error == SimError.NO_ERROR: + if log_timestep_info and sim_error == output.SimError.NO_ERROR: # The "sim_state" here has been updated by the loop above. _log_timestep( sim_state.t, @@ -1208,7 +1180,7 @@ def run_simulation( simulation_time, wall_clock_time_elapsed, ) - return ToraxSimOutputs( + return output.ToraxSimOutputs( sim_error=sim_error, sim_history=tuple(sim_history) ) diff --git a/torax/tests/output.py b/torax/tests/output.py index 0db7fc81..aee34025 100644 --- a/torax/tests/output.py +++ b/torax/tests/output.py @@ -25,7 +25,6 @@ from torax import geometry from torax import geometry_provider from torax import output -from torax import sim as sim_lib from torax import state from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params @@ -101,10 +100,10 @@ def test_state_history_init(self): inner_solver_iterations=1, ), ) - sim_error = sim_lib.SimError.NO_ERROR + sim_error = output.SimError.NO_ERROR output.StateHistory( - sim_lib.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,)) + output.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,)) ) def test_state_history_to_xr(self): @@ -124,9 +123,9 @@ def test_state_history_to_xr(self): inner_solver_iterations=1, ), ) - sim_error = sim_lib.SimError.NO_ERROR + sim_error = output.SimError.NO_ERROR history = output.StateHistory( - sim_lib.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,)) + output.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,)) ) history.simulation_output_to_xr(self.geo) @@ -149,9 +148,9 @@ def test_load_core_profiles_from_xr(self): inner_solver_iterations=1, ), ) - sim_error = sim_lib.SimError.NO_ERROR + sim_error = output.SimError.NO_ERROR history = output.StateHistory( - sim_lib.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,)) + output.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,)) ) # Output to an xr.Dataset and save to disk. ds = history.simulation_output_to_xr(self.geo) diff --git a/torax/tests/sim.py b/torax/tests/sim.py index 13b2dea4..2e40e849 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -638,7 +638,7 @@ def test_nans_trigger_error(self): sim_outputs = sim.run() state_history = output.StateHistory(sim_outputs) - self.assertEqual(state_history.sim_error, sim_lib.SimError.NAN_DETECTED) + self.assertEqual(state_history.sim_error, output.SimError.NAN_DETECTED) assert ( state_history.times[-1] < config_module.CONFIG['runtime_params']['numerics']['t_final']