Skip to content

Commit

Permalink
Assortment of sources changes and simplifications
Browse files Browse the repository at this point in the history
* Change calc qei to be part of build_source_profiles in the implicit case, these were being calculated later but on the same values and just in the implicit case
* Use a static arg for explicit in build_source_profiles which simplifies code and has no impact on performance
* Remove ProfileType and replace with simple functions for face/cell shape

PiperOrigin-RevId: 720968920
  • Loading branch information
tamaranorman authored and Torax team committed Jan 29, 2025
1 parent 9591640 commit 6be2c3c
Show file tree
Hide file tree
Showing 15 changed files with 92 additions and 203 deletions.
101 changes: 27 additions & 74 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,67 +388,38 @@ def _calc_coeffs_full(
# recalculate here to avoid issues with JAX branching in the logic.
# Decide which values to use depending on whether the source is explicit or
# implicit.
sigma = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.sigma,
implicit_source_profiles.j_bootstrap.sigma,
)
sigma_face = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.sigma_face,
implicit_source_profiles.j_bootstrap.sigma_face,
)
j_bootstrap = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.j_bootstrap,
implicit_source_profiles.j_bootstrap.j_bootstrap,
)
j_bootstrap_face = jax_utils.select(
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.j_bootstrap_face,
implicit_source_profiles.j_bootstrap.j_bootstrap_face,
)
I_bootstrap = jax_utils.select( # pylint: disable=invalid-name
static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit,
explicit_source_profiles.j_bootstrap.I_bootstrap,
implicit_source_profiles.j_bootstrap.I_bootstrap,
)
if static_runtime_params_slice.sources[
source_models.j_bootstrap_name
].is_explicit:
j_bootstrap = explicit_source_profiles.j_bootstrap
else:
j_bootstrap = implicit_source_profiles.j_bootstrap

