Skip to content

Commit

Permalink
Move ToraxSimOutputs and SimError to output.py.
Browse files Browse the repository at this point in the history
This is to avoid a circular import when doing follow up work to add restart functionality.

PiperOrigin-RevId: 673329404
  • Loading branch information
Nush395 authored and Torax team committed Sep 11, 2024
1 parent 6d40708 commit 62affca
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 61 deletions.
60 changes: 42 additions & 18 deletions torax/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

"""Module containing functions for saving and loading simulation output."""
from typing import TypeAlias
from __future__ import annotations

import enum

from absl import logging
import chex
import jax
from jax import numpy as jnp
from torax import geometry
from torax import sim as sim_lib
from torax import state
from torax.config import config_args
from torax.sources import source_profiles
Expand All @@ -29,6 +30,34 @@
import os


@enum.unique
class SimError(enum.Enum):
"""Integer enum for sim error handling."""

NO_ERROR = 0
NAN_DETECTED = 1


@chex.dataclass(frozen=True)
class ToraxSimOutputs:
"""Output structure returned by `run_simulation()`.
Contains the error state and the history of the simulation state.
Can be extended in the future to include more metadata about the simulation.
Attributes:
sim_error: simulation error state: NO_ERROR for no error, NAN_DETECTED for
NaNs found in core profiles.
sim_history: history of the simulation state.
"""

# Error state
sim_error: SimError

# Time-dependent TORAX outputs
sim_history: tuple[state.ToraxSimState, ...]


# Core profiles.
TEMP_EL = "temp_el"
TEMP_EL_RIGHT_BC = "temp_el_right_bc"
Expand Down Expand Up @@ -80,31 +109,26 @@
# Simulation error state.
SIM_ERROR = "sim_error"

# Tuple of (path_to_xarray_file, time_to_load_from).
FilepathAndTime: TypeAlias = tuple[str, float]


def load_state_file(
filepath_and_time: FilepathAndTime, data_var: str
) -> xr.DataArray:
filepath: str, time: float,
) -> xr.Dataset:
"""Loads a state file from a filepath."""
path, t = filepath_and_time
if os.path.exists(path):
with open(path, "rb") as f:
logging.info("Loading %s from state file %s, time %s", data_var, path, t)
da = xr.load_dataset(f).sel(time=slice(t, None)).data_vars[data_var]
if RHO_CELL_NORM in da.coords:
return da.rename({RHO_CELL_NORM: config_args.RHO_NORM})
else:
return da
if os.path.exists(filepath):
with open(filepath, "rb") as f:
logging.info("Loading state file %s, time %s", filepath, time)
ds = xr.load_dataset(f).sel(time=time, method="nearest").squeeze()
ds = ds.rename({RHO_CELL_NORM: config_args.RHO_NORM})
ds.close() # Release any resources.
return ds
else:
raise ValueError(f"File {path} does not exist.")
raise ValueError(f"File {filepath} does not exist.")


class StateHistory:
"""A history of the state of the simulation and its error state."""

