Skip to content

Commit

Permalink
move users of build_sim_object to use create method
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715742162
  • Loading branch information
tamaranorman authored and Torax team committed Jan 15, 2025
1 parent a475745 commit b8c6d3b
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 165 deletions.
12 changes: 2 additions & 10 deletions docs/model_integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,6 @@ the |torax.sim.Sim|_ object.
.. code-block:: python
# in sim.py. Copied here for reference, no need to modify this.
def build_sim_object(
...
transport_model_builder: transport_model_lib.TransportModelBuilder,
...
) -> Sim:
# in your TORAX configuration or run file .py
my_custom_transport_builder = MyCustomTransportModelBuilder()
Expand All @@ -334,7 +326,7 @@ the |torax.sim.Sim|_ object.
my_custom_transport_builder.runtime_params.bar = 4.0
# Build the Sim object.
sim_object = sim_lib.build_sim_object(
sim_object = sim_lib.Sim.create(
...,
transport_model_builder=my_custom_transport_builder,
...
Expand All @@ -348,7 +340,7 @@ As of 7 June 2024, you cannot instantiate and configure a custom transport model
via the config dictionary. You may still configure the other components of your
TORAX simulation via the config dict and use other functions in
|torax.config.build_sim|_ to convert those to the objects you can pass into
``build_sim_object()``. We are working on making this easier, but reach out
``Sim.create()``. We are working on making this easier, but reach out
if this is something you need.
Expand Down
2 changes: 1 addition & 1 deletion torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def build_sim_from_config(
else:
file_restart = None

return sim_lib.build_sim_object(
return sim_lib.Sim.create(
runtime_params=runtime_params,
geometry_provider=geo_provider,
source_models_builder=build_sources_builder_from_config(
Expand Down
297 changes: 147 additions & 150 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,153 @@ def run(
log_timestep_info=log_timestep_info,
)

@classmethod
def create(
cls,
*,
runtime_params: general_runtime_params.GeneralRuntimeParams,
geometry_provider: geometry_provider_lib.GeometryProvider,
stepper_builder: stepper_lib.StepperBuilder,
transport_model_builder: transport_model_lib.TransportModelBuilder,
source_models_builder: source_models_lib.SourceModelsBuilder,
pedestal_model_builder: pedestal_model_lib.PedestalModelBuilder,
time_step_calculator: Optional[ts.TimeStepCalculator] = None,
file_restart: Optional[general_runtime_params.FileRestart] = None,
) -> Sim:
"""Builds a Sim object from the input runtime params and sim components.
Args:
runtime_params: The input runtime params used throughout the simulation
run.
geometry_provider: The geometry used throughout the simulation run.
stepper_builder: A callable to build the stepper. The stepper has already
been factored out of the config.
transport_model_builder: A callable to build the transport model.
source_models_builder: Builds the SourceModels and holds its
runtime_params.
pedestal_model_builder: A callable to build the pedestal model.
time_step_calculator: The time_step_calculator, if built, otherwise a
ChiTimeStepCalculator will be built by default.
file_restart: If provided we will reconstruct the initial state from the
provided file at the given time step. This state from the file will only
be used for constructing the initial state (as well as the config) and
for all subsequent steps, the evolved state and runtime parameters from
config are used.
Returns:
sim: The built Sim instance.
"""

transport_model = transport_model_builder()
pedestal_model = pedestal_model_builder()

# TODO(b/385788907): Document all changes that lead to recompilations.
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
stepper=stepper_builder.runtime_params,
)
)
dynamic_runtime_params_slice_provider = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
transport=transport_model_builder.runtime_params,
sources=source_models_builder.runtime_params,
stepper=stepper_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
pedestal=pedestal_model_builder.runtime_params,
)
)
source_models = source_models_builder()
stepper = stepper_builder(transport_model, source_models, pedestal_model)

if time_step_calculator is None:
time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator()

# Build dynamic_runtime_params_slice at t_initial for initial conditions.
dynamic_runtime_params_slice_for_init, geo_for_init = (
get_consistent_dynamic_runtime_params_slice_and_geometry(
runtime_params.numerics.t_initial,
dynamic_runtime_params_slice_provider,
geometry_provider,
)
)
if file_restart is not None and file_restart.do_restart:
data_tree = output.load_state_file(file_restart.filename)
# Find the closest time in the given dataset.
data_tree = data_tree.sel(time=file_restart.time, method='nearest')
t_restart = data_tree.time.item()
core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset
# Remap coordinates in saved file to be consistent with expectations of
# how config_args parses xarrays.
core_profiles_dataset = core_profiles_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
core_profiles_dataset = core_profiles_dataset.squeeze()
if t_restart != runtime_params.numerics.t_initial:
logging.warning(
'Requested restart time %f not exactly available in state file %s.'
' Restarting from closest available time %f instead.',
file_restart.time,
file_restart.filename,
t_restart,
)
# Override some of dynamic runtime params slice from t=t_initial.
dynamic_runtime_params_slice_for_init, geo_for_init = (
_override_initial_runtime_params_from_file(
dynamic_runtime_params_slice_for_init,
geo_for_init,
t_restart,
core_profiles_dataset,
)
)
post_processed_dataset = data_tree.children[
output.POST_PROCESSED_OUTPUTS
].dataset
post_processed_dataset = post_processed_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
post_processed_dataset = post_processed_dataset.squeeze()
post_processed_outputs = (
_override_initial_state_post_processed_outputs_from_file(
geo_for_init,
post_processed_dataset,
)
)

step_fn = SimulationStepFn(
stepper=stepper,
time_step_calculator=time_step_calculator,
transport_model=transport_model,
pedestal_model=pedestal_model,
)

initial_state = get_initial_state(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
geo=geo_for_init,
step_fn=step_fn,
)

# If we are restarting from a file, we need to override the initial state
# post processed outputs such that cumulative outputs remain correct.
if file_restart is not None and file_restart.do_restart:
initial_state = dataclasses.replace(
initial_state,
post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable
)

return cls(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
geometry_provider=geometry_provider,
initial_state=initial_state,
step_fn=step_fn,
file_restart=file_restart,
)


def _override_initial_runtime_params_from_file(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
Expand Down Expand Up @@ -895,156 +1042,6 @@ def _override_initial_state_post_processed_outputs_from_file(
return post_processed_outputs


def build_sim_object(
runtime_params: general_runtime_params.GeneralRuntimeParams,
geometry_provider: geometry_provider_lib.GeometryProvider,
stepper_builder: stepper_lib.StepperBuilder,
transport_model_builder: transport_model_lib.TransportModelBuilder,
source_models_builder: source_models_lib.SourceModelsBuilder,
pedestal_model_builder: pedestal_model_lib.PedestalModelBuilder,
time_step_calculator: Optional[ts.TimeStepCalculator] = None,
file_restart: Optional[general_runtime_params.FileRestart] = None,
) -> Sim:
"""Builds a Sim object from the input runtime params and sim components.
The Sim object provides a container for all the components that go into a
single TORAX simulation run. It gives a way to reuse components without having
to rebuild or recompile them if JAX shapes or static arguments do not change.
Read more about the Sim object in its class docstring. The use of it is
optional, and users may call `sim.run_simulation()` directly as well.
Args:
runtime_params: The input runtime params used throughout the simulation run.
geometry_provider: The geometry used throughout the simulation run.
stepper_builder: A callable to build the stepper. The stepper has already
been factored out of the config.
transport_model_builder: A callable to build the transport model.
source_models_builder: Builds the SourceModels and holds its runtime_params.
pedestal_model_builder: A callable to build the pedestal model.
time_step_calculator: The time_step_calculator, if built, otherwise a
ChiTimeStepCalculator will be built by default.
file_restart: If provided we will reconstruct the initial state from the
provided file at the given time step. This state from the file will only
be used for constructing the initial state (as well as the config) and for
all subsequent steps, the evolved state and runtime parameters from config
are used.
Returns:
sim: The built Sim instance.
"""

transport_model = transport_model_builder()
pedestal_model = pedestal_model_builder()

# TODO(b/385788907): Clearly document all changes that lead to recompilations.
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
stepper=stepper_builder.runtime_params,
)
)
dynamic_runtime_params_slice_provider = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
transport=transport_model_builder.runtime_params,
sources=source_models_builder.runtime_params,
stepper=stepper_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
pedestal=pedestal_model_builder.runtime_params,
)
)
source_models = source_models_builder()
stepper = stepper_builder(transport_model, source_models, pedestal_model)

