Skip to content

Commit

Permalink
Move source profile methods and sum methods to own methods
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719152153
  • Loading branch information
tamaranorman authored and Torax team committed Jan 29, 2025
1 parent 9591640 commit 2d7c66d
Show file tree
Hide file tree
Showing 30 changed files with 755 additions and 745 deletions.
121 changes: 38 additions & 83 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_operations
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.transport_model import transport_model as transport_model_lib

Expand Down Expand Up @@ -375,7 +377,7 @@ def _calc_coeffs_full(
# This only calculates sources set to implicit in the config. All other
# sources are set to 0 (and should have their profiles already calculated in
# explicit_source_profiles).
implicit_source_profiles = source_models_lib.build_source_profiles(
implicit_source_profiles = source_profile_builders.build_source_profiles(
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand All @@ -388,79 +390,50 @@ def _calc_coeffs_full(
# recalculate here to avoid issues with JAX branching in the logic.
# Decide which values to use depending on whether the source is explicit or
# implicit.
sigma = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.sigma,
implicit_source_profiles.j_bootstrap.sigma,
)
sigma_face = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.sigma_face,
implicit_source_profiles.j_bootstrap.sigma_face,
)
j_bootstrap = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.j_bootstrap,
implicit_source_profiles.j_bootstrap.j_bootstrap,
)
j_bootstrap_face = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.j_bootstrap_face,
implicit_source_profiles.j_bootstrap.j_bootstrap_face,
)
I_bootstrap = jax_utils.select( # pylint: disable=invalid-name
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.I_bootstrap,
implicit_source_profiles.j_bootstrap.I_bootstrap,
)
if static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit:
j_bootstrap = explicit_source_profiles.j_bootstrap
else:
j_bootstrap = implicit_source_profiles.j_bootstrap

