diff --git a/torax/fvm/calc_coeffs.py b/torax/fvm/calc_coeffs.py index 63e84b5e..c5df5707 100644 --- a/torax/fvm/calc_coeffs.py +++ b/torax/fvm/calc_coeffs.py @@ -34,6 +34,8 @@ from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_operations +from torax.sources import source_profile_builders from torax.sources import source_profiles as source_profiles_lib from torax.transport_model import transport_model as transport_model_lib @@ -375,7 +377,7 @@ def _calc_coeffs_full( # This only calculates sources set to implicit in the config. All other # sources are set to 0 (and should have their profiles already calculated in # explicit_source_profiles). - implicit_source_profiles = source_models_lib.build_source_profiles( + implicit_source_profiles = source_profile_builders.build_source_profiles( source_models=source_models, dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, @@ -427,11 +429,11 @@ def _calc_coeffs_full( source_mat_psi = jnp.zeros_like(geo.rho) # fill source vector based on both original and updated core profiles - source_psi = source_models_lib.sum_sources_psi( + source_psi = source_operations.sum_sources_psi( geo, implicit_source_profiles, source_models, - ) + source_models_lib.sum_sources_psi( + ) + source_operations.sum_sources_psi( geo, explicit_source_profiles, source_models, @@ -622,11 +624,11 @@ def _calc_coeffs_full( source_mat_nn = jnp.zeros_like(geo.rho) # density source vector based both on original and updated core profiles - source_ne = source_models_lib.sum_sources_ne( + source_ne = source_operations.sum_sources_ne( geo, explicit_source_profiles, source_models, - ) + source_models_lib.sum_sources_ne( + ) + source_operations.sum_sources_ne( geo, implicit_source_profiles, source_models, @@ -718,21 +720,21 @@ def _calc_coeffs_full( source_mat_ii = jnp.zeros_like(geo.rho) source_mat_ee = jnp.zeros_like(geo.rho) - source_i = source_models_lib.sum_sources_temp_ion( + source_i = source_operations.sum_sources_temp_ion( geo, explicit_source_profiles, source_models, - ) + source_models_lib.sum_sources_temp_ion( + ) + source_operations.sum_sources_temp_ion( geo, implicit_source_profiles, source_models, ) - source_e = source_models_lib.sum_sources_temp_el( + source_e = source_operations.sum_sources_temp_el( geo, explicit_source_profiles, source_models, - ) + source_models_lib.sum_sources_temp_el( + ) + source_operations.sum_sources_temp_el( geo, implicit_source_profiles, source_models, diff --git a/torax/fvm/newton_raphson_solve_block.py b/torax/fvm/newton_raphson_solve_block.py index 7c4c6862..fde19bb0 100644 --- a/torax/fvm/newton_raphson_solve_block.py +++ b/torax/fvm/newton_raphson_solve_block.py @@ -35,6 +35,7 @@ from torax.geometry import geometry from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles from torax.stepper import predictor_corrector_method from torax.transport_model import transport_model as transport_model_lib @@ -226,7 +227,7 @@ def newton_raphson_solve_block( # Initialized here with correct shapes to help with tracing in case # this is jitted. ( - source_models_lib.build_all_zero_profiles( + source_profile_builders.build_all_zero_profiles( geo_t, source_models, ), diff --git a/torax/fvm/optimizer_solve_block.py b/torax/fvm/optimizer_solve_block.py index 8c9ae79f..5c30ab4b 100644 --- a/torax/fvm/optimizer_solve_block.py +++ b/torax/fvm/optimizer_solve_block.py @@ -29,6 +29,7 @@ from torax.geometry import geometry from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles from torax.stepper import predictor_corrector_method from torax.transport_model import transport_model as transport_model_lib @@ -157,7 +158,7 @@ def optimizer_solve_block( # Initialized here with correct shapes to help with tracing in case # this is jitted. ( - source_models_lib.build_all_zero_profiles( + source_profile_builders.build_all_zero_profiles( geo_t, source_models, ), diff --git a/torax/fvm/tests/calc_coeffs.py b/torax/fvm/tests/calc_coeffs.py index 8914fda3..af8518ed 100644 --- a/torax/fvm/tests/calc_coeffs.py +++ b/torax/fvm/tests/calc_coeffs.py @@ -24,7 +24,7 @@ from torax.geometry import circular_geometry from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params -from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.stepper import runtime_params as stepper_params_lib from torax.tests.test_lib import default_sources from torax.transport_model import constant as constant_transport_model @@ -108,7 +108,7 @@ def test_calc_coeffs_smoke_test( source_models, ) evolving_names = tuple(['temp_ion']) - explicit_source_profiles = source_models_lib.build_source_profiles( + explicit_source_profiles = source_profile_builders.build_source_profiles( static_runtime_params_slice=static_runtime_params_slice, source_models=source_models, dynamic_runtime_params_slice=dynamic_runtime_params_slice, diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index 4d1663a4..a4ea0292 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -36,7 +36,7 @@ from torax.geometry import circular_geometry from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params -from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.stepper import runtime_params as stepper_runtime_params from torax.tests.test_lib import default_sources from torax.tests.test_lib import torax_refs @@ -443,7 +443,7 @@ def test_nonlinear_solve_block_loss_minimum( source_models, ) evolving_names = tuple(['temp_ion']) - explicit_source_profiles = source_models_lib.build_source_profiles( + explicit_source_profiles = source_profile_builders.build_source_profiles( source_models=source_models, dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, @@ -587,7 +587,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): geo, source_models, ) - explicit_source_profiles = source_models_lib.build_source_profiles( + explicit_source_profiles = source_profile_builders.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, @@ -740,7 +740,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self): geo, source_models, ) - explicit_source_profiles = source_models_lib.build_source_profiles( + explicit_source_profiles = source_profile_builders.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice_theta0, geo=geo, diff --git a/torax/orchestration/step_function.py b/torax/orchestration/step_function.py index 69def509..15eeec62 100644 --- a/torax/orchestration/step_function.py +++ b/torax/orchestration/step_function.py @@ -31,6 +31,7 @@ from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.sources import ohmic_heat_source from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles as source_profiles_lib from torax.stepper import stepper as stepper_lib from torax.time_step_calculator import time_step_calculator as ts @@ -144,7 +145,7 @@ def __call__( # This only computes sources set to explicit in the # DynamicSourceConfigSlice. All implicit sources will have their profiles # set to 0. - explicit_source_profiles = source_models_lib.build_source_profiles( + explicit_source_profiles = source_profile_builders.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice_t, static_runtime_params_slice=static_runtime_params_slice, geo=geo_t, diff --git a/torax/sim.py b/torax/sim.py index 867e03a5..44d6f9ea 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -46,6 +46,7 @@ from torax.orchestration import step_function from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles as source_profiles_lib from torax.stepper import stepper as stepper_lib from torax.time_step_calculator import chi_time_step_calculator @@ -631,7 +632,7 @@ def _run_simulation( geometry_provider=geometry_provider, ) ) - explicit_source_profiles = source_models_lib.build_source_profiles( + explicit_source_profiles = source_profile_builders.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, @@ -700,7 +701,7 @@ def get_initial_source_profiles( Implicit and explicit SourceProfiles from source models based on the core profiles from the starting state. """ - implicit_profiles = source_models_lib.build_source_profiles( + implicit_profiles = source_profile_builders.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, @@ -709,7 +710,7 @@ def get_initial_source_profiles( explicit=False, ) # Also add in the explicit sources to the initial sources. - explicit_source_profiles = source_models_lib.build_source_profiles( + explicit_source_profiles = source_profile_builders.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, diff --git a/torax/sources/ohmic_heat_source.py b/torax/sources/ohmic_heat_source.py index a364c0f3..1547980d 100644 --- a/torax/sources/ohmic_heat_source.py +++ b/torax/sources/ohmic_heat_source.py @@ -32,6 +32,7 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_operations @functools.partial( @@ -70,12 +71,14 @@ def calc_psidot( """ consts = constants.CONSTANTS - psi_sources, sigma, sigma_face = source_models_lib.calc_and_sum_sources_psi( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, + psi_sources, sigma, sigma_face = ( + source_operations.calc_and_sum_sources_psi( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + ) ) # Calculate transient term toc_psi = ( diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index c434b871..0734557f 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -17,14 +17,9 @@ from __future__ import annotations from collections.abc import Mapping -import functools -import chex -import jax import jax.numpy as jnp from torax import array_typing -from torax import constants -from torax import jax_utils from torax import state from torax.config import runtime_params_slice from torax.geometry import geometry @@ -33,309 +28,6 @@ from torax.sources import qei_source as qei_source_lib from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_profiles - - -@functools.partial( - jax_utils.jit, - static_argnames=[ - 'source_models', - 'static_runtime_params_slice', - 'explicit', - ], -) -def build_source_profiles( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: SourceModels, - explicit: bool, -) -> source_profiles.SourceProfiles: - """Builds explicit or implicit source profiles. - - Args: - static_runtime_params_slice: Input config. Cannot change from time step to - time step. - dynamic_runtime_params_slice: Input config for this time step. Can change - from time step to time step. - geo: Geometry of the torus. - core_profiles: Core plasma profiles, either at the start of the time step - (if explicit) or the live profiles being evolved during the time step (if - implicit). - source_models: Functions computing profiles for all TORAX sources/sinks. - explicit: If True, this function should return profiles for all explicit - sources. All implicit sources should be set to 0. And same vice versa. - - Returns: - SourceProfiles for either explicit or implicit sources (and all others set - to zero). - """ - # Bootstrap current is a special-case source with multiple outputs, so handle - # it here. - static_bootstrap_runtime_params = static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ] - bootstrap_profiles = _build_bootstrap_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_bootstrap_runtime_params, - geo=geo, - core_profiles=core_profiles, - j_bootstrap_source=source_models.j_bootstrap, - explicit=explicit, - ) - other_profiles = _build_standard_source_profiles( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, - explicit, - ) - if not explicit: - qei = source_models.qei_source.get_qei( - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) - else: - qei = source_profiles.QeiInfo.zeros(geo) - return source_profiles.SourceProfiles( - profiles=other_profiles, - j_bootstrap=bootstrap_profiles, - qei=qei, - ) - - -def _build_bootstrap_profiles( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - j_bootstrap_source: bootstrap_current_source.BootstrapCurrentSource, - explicit: bool = True, - calculate_anyway: bool = False, -) -> source_profiles.BootstrapCurrentProfile: - """Computes the bootstrap current profile. - - Args: - static_runtime_params_slice: Input config. Cannot change from time step to - time step. - static_source_runtime_params: Input runtime parameters specific to the - bootstrap current source that do not change from time step to time step. - dynamic_runtime_params_slice: Input config for this time step. Can change - from time step to time step. - geo: Geometry of the torus. - core_profiles: Core plasma profiles, either at the start of the time step - (if explicit) or the live profiles being evolved during the time step (if - implicit). - j_bootstrap_source: Bootstrap current source used to compute the profile. - explicit: If True, this function should return the profile for an explicit - source. If explicit is True and the bootstrap current source is not - explicit, then this should return all zeros. And same with implicit (if - explicit=False and the source is set to be explicit, then this will return - all zeros). - calculate_anyway: If True, returns values regardless of explicit - - Returns: - Bootstrap current profile. - """ - bootstrap_profile = j_bootstrap_source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) - if explicit == static_source_runtime_params.is_explicit | calculate_anyway: - return bootstrap_profile - else: - return source_profiles.BootstrapCurrentProfile.zero_profile(geo) - - -def _build_standard_source_profiles( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: SourceModels, - explicit: bool = True, - calculate_anyway: bool = False, - affected_core_profiles: tuple[source_lib.AffectedCoreProfile, ...] = ( - source_lib.AffectedCoreProfile.PSI, - source_lib.AffectedCoreProfile.NE, - source_lib.AffectedCoreProfile.TEMP_ION, - source_lib.AffectedCoreProfile.TEMP_EL, - ), -) -> dict[str, chex.ArrayTree]: - """Computes sources and builds a kwargs dict for SourceProfiles. - - Args: - static_runtime_params_slice: Input config. Cannot change from time step to - time step. - dynamic_runtime_params_slice: Input config for this time step. Can change - from time step to time step. - geo: Geometry of the torus. - core_profiles: Core plasma profiles, either at the start of the time step - (if explicit) or the live profiles being evolved during the time step (if - implicit). - source_models: Collection of all TORAX sources. - explicit: If True, this function should return the profile for an explicit - source. If explicit is True and a given source is not explicit, then this - function will return zeros for that source. And same with implicit (if - explicit=False and the source is set to be explicit, then this will return - all zeros). - calculate_anyway: If True, returns values regardless of explicit - affected_core_profiles: Populate the output for sources that affect these - core profiles. - - Returns: - dict of source profiles excluding the two special-case sources (bootstrap - and qei). - """ - computed_source_profiles = {} - affected_core_profiles_set = set(affected_core_profiles) - for source_name, source in source_models.standard_sources.items(): - if affected_core_profiles_set.intersection(source.affected_core_profiles): - static_source_runtime_params = static_runtime_params_slice.sources[ - source_name - ] - if ( - explicit - == static_source_runtime_params.is_explicit | calculate_anyway - ): - computed_source_profiles[source_name] = source.get_value( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - ) - else: - computed_source_profiles[source_name] = jnp.zeros( - source.output_shape_getter(geo)) - return computed_source_profiles - - -def sum_sources_psi( - geo: geometry.Geometry, - source_profile: source_profiles.SourceProfiles, - source_models: SourceModels, -) -> jax.Array: - """Computes psi source values for sim.calc_coeffs.""" - total = source_profile.j_bootstrap.j_bootstrap - for source_name, source in source_models.psi_sources.items(): - total += source.get_source_profile_for_affected_core_profile( - profile=source_profile.profiles[source_name], - affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, - geo=geo, - ) - mu0 = constants.CONSTANTS.mu0 - prefactor = 8 * geo.vpr * jnp.pi**2 * geo.B0 * mu0 * geo.Phib / geo.F**2 - scale_source = lambda src: -src * prefactor - return scale_source(total) - - -def sum_sources_ne( - geo: geometry.Geometry, - source_profile: source_profiles.SourceProfiles, - source_models: SourceModels, -) -> jax.Array: - """Computes ne source values for sim.calc_coeffs.""" - total = jnp.zeros_like(geo.rho) - for source_name, source in source_models.ne_sources.items(): - total += source.get_source_profile_for_affected_core_profile( - profile=source_profile.profiles[source_name], - affected_core_profile=source_lib.AffectedCoreProfile.NE.value, - geo=geo, - ) - return total * geo.vpr - - -def sum_sources_temp_ion( - geo: geometry.Geometry, - source_profile: source_profiles.SourceProfiles, - source_models: SourceModels, -) -> jax.Array: - """Computes temp_ion source values for sim.calc_coeffs.""" - total = jnp.zeros_like(geo.rho) - for source_name, source in source_models.temp_ion_sources.items(): - total += source.get_source_profile_for_affected_core_profile( - profile=source_profile.profiles[source_name], - affected_core_profile=(source_lib.AffectedCoreProfile.TEMP_ION.value), - geo=geo, - ) - return total * geo.vpr - - -def sum_sources_temp_el( - geo: geometry.Geometry, - source_profile: source_profiles.SourceProfiles, - source_models: SourceModels, -) -> jax.Array: - """Computes temp_el source values for sim.calc_coeffs.""" - total = jnp.zeros_like(geo.rho) - for source_name, source in source_models.temp_el_sources.items(): - total += source.get_source_profile_for_affected_core_profile( - profile=source_profile.profiles[source_name], - affected_core_profile=(source_lib.AffectedCoreProfile.TEMP_EL.value), - geo=geo, - ) - return total * geo.vpr - - -def calc_and_sum_sources_psi( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: SourceModels, -) -> tuple[jax.Array, jax.Array, jax.Array]: - """Computes sum of psi sources for psi_dot calculation.""" - - # TODO(b/335597108): Revisit how to calculate this once we enable more - # expensive source functions that might not jittable (like file-based or - # RPC-based sources). - psi_profiles = _build_standard_source_profiles( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, - calculate_anyway=True, - affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), - ) - total = 0 - for source_name, source in source_models.psi_sources.items(): - total += source.get_source_profile_for_affected_core_profile( - profile=psi_profiles[source_name], - affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, - geo=geo, - ) - static_bootstrap_runtime_params = static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ] - j_bootstrap_profiles = _build_bootstrap_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_bootstrap_runtime_params, - geo=geo, - core_profiles=core_profiles, - j_bootstrap_source=source_models.j_bootstrap, - calculate_anyway=True, - ) - total += j_bootstrap_profiles.j_bootstrap - - mu0 = constants.CONSTANTS.mu0 - prefactor = 8 * geo.vpr * jnp.pi**2 * geo.B0 * mu0 * geo.Phib / geo.F**2 - scale_source = lambda src: -src * prefactor - - return ( - scale_source(total), - j_bootstrap_profiles.sigma, - j_bootstrap_profiles.sigma_face, - ) class SourceModels: @@ -705,19 +397,3 @@ def runtime_params(self) -> dict[str, runtime_params_lib.RuntimeParams]: source_name: builder.runtime_params for source_name, builder in self.source_builders.items() } - - -def build_all_zero_profiles( - geo: geometry.Geometry, - source_models: SourceModels, -) -> source_profiles.SourceProfiles: - """Returns a SourceProfiles object with all zero profiles.""" - profiles = { - source_name: jnp.zeros(source_model.output_shape_getter(geo)) - for source_name, source_model in source_models.standard_sources.items() - } - return source_profiles.SourceProfiles( - profiles=profiles, - j_bootstrap=source_profiles.BootstrapCurrentProfile.zero_profile(geo), - qei=source_profiles.QeiInfo.zeros(geo), - ) diff --git a/torax/sources/source_operations.py b/torax/sources/source_operations.py new file mode 100644 index 00000000..9917a609 --- /dev/null +++ b/torax/sources/source_operations.py @@ -0,0 +1,148 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building source profiles in TORAX.""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from torax import constants +from torax import state +from torax.config import runtime_params_slice +from torax.geometry import geometry +from torax.sources import source as source_lib +from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders +from torax.sources import source_profiles + + +def sum_sources_psi( + geo: geometry.Geometry, + source_profile: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, +) -> jax.Array: + """Computes psi source values for sim.calc_coeffs.""" + total = source_profile.j_bootstrap.j_bootstrap + for source_name, source in source_models.psi_sources.items(): + total += source.get_source_profile_for_affected_core_profile( + profile=source_profile.profiles[source_name], + affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, + geo=geo, + ) + mu0 = constants.CONSTANTS.mu0 + prefactor = 8 * geo.vpr * jnp.pi**2 * geo.B0 * mu0 * geo.Phib / geo.F**2 + scale_source = lambda src: -src * prefactor + return scale_source(total) + + +def sum_sources_ne( + geo: geometry.Geometry, + source_profile: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, +) -> jax.Array: + """Computes ne source values for sim.calc_coeffs.""" + total = jnp.zeros_like(geo.rho) + for source_name, source in source_models.ne_sources.items(): + total += source.get_source_profile_for_affected_core_profile( + profile=source_profile.profiles[source_name], + affected_core_profile=source_lib.AffectedCoreProfile.NE.value, + geo=geo, + ) + return total * geo.vpr + + +def sum_sources_temp_ion( + geo: geometry.Geometry, + source_profile: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, +) -> jax.Array: + """Computes temp_ion source values for sim.calc_coeffs.""" + total = jnp.zeros_like(geo.rho) + for source_name, source in source_models.temp_ion_sources.items(): + total += source.get_source_profile_for_affected_core_profile( + profile=source_profile.profiles[source_name], + affected_core_profile=(source_lib.AffectedCoreProfile.TEMP_ION.value), + geo=geo, + ) + return total * geo.vpr + + +def sum_sources_temp_el( + geo: geometry.Geometry, + source_profile: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, +) -> jax.Array: + """Computes temp_el source values for sim.calc_coeffs.""" + total = jnp.zeros_like(geo.rho) + for source_name, source in source_models.temp_el_sources.items(): + total += source.get_source_profile_for_affected_core_profile( + profile=source_profile.profiles[source_name], + affected_core_profile=(source_lib.AffectedCoreProfile.TEMP_EL.value), + geo=geo, + ) + return total * geo.vpr + + +def calc_and_sum_sources_psi( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, +) -> tuple[jax.Array, jax.Array, jax.Array]: + """Computes sum of psi sources for psi_dot calculation.""" + + # TODO(b/335597108): Revisit how to calculate this once we enable more + # expensive source functions that might not jittable (like file-based or + # RPC-based sources). + psi_profiles = source_profile_builders.build_standard_source_profiles( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + calculate_anyway=True, + affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), + ) + total = 0 + for source_name, source in source_models.psi_sources.items(): + total += source.get_source_profile_for_affected_core_profile( + profile=psi_profiles[source_name], + affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, + geo=geo, + ) + static_bootstrap_runtime_params = static_runtime_params_slice.sources[ + source_models.j_bootstrap_name + ] + j_bootstrap_profiles = source_profile_builders.build_bootstrap_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, + static_source_runtime_params=static_bootstrap_runtime_params, + geo=geo, + core_profiles=core_profiles, + j_bootstrap_source=source_models.j_bootstrap, + calculate_anyway=True, + ) + total += j_bootstrap_profiles.j_bootstrap + + mu0 = constants.CONSTANTS.mu0 + prefactor = 8 * geo.vpr * jnp.pi**2 * geo.B0 * mu0 * geo.Phib / geo.F**2 + scale_source = lambda src: -src * prefactor + + return ( + scale_source(total), + j_bootstrap_profiles.sigma, + j_bootstrap_profiles.sigma_face, + ) diff --git a/torax/sources/source_profile_builders.py b/torax/sources/source_profile_builders.py new file mode 100644 index 00000000..1359a2df --- /dev/null +++ b/torax/sources/source_profile_builders.py @@ -0,0 +1,229 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building source profiles in TORAX.""" + +from __future__ import annotations + +import functools + +import chex +import jax.numpy as jnp +from torax import jax_utils +from torax import state +from torax.config import runtime_params_slice +from torax.geometry import geometry +from torax.sources import bootstrap_current_source +from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib +from torax.sources import source_models as source_models_lib +from torax.sources import source_profiles + + +@functools.partial( + jax_utils.jit, + static_argnames=[ + 'source_models', + 'static_runtime_params_slice', + 'explicit', + ], +) +def build_source_profiles( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, + explicit: bool, +) -> source_profiles.SourceProfiles: + """Builds explicit or implicit source profiles. + + Args: + static_runtime_params_slice: Input config. Cannot change from time step to + time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. + geo: Geometry of the torus. + core_profiles: Core plasma profiles, either at the start of the time step + (if explicit) or the live profiles being evolved during the time step (if + implicit). + source_models: Functions computing profiles for all TORAX sources/sinks. + explicit: If True, this function should return profiles for all explicit + sources. All implicit sources should be set to 0. And same vice versa. + + Returns: + SourceProfiles for either explicit or implicit sources (and all others set + to zero). + """ + # Bootstrap current is a special-case source with multiple outputs, so handle + # it here. + static_bootstrap_runtime_params = static_runtime_params_slice.sources[ + source_models.j_bootstrap_name + ] + bootstrap_profiles = build_bootstrap_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, + static_source_runtime_params=static_bootstrap_runtime_params, + geo=geo, + core_profiles=core_profiles, + j_bootstrap_source=source_models.j_bootstrap, + explicit=explicit, + ) + other_profiles = build_standard_source_profiles( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + explicit, + ) + if not explicit: + qei = source_models.qei_source.get_qei( + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) + else: + qei = source_profiles.QeiInfo.zeros(geo) + return source_profiles.SourceProfiles( + profiles=other_profiles, + j_bootstrap=bootstrap_profiles, + qei=qei, + ) + + +def build_bootstrap_profiles( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + j_bootstrap_source: bootstrap_current_source.BootstrapCurrentSource, + explicit: bool = True, + calculate_anyway: bool = False, +) -> source_profiles.BootstrapCurrentProfile: + """Computes the bootstrap current profile. + + Args: + static_runtime_params_slice: Input config. Cannot change from time step to + time step. + static_source_runtime_params: Input runtime parameters specific to the + bootstrap current source that do not change from time step to time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. + geo: Geometry of the torus. + core_profiles: Core plasma profiles, either at the start of the time step + (if explicit) or the live profiles being evolved during the time step (if + implicit). + j_bootstrap_source: Bootstrap current source used to compute the profile. + explicit: If True, this function should return the profile for an explicit + source. If explicit is True and the bootstrap current source is not + explicit, then this should return all zeros. And same with implicit (if + explicit=False and the source is set to be explicit, then this will return + all zeros). + calculate_anyway: If True, returns values regardless of explicit + + Returns: + Bootstrap current profile. + """ + bootstrap_profile = j_bootstrap_source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) + if explicit == static_source_runtime_params.is_explicit | calculate_anyway: + return bootstrap_profile + else: + return source_profiles.BootstrapCurrentProfile.zero_profile(geo) + + +def build_standard_source_profiles( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, + explicit: bool = True, + calculate_anyway: bool = False, + affected_core_profiles: tuple[source_lib.AffectedCoreProfile, ...] = ( + source_lib.AffectedCoreProfile.PSI, + source_lib.AffectedCoreProfile.NE, + source_lib.AffectedCoreProfile.TEMP_ION, + source_lib.AffectedCoreProfile.TEMP_EL, + ), +) -> dict[str, chex.ArrayTree]: + """Computes sources and builds a kwargs dict for SourceProfiles. + + Args: + static_runtime_params_slice: Input config. Cannot change from time step to + time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. + geo: Geometry of the torus. + core_profiles: Core plasma profiles, either at the start of the time step + (if explicit) or the live profiles being evolved during the time step (if + implicit). + source_models: Collection of all TORAX sources. + explicit: If True, this function should return the profile for an explicit + source. If explicit is True and a given source is not explicit, then this + function will return zeros for that source. And same with implicit (if + explicit=False and the source is set to be explicit, then this will return + all zeros). + calculate_anyway: If True, returns values regardless of explicit + affected_core_profiles: Populate the output for sources that affect these + core profiles. + + Returns: + dict of source profiles excluding the two special-case sources (bootstrap + and qei). + """ + computed_source_profiles = {} + affected_core_profiles_set = set(affected_core_profiles) + for source_name, source in source_models.standard_sources.items(): + if affected_core_profiles_set.intersection(source.affected_core_profiles): + static_source_runtime_params = static_runtime_params_slice.sources[ + source_name + ] + if ( + explicit + == static_source_runtime_params.is_explicit | calculate_anyway + ): + computed_source_profiles[source_name] = source.get_value( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + ) + else: + computed_source_profiles[source_name] = jnp.zeros( + source.output_shape_getter(geo)) + return computed_source_profiles + + +def build_all_zero_profiles( + geo: geometry.Geometry, + source_models: source_models_lib.SourceModels, +) -> source_profiles.SourceProfiles: + """Returns a SourceProfiles object with all zero profiles.""" + profiles = { + source_name: jnp.zeros(source_model.output_shape_getter(geo)) + for source_name, source_model in source_models.standard_sources.items() + } + return source_profiles.SourceProfiles( + profiles=profiles, + j_bootstrap=source_profiles.BootstrapCurrentProfile.zero_profile(geo), + qei=source_profiles.QeiInfo.zeros(geo), + ) diff --git a/torax/sources/tests/source_models_test.py b/torax/sources/tests/source_models_test.py index b4847e51..2e8dd9a4 100644 --- a/torax/sources/tests/source_models_test.py +++ b/torax/sources/tests/source_models_test.py @@ -15,236 +15,18 @@ from absl.testing import absltest from absl.testing import parameterized -import jax -import jax.numpy as jnp import numpy as np from torax import core_profile_setters from torax.config import runtime_params as runtime_params_lib from torax.config import runtime_params_slice from torax.geometry import circular_geometry -from torax.geometry import geometry from torax.sources import generic_current_source from torax.sources import runtime_params as source_runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib -from torax.sources import source_profiles as source_profiles_lib -from torax.stepper import runtime_params as stepper_runtime_params_lib -from torax.tests.test_lib import default_sources -@dataclasses.dataclass(frozen=True) -class FooSource(source_lib.Source): - """A test source.""" - - @property - def source_name(self) -> str: - return 'foo' - - @property - def affected_core_profiles( - self, - ) -> tuple[source_lib.AffectedCoreProfile, ...]: - return ( - source_lib.AffectedCoreProfile.TEMP_EL, - source_lib.AffectedCoreProfile.NE, - ) - - @property - def output_shape_getter(self) -> source_lib.SourceOutputShapeFunction: - return source_lib.get_ion_el_output_shape - - -_FooSourceBuilder = source_lib.make_source_builder( - FooSource, -) - - -class SourceProfilesTest(parameterized.TestCase): - """Tests for computing source profiles.""" - - def test_computing_source_profiles_works_with_all_defaults(self): - """Tests that you can compute source profiles with all defaults.""" - runtime_params = runtime_params_lib.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() - source_models_builder = source_models_lib.SourceModelsBuilder() - source_models = source_models_builder() - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params=runtime_params, - source_runtime_params=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - stepper_params = stepper_runtime_params_lib.RuntimeParams() - static_runtime_params_slice = ( - runtime_params_slice.build_static_runtime_params_slice( - runtime_params=runtime_params, - source_runtime_params=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - stepper=stepper_params, - ) - ) - _ = source_models_lib.build_source_profiles( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, - explicit=True, - ) - _ = source_models_lib.build_source_profiles( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, - explicit=False, - ) - - def test_summed_temp_ion_profiles_dont_change_when_jitting(self): - """Test that sum_sources_temp_{ion|el} works with jitting.""" - geo = circular_geometry.build_circular_geometry() - - # Use the default sources where the generic_ion_el_heat_source, - # fusion_heat_source, and ohmic_heat_source are included and produce - # profiles for ion and electron heat. - # temperature. - source_models_builder = default_sources.get_default_sources_builder() - source_models = source_models_builder() - # Make some dummy source profiles that could have come from these sources. - ones = jnp.ones(source_lib.get_cell_profile_shape(geo)) - profiles = source_profiles_lib.SourceProfiles( - j_bootstrap=source_profiles_lib.BootstrapCurrentProfile.zero_profile( - geo - ), - qei=source_profiles_lib.QeiInfo.zeros(geo), - profiles={ - 'generic_ion_el_heat_source': jnp.stack([ones, ones * 2]), - 'fusion_heat_source': jnp.stack([ones * 3, ones * 4]), - 'bremsstrahlung_heat_sink': -ones, - 'ohmic_heat_source': ones * 5, # only used for electron temp. - }, - ) - with self.subTest('without_jit'): - summed_temp_ion = source_models_lib.sum_sources_temp_ion( - geo, profiles, source_models - ) - np.testing.assert_allclose(summed_temp_ion, ones * 4 * geo.vpr) - summed_temp_el = source_models_lib.sum_sources_temp_el( - geo, profiles, source_models - ) - np.testing.assert_allclose(summed_temp_el, ones * 10 * geo.vpr) - - with self.subTest('with_jit'): - sum_temp_ion = jax.jit( - source_models_lib.sum_sources_temp_ion, - static_argnames=['source_models'], - ) - jitted_temp_ion = sum_temp_ion(geo, profiles, source_models) - np.testing.assert_allclose(jitted_temp_ion, ones * 4 * geo.vpr) - sum_temp_el = jax.jit( - source_models_lib.sum_sources_temp_el, - static_argnames=['source_models'], - ) - jitted_temp_el = sum_temp_el(geo, profiles, source_models) - np.testing.assert_allclose(jitted_temp_el, ones * 10 * geo.vpr) - - def test_custom_source_profiles_dont_change_when_jitted(self): - """Test that custom source profiles don't change profiles when jitted.""" - source_name = 'foo' - - def foo_formula( - unused_dcs, - unused_static_runtime_params_slice, - geo: geometry.Geometry, - unused_source_name: str, - unused_state, - unused_source_models, - ): - return jnp.stack([ - jnp.zeros(source_lib.get_cell_profile_shape(geo)), - jnp.ones(source_lib.get_cell_profile_shape(geo)), - ]) - - foo_source_builder = source_lib.make_source_builder( - FooSource, model_func=foo_formula - )() - # Set the source mode to MODEL_BASED. - foo_source_builder.runtime_params.mode = ( - source_runtime_params_lib.Mode.MODEL_BASED - ) - source_models_builder = source_models_lib.SourceModelsBuilder( - {source_name: foo_source_builder}, - ) - source_models = source_models_builder() - runtime_params = runtime_params_lib.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params=runtime_params, - source_runtime_params=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - - def compute_and_sum_profiles(): - profiles = source_models_lib.build_source_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - core_profiles=core_profiles, - source_models=source_models, - # Configs set sources to implicit by default, so set this to False to - # calculate the custom source's profile. - explicit=False, - ) - ne = source_models_lib.sum_sources_ne(geo, profiles, source_models) - temp_el = source_models_lib.sum_sources_temp_el( - geo, profiles, source_models - ) - return (ne, temp_el) - - expected_ne = ( - jnp.ones(source_lib.get_cell_profile_shape(geo)) * geo.vpr - ) - expected_temp_el = jnp.zeros( - source_lib.get_cell_profile_shape(geo) - ) - with self.subTest('without_jit'): - (ne, temp_el) = compute_and_sum_profiles() - np.testing.assert_allclose(ne, expected_ne) - np.testing.assert_allclose(temp_el, expected_temp_el) - with self.subTest('with_jit'): - jitted_compute_and_sum = jax.jit(compute_and_sum_profiles) - (ne, temp_el) = jitted_compute_and_sum() - np.testing.assert_allclose(ne, expected_ne) - np.testing.assert_allclose(temp_el, expected_temp_el) +class SourceModelsTest(parameterized.TestCase): @parameterized.parameters( (source_runtime_params_lib.Mode.ZERO, 0), diff --git a/torax/sources/tests/source_operations_test.py b/torax/sources/tests/source_operations_test.py new file mode 100644 index 00000000..c5b6fd13 --- /dev/null +++ b/torax/sources/tests/source_operations_test.py @@ -0,0 +1,191 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from torax import core_profile_setters +from torax.config import runtime_params as runtime_params_lib +from torax.config import runtime_params_slice +from torax.geometry import circular_geometry +from torax.geometry import geometry +from torax.sources import runtime_params as source_runtime_params_lib +from torax.sources import source as source_lib +from torax.sources import source_models as source_models_lib +from torax.sources import source_operations +from torax.sources import source_profile_builders +from torax.sources import source_profiles as source_profiles_lib +from torax.tests.test_lib import default_sources + + +@dataclasses.dataclass(frozen=True) +class FooSource(source_lib.Source): + """A test source.""" + + @property + def source_name(self) -> str: + return 'foo' + + @property + def affected_core_profiles( + self, + ) -> tuple[source_lib.AffectedCoreProfile, ...]: + return ( + source_lib.AffectedCoreProfile.TEMP_EL, + source_lib.AffectedCoreProfile.NE, + ) + + @property + def output_shape_getter(self) -> source_lib.SourceOutputShapeFunction: + return source_lib.get_ion_el_output_shape + + +class SourceOperationsTest(parameterized.TestCase): + + def test_summed_temp_ion_profiles_dont_change_when_jitting(self): + geo = circular_geometry.build_circular_geometry() + + # Use the default sources where the generic_ion_el_heat_source, + # fusion_heat_source, and ohmic_heat_source are included and produce + # profiles for ion and electron heat. + # temperature. + source_models_builder = default_sources.get_default_sources_builder() + source_models = source_models_builder() + # Make some dummy source profiles that could have come from these sources. + ones = jnp.ones(source_lib.get_cell_profile_shape(geo)) + profiles = source_profiles_lib.SourceProfiles( + j_bootstrap=source_profiles_lib.BootstrapCurrentProfile.zero_profile( + geo + ), + qei=source_profiles_lib.QeiInfo.zeros(geo), + profiles={ + 'generic_ion_el_heat_source': jnp.stack([ones, ones * 2]), + 'fusion_heat_source': jnp.stack([ones * 3, ones * 4]), + 'bremsstrahlung_heat_sink': -ones, + 'ohmic_heat_source': ones * 5, # only used for electron temp. + }, + ) + with self.subTest('without_jit'): + summed_temp_ion = source_operations.sum_sources_temp_ion( + geo, profiles, source_models + ) + np.testing.assert_allclose(summed_temp_ion, ones * 4 * geo.vpr) + summed_temp_el = source_operations.sum_sources_temp_el( + geo, profiles, source_models + ) + np.testing.assert_allclose(summed_temp_el, ones * 10 * geo.vpr) + + with self.subTest('with_jit'): + sum_temp_ion = jax.jit( + source_operations.sum_sources_temp_ion, + static_argnames=['source_models'], + ) + jitted_temp_ion = sum_temp_ion(geo, profiles, source_models) + np.testing.assert_allclose(jitted_temp_ion, ones * 4 * geo.vpr) + sum_temp_el = jax.jit( + source_operations.sum_sources_temp_el, + static_argnames=['source_models'], + ) + jitted_temp_el = sum_temp_el(geo, profiles, source_models) + np.testing.assert_allclose(jitted_temp_el, ones * 10 * geo.vpr) + + def test_custom_source_profiles_dont_change_when_jitted(self): + source_name = 'foo' + + def foo_formula( + unused_dcs, + unused_static_runtime_params_slice, + geo: geometry.Geometry, + unused_source_name: str, + unused_state, + unused_source_models, + ): + return jnp.stack([ + jnp.zeros(source_lib.get_cell_profile_shape(geo)), + jnp.ones(source_lib.get_cell_profile_shape(geo)), + ]) + + foo_source_builder = source_lib.make_source_builder( + FooSource, model_func=foo_formula + )() + # Set the source mode to MODEL_BASED. + foo_source_builder.runtime_params.mode = ( + source_runtime_params_lib.Mode.MODEL_BASED + ) + source_models_builder = source_models_lib.SourceModelsBuilder( + {source_name: foo_source_builder}, + ) + source_models = source_models_builder() + runtime_params = runtime_params_lib.GeneralRuntimeParams() + geo = circular_geometry.build_circular_geometry() + dynamic_runtime_params_slice = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params, + sources=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + )( + t=runtime_params.numerics.t_initial, + ) + ) + static_slice = runtime_params_slice.build_static_runtime_params_slice( + runtime_params=runtime_params, + source_runtime_params=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + ) + core_profiles = core_profile_setters.initial_core_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + source_models=source_models, + ) + + def compute_and_sum_profiles(): + profiles = source_profile_builders.build_source_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + core_profiles=core_profiles, + source_models=source_models, + # Configs set sources to implicit by default, so set this to False to + # calculate the custom source's profile. + explicit=False, + ) + ne = source_operations.sum_sources_ne(geo, profiles, source_models) + temp_el = source_operations.sum_sources_temp_el( + geo, profiles, source_models + ) + return (ne, temp_el) + + expected_ne = ( + jnp.ones(source_lib.get_cell_profile_shape(geo)) * geo.vpr + ) + expected_temp_el = jnp.zeros( + source_lib.get_cell_profile_shape(geo) + ) + with self.subTest('without_jit'): + (ne, temp_el) = compute_and_sum_profiles() + np.testing.assert_allclose(ne, expected_ne) + np.testing.assert_allclose(temp_el, expected_temp_el) + with self.subTest('with_jit'): + jitted_compute_and_sum = jax.jit(compute_and_sum_profiles) + (ne, temp_el) = jitted_compute_and_sum() + np.testing.assert_allclose(ne, expected_ne) + np.testing.assert_allclose(temp_el, expected_temp_el) + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/sources/tests/source_profile_builders_test.py b/torax/sources/tests/source_profile_builders_test.py new file mode 100644 index 00000000..d1c1a9c5 --- /dev/null +++ b/torax/sources/tests/source_profile_builders_test.py @@ -0,0 +1,81 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +from absl.testing import parameterized +from torax import core_profile_setters +from torax.config import runtime_params as runtime_params_lib +from torax.config import runtime_params_slice +from torax.geometry import circular_geometry +from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders +from torax.stepper import runtime_params as stepper_runtime_params_lib + + +class SourceModelsTest(parameterized.TestCase): + + def test_computing_source_profiles_works_with_all_defaults(self): + """Tests that you can compute source profiles with all defaults.""" + runtime_params = runtime_params_lib.GeneralRuntimeParams() + geo = circular_geometry.build_circular_geometry() + source_models_builder = source_models_lib.SourceModelsBuilder() + source_models = source_models_builder() + dynamic_runtime_params_slice = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params, + sources=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + )( + t=runtime_params.numerics.t_initial, + ) + ) + static_slice = runtime_params_slice.build_static_runtime_params_slice( + runtime_params=runtime_params, + source_runtime_params=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + ) + core_profiles = core_profile_setters.initial_core_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + source_models=source_models, + ) + stepper_params = stepper_runtime_params_lib.RuntimeParams() + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params=runtime_params, + source_runtime_params=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + stepper=stepper_params, + ) + ) + _ = source_profile_builders.build_source_profiles( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + explicit=True, + ) + _ = source_profile_builders.build_source_profiles( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + explicit=False, + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/stepper/linear_theta_method.py b/torax/stepper/linear_theta_method.py index 32e43015..fca2f7d1 100644 --- a/torax/stepper/linear_theta_method.py +++ b/torax/stepper/linear_theta_method.py @@ -26,6 +26,7 @@ from torax.geometry import geometry from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles from torax.stepper import predictor_corrector_method from torax.stepper import runtime_params as runtime_params_lib @@ -105,7 +106,7 @@ def _x_new( init_val = ( x_new_init, ( - source_models_lib.build_all_zero_profiles( + source_profile_builders.build_all_zero_profiles( geo_t, self.source_models, ), diff --git a/torax/stepper/stepper.py b/torax/stepper/stepper.py index 36b02cbd..461eba2e 100644 --- a/torax/stepper/stepper.py +++ b/torax/stepper/stepper.py @@ -28,6 +28,7 @@ from torax.geometry import geometry from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles from torax.stepper import runtime_params as runtime_params_lib from torax.transport_model import transport_model as transport_model_lib @@ -143,7 +144,7 @@ def __call__( ) else: x_new = tuple() - core_sources = source_models_lib.build_all_zero_profiles( + core_sources = source_profile_builders.build_all_zero_profiles( source_models=self.source_models, geo=geo_t, ) diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 7d3c4370..c9907cc0 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -36,6 +36,7 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles as source_profiles_lib from torax.sources.tests import test_lib from torax.tests.test_lib import default_sources @@ -171,7 +172,7 @@ def mock_step_fn( dt=dt, time_step_calculator_state=(), # The returned source profiles include only the implicit sources. - core_sources=source_models_lib.build_source_profiles( + core_sources=source_profile_builders.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice_provider( t=new_t, ), diff --git a/torax/tests/sim_time_dependence.py b/torax/tests/sim_time_dependence.py index eff7ab79..3a32167e 100644 --- a/torax/tests/sim_time_dependence.py +++ b/torax/tests/sim_time_dependence.py @@ -35,6 +35,7 @@ from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib +from torax.sources import source_profile_builders from torax.sources import source_profiles from torax.stepper import stepper as stepper_lib from torax.time_step_calculator import fixed_time_step_calculator @@ -212,7 +213,7 @@ def __call__( pedestal_model_output, ) # Use Qei as a hacky way to extract what the combined value was. - core_sources = source_models_lib.build_all_zero_profiles( + core_sources = source_profile_builders.build_all_zero_profiles( geo=geo_t, source_models=self.source_models, ) diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index 166bcff9..3f30bbf7 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -30,7 +30,8 @@ from torax.config import runtime_params_slice from torax.fvm import diffusion_terms from torax.geometry import geometry -from torax.sources import source_models +from torax.sources import source_operations +from torax.sources import source_profile_builders from torax.sources import source_profiles from torax.stepper import stepper as stepper_lib from torax.transport_model import constant as constant_transport_model @@ -106,7 +107,7 @@ def __call__( ) # Source term - c += source_models.sum_sources_temp_ion( + c += source_operations.sum_sources_temp_ion( geo_t, explicit_source_profiles, self.source_models, @@ -153,7 +154,7 @@ def __call__( q_face=q_face, s_face=s_face, ), - source_models.build_all_zero_profiles( + source_profile_builders.build_all_zero_profiles( geo=geo_t, source_models=self.source_models, ),