if time_step_calculator is None:
time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator()

# Build dynamic_runtime_params_slice at t_initial for initial conditions.
dynamic_runtime_params_slice_for_init, geo_for_init = (
get_consistent_dynamic_runtime_params_slice_and_geometry(
runtime_params.numerics.t_initial,
dynamic_runtime_params_slice_provider,
geometry_provider,
)
)
if file_restart is not None and file_restart.do_restart:
data_tree = output.load_state_file(file_restart.filename)
# Find the closest time in the given dataset.
data_tree = data_tree.sel(time=file_restart.time, method='nearest')
t_restart = data_tree.time.item()
core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset
# Remap coordinates in saved file to be consistent with expectations of
# how config_args parses xarrays.
core_profiles_dataset = core_profiles_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
core_profiles_dataset = core_profiles_dataset.squeeze()
if t_restart != runtime_params.numerics.t_initial:
logging.warning(
'Requested restart time %f not exactly available in state file %s.'
' Restarting from closest available time %f instead.',
file_restart.time,
file_restart.filename,
t_restart,
)
# Override some of dynamic runtime params slice from t=t_initial.
dynamic_runtime_params_slice_for_init, geo_for_init = (
_override_initial_runtime_params_from_file(
dynamic_runtime_params_slice_for_init,
geo_for_init,
t_restart,
core_profiles_dataset,
)
)
post_processed_dataset = data_tree.children[
output.POST_PROCESSED_OUTPUTS
].dataset
post_processed_dataset = post_processed_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
post_processed_dataset = post_processed_dataset.squeeze()
post_processed_outputs = (
_override_initial_state_post_processed_outputs_from_file(
geo_for_init,
post_processed_dataset,
)
)