external_current = jnp.zeros_like(geo.rho)
# Sum over all psi sources (except the bootstrap current).
for source_name, source in source_models.psi_sources.items():
external_current += jax_utils.select(
static_runtime_params_slice.sources[source_name].is_explicit,
source.get_source_profile_for_affected_core_profile(
profile=explicit_source_profiles.profiles[source_name],
affected_core_profile=source_lib.AffectedCoreProfile.PSI.value,
geo=geo,
),
source.get_source_profile_for_affected_core_profile(
profile=implicit_source_profiles.profiles[source_name],
affected_core_profile=source_lib.AffectedCoreProfile.PSI.value,
geo=geo,
),
if static_runtime_params_slice.sources[source_name].is_explicit:
profiles = explicit_source_profiles.profiles
else:
profiles = implicit_source_profiles.profiles
external_current += source.get_source_profile_for_affected_core_profile(
profile=profiles[source_name],
affected_core_profile=source_lib.AffectedCoreProfile.PSI.value,
geo=geo,
)

currents = dataclasses.replace(
core_profiles.currents,
j_bootstrap=j_bootstrap,
j_bootstrap_face=j_bootstrap_face,
j_bootstrap=j_bootstrap.j_bootstrap,
j_bootstrap_face=j_bootstrap.j_bootstrap_face,
external_current_source=external_current,
johm=(core_profiles.currents.jtot - j_bootstrap - external_current),
I_bootstrap=I_bootstrap,
sigma=sigma,
johm=(
core_profiles.currents.jtot
- j_bootstrap.j_bootstrap
- external_current
),
I_bootstrap=j_bootstrap.I_bootstrap,
sigma=j_bootstrap.sigma,
)
core_profiles = dataclasses.replace(core_profiles, currents=currents)

# psi source terms. Source matrix is zero for all psi sources
source_mat_psi = jnp.zeros_like(geo.rho)

# fill source vector based on both original and updated core profiles
source_psi = source_models_lib.sum_sources_psi(
source_psi = source_operations.sum_sources_psi(
geo,
implicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_psi(
) + source_operations.sum_sources_psi(
geo,
explicit_source_profiles,
source_models,
Expand Down Expand Up @@ -495,7 +468,7 @@ def _calc_coeffs_full(
1.0
/ dynamic_runtime_params_slice.numerics.resistivity_mult
* geo.rho_norm
* sigma
* j_bootstrap.sigma
* consts.mu0
* 16
* jnp.pi**2
Expand Down Expand Up @@ -651,11 +624,11 @@ def _calc_coeffs_full(
source_mat_nn = jnp.zeros_like(geo.rho)

# density source vector based both on original and updated core profiles
source_ne = source_models_lib.sum_sources_ne(
source_ne = source_operations.sum_sources_ne(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_ne(
) + source_operations.sum_sources_ne(
geo,
implicit_source_profiles,
source_models,
Expand Down Expand Up @@ -737,56 +710,38 @@ def _calc_coeffs_full(
* consts.mu0
* geo.Phibdot
* geo.Phib
* sigma_face
* j_bootstrap.sigma_face
* geo.rho_face_norm**2
/ geo.F_face**2
)

# Ion and electron heat sources.
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
# For Qei, always use the current set of core profiles.
# In the linear solver, core_profiles is the set of profiles at time t (at
# the start of the time step) or the updated core_profiles in
# predictor-corrector, and in the nonlinear solver, calc_coeffs is called
# at least twice, once with the core_profiles at time t, and again
# (iteratively) with core_profiles at t+dt.
core_profiles=core_profiles,
)
# Update the implicit profiles with the qei info.
implicit_source_profiles = dataclasses.replace(
implicit_source_profiles,
qei=qei,
)

# Fill heat transport equation sources. Initialize source matrices to zero

source_mat_ii = jnp.zeros_like(geo.rho)
source_mat_ee = jnp.zeros_like(geo.rho)

source_i = source_models_lib.sum_sources_temp_ion(
source_i = source_operations.sum_sources_temp_ion(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_temp_ion(
) + source_operations.sum_sources_temp_ion(
geo,
implicit_source_profiles,
source_models,
)

source_e = source_models_lib.sum_sources_temp_el(
source_e = source_operations.sum_sources_temp_el(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_temp_el(
) + source_operations.sum_sources_temp_el(
geo,
implicit_source_profiles,
source_models,
)

# Add the Qei effects.
qei = implicit_source_profiles.qei
source_mat_ii += qei.implicit_ii * geo.vpr
source_i += qei.explicit_i * geo.vpr
source_mat_ee += qei.implicit_ee * geo.vpr
Expand Down Expand Up @@ -853,7 +808,7 @@ def _calc_coeffs_full(
# Add effective phibdot poloidal flux source term

ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient(
sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm
j_bootstrap.sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm
)

source_psi += (
Expand Down
3 changes: 2 additions & 1 deletion torax/fvm/newton_raphson_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib
Expand Down Expand Up @@ -226,7 +227,7 @@ def newton_raphson_solve_block(
# Initialized here with correct shapes to help with tracing in case
# this is jitted.
(
source_models_lib.build_all_zero_profiles(
source_profile_builders.build_all_zero_profiles(
geo_t,
source_models,
),
Expand Down
3 changes: 2 additions & 1 deletion torax/fvm/optimizer_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib
Expand Down Expand Up @@ -157,7 +158,7 @@ def optimizer_solve_block(
# Initialized here with correct shapes to help with tracing in case
# this is jitted.
(
source_models_lib.build_all_zero_profiles(
source_profile_builders.build_all_zero_profiles(
geo_t,
source_models,
),
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/tests/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.stepper import runtime_params as stepper_params_lib
from torax.tests.test_lib import default_sources
from torax.transport_model import constant as constant_transport_model
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_calc_coeffs_smoke_test(
source_models,
)
evolving_names = tuple(['temp_ion'])
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
static_runtime_params_slice=static_runtime_params_slice,
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
Expand Down
8 changes: 4 additions & 4 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.stepper import runtime_params as stepper_runtime_params
from torax.tests.test_lib import default_sources
from torax.tests.test_lib import torax_refs
Expand Down Expand Up @@ -443,7 +443,7 @@ def test_nonlinear_solve_block_loss_minimum(
source_models,
)
evolving_names = tuple(['temp_ion'])
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand Down Expand Up @@ -587,7 +587,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
geo,
source_models,
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -740,7 +740,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
geo,
source_models,
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice_theta0,
geo=geo,
Expand Down
3 changes: 2 additions & 1 deletion torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import time_step_calculator as ts
Expand Down Expand Up @@ -144,7 +145,7 @@ def __call__(
# This only computes sources set to explicit in the
# DynamicSourceConfigSlice. All implicit sources will have their profiles
# set to 0.
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice_t,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo_t,
Expand Down
14 changes: 4 additions & 10 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from torax.orchestration import step_function
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
Expand Down Expand Up @@ -631,7 +632,7 @@ def _run_simulation(
geometry_provider=geometry_provider,
)
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -700,23 +701,16 @@ def get_initial_source_profiles(
Implicit and explicit SourceProfiles from source models based on the core
profiles from the starting state.
"""
implicit_profiles = source_models_lib.build_source_profiles(
implicit_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
explicit=False,
)
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
implicit_profiles = dataclasses.replace(implicit_profiles, qei=qei)
# Also add in the explicit sources to the initial sources.
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down
6 changes: 3 additions & 3 deletions torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):

def _default_output_shapes(geo) -> tuple[int, int, int, int]:
return (
source.ProfileType.CELL.get_profile_shape(geo) # sigmaneo
+ source.ProfileType.CELL.get_profile_shape(geo) # bootstrap
+ source.ProfileType.FACE.get_profile_shape(geo) # bootstrap face
source.get_cell_profile_shape(geo) # sigmaneo
+ source.get_cell_profile_shape(geo) # bootstrap
+ source.get_face_profile_shape(geo) # bootstrap face
+ (1,) # I_bootstrap
)

Expand Down
2 changes: 1 addition & 1 deletion torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def calc_heating_and_current(


def _get_ec_output_shape(geo: geometry.Geometry) -> tuple[int, ...]:
return (2,) + source.ProfileType.CELL.get_profile_shape(geo)
return (2,) + source.get_cell_profile_shape(geo)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down
2 changes: 1 addition & 1 deletion torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,4 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:

@property
def output_shape_getter(self) -> source.SourceOutputShapeFunction:
return source.ProfileType.CELL.get_profile_shape
return source.get_cell_profile_shape
Loading

0 comments on commit 2d7c66d

Please sign in to comment.