def __init__(self, sim_outputs: sim_lib.ToraxSimOutputs):
def __init__(self, sim_outputs: ToraxSimOutputs):
core_profiles = [
state.core_profiles.history_elem()
for state in sim_outputs.sim_history
Expand Down
42 changes: 7 additions & 35 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from __future__ import annotations

import dataclasses
import enum
import time
from typing import Any, Optional

Expand All @@ -41,6 +40,7 @@
from torax import geometry
from torax import geometry_provider as geometry_provider_lib
from torax import jax_utils
from torax import output
from torax import physics
from torax import state
from torax.config import config_args
Expand All @@ -56,34 +56,6 @@
from torax.transport_model import transport_model as transport_model_lib


@enum.unique
class SimError(enum.Enum):
"""Integer enum for sim error handling."""

NO_ERROR = 0
NAN_DETECTED = 1


@chex.dataclass(frozen=True)
class ToraxSimOutputs:
"""Output structure returned by `run_simulation()`.
Contains the error state and the history of the simulation state.
Can be extended in the future to include more metadata about the simulation.
Attributes:
sim_error: simulation error state: NO_ERROR for no error, NAN_DETECTED for
NaNs found in core profiles.
sim_history: history of the simulation state.
"""

# Error state
sim_error: SimError

# Time-dependent TORAX outputs
sim_history: tuple[state.ToraxSimState, ...]


def _log_timestep(
t: jax.Array, dt: jax.Array, outer_stepper_iterations: int
) -> None:
Expand Down Expand Up @@ -837,7 +809,7 @@ def run(
self,
log_timestep_info: bool = False,
spectator: spectator_lib.Spectator | None = None,
) -> ToraxSimOutputs:
) -> output.ToraxSimOutputs:
"""Runs the transport simulation over a prescribed time interval.
See `run_simulation` for details.
Expand Down Expand Up @@ -967,7 +939,7 @@ def run_simulation(
step_fn: SimulationStepFn,
log_timestep_info: bool = False,
spectator: spectator_lib.Spectator | None = None,
) -> ToraxSimOutputs:
) -> output.ToraxSimOutputs:
"""Runs the transport simulation over a prescribed time interval.
This is the main entrypoint for running a TORAX simulation.
Expand Down Expand Up @@ -1063,7 +1035,7 @@ def run_simulation(

# Set the sim_error to NO_ERROR. If we encounter an error, we will set it to
# the appropriate error code.
sim_error = SimError.NO_ERROR
sim_error = output.SimError.NO_ERROR
# Keep advancing the simulation until the time_step_calculator tells us we are
# done.
while time_step_calculator.not_done(
Expand Down Expand Up @@ -1148,11 +1120,11 @@ def run_simulation(
Possible cause is negative temperatures or densities.
Output file contains all profiles up to the last valid step.
""")
sim_error = SimError.NAN_DETECTED
sim_error = output.SimError.NAN_DETECTED
break

# Log final timestep
if log_timestep_info and sim_error == SimError.NO_ERROR:
if log_timestep_info and sim_error == output.SimError.NO_ERROR:
# The "sim_state" here has been updated by the loop above.
_log_timestep(
sim_state.t,
Expand Down Expand Up @@ -1208,7 +1180,7 @@ def run_simulation(
simulation_time,
wall_clock_time_elapsed,
)
return ToraxSimOutputs(
return output.ToraxSimOutputs(
sim_error=sim_error, sim_history=tuple(sim_history)
)

Expand Down
13 changes: 6 additions & 7 deletions torax/tests/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from torax import geometry
from torax import geometry_provider
from torax import output
from torax import sim as sim_lib
from torax import state
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
Expand Down Expand Up @@ -101,10 +100,10 @@ def test_state_history_init(self):
inner_solver_iterations=1,
),
)
sim_error = sim_lib.SimError.NO_ERROR
sim_error = output.SimError.NO_ERROR

output.StateHistory(
sim_lib.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,))
output.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,))
)

def test_state_history_to_xr(self):
Expand All @@ -124,9 +123,9 @@ def test_state_history_to_xr(self):
inner_solver_iterations=1,
),
)
sim_error = sim_lib.SimError.NO_ERROR
sim_error = output.SimError.NO_ERROR
history = output.StateHistory(
sim_lib.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,))
output.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,))
)

history.simulation_output_to_xr(self.geo)
Expand All @@ -149,9 +148,9 @@ def test_load_core_profiles_from_xr(self):
inner_solver_iterations=1,
),
)
sim_error = sim_lib.SimError.NO_ERROR
sim_error = output.SimError.NO_ERROR
history = output.StateHistory(
sim_lib.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,))
output.ToraxSimOutputs(sim_error=sim_error, sim_history=(sim_state,))
)
# Output to an xr.Dataset and save to disk.
ds = history.simulation_output_to_xr(self.geo)
Expand Down
2 changes: 1 addition & 1 deletion torax/tests/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def test_nans_trigger_error(self):
sim_outputs = sim.run()

state_history = output.StateHistory(sim_outputs)
self.assertEqual(state_history.sim_error, sim_lib.SimError.NAN_DETECTED)
self.assertEqual(state_history.sim_error, output.SimError.NAN_DETECTED)
assert (
state_history.times[-1]
< config_module.CONFIG['runtime_params']['numerics']['t_final']
Expand Down

0 comments on commit 62affca

Please sign in to comment.