Skip to content

Commit

Permalink
Move source profile methods and sum methods to own methods
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720970981
  • Loading branch information
tamaranorman authored and Torax team committed Jan 29, 2025
1 parent 6be2c3c commit 7260a90
Show file tree
Hide file tree
Showing 19 changed files with 698 additions and 577 deletions.
20 changes: 11 additions & 9 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_operations
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.transport_model import transport_model as transport_model_lib

Expand Down Expand Up @@ -375,7 +377,7 @@ def _calc_coeffs_full(
# This only calculates sources set to implicit in the config. All other
# sources are set to 0 (and should have their profiles already calculated in
# explicit_source_profiles).
implicit_source_profiles = source_models_lib.build_source_profiles(
implicit_source_profiles = source_profile_builders.build_source_profiles(
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand Down Expand Up @@ -427,11 +429,11 @@ def _calc_coeffs_full(
source_mat_psi = jnp.zeros_like(geo.rho)

# fill source vector based on both original and updated core profiles
source_psi = source_models_lib.sum_sources_psi(
source_psi = source_operations.sum_sources_psi(
geo,
implicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_psi(
) + source_operations.sum_sources_psi(
geo,
explicit_source_profiles,
source_models,
Expand Down Expand Up @@ -622,11 +624,11 @@ def _calc_coeffs_full(
source_mat_nn = jnp.zeros_like(geo.rho)

# density source vector based both on original and updated core profiles
source_ne = source_models_lib.sum_sources_ne(
source_ne = source_operations.sum_sources_ne(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_ne(
) + source_operations.sum_sources_ne(
geo,
implicit_source_profiles,
source_models,
Expand Down Expand Up @@ -718,21 +720,21 @@ def _calc_coeffs_full(
source_mat_ii = jnp.zeros_like(geo.rho)
source_mat_ee = jnp.zeros_like(geo.rho)

source_i = source_models_lib.sum_sources_temp_ion(
source_i = source_operations.sum_sources_temp_ion(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_temp_ion(
) + source_operations.sum_sources_temp_ion(
geo,
implicit_source_profiles,
source_models,
)

source_e = source_models_lib.sum_sources_temp_el(
source_e = source_operations.sum_sources_temp_el(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_temp_el(
) + source_operations.sum_sources_temp_el(
geo,
implicit_source_profiles,
source_models,
Expand Down
3 changes: 2 additions & 1 deletion torax/fvm/newton_raphson_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib
Expand Down Expand Up @@ -226,7 +227,7 @@ def newton_raphson_solve_block(
# Initialized here with correct shapes to help with tracing in case
# this is jitted.
(
source_models_lib.build_all_zero_profiles(
source_profile_builders.build_all_zero_profiles(
geo_t,
source_models,
),
Expand Down
3 changes: 2 additions & 1 deletion torax/fvm/optimizer_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib
Expand Down Expand Up @@ -157,7 +158,7 @@ def optimizer_solve_block(
# Initialized here with correct shapes to help with tracing in case
# this is jitted.
(
source_models_lib.build_all_zero_profiles(
source_profile_builders.build_all_zero_profiles(
geo_t,
source_models,
),
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/tests/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.stepper import runtime_params as stepper_params_lib
from torax.tests.test_lib import default_sources
from torax.transport_model import constant as constant_transport_model
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_calc_coeffs_smoke_test(
source_models,
)
evolving_names = tuple(['temp_ion'])
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
static_runtime_params_slice=static_runtime_params_slice,
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
Expand Down
8 changes: 4 additions & 4 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.stepper import runtime_params as stepper_runtime_params
from torax.tests.test_lib import default_sources
from torax.tests.test_lib import torax_refs
Expand Down Expand Up @@ -443,7 +443,7 @@ def test_nonlinear_solve_block_loss_minimum(
source_models,
)
evolving_names = tuple(['temp_ion'])
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand Down Expand Up @@ -587,7 +587,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
geo,
source_models,
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -740,7 +740,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
geo,
source_models,
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice_theta0,
geo=geo,
Expand Down
3 changes: 2 additions & 1 deletion torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import time_step_calculator as ts
Expand Down Expand Up @@ -144,7 +145,7 @@ def __call__(
# This only computes sources set to explicit in the
# DynamicSourceConfigSlice. All implicit sources will have their profiles
# set to 0.
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice_t,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo_t,
Expand Down
7 changes: 4 additions & 3 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from torax.orchestration import step_function
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
Expand Down Expand Up @@ -631,7 +632,7 @@ def _run_simulation(
geometry_provider=geometry_provider,
)
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -700,7 +701,7 @@ def get_initial_source_profiles(
Implicit and explicit SourceProfiles from source models based on the core
profiles from the starting state.
"""
implicit_profiles = source_models_lib.build_source_profiles(
implicit_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand All @@ -709,7 +710,7 @@ def get_initial_source_profiles(
explicit=False,
)
# Also add in the explicit sources to the initial sources.
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down
15 changes: 9 additions & 6 deletions torax/sources/ohmic_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_operations


@functools.partial(
Expand Down Expand Up @@ -70,12 +71,14 @@ def calc_psidot(
"""
consts = constants.CONSTANTS

psi_sources, sigma, sigma_face = source_models_lib.calc_and_sum_sources_psi(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
psi_sources, sigma, sigma_face = (
source_operations.calc_and_sum_sources_psi(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
)
)
# Calculate transient term
toc_psi = (
Expand Down
Loading

0 comments on commit 7260a90

Please sign in to comment.