From b8c6d3b5ec4c32247f6b799cf432c79de4cb45b6 Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Wed, 15 Jan 2025 04:18:32 -0800 Subject: [PATCH] move users of build_sim_object to use create method PiperOrigin-RevId: 715742162 --- docs/model_integration.rst | 12 +- torax/config/build_sim.py | 2 +- torax/sim.py | 297 ++++++++++++------------- torax/tests/sim.py | 2 +- torax/tests/sim_custom_sources.py | 2 +- torax/tests/sim_time_dependence.py | 2 +- torax/tests/test_data/test_explicit.py | 2 +- 7 files changed, 154 insertions(+), 165 deletions(-) diff --git a/docs/model_integration.rst b/docs/model_integration.rst index 768e59f3..e19e8b91 100644 --- a/docs/model_integration.rst +++ b/docs/model_integration.rst @@ -312,14 +312,6 @@ the |torax.sim.Sim|_ object. .. code-block:: python - # in sim.py. Copied here for reference, no need to modify this. - - def build_sim_object( - ... - transport_model_builder: transport_model_lib.TransportModelBuilder, - ... - ) -> Sim: - # in your TORAX configuration or run file .py my_custom_transport_builder = MyCustomTransportModelBuilder() @@ -334,7 +326,7 @@ the |torax.sim.Sim|_ object. my_custom_transport_builder.runtime_params.bar = 4.0 # Build the Sim object. - sim_object = sim_lib.build_sim_object( + sim_object = sim_lib.Sim.create( ..., transport_model_builder=my_custom_transport_builder, ... @@ -348,7 +340,7 @@ As of 7 June 2024, you cannot instantiate and configure a custom transport model via the config dictionary. You may still configure the other components of your TORAX simulation via the config dict and use other functions in |torax.config.build_sim|_ to convert those to the objects you can pass into -``build_sim_object()``. We are working on making this easier, but reach out +``Sim.create()``. We are working on making this easier, but reach out if this is something you need. diff --git a/torax/config/build_sim.py b/torax/config/build_sim.py index f2eb53af..7c9e3902 100644 --- a/torax/config/build_sim.py +++ b/torax/config/build_sim.py @@ -270,7 +270,7 @@ def build_sim_from_config( else: file_restart = None - return sim_lib.build_sim_object( + return sim_lib.Sim.create( runtime_params=runtime_params, geometry_provider=geo_provider, source_models_builder=build_sources_builder_from_config( diff --git a/torax/sim.py b/torax/sim.py index 0a6a75bc..d02dddc4 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -826,6 +826,153 @@ def run( log_timestep_info=log_timestep_info, ) + @classmethod + def create( + cls, + *, + runtime_params: general_runtime_params.GeneralRuntimeParams, + geometry_provider: geometry_provider_lib.GeometryProvider, + stepper_builder: stepper_lib.StepperBuilder, + transport_model_builder: transport_model_lib.TransportModelBuilder, + source_models_builder: source_models_lib.SourceModelsBuilder, + pedestal_model_builder: pedestal_model_lib.PedestalModelBuilder, + time_step_calculator: Optional[ts.TimeStepCalculator] = None, + file_restart: Optional[general_runtime_params.FileRestart] = None, + ) -> Sim: + """Builds a Sim object from the input runtime params and sim components. + + Args: + runtime_params: The input runtime params used throughout the simulation + run. + geometry_provider: The geometry used throughout the simulation run. + stepper_builder: A callable to build the stepper. The stepper has already + been factored out of the config. + transport_model_builder: A callable to build the transport model. + source_models_builder: Builds the SourceModels and holds its + runtime_params. + pedestal_model_builder: A callable to build the pedestal model. + time_step_calculator: The time_step_calculator, if built, otherwise a + ChiTimeStepCalculator will be built by default. + file_restart: If provided we will reconstruct the initial state from the + provided file at the given time step. This state from the file will only + be used for constructing the initial state (as well as the config) and + for all subsequent steps, the evolved state and runtime parameters from + config are used. + + Returns: + sim: The built Sim instance. + """ + + transport_model = transport_model_builder() + pedestal_model = pedestal_model_builder() + + # TODO(b/385788907): Document all changes that lead to recompilations. + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params=runtime_params, + source_runtime_params=source_models_builder.runtime_params, + torax_mesh=geometry_provider.torax_mesh, + stepper=stepper_builder.runtime_params, + ) + ) + dynamic_runtime_params_slice_provider = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport=transport_model_builder.runtime_params, + sources=source_models_builder.runtime_params, + stepper=stepper_builder.runtime_params, + torax_mesh=geometry_provider.torax_mesh, + pedestal=pedestal_model_builder.runtime_params, + ) + ) + source_models = source_models_builder() + stepper = stepper_builder(transport_model, source_models, pedestal_model) + + if time_step_calculator is None: + time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() + + # Build dynamic_runtime_params_slice at t_initial for initial conditions. + dynamic_runtime_params_slice_for_init, geo_for_init = ( + get_consistent_dynamic_runtime_params_slice_and_geometry( + runtime_params.numerics.t_initial, + dynamic_runtime_params_slice_provider, + geometry_provider, + ) + ) + if file_restart is not None and file_restart.do_restart: + data_tree = output.load_state_file(file_restart.filename) + # Find the closest time in the given dataset. + data_tree = data_tree.sel(time=file_restart.time, method='nearest') + t_restart = data_tree.time.item() + core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset + # Remap coordinates in saved file to be consistent with expectations of + # how config_args parses xarrays. + core_profiles_dataset = core_profiles_dataset.rename( + {output.RHO_CELL_NORM: config_args.RHO_NORM} + ) + core_profiles_dataset = core_profiles_dataset.squeeze() + if t_restart != runtime_params.numerics.t_initial: + logging.warning( + 'Requested restart time %f not exactly available in state file %s.' + ' Restarting from closest available time %f instead.', + file_restart.time, + file_restart.filename, + t_restart, + ) + # Override some of dynamic runtime params slice from t=t_initial. + dynamic_runtime_params_slice_for_init, geo_for_init = ( + _override_initial_runtime_params_from_file( + dynamic_runtime_params_slice_for_init, + geo_for_init, + t_restart, + core_profiles_dataset, + ) + ) + post_processed_dataset = data_tree.children[ + output.POST_PROCESSED_OUTPUTS + ].dataset + post_processed_dataset = post_processed_dataset.rename( + {output.RHO_CELL_NORM: config_args.RHO_NORM} + ) + post_processed_dataset = post_processed_dataset.squeeze() + post_processed_outputs = ( + _override_initial_state_post_processed_outputs_from_file( + geo_for_init, + post_processed_dataset, + ) + ) + + step_fn = SimulationStepFn( + stepper=stepper, + time_step_calculator=time_step_calculator, + transport_model=transport_model, + pedestal_model=pedestal_model, + ) + + initial_state = get_initial_state( + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init, + geo=geo_for_init, + step_fn=step_fn, + ) + + # If we are restarting from a file, we need to override the initial state + # post processed outputs such that cumulative outputs remain correct. + if file_restart is not None and file_restart.do_restart: + initial_state = dataclasses.replace( + initial_state, + post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable + ) + + return cls( + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + geometry_provider=geometry_provider, + initial_state=initial_state, + step_fn=step_fn, + file_restart=file_restart, + ) + def _override_initial_runtime_params_from_file( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, @@ -895,156 +1042,6 @@ def _override_initial_state_post_processed_outputs_from_file( return post_processed_outputs -def build_sim_object( - runtime_params: general_runtime_params.GeneralRuntimeParams, - geometry_provider: geometry_provider_lib.GeometryProvider, - stepper_builder: stepper_lib.StepperBuilder, - transport_model_builder: transport_model_lib.TransportModelBuilder, - source_models_builder: source_models_lib.SourceModelsBuilder, - pedestal_model_builder: pedestal_model_lib.PedestalModelBuilder, - time_step_calculator: Optional[ts.TimeStepCalculator] = None, - file_restart: Optional[general_runtime_params.FileRestart] = None, -) -> Sim: - """Builds a Sim object from the input runtime params and sim components. - - The Sim object provides a container for all the components that go into a - single TORAX simulation run. It gives a way to reuse components without having - to rebuild or recompile them if JAX shapes or static arguments do not change. - - Read more about the Sim object in its class docstring. The use of it is - optional, and users may call `sim.run_simulation()` directly as well. - - Args: - runtime_params: The input runtime params used throughout the simulation run. - geometry_provider: The geometry used throughout the simulation run. - stepper_builder: A callable to build the stepper. The stepper has already - been factored out of the config. - transport_model_builder: A callable to build the transport model. - source_models_builder: Builds the SourceModels and holds its runtime_params. - pedestal_model_builder: A callable to build the pedestal model. - time_step_calculator: The time_step_calculator, if built, otherwise a - ChiTimeStepCalculator will be built by default. - file_restart: If provided we will reconstruct the initial state from the - provided file at the given time step. This state from the file will only - be used for constructing the initial state (as well as the config) and for - all subsequent steps, the evolved state and runtime parameters from config - are used. - - Returns: - sim: The built Sim instance. - """ - - transport_model = transport_model_builder() - pedestal_model = pedestal_model_builder() - - # TODO(b/385788907): Clearly document all changes that lead to recompilations. - static_runtime_params_slice = ( - runtime_params_slice.build_static_runtime_params_slice( - runtime_params=runtime_params, - source_runtime_params=source_models_builder.runtime_params, - torax_mesh=geometry_provider.torax_mesh, - stepper=stepper_builder.runtime_params, - ) - ) - dynamic_runtime_params_slice_provider = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - transport=transport_model_builder.runtime_params, - sources=source_models_builder.runtime_params, - stepper=stepper_builder.runtime_params, - torax_mesh=geometry_provider.torax_mesh, - pedestal=pedestal_model_builder.runtime_params, - ) - ) - source_models = source_models_builder() - stepper = stepper_builder(transport_model, source_models, pedestal_model) - - if time_step_calculator is None: - time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() - - # Build dynamic_runtime_params_slice at t_initial for initial conditions. - dynamic_runtime_params_slice_for_init, geo_for_init = ( - get_consistent_dynamic_runtime_params_slice_and_geometry( - runtime_params.numerics.t_initial, - dynamic_runtime_params_slice_provider, - geometry_provider, - ) - ) - if file_restart is not None and file_restart.do_restart: - data_tree = output.load_state_file(file_restart.filename) - # Find the closest time in the given dataset. - data_tree = data_tree.sel(time=file_restart.time, method='nearest') - t_restart = data_tree.time.item() - core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset - # Remap coordinates in saved file to be consistent with expectations of - # how config_args parses xarrays. - core_profiles_dataset = core_profiles_dataset.rename( - {output.RHO_CELL_NORM: config_args.RHO_NORM} - ) - core_profiles_dataset = core_profiles_dataset.squeeze() - if t_restart != runtime_params.numerics.t_initial: - logging.warning( - 'Requested restart time %f not exactly available in state file %s.' - ' Restarting from closest available time %f instead.', - file_restart.time, - file_restart.filename, - t_restart, - ) - # Override some of dynamic runtime params slice from t=t_initial. - dynamic_runtime_params_slice_for_init, geo_for_init = ( - _override_initial_runtime_params_from_file( - dynamic_runtime_params_slice_for_init, - geo_for_init, - t_restart, - core_profiles_dataset, - ) - ) - post_processed_dataset = data_tree.children[ - output.POST_PROCESSED_OUTPUTS - ].dataset - post_processed_dataset = post_processed_dataset.rename( - {output.RHO_CELL_NORM: config_args.RHO_NORM} - ) - post_processed_dataset = post_processed_dataset.squeeze() - post_processed_outputs = ( - _override_initial_state_post_processed_outputs_from_file( - geo_for_init, - post_processed_dataset, - ) - ) - - step_fn = SimulationStepFn( - stepper=stepper, - time_step_calculator=time_step_calculator, - transport_model=transport_model, - pedestal_model=pedestal_model, - ) - - initial_state = get_initial_state( - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init, - geo=geo_for_init, - step_fn=step_fn, - ) - - # If we are restarting from a file, we need to override the initial state - # post processed outputs such that cumulative outputs remain correct. - if file_restart is not None and file_restart.do_restart: - initial_state = dataclasses.replace( - initial_state, - post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable - ) - - return Sim( - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, - geometry_provider=geometry_provider, - initial_state=initial_state, - step_fn=step_fn, - file_restart=file_restart, - ) - - def _run_simulation( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, diff --git a/torax/tests/sim.py b/torax/tests/sim.py index 9eaba011..6cf3cb45 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -481,7 +481,7 @@ def test_no_op(self): geometry.build_circular_geometry() ) - sim = sim_lib.build_sim_object( + sim = sim_lib.Sim.create( runtime_params=runtime_params, geometry_provider=geo_provider, stepper_builder=linear_theta_method.LinearThetaMethodBuilder(), diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index 654b9434..ffdbdb4e 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -191,7 +191,7 @@ def custom_source_formula( geo_provider = geometry_provider.ConstantGeometryProvider( geometry.build_circular_geometry() ) - sim = sim_lib.build_sim_object( + sim = sim_lib.Sim.create( runtime_params=self.test_particle_sources_constant_runtime_params, geometry_provider=geo_provider, stepper_builder=self.stepper_builder, diff --git a/torax/tests/sim_time_dependence.py b/torax/tests/sim_time_dependence.py index 68722112..e6375960 100644 --- a/torax/tests/sim_time_dependence.py +++ b/torax/tests/sim_time_dependence.py @@ -77,7 +77,7 @@ def test_time_dependent_params_update_in_adaptive_dt( # max combined value of Ti_bound_right should be 2.5. Higher will make the # error state from the stepper be 1. time_calculator = fixed_time_step_calculator.FixedTimeStepCalculator() - sim = sim_lib.build_sim_object( + sim = sim_lib.Sim.create( runtime_params=runtime_params, geometry_provider=geometry_provider, stepper_builder=FakeStepperBuilder( diff --git a/torax/tests/test_data/test_explicit.py b/torax/tests/test_data/test_explicit.py index 7c805d56..6cd2c320 100644 --- a/torax/tests/test_data/test_explicit.py +++ b/torax/tests/test_data/test_explicit.py @@ -106,7 +106,7 @@ def get_sim() -> sim_lib.Sim: # config taking place via constructor args in this function. runtime_params = get_runtime_params() geo_provider = get_geometry_provider() - return sim_lib.build_sim_object( + return sim_lib.Sim.create( runtime_params=runtime_params, geometry_provider=geo_provider, source_models_builder=get_sources_builder(),