external_current = jnp.zeros_like(geo.rho)
# Sum over all psi sources (except the bootstrap current).
for source_name, source in source_models.psi_sources.items():
external_current += jax_utils.select(
static_runtime_params_slice.sources[source_name].is_explicit,
source.get_source_profile_for_affected_core_profile(
profile=explicit_source_profiles.profiles[source_name],
affected_core_profile=source_lib.AffectedCoreProfile.PSI.value,
geo=geo,
),
source.get_source_profile_for_affected_core_profile(
profile=implicit_source_profiles.profiles[source_name],
affected_core_profile=source_lib.AffectedCoreProfile.PSI.value,
geo=geo,
),
if static_runtime_params_slice.sources[source_name].is_explicit:
profiles = explicit_source_profiles.profiles
else:
profiles = implicit_source_profiles.profiles
external_current += source.get_source_profile_for_affected_core_profile(
profile=profiles[source_name],
affected_core_profile=source_lib.AffectedCoreProfile.PSI.value,
geo=geo,
)

currents = dataclasses.replace(
core_profiles.currents,
j_bootstrap=j_bootstrap,
j_bootstrap_face=j_bootstrap_face,
j_bootstrap=j_bootstrap.j_bootstrap,
j_bootstrap_face=j_bootstrap.j_bootstrap_face,
external_current_source=external_current,
johm=(core_profiles.currents.jtot - j_bootstrap - external_current),
I_bootstrap=I_bootstrap,
sigma=sigma,
johm=(
core_profiles.currents.jtot
- j_bootstrap.j_bootstrap
- external_current
),
I_bootstrap=j_bootstrap.I_bootstrap,
sigma=j_bootstrap.sigma,
)
core_profiles = dataclasses.replace(core_profiles, currents=currents)

Expand Down Expand Up @@ -495,7 +466,7 @@ def _calc_coeffs_full(
1.0
/ dynamic_runtime_params_slice.numerics.resistivity_mult
* geo.rho_norm
* sigma
* j_bootstrap.sigma
* consts.mu0
* 16
* jnp.pi**2
Expand Down Expand Up @@ -737,30 +708,11 @@ def _calc_coeffs_full(
* consts.mu0
* geo.Phibdot
* geo.Phib
* sigma_face
* j_bootstrap.sigma_face
* geo.rho_face_norm**2
/ geo.F_face**2
)

# Ion and electron heat sources.
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
# For Qei, always use the current set of core profiles.
# In the linear solver, core_profiles is the set of profiles at time t (at
# the start of the time step) or the updated core_profiles in
# predictor-corrector, and in the nonlinear solver, calc_coeffs is called
# at least twice, once with the core_profiles at time t, and again
# (iteratively) with core_profiles at t+dt.
core_profiles=core_profiles,
)
# Update the implicit profiles with the qei info.
implicit_source_profiles = dataclasses.replace(
implicit_source_profiles,
qei=qei,
)

# Fill heat transport equation sources. Initialize source matrices to zero

source_mat_ii = jnp.zeros_like(geo.rho)
Expand All @@ -787,6 +739,7 @@ def _calc_coeffs_full(
)

# Add the Qei effects.
qei = implicit_source_profiles.qei
source_mat_ii += qei.implicit_ii * geo.vpr
source_i += qei.explicit_i * geo.vpr
source_mat_ee += qei.implicit_ee * geo.vpr
Expand Down Expand Up @@ -853,7 +806,7 @@ def _calc_coeffs_full(
# Add effective phibdot poloidal flux source term

ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient(
sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm
j_bootstrap.sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm
)

source_psi += (
Expand Down
7 changes: 0 additions & 7 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,13 +708,6 @@ def get_initial_source_profiles(
source_models=source_models,
explicit=False,
)
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
implicit_profiles = dataclasses.replace(implicit_profiles, qei=qei)
# Also add in the explicit sources to the initial sources.
explicit_source_profiles = source_models_lib.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
Expand Down
6 changes: 3 additions & 3 deletions torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):

def _default_output_shapes(geo) -> tuple[int, int, int, int]:
return (
source.ProfileType.CELL.get_profile_shape(geo) # sigmaneo
+ source.ProfileType.CELL.get_profile_shape(geo) # bootstrap
+ source.ProfileType.FACE.get_profile_shape(geo) # bootstrap face
source.get_cell_profile_shape(geo) # sigmaneo
+ source.get_cell_profile_shape(geo) # bootstrap
+ source.get_face_profile_shape(geo) # bootstrap face
+ (1,) # I_bootstrap
)

Expand Down
2 changes: 1 addition & 1 deletion torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def calc_heating_and_current(


def _get_ec_output_shape(geo: geometry.Geometry) -> tuple[int, ...]:
return (2,) + source.ProfileType.CELL.get_profile_shape(geo)
return (2,) + source.get_cell_profile_shape(geo)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down
2 changes: 1 addition & 1 deletion torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,4 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:

@property
def output_shape_getter(self) -> source.SourceOutputShapeFunction:
return source.ProfileType.CELL.get_profile_shape
return source.get_cell_profile_shape
37 changes: 12 additions & 25 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ def __call__(
]


def get_cell_profile_shape(
geo: geometry.Geometry,
):
"""Returns the shape of a source profile on the cell grid."""
return ProfileType.CELL.get_profile_shape(geo)


@enum.unique
class AffectedCoreProfile(enum.IntEnum):
"""Defines which part of the core profiles the source helps evolve.
Expand Down Expand Up @@ -178,7 +171,7 @@ def get_value(
]
output_shape = self.output_shape_getter(geo)

return get_source_profiles(
return _get_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -247,27 +240,19 @@ def get_source_profile_for_affected_core_profile(
)


class ProfileType(enum.Enum):
"""Describes what kind of profile is expected from a source."""

# Source should return a profile on the cell grid.
CELL = enum.auto()
def get_cell_profile_shape(geo: geometry.Geometry) -> tuple[int, ...]:
"""Returns the shape of a source profile on the cell grid."""
return geo.torax_mesh.cell_centers.shape

# Source should return a profile on the face grid.
FACE = enum.auto()

def get_profile_shape(self, geo: geometry.Geometry) -> tuple[int, ...]:
"""Returns the expected length of the source profile."""
profile_type_to_len = {
ProfileType.CELL: geo.rho.shape,
ProfileType.FACE: geo.rho_face.shape,
}
return profile_type_to_len[self]
def get_face_profile_shape(geo: geometry.Geometry) -> tuple[int, ...]:
"""Returns the shape of a source profile on the face grid."""
return geo.torax_mesh.face_centers.shape


# pytype bug: 'source_models.SourceModels' not treated as a forward ref
# pytype: disable=name-error
def get_source_profiles(
def _get_source_profiles(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -320,12 +305,14 @@ def get_source_profiles(
)
case runtime_params_lib.Mode.PRESCRIBED.value:
return prescribed_values
case _:
case runtime_params_lib.Mode.ZERO.value:
return jnp.zeros(output_shape)
case _:
raise ValueError(f'Unknown mode: {mode}')


def get_ion_el_output_shape(geo):
return (2,) + ProfileType.CELL.get_profile_shape(geo)
return (2,) + get_cell_profile_shape(geo)


@dataclasses.dataclass(frozen=False, kw_only=True)
Expand Down
Loading

0 comments on commit 6be2c3c

Please sign in to comment.