step_fn = SimulationStepFn(
stepper=stepper,
time_step_calculator=time_step_calculator,
transport_model=transport_model,
pedestal_model=pedestal_model,
)

initial_state = get_initial_state(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
geo=geo_for_init,
step_fn=step_fn,
)

# If we are restarting from a file, we need to override the initial state
# post processed outputs such that cumulative outputs remain correct.
if file_restart is not None and file_restart.do_restart:
initial_state = dataclasses.replace(
initial_state,
post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable
)

return Sim(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
geometry_provider=geometry_provider,
initial_state=initial_state,
step_fn=step_fn,
file_restart=file_restart,
)


def _run_simulation(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider,
Expand Down
2 changes: 1 addition & 1 deletion torax/tests/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def test_no_op(self):
geometry.build_circular_geometry()
)

sim = sim_lib.build_sim_object(
sim = sim_lib.Sim.create(
runtime_params=runtime_params,
geometry_provider=geo_provider,
stepper_builder=linear_theta_method.LinearThetaMethodBuilder(),
Expand Down
2 changes: 1 addition & 1 deletion torax/tests/sim_custom_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def custom_source_formula(
geo_provider = geometry_provider.ConstantGeometryProvider(
geometry.build_circular_geometry()
)
sim = sim_lib.build_sim_object(
sim = sim_lib.Sim.create(
runtime_params=self.test_particle_sources_constant_runtime_params,
geometry_provider=geo_provider,
stepper_builder=self.stepper_builder,
Expand Down
2 changes: 1 addition & 1 deletion torax/tests/sim_time_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_time_dependent_params_update_in_adaptive_dt(
# max combined value of Ti_bound_right should be 2.5. Higher will make the
# error state from the stepper be 1.
time_calculator = fixed_time_step_calculator.FixedTimeStepCalculator()
sim = sim_lib.build_sim_object(
sim = sim_lib.Sim.create(
runtime_params=runtime_params,
geometry_provider=geometry_provider,
stepper_builder=FakeStepperBuilder(
Expand Down
2 changes: 1 addition & 1 deletion torax/tests/test_data/test_explicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_sim() -> sim_lib.Sim:
# config taking place via constructor args in this function.
runtime_params = get_runtime_params()
geo_provider = get_geometry_provider()
return sim_lib.build_sim_object(
return sim_lib.Sim.create(
runtime_params=runtime_params,
geometry_provider=geo_provider,
source_models_builder=get_sources_builder(),
Expand Down

0 comments on commit b8c6d3b

Please sign in to comment.