From 6be2c3cea961efbc87081a6f596e6fcdb069355a Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Wed, 29 Jan 2025 07:32:11 -0800 Subject: [PATCH] Assortment of sources changes and simplifications * Change calc qei to be part of build_source_profiles in the implicit case, these were being calculated later but on the same values and just in the implicit case * Use a static arg for explicit in build_source_profiles which simplifies code and has no impact on performance * Remove ProfileType and replace with simple functions for face/cell shape PiperOrigin-RevId: 720968920 --- torax/fvm/calc_coeffs.py | 101 +++++------------- torax/sim.py | 7 -- torax/sources/bootstrap_current_source.py | 6 +- torax/sources/electron_cyclotron_source.py | 2 +- torax/sources/generic_current_source.py | 2 +- torax/sources/source.py | 37 +++---- torax/sources/source_models.py | 94 ++++++---------- .../tests/bootstrap_current_source_test.py | 4 +- .../tests/electron_cyclotron_source_test.py | 2 +- .../tests/generic_current_source_test.py | 4 +- .../impurity_radiation_heat_sink_test.py | 2 +- torax/sources/tests/source_models_test.py | 10 +- torax/sources/tests/source_test.py | 20 +--- torax/sources/tests/test_lib.py | 2 +- torax/tests/output.py | 2 +- 15 files changed, 92 insertions(+), 203 deletions(-) diff --git a/torax/fvm/calc_coeffs.py b/torax/fvm/calc_coeffs.py index 3a0f4690..63e84b5e 100644 --- a/torax/fvm/calc_coeffs.py +++ b/torax/fvm/calc_coeffs.py @@ -388,67 +388,38 @@ def _calc_coeffs_full( # recalculate here to avoid issues with JAX branching in the logic. # Decide which values to use depending on whether the source is explicit or # implicit. - sigma = jax_utils.select( - static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ].is_explicit, - explicit_source_profiles.j_bootstrap.sigma, - implicit_source_profiles.j_bootstrap.sigma, - ) - sigma_face = jax_utils.select( - static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ].is_explicit, - explicit_source_profiles.j_bootstrap.sigma_face, - implicit_source_profiles.j_bootstrap.sigma_face, - ) - j_bootstrap = jax_utils.select( - static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ].is_explicit, - explicit_source_profiles.j_bootstrap.j_bootstrap, - implicit_source_profiles.j_bootstrap.j_bootstrap, - ) - j_bootstrap_face = jax_utils.select( - static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ].is_explicit, - explicit_source_profiles.j_bootstrap.j_bootstrap_face, - implicit_source_profiles.j_bootstrap.j_bootstrap_face, - ) - I_bootstrap = jax_utils.select( # pylint: disable=invalid-name - static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ].is_explicit, - explicit_source_profiles.j_bootstrap.I_bootstrap, - implicit_source_profiles.j_bootstrap.I_bootstrap, - ) + if static_runtime_params_slice.sources[ + source_models.j_bootstrap_name + ].is_explicit: + j_bootstrap = explicit_source_profiles.j_bootstrap + else: + j_bootstrap = implicit_source_profiles.j_bootstrap external_current = jnp.zeros_like(geo.rho) # Sum over all psi sources (except the bootstrap current). for source_name, source in source_models.psi_sources.items(): - external_current += jax_utils.select( - static_runtime_params_slice.sources[source_name].is_explicit, - source.get_source_profile_for_affected_core_profile( - profile=explicit_source_profiles.profiles[source_name], - affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, - geo=geo, - ), - source.get_source_profile_for_affected_core_profile( - profile=implicit_source_profiles.profiles[source_name], - affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, - geo=geo, - ), + if static_runtime_params_slice.sources[source_name].is_explicit: + profiles = explicit_source_profiles.profiles + else: + profiles = implicit_source_profiles.profiles + external_current += source.get_source_profile_for_affected_core_profile( + profile=profiles[source_name], + affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, + geo=geo, ) currents = dataclasses.replace( core_profiles.currents, - j_bootstrap=j_bootstrap, - j_bootstrap_face=j_bootstrap_face, + j_bootstrap=j_bootstrap.j_bootstrap, + j_bootstrap_face=j_bootstrap.j_bootstrap_face, external_current_source=external_current, - johm=(core_profiles.currents.jtot - j_bootstrap - external_current), - I_bootstrap=I_bootstrap, - sigma=sigma, + johm=( + core_profiles.currents.jtot + - j_bootstrap.j_bootstrap + - external_current + ), + I_bootstrap=j_bootstrap.I_bootstrap, + sigma=j_bootstrap.sigma, ) core_profiles = dataclasses.replace(core_profiles, currents=currents) @@ -495,7 +466,7 @@ def _calc_coeffs_full( 1.0 / dynamic_runtime_params_slice.numerics.resistivity_mult * geo.rho_norm - * sigma + * j_bootstrap.sigma * consts.mu0 * 16 * jnp.pi**2 @@ -737,30 +708,11 @@ def _calc_coeffs_full( * consts.mu0 * geo.Phibdot * geo.Phib - * sigma_face + * j_bootstrap.sigma_face * geo.rho_face_norm**2 / geo.F_face**2 ) - # Ion and electron heat sources. - qei = source_models.qei_source.get_qei( - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - geo=geo, - # For Qei, always use the current set of core profiles. - # In the linear solver, core_profiles is the set of profiles at time t (at - # the start of the time step) or the updated core_profiles in - # predictor-corrector, and in the nonlinear solver, calc_coeffs is called - # at least twice, once with the core_profiles at time t, and again - # (iteratively) with core_profiles at t+dt. - core_profiles=core_profiles, - ) - # Update the implicit profiles with the qei info. - implicit_source_profiles = dataclasses.replace( - implicit_source_profiles, - qei=qei, - ) - # Fill heat transport equation sources. Initialize source matrices to zero source_mat_ii = jnp.zeros_like(geo.rho) @@ -787,6 +739,7 @@ def _calc_coeffs_full( ) # Add the Qei effects. + qei = implicit_source_profiles.qei source_mat_ii += qei.implicit_ii * geo.vpr source_i += qei.explicit_i * geo.vpr source_mat_ee += qei.implicit_ee * geo.vpr @@ -853,7 +806,7 @@ def _calc_coeffs_full( # Add effective phibdot poloidal flux source term ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient( - sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm + j_bootstrap.sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm ) source_psi += ( diff --git a/torax/sim.py b/torax/sim.py index 784ccd96..867e03a5 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -708,13 +708,6 @@ def get_initial_source_profiles( source_models=source_models, explicit=False, ) - qei = source_models.qei_source.get_qei( - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) - implicit_profiles = dataclasses.replace(implicit_profiles, qei=qei) # Also add in the explicit sources to the initial sources. explicit_source_profiles = source_models_lib.build_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index 4ebe0e95..b3ab9f18 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -69,9 +69,9 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): def _default_output_shapes(geo) -> tuple[int, int, int, int]: return ( - source.ProfileType.CELL.get_profile_shape(geo) # sigmaneo - + source.ProfileType.CELL.get_profile_shape(geo) # bootstrap - + source.ProfileType.FACE.get_profile_shape(geo) # bootstrap face + source.get_cell_profile_shape(geo) # sigmaneo + + source.get_cell_profile_shape(geo) # bootstrap + + source.get_face_profile_shape(geo) # bootstrap face + (1,) # I_bootstrap ) diff --git a/torax/sources/electron_cyclotron_source.py b/torax/sources/electron_cyclotron_source.py index 4319486d..890ec8ac 100644 --- a/torax/sources/electron_cyclotron_source.py +++ b/torax/sources/electron_cyclotron_source.py @@ -179,7 +179,7 @@ def calc_heating_and_current( def _get_ec_output_shape(geo: geometry.Geometry) -> tuple[int, ...]: - return (2,) + source.ProfileType.CELL.get_profile_shape(geo) + return (2,) + source.get_cell_profile_shape(geo) @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) diff --git a/torax/sources/generic_current_source.py b/torax/sources/generic_current_source.py index f09530ab..74b87edc 100644 --- a/torax/sources/generic_current_source.py +++ b/torax/sources/generic_current_source.py @@ -188,4 +188,4 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @property def output_shape_getter(self) -> source.SourceOutputShapeFunction: - return source.ProfileType.CELL.get_profile_shape + return source.get_cell_profile_shape diff --git a/torax/sources/source.py b/torax/sources/source.py index aebdaf59..d02f10f8 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -74,13 +74,6 @@ def __call__( ] -def get_cell_profile_shape( - geo: geometry.Geometry, -): - """Returns the shape of a source profile on the cell grid.""" - return ProfileType.CELL.get_profile_shape(geo) - - @enum.unique class AffectedCoreProfile(enum.IntEnum): """Defines which part of the core profiles the source helps evolve. @@ -178,7 +171,7 @@ def get_value( ] output_shape = self.output_shape_getter(geo) - return get_source_profiles( + return _get_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, @@ -247,27 +240,19 @@ def get_source_profile_for_affected_core_profile( ) -class ProfileType(enum.Enum): - """Describes what kind of profile is expected from a source.""" - - # Source should return a profile on the cell grid. - CELL = enum.auto() +def get_cell_profile_shape(geo: geometry.Geometry) -> tuple[int, ...]: + """Returns the shape of a source profile on the cell grid.""" + return geo.torax_mesh.cell_centers.shape - # Source should return a profile on the face grid. - FACE = enum.auto() - def get_profile_shape(self, geo: geometry.Geometry) -> tuple[int, ...]: - """Returns the expected length of the source profile.""" - profile_type_to_len = { - ProfileType.CELL: geo.rho.shape, - ProfileType.FACE: geo.rho_face.shape, - } - return profile_type_to_len[self] +def get_face_profile_shape(geo: geometry.Geometry) -> tuple[int, ...]: + """Returns the shape of a source profile on the face grid.""" + return geo.torax_mesh.face_centers.shape # pytype bug: 'source_models.SourceModels' not treated as a forward ref # pytype: disable=name-error -def get_source_profiles( +def _get_source_profiles( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -320,12 +305,14 @@ def get_source_profiles( ) case runtime_params_lib.Mode.PRESCRIBED.value: return prescribed_values - case _: + case runtime_params_lib.Mode.ZERO.value: return jnp.zeros(output_shape) + case _: + raise ValueError(f'Unknown mode: {mode}') def get_ion_el_output_shape(geo): - return (2,) + ProfileType.CELL.get_profile_shape(geo) + return (2,) + get_cell_profile_shape(geo) @dataclasses.dataclass(frozen=False, kw_only=True) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 3f826cb6..c434b871 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -19,6 +19,7 @@ from collections.abc import Mapping import functools +import chex import jax import jax.numpy as jnp from torax import array_typing @@ -40,6 +41,7 @@ static_argnames=[ 'source_models', 'static_runtime_params_slice', + 'explicit', ], ) def build_source_profiles( @@ -91,12 +93,19 @@ def build_source_profiles( source_models, explicit, ) + if not explicit: + qei = source_models.qei_source.get_qei( + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) + else: + qei = source_profiles.QeiInfo.zeros(geo) return source_profiles.SourceProfiles( profiles=other_profiles, j_bootstrap=bootstrap_profiles, - # Qei is computed within calc_coeffs and will replace this value. This is - # here as a placeholder with correct shapes. - qei=source_profiles.QeiInfo.zeros(geo), + qei=qei, ) @@ -140,53 +149,10 @@ def _build_bootstrap_profiles( geo=geo, core_profiles=core_profiles, ) - sigma = jax_utils.select( - jnp.logical_or( - explicit == static_source_runtime_params.is_explicit, - calculate_anyway, - ), - bootstrap_profile.sigma, - jnp.zeros_like(bootstrap_profile.sigma), - ) - sigma_face = jax_utils.select( - jnp.logical_or( - explicit == static_source_runtime_params.is_explicit, - calculate_anyway, - ), - bootstrap_profile.sigma_face, - jnp.zeros_like(bootstrap_profile.sigma_face), - ) - j_bootstrap = jax_utils.select( - jnp.logical_or( - explicit == static_source_runtime_params.is_explicit, - calculate_anyway, - ), - bootstrap_profile.j_bootstrap, - jnp.zeros_like(bootstrap_profile.j_bootstrap), - ) - j_bootstrap_face = jax_utils.select( - jnp.logical_or( - explicit == static_source_runtime_params.is_explicit, - calculate_anyway, - ), - bootstrap_profile.j_bootstrap_face, - jnp.zeros_like(bootstrap_profile.j_bootstrap_face), - ) - I_bootstrap = jax_utils.select( # pylint: disable=invalid-name - jnp.logical_or( - explicit == static_source_runtime_params.is_explicit, - calculate_anyway, - ), - bootstrap_profile.I_bootstrap, - jnp.zeros_like(bootstrap_profile.I_bootstrap), - ) - return source_profiles.BootstrapCurrentProfile( - sigma=sigma, - sigma_face=sigma_face, - j_bootstrap=j_bootstrap, - j_bootstrap_face=j_bootstrap_face, - I_bootstrap=I_bootstrap, - ) + if explicit == static_source_runtime_params.is_explicit | calculate_anyway: + return bootstrap_profile + else: + return source_profiles.BootstrapCurrentProfile.zero_profile(geo) def _build_standard_source_profiles( @@ -203,7 +169,7 @@ def _build_standard_source_profiles( source_lib.AffectedCoreProfile.TEMP_ION, source_lib.AffectedCoreProfile.TEMP_EL, ), -) -> dict[str, jax.Array]: +) -> dict[str, chex.ArrayTree]: """Computes sources and builds a kwargs dict for SourceProfiles. Args: @@ -236,19 +202,19 @@ def _build_standard_source_profiles( static_source_runtime_params = static_runtime_params_slice.sources[ source_name ] - computed_source_profiles[source_name] = jax_utils.select( - jnp.logical_or( - explicit == static_source_runtime_params.is_explicit, - calculate_anyway, - ), - source.get_value( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - ), - jnp.zeros(source.output_shape_getter(geo)), - ) + if ( + explicit + == static_source_runtime_params.is_explicit | calculate_anyway + ): + computed_source_profiles[source_name] = source.get_value( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + ) + else: + computed_source_profiles[source_name] = jnp.zeros( + source.output_shape_getter(geo)) return computed_source_profiles diff --git a/torax/sources/tests/bootstrap_current_source_test.py b/torax/sources/tests/bootstrap_current_source_test.py index 7c3f2865..d693532b 100644 --- a/torax/sources/tests/bootstrap_current_source_test.py +++ b/torax/sources/tests/bootstrap_current_source_test.py @@ -37,8 +37,8 @@ def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" source = bootstrap_current_source.BootstrapCurrentSource() geo = circular_geometry.build_circular_geometry() - cell = source_lib.ProfileType.CELL.get_profile_shape(geo) - face = source_lib.ProfileType.FACE.get_profile_shape(geo) + cell = source_lib.get_cell_profile_shape(geo) + face = source_lib.get_face_profile_shape(geo) fake_profile = source_profiles.BootstrapCurrentProfile( sigma=jnp.zeros(cell), sigma_face=jnp.zeros(face), diff --git a/torax/sources/tests/electron_cyclotron_source_test.py b/torax/sources/tests/electron_cyclotron_source_test.py index 001ac05c..d14ec5c1 100644 --- a/torax/sources/tests/electron_cyclotron_source_test.py +++ b/torax/sources/tests/electron_cyclotron_source_test.py @@ -92,7 +92,7 @@ def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" geo = circular_geometry.build_circular_geometry() source = self._source_class() - cell = source_lib.ProfileType.CELL.get_profile_shape(geo) + cell = source_lib.get_cell_profile_shape(geo) fake_profile = jnp.stack((jnp.ones(cell), 2 * jnp.ones(cell))) # Check TEMP_EL and PSI are modified np.testing.assert_allclose( diff --git a/torax/sources/tests/generic_current_source_test.py b/torax/sources/tests/generic_current_source_test.py index 0ad3471d..0fd96119 100644 --- a/torax/sources/tests/generic_current_source_test.py +++ b/torax/sources/tests/generic_current_source_test.py @@ -42,7 +42,7 @@ def test_profile_is_on_cell_grid(self): source = source_builder() self.assertEqual( source.output_shape_getter(geo), - source_lib.ProfileType.CELL.get_profile_shape(geo), + source_lib.get_cell_profile_shape(geo), ) runtime_params = general_runtime_params.GeneralRuntimeParams() dynamic_runtime_params_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -72,7 +72,7 @@ def test_profile_is_on_cell_grid(self): geo, core_profiles=None, ).shape, - source_lib.ProfileType.CELL.get_profile_shape(geo), + source_lib.get_cell_profile_shape(geo), ) @parameterized.named_parameters( diff --git a/torax/sources/tests/impurity_radiation_heat_sink_test.py b/torax/sources/tests/impurity_radiation_heat_sink_test.py index 94ba4614..3d3b43d2 100644 --- a/torax/sources/tests/impurity_radiation_heat_sink_test.py +++ b/torax/sources/tests/impurity_radiation_heat_sink_test.py @@ -158,7 +158,7 @@ def test_extraction_of_relevant_profile_from_output(self): source_models = source_models_builder() source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) - cell = source_lib.ProfileType.CELL.get_profile_shape(geo) + cell = source_lib.get_cell_profile_shape(geo) fake_profile = jnp.ones(cell) # Check TEMP_EL is modified np.testing.assert_allclose( diff --git a/torax/sources/tests/source_models_test.py b/torax/sources/tests/source_models_test.py index ec392d24..b4847e51 100644 --- a/torax/sources/tests/source_models_test.py +++ b/torax/sources/tests/source_models_test.py @@ -125,7 +125,7 @@ def test_summed_temp_ion_profiles_dont_change_when_jitting(self): source_models_builder = default_sources.get_default_sources_builder() source_models = source_models_builder() # Make some dummy source profiles that could have come from these sources. - ones = jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(geo)) + ones = jnp.ones(source_lib.get_cell_profile_shape(geo)) profiles = source_profiles_lib.SourceProfiles( j_bootstrap=source_profiles_lib.BootstrapCurrentProfile.zero_profile( geo @@ -175,8 +175,8 @@ def foo_formula( unused_source_models, ): return jnp.stack([ - jnp.zeros(source_lib.ProfileType.CELL.get_profile_shape(geo)), - jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(geo)), + jnp.zeros(source_lib.get_cell_profile_shape(geo)), + jnp.ones(source_lib.get_cell_profile_shape(geo)), ]) foo_source_builder = source_lib.make_source_builder( @@ -231,10 +231,10 @@ def compute_and_sum_profiles(): return (ne, temp_el) expected_ne = ( - jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(geo)) * geo.vpr + jnp.ones(source_lib.get_cell_profile_shape(geo)) * geo.vpr ) expected_temp_el = jnp.zeros( - source_lib.ProfileType.CELL.get_profile_shape(geo) + source_lib.get_cell_profile_shape(geo) ) with self.subTest('without_jit'): (ne, temp_el) = compute_and_sum_profiles() diff --git a/torax/sources/tests/source_test.py b/torax/sources/tests/source_test.py index 82d3b579..ea9110e7 100644 --- a/torax/sources/tests/source_test.py +++ b/torax/sources/tests/source_test.py @@ -14,28 +14,18 @@ import dataclasses from absl.testing import absltest from absl.testing import parameterized -import jax from jax import numpy as jnp import numpy as np from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice from torax.geometry import circular_geometry -from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib -def get_zero_profile( - profile_type: source_lib.ProfileType, - geo: geometry.Geometry, -) -> jax.Array: - """Returns a source profile with all zeros.""" - return jnp.zeros(profile_type.get_profile_shape(geo)) - - @dataclasses.dataclass(frozen=True, eq=True) class PsiTestSource(source_lib.Source): @@ -201,7 +191,7 @@ def test_zero_profile_works_by_default(self): ) np.testing.assert_allclose( profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), + np.zeros_like(geo.torax_mesh.cell_centers) ) @parameterized.parameters( @@ -214,7 +204,7 @@ def test_correct_mode_called(self, mode, expected_profile): source_builder = source_lib.make_source_builder( test_lib.TestSource, model_func=lambda _0, _1, _2, _3, _4, _5: jnp.ones( - source_lib.ProfileType.CELL.get_profile_shape(geo) + source_lib.get_cell_profile_shape(geo) ) * 2, )() source_models_builder = source_models_lib.SourceModelsBuilder( @@ -328,13 +318,13 @@ def test_defaults_output_zeros(self): ) np.testing.assert_allclose( profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), + np.zeros_like(geo.torax_mesh.cell_centers), ) def test_overriding_model(self): """The user-specified model should override the default model.""" geo = circular_geometry.build_circular_geometry() - output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) + output_shape = source_lib.get_cell_profile_shape(geo) expected_output = jnp.ones(output_shape) source_builder = source_lib.make_source_builder( IonElTestSource, @@ -378,7 +368,7 @@ def test_overriding_model(self): def test_overriding_prescribed_values(self): """Providing prescribed values results in the correct profile.""" geo = circular_geometry.build_circular_geometry() - output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) + output_shape = source_lib.get_cell_profile_shape(geo) # Define the expected output expected_output = jnp.ones(output_shape) # Create the source diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 8c253f28..217af6bf 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -220,7 +220,7 @@ def test_extraction_of_relevant_profile_from_output(self): # pylint: disable=missing-kwoa source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa - cell = source_lib.ProfileType.CELL.get_profile_shape(geo) + cell = source_lib.get_cell_profile_shape(geo) fake_profile = jnp.stack((jnp.ones(cell), 2 * jnp.ones(cell))) np.testing.assert_allclose( source.get_source_profile_for_affected_core_profile( diff --git a/torax/tests/output.py b/torax/tests/output.py index 6ac85788..b01e989b 100644 --- a/torax/tests/output.py +++ b/torax/tests/output.py @@ -58,7 +58,7 @@ def setUp(self): source_models = source_models_builder() # Make some dummy source profiles that could have come from these sources. self.geo = circular_geometry.build_circular_geometry() - ones = jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(self.geo)) + ones = jnp.ones(source_lib.get_cell_profile_shape(self.geo)) geo_provider = geometry_provider.ConstantGeometryProvider(self.geo) dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry(