Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move source profile methods and sum methods to own methods #685

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_operations
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.transport_model import transport_model as transport_model_lib

Expand Down Expand Up @@ -375,7 +377,7 @@ def _calc_coeffs_full(
# This only calculates sources set to implicit in the config. All other
# sources are set to 0 (and should have their profiles already calculated in
# explicit_source_profiles).
implicit_source_profiles = source_models_lib.build_source_profiles(
implicit_source_profiles = source_profile_builders.build_source_profiles(
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand Down Expand Up @@ -427,11 +429,11 @@ def _calc_coeffs_full(
source_mat_psi = jnp.zeros_like(geo.rho)

# fill source vector based on both original and updated core profiles
source_psi = source_models_lib.sum_sources_psi(
source_psi = source_operations.sum_sources_psi(
geo,
implicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_psi(
) + source_operations.sum_sources_psi(
geo,
explicit_source_profiles,
source_models,
Expand Down Expand Up @@ -622,11 +624,11 @@ def _calc_coeffs_full(
source_mat_nn = jnp.zeros_like(geo.rho)

# density source vector based both on original and updated core profiles
source_ne = source_models_lib.sum_sources_ne(
source_ne = source_operations.sum_sources_ne(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_ne(
) + source_operations.sum_sources_ne(
geo,
implicit_source_profiles,
source_models,
Expand Down Expand Up @@ -718,21 +720,21 @@ def _calc_coeffs_full(
source_mat_ii = jnp.zeros_like(geo.rho)
source_mat_ee = jnp.zeros_like(geo.rho)

source_i = source_models_lib.sum_sources_temp_ion(
source_i = source_operations.sum_sources_temp_ion(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_temp_ion(
) + source_operations.sum_sources_temp_ion(
geo,
implicit_source_profiles,
source_models,
)

source_e = source_models_lib.sum_sources_temp_el(
source_e = source_operations.sum_sources_temp_el(
geo,
explicit_source_profiles,
source_models,
) + source_models_lib.sum_sources_temp_el(
) + source_operations.sum_sources_temp_el(
geo,
implicit_source_profiles,
source_models,
Expand Down
3 changes: 2 additions & 1 deletion torax/fvm/newton_raphson_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib
Expand Down Expand Up @@ -226,7 +227,7 @@ def newton_raphson_solve_block(
# Initialized here with correct shapes to help with tracing in case
# this is jitted.
(
source_models_lib.build_all_zero_profiles(
source_profile_builders.build_all_zero_profiles(
geo_t,
source_models,
),
Expand Down
3 changes: 2 additions & 1 deletion torax/fvm/optimizer_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib
Expand Down Expand Up @@ -157,7 +158,7 @@ def optimizer_solve_block(
# Initialized here with correct shapes to help with tracing in case
# this is jitted.
(
source_models_lib.build_all_zero_profiles(
source_profile_builders.build_all_zero_profiles(
geo_t,
source_models,
),
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/tests/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.stepper import runtime_params as stepper_params_lib
from torax.tests.test_lib import default_sources
from torax.transport_model import constant as constant_transport_model
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_calc_coeffs_smoke_test(
source_models,
)
evolving_names = tuple(['temp_ion'])
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
static_runtime_params_slice=static_runtime_params_slice,
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
Expand Down
8 changes: 4 additions & 4 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.stepper import runtime_params as stepper_runtime_params
from torax.tests.test_lib import default_sources
from torax.tests.test_lib import torax_refs
Expand Down Expand Up @@ -443,7 +443,7 @@ def test_nonlinear_solve_block_loss_minimum(
source_models,
)
evolving_names = tuple(['temp_ion'])
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand Down Expand Up @@ -587,7 +587,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
geo,
source_models,
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -740,7 +740,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
geo,
source_models,
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice_theta0,
geo=geo,
Expand Down
3 changes: 2 additions & 1 deletion torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import time_step_calculator as ts
Expand Down Expand Up @@ -144,7 +145,7 @@ def __call__(
# This only computes sources set to explicit in the
# DynamicSourceConfigSlice. All implicit sources will have their profiles
# set to 0.
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice_t,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo_t,
Expand Down
7 changes: 4 additions & 3 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from torax.orchestration import step_function
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
Expand Down Expand Up @@ -631,7 +632,7 @@ def _run_simulation(
geometry_provider=geometry_provider,
)
)
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -700,7 +701,7 @@ def get_initial_source_profiles(
Implicit and explicit SourceProfiles from source models based on the core
profiles from the starting state.
"""
implicit_profiles = source_models_lib.build_source_profiles(
implicit_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand All @@ -709,7 +710,7 @@ def get_initial_source_profiles(
explicit=False,
)
# Also add in the explicit sources to the initial sources.
explicit_source_profiles = source_models_lib.build_source_profiles(
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down
15 changes: 9 additions & 6 deletions torax/sources/ohmic_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_operations


@functools.partial(
Expand Down Expand Up @@ -70,12 +71,14 @@ def calc_psidot(
"""
consts = constants.CONSTANTS

psi_sources, sigma, sigma_face = source_models_lib.calc_and_sum_sources_psi(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
psi_sources, sigma, sigma_face = (
source_operations.calc_and_sum_sources_psi(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
)
)
# Calculate transient term
toc_psi = (
Expand Down
Loading