From 01ff89e4d182db87958e720b90e58c7a3efb651a Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Thu, 16 Jan 2025 21:25:51 -0800 Subject: [PATCH] Change core_profile_setters_test to have test prefix and simplify tests PiperOrigin-RevId: 716503857 --- torax/core_profile_setters.py | 122 ++++----- torax/fvm/cell_variable.py | 34 --- torax/state.py | 23 -- torax/tests/arg_order.py | 3 + ...etters.py => core_profile_setters_test.py} | 246 ++++++------------ torax/tests/state.py | 31 --- 6 files changed, 139 insertions(+), 320 deletions(-) rename torax/tests/{core_profile_setters.py => core_profile_setters_test.py} (76%) diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index b8a350b3..afe49d66 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -29,6 +29,8 @@ from torax import math_utils from torax import physics from torax import state +from torax.config import numerics +from torax.config import profile_conditions from torax.config import runtime_params_slice from torax.fvm import cell_variable from torax.geometry import geometry @@ -38,83 +40,76 @@ _trapz = jax.scipy.integrate.trapezoid +# Using capitalized variables for physics notational conventions rather than +# Python style. +# pylint: disable=invalid-name -def updated_ion_temperature( - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + +def _updated_ion_temperature( + dynamic_profile_conditions: profile_conditions.DynamicProfileConditions, geo: geometry.Geometry, ) -> cell_variable.CellVariable: """Updated ion temp. Used upon initialization and if temp_ion=False.""" - # pylint: disable=invalid-name - Ti_bound_right = ( - dynamic_runtime_params_slice.profile_conditions.Ti_bound_right - ) - Ti_bound_right = jax_utils.error_if_not_positive( - Ti_bound_right, + dynamic_profile_conditions.Ti_bound_right, 'Ti_bound_right', ) temp_ion = cell_variable.CellVariable( - value=dynamic_runtime_params_slice.profile_conditions.Ti, + value=dynamic_profile_conditions.Ti, left_face_grad_constraint=jnp.zeros(()), right_face_grad_constraint=None, right_face_constraint=Ti_bound_right, dr=geo.drho_norm, ) - # pylint: enable=invalid-name + return temp_ion -def updated_electron_temperature( - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, +def _updated_electron_temperature( + dynamic_profile_conditions: profile_conditions.DynamicProfileConditions, geo: geometry.Geometry, ) -> cell_variable.CellVariable: """Updated electron temp. Used upon initialization and if temp_el=False.""" - # pylint: disable=invalid-name - Te_bound_right = ( - dynamic_runtime_params_slice.profile_conditions.Te_bound_right - ) - Te_bound_right = jax_utils.error_if_not_positive( - Te_bound_right, + dynamic_profile_conditions.Te_bound_right, 'Te_bound_right', ) temp_el = cell_variable.CellVariable( - value=dynamic_runtime_params_slice.profile_conditions.Te, + value=dynamic_profile_conditions.Te, left_face_grad_constraint=jnp.zeros(()), right_face_grad_constraint=None, right_face_constraint=Te_bound_right, dr=geo.drho_norm, ) - # pylint: enable=invalid-name return temp_el -# pylint: disable=invalid-name def _get_ne( - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_numerics: numerics.DynamicNumerics, + dynamic_profile_conditions: profile_conditions.DynamicProfileConditions, geo: geometry.Geometry, ) -> cell_variable.CellVariable: """Gets initial or prescribed electron density profile at current timestep.""" - # pylint: disable=invalid-name + nGW = ( - dynamic_runtime_params_slice.profile_conditions.Ip_tot + dynamic_profile_conditions.Ip_tot / (jnp.pi * geo.Rmin**2) * 1e20 - / dynamic_runtime_params_slice.numerics.nref + / dynamic_numerics.nref ) ne_value = jnp.where( - dynamic_runtime_params_slice.profile_conditions.ne_is_fGW, - dynamic_runtime_params_slice.profile_conditions.ne * nGW, - dynamic_runtime_params_slice.profile_conditions.ne, + dynamic_profile_conditions.ne_is_fGW, + dynamic_profile_conditions.ne * nGW, + dynamic_profile_conditions.ne, ) # Calculate ne_bound_right. ne_bound_right = jnp.where( - dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_fGW, - dynamic_runtime_params_slice.profile_conditions.ne_bound_right * nGW, - dynamic_runtime_params_slice.profile_conditions.ne_bound_right, + dynamic_profile_conditions.ne_bound_right_is_fGW, + dynamic_profile_conditions.ne_bound_right * nGW, + dynamic_profile_conditions.ne_bound_right, ) - if dynamic_runtime_params_slice.profile_conditions.normalize_to_nbar: + if dynamic_profile_conditions.normalize_to_nbar: face_left = ne_value[0] # Zero gradient boundary condition at left face. face_right = ne_bound_right face_inner = (ne_value[..., :-1] + ne_value[..., 1:]) / 2.0 @@ -132,16 +127,15 @@ def _get_ne( Rmin_out = geo.Rout_face[-1] - geo.Rout_face[0] # find target nbar in absolute units target_nbar = jnp.where( - dynamic_runtime_params_slice.profile_conditions.ne_is_fGW, - dynamic_runtime_params_slice.profile_conditions.nbar * nGW, - dynamic_runtime_params_slice.profile_conditions.nbar, + dynamic_profile_conditions.ne_is_fGW, + dynamic_profile_conditions.nbar * nGW, + dynamic_profile_conditions.nbar, ) if ( - not dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_absolute + not dynamic_profile_conditions.ne_bound_right_is_absolute ): # In this case, ne_bound_right is taken from ne and we also normalize it. C = target_nbar / (_trapz(ne_face, geo.Rout_face) / Rmin_out) - # pylint: enable=invalid-name ne_bound_right = C * ne_bound_right else: # If ne_bound_right is absolute, subtract off contribution from outer @@ -180,7 +174,6 @@ def _get_charge_states( array_typing.ArrayFloat, ]: """Updated charge states based on IonMixtures and electron temperature.""" - # pylint: disable=invalid-name Zi = charge_states.get_average_charge_state( ion_symbols=static_runtime_params_slice.main_ion_names, ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion, @@ -308,7 +301,6 @@ def _prescribe_currents_no_bootstrap( """ # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style - # pylint: disable=invalid-name # Calculate splitting of currents depending on input runtime params. Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot @@ -397,7 +389,6 @@ def _prescribe_currents_with_bootstrap( # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style - # pylint: disable=invalid-name Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot bootstrap_profile = source_models.j_bootstrap.get_value( @@ -483,7 +474,6 @@ def _calculate_currents_from_psi( # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style - # pylint: disable=invalid-name jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi( geo, core_profiles.psi, @@ -574,7 +564,6 @@ def _update_psi_from_j( return psi -# pylint: enable=invalid-name def _calculate_psi_grad_constraint( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -689,12 +678,10 @@ def _init_psi_and_current( geo, currents.jtot_hires, ) - # pylint: disable=invalid-name _, _, Ip_profile_face = physics.calc_jtot_from_psi( geo, psi, ) - # pylint: enable=invalid-name currents = dataclasses.replace(currents, Ip_profile_face=Ip_profile_face) else: raise ValueError('Cannot compute psi for given config.') @@ -721,14 +708,19 @@ def initial_core_profiles( Returns: Initial core profiles. """ - # pylint: disable=invalid-name # To set initial values and compute the boundary conditions, we need to handle # potentially time-varying inputs from the users. # The default time in build_dynamic_runtime_params_slice is t_initial - temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo) - temp_el = updated_electron_temperature(dynamic_runtime_params_slice, geo) - ne = _get_ne(dynamic_runtime_params_slice, geo) + temp_ion = _updated_ion_temperature( + dynamic_runtime_params_slice.profile_conditions, geo + ) + temp_el = _updated_electron_temperature( + dynamic_runtime_params_slice.profile_conditions, geo + ) + ne = _get_ne(dynamic_runtime_params_slice.numerics, + dynamic_runtime_params_slice.profile_conditions, + geo) ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states( static_runtime_params_slice, @@ -795,15 +787,12 @@ def initial_core_profiles( core_profiles = dataclasses.replace(core_profiles, psidot=psidot) # Set psi as source of truth and recalculate jtot, q, s - core_profiles = physics.update_jtot_q_face_s_face( + return physics.update_jtot_q_face_s_face( geo=geo, core_profiles=core_profiles, q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, ) - # pylint: enable=invalid-name - return core_profiles - def get_prescribed_core_profile_values( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, @@ -824,23 +813,22 @@ def get_prescribed_core_profile_values( Returns: Updated core profiles values on the cell grid. """ - # pylint: disable=invalid-name - # If profiles are not evolved, they can still potential be time-evolving, # depending on the runtime params. If so, they are updated below. if ( not static_runtime_params_slice.ion_heat_eq and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution ): - temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo).value + temp_ion = _updated_ion_temperature( + dynamic_runtime_params_slice.profile_conditions, geo).value else: temp_ion = core_profiles.temp_ion.value if ( not static_runtime_params_slice.el_heat_eq and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution ): - temp_el_cell_variable = updated_electron_temperature( - dynamic_runtime_params_slice, geo + temp_el_cell_variable = _updated_electron_temperature( + dynamic_runtime_params_slice.profile_conditions, geo ) temp_el = temp_el_cell_variable.value else: @@ -850,7 +838,10 @@ def get_prescribed_core_profile_values( not static_runtime_params_slice.dens_eq and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution ): - ne_cell_variable = _get_ne(dynamic_runtime_params_slice, geo) + ne_cell_variable = _get_ne( + dynamic_runtime_params_slice.numerics, + dynamic_runtime_params_slice.profile_conditions, + geo) else: ne_cell_variable = core_profiles.ne ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states( @@ -908,7 +899,6 @@ def get_update(x_new, var): psi = get_update(x_new, 'psi') ne = get_update(x_new, 'ne') - # pylint: disable=invalid-name ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states( static_runtime_params_slice, dynamic_runtime_params_slice, @@ -916,7 +906,6 @@ def get_update(x_new, var): ne, temp_el, ) - # pylint: enable=invalid-name return dataclasses.replace( core_profiles, @@ -950,24 +939,24 @@ def compute_boundary_conditions( each CellVariable in the state. This dict can in theory recursively replace values in a State object. """ - Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name + Ti_bound_right = jax_utils.error_if_not_positive( dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 'Ti_bound_right', ) - Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name + Te_bound_right = jax_utils.error_if_not_positive( dynamic_runtime_params_slice.profile_conditions.Te_bound_right, 'Te_bound_right', ) # TODO(b/390143606): Separate out the boundary condition calculation from the # core profile calculation. ne = _get_ne( - dynamic_runtime_params_slice, + dynamic_runtime_params_slice.numerics, + dynamic_runtime_params_slice.profile_conditions, geo, ) ne_bound_right = ne.right_face_constraint - # pylint: disable=invalid-name Zi_edge = charge_states.get_average_charge_state( static_runtime_params_slice.main_ion_names, ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion, @@ -978,7 +967,6 @@ def compute_boundary_conditions( ion_mixture=dynamic_runtime_params_slice.plasma_composition.impurity, Te=Te_bound_right, ) - # pylint: disable=invalid-name dilution_factor_edge = physics.get_main_ion_dilution_factor( Zi_edge, @@ -1026,7 +1014,6 @@ def compute_boundary_conditions( } -# pylint: disable=invalid-name def _get_jtot_hires( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -1064,6 +1051,3 @@ def _get_jtot_hires( johm_hires = jformula_hires * Cohm_hires jtot_hires = johm_hires + external_current_hires + j_bootstrap_hires return jtot_hires - - -# pylint: enable=invalid-name diff --git a/torax/fvm/cell_variable.py b/torax/fvm/cell_variable.py index ed508406..050beb50 100644 --- a/torax/fvm/cell_variable.py +++ b/torax/fvm/cell_variable.py @@ -76,28 +76,6 @@ class CellVariable: # Can't make the above default values be jax zeros because that would be a # call to jax before absl.app.run - def project(self, weights): - assert self.history is not None - - def project(x): - return jnp.dot(weights, x) - - def opt_project(x): - if x is None: - return None - return project(x) - - return dataclasses.replace( - self, - value=project(self.value), - dr=self.dr[0], - left_face_constraint=opt_project(self.left_face_constraint), - left_face_grad_constraint=opt_project(self.left_face_grad_constraint), - right_face_constraint=opt_project(self.right_face_constraint), - right_face_grad_constraint=opt_project(self.right_face_grad_constraint), - history=None, - ) - def __post_init__(self): self.sanity_check() @@ -266,18 +244,6 @@ def assert_not_history(self): 'by `jax.lax.scan`. Most methods of a CellVariable ' 'do not work in history mode.' ) - if hasattr(self.history, 'ndim'): - if self.history.ndim == 0 or ( - self.history.ndim == 1 and self.history.shape[0] == 1 - ): - msg += ( - f' self.history={self.history} which probably indicates' - ' (due to its scalar shape)' - ' that an indexing or projection operation failed to' - ' turn off history mode. self.history should be None for' - ' non-history or a a vector of shape (history_length) for' - ' history.' - ) raise AssertionError(msg) def __hash__(self): diff --git a/torax/state.py b/torax/state.py index 1f7bc6a2..d432d11a 100644 --- a/torax/state.py +++ b/torax/state.py @@ -194,29 +194,6 @@ def sanity_check(self): if hasattr(value, "sanity_check"): value.sanity_check() - def project(self, weights): - project = lambda x: jnp.dot(weights, x) - proj_currents = jax.tree_util.tree_map(project, self.currents) - return dataclasses.replace( - self, - temp_ion=self.temp_ion.project(weights), - temp_el=self.temp_el.project(weights), - psi=self.psi.project(weights), - psidot=self.psidot.project(weights), - ne=self.ne.project(weights), - ni=self.ni.project(weights), - currents=proj_currents, - q_face=project(self.q_face), - s_face=project(self.s_face), - nref=project(self.nref), - Zi=project(self.Zi), - Zi_face=project(self.Zi_face), - Ai=project(self.Ai), - Zimp=project(self.Zimp), - Zimp_face=project(self.Zimp_face), - Aimp=project(self.Aimp), - ) - def __hash__(self): """Make CoreProfiles hashable. diff --git a/torax/tests/arg_order.py b/torax/tests/arg_order.py index 763b0897..ef67c30a 100644 --- a/torax/tests/arg_order.py +++ b/torax/tests/arg_order.py @@ -82,6 +82,9 @@ def test_arg_order(self, module): fields = inspect.getmembers(module) print(module.__name__) for name, obj in fields: + if name.startswith("_"): + # Ignore private fields and methods. + continue if inspect.isfunction(obj): print("\t", name) params = inspect.signature(obj).parameters.keys() diff --git a/torax/tests/core_profile_setters.py b/torax/tests/core_profile_setters_test.py similarity index 76% rename from torax/tests/core_profile_setters.py rename to torax/tests/core_profile_setters_test.py index 136d7542..b1d755b7 100644 --- a/torax/tests/core_profile_setters.py +++ b/torax/tests/core_profile_setters_test.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Tests for module torax.boundary_conditions.""" +from unittest import mock from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp import numpy as np from torax import core_profile_setters +from torax import jax_utils from torax import physics from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params @@ -39,165 +39,70 @@ class CoreProfileSettersTest(parameterized.TestCase): def setUp(self): super().setUp() + jax_utils.enable_errors(True) self.geo = geometry.build_circular_geometry(n_rho=4) - @parameterized.parameters( - (0.0, np.array([10.5, 7.5, 4.5, 1.5])), - (80.0, np.array([1.0, 1.0, 1.0, 1.0])), - ( - 40.0, - np.array([ - (1.0 + 10.5) / 2, - (1.0 + 7.5) / 2, - (1.0 + 4.5) / 2, - (1.0 + 1.5) / 2, - ]), - ), - ) - def test_temperature_rho_and_time_interpolation( - self, - t: float, - expected_temperature: np.ndarray, - ): - """Tests that the temperature rho and time interpolation works.""" - runtime_params = general_runtime_params.GeneralRuntimeParams( - profile_conditions=profile_conditions_lib.ProfileConditions( - Ti={0.0: {0.0: 12.0, 1.0: SMALL_VALUE}, 80.0: {0.0: 1.0}}, - Ti_bound_right=SMALL_VALUE, - Te={0.0: {0.0: 12.0, 1.0: SMALL_VALUE}, 80.0: {0.0: 1.0}}, - Te_bound_right=SMALL_VALUE, - ), - ) - geo = geometry.build_circular_geometry(n_rho=4) - dynamic_slice = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( - runtime_params, - torax_mesh=geo.torax_mesh, - )(t=t) - Ti = core_profile_setters.updated_ion_temperature(dynamic_slice, geo) - Te = core_profile_setters.updated_electron_temperature(dynamic_slice, geo) - np.testing.assert_allclose( - Ti.value, - expected_temperature, - rtol=1e-6, - atol=1e-6, - ) - np.testing.assert_allclose( - Te.value, - expected_temperature, - rtol=1e-6, - atol=1e-6, - ) - - @parameterized.parameters( - (None, None, 2.0, 2.0), - (1.0, None, 1.0, 2.0), - (None, 1.0, 2.0, 1.0), - (None, None, 2.0, 2.0), - ) - def test_temperature_boundary_condition_override( - self, - Ti_bound_right: float | None, - Te_bound_right: float | None, - expected_Ti_bound_right: float, - expected_Te_bound_right: float, - ): - """Tests that the temperature boundary condition override works.""" - runtime_params = general_runtime_params.GeneralRuntimeParams( - profile_conditions=profile_conditions_lib.ProfileConditions( - Ti={ - 0.0: {0.0: 12.0, 1.0: 2.0}, - }, - Te={ - 0.0: {0.0: 12.0, 1.0: 2.0}, - }, - Ti_bound_right=Ti_bound_right, - Te_bound_right=Te_bound_right, - ), - ) - t = 0.0 - geo = geometry.build_circular_geometry(n_rho=4) - dynamic_slice = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( - runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=t, - ) - Ti_bound_right = core_profile_setters.updated_ion_temperature( - dynamic_slice, geo - ).right_face_constraint - Te_bound_right = core_profile_setters.updated_electron_temperature( - dynamic_slice, geo - ).right_face_constraint - self.assertEqual( - Ti_bound_right, - expected_Ti_bound_right, - ) - self.assertEqual( - Te_bound_right, - expected_Te_bound_right, - ) - - def test_time_dependent_provider_with_temperature_is_time_dependent(self): - """Tests that the runtime_params slice provider is time dependent for T.""" - runtime_params = general_runtime_params.GeneralRuntimeParams( - profile_conditions=profile_conditions_lib.ProfileConditions( - Ti={0.0: {0.0: 12.0, 1.0: SMALL_VALUE}, 3.0: {0.0: SMALL_VALUE}}, - Ti_bound_right=SMALL_VALUE, - Te={0.0: {0.0: 12.0, 1.0: SMALL_VALUE}, 3.0: {0.0: SMALL_VALUE}}, - Te_bound_right=SMALL_VALUE, - ), - ) - geo = geometry.build_circular_geometry(n_rho=4) - provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - transport=transport_params_lib.RuntimeParams(), - sources={}, - stepper=stepper_params_lib.RuntimeParams(), - torax_mesh=geo.torax_mesh, - ) - - dynamic_runtime_params_slice = provider(t=1.0) - Ti = core_profile_setters.updated_ion_temperature( - dynamic_runtime_params_slice, geo - ) - Te = core_profile_setters.updated_electron_temperature( - dynamic_runtime_params_slice, geo - ) - - np.testing.assert_allclose( - Ti.value, - np.array([7.0, 5.0, 3.0, 1.0]), - atol=1e-6, - rtol=1e-6, - ) - np.testing.assert_allclose( - Te.value, - np.array([7.0, 5.0, 3.0, 1.0]), - atol=1e-6, - rtol=1e-6, + def test_updated_ion_temperature(self): + bound = np.array(42.) + value = np.array([12.0, 10.0, 8.0, 6.0]) + profile_conditions = mock.create_autospec( + profile_conditions_lib.DynamicProfileConditions, + instance=True, + Ti_bound_right=bound, + Ti=value, + ) + result = core_profile_setters._updated_ion_temperature( + profile_conditions, + self.geo, ) + np.testing.assert_allclose(result.value, value) + np.testing.assert_equal(result.right_face_constraint, bound) + + @parameterized.parameters(0, -1) + def test_updated_ion_temperature_negative_Ti_bound_right( + self, Ti_bound_right: float): + profile_conditions = mock.create_autospec( + profile_conditions_lib.DynamicProfileConditions, + instance=True, + Ti_bound_right=np.array(Ti_bound_right), + Ti=np.array([12.0, 10.0, 8.0, 6.0]), + ) + with self.assertRaisesRegex(RuntimeError, 'Ti_bound_right'): + core_profile_setters._updated_ion_temperature( + profile_conditions, + self.geo, + ) - dynamic_runtime_params_slice = provider( - t=2.0, - ) - Ti = core_profile_setters.updated_ion_temperature( - dynamic_runtime_params_slice, geo - ) - Te = core_profile_setters.updated_electron_temperature( - dynamic_runtime_params_slice, geo - ) - np.testing.assert_allclose( - Ti.value, - np.array([3.5, 2.5, 1.5, 0.5]), - atol=1e-6, - rtol=1e-6, - ) - np.testing.assert_allclose( - Te.value, - np.array([3.5, 2.5, 1.5, 0.5]), - atol=1e-6, - rtol=1e-6, - ) + def test_updated_electron_temperature(self): + bound = np.array(42.) + value = np.array([12.0, 10.0, 8.0, 6.0]) + profile_conditions = mock.create_autospec( + profile_conditions_lib.DynamicProfileConditions, + instance=True, + Te_bound_right=bound, + Te=value, + ) + result = core_profile_setters._updated_electron_temperature( + profile_conditions, + self.geo + ) + np.testing.assert_allclose(result.value, value) + np.testing.assert_equal(result.right_face_constraint, bound) + + @parameterized.parameters(0, -1) + def test_updated_electron_temperature_negative_Te_bound_right( + self, Te_bound_right: float): + profile_conditions = mock.create_autospec( + profile_conditions_lib.DynamicProfileConditions, + instance=True, + Te_bound_right=np.array(Te_bound_right), + Te=np.array([12.0, 10.0, 8.0, 6.0]), + ) + with self.assertRaisesRegex(RuntimeError, 'Te_bound_right'): + core_profile_setters._updated_electron_temperature( + profile_conditions, + self.geo, + ) def test_ne_core_profile_setter(self): """Tests that setting ne works.""" @@ -225,6 +130,7 @@ def test_ne_core_profile_setter(self): torax_mesh=self.geo.torax_mesh, ) dynamic_runtime_params_slice = provider(t=1.0) + temp_el = cell_variable.CellVariable( value=jnp.ones_like(self.geo.rho_norm) * 100.0, # ensure full ionization @@ -234,7 +140,9 @@ def test_ne_core_profile_setter(self): dr=self.geo.drho_norm, ) ne = core_profile_setters._get_ne( - dynamic_runtime_params_slice, self.geo + dynamic_runtime_params_slice.numerics, + dynamic_runtime_params_slice.profile_conditions, + self.geo ) ni, nimp, Zi, _, Zimp, _ = ( core_profile_setters.get_ion_density_and_charge_states( @@ -306,7 +214,11 @@ def test_density_boundary_condition_override( dynamic_runtime_params_slice = provider( t=1.0, ) - ne = core_profile_setters._get_ne(dynamic_runtime_params_slice, self.geo) + ne = core_profile_setters._get_ne( + dynamic_runtime_params_slice.numerics, + dynamic_runtime_params_slice.profile_conditions, + self.geo, + ) np.testing.assert_allclose( ne.right_face_constraint, expected_value, @@ -340,7 +252,9 @@ def test_ne_core_profile_setter_with_normalization( ) ne_normalized = core_profile_setters._get_ne( - dynamic_runtime_params_slice_normalized, self.geo + dynamic_runtime_params_slice_normalized.numerics, + dynamic_runtime_params_slice_normalized.profile_conditions, + self.geo ) np.testing.assert_allclose(np.mean(ne_normalized.value), nbar, rtol=1e-1) @@ -351,7 +265,9 @@ def test_ne_core_profile_setter_with_normalization( ) ne_unnormalized = core_profile_setters._get_ne( - dynamic_runtime_params_slice_unnormalized, self.geo + dynamic_runtime_params_slice_unnormalized.numerics, + dynamic_runtime_params_slice_unnormalized.profile_conditions, + self.geo ) ratio = ne_unnormalized.value / ne_normalized.value @@ -387,7 +303,9 @@ def test_ne_core_profile_setter_with_fGW( t=1.0, ) ne_fGW = core_profile_setters._get_ne( - dynamic_runtime_params_slice_fGW, self.geo + dynamic_runtime_params_slice_fGW.numerics, + dynamic_runtime_params_slice_fGW.profile_conditions, + self.geo ) runtime_params.profile_conditions.ne_is_fGW = False @@ -396,7 +314,9 @@ def test_ne_core_profile_setter_with_fGW( ) ne = core_profile_setters._get_ne( - dynamic_runtime_params_slice, self.geo + dynamic_runtime_params_slice.numerics, + dynamic_runtime_params_slice.profile_conditions, + self.geo ) ratio = ne.value / ne_fGW.value diff --git a/torax/tests/state.py b/torax/tests/state.py index a9755703..f20918ca 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -153,37 +153,6 @@ def test_index( for i in range(self.history_length): self.assertEqual(i, history.index(i).temp_ion.value[0]) - @parameterized.parameters([ - dict(references_getter=torax_refs.circular_references), - dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict( - references_getter=torax_refs.chease_references_Ip_from_runtime_params - ), - ]) - def test_project( - self, - references_getter: Callable[[], torax_refs.References], - ): - """Test State.project.""" - references = references_getter() - history = self._make_history( - references.runtime_params, references.geometry_provider - ) - - seed = 20230421 - rng_state = jax.random.PRNGKey(seed) - del seed # Make sure seed isn't accidentally re-used - weights = jax.random.normal(rng_state, (self.history_length,)) - del rng_state # Make sure rng_state isn't accidentally re-used - - expected = jnp.dot(weights, jnp.arange(self.history_length)) - - projected = history.project(weights) - - actual = projected.temp_ion.value[0] - - np.testing.assert_allclose(expected, actual) - class InitialStatesTest(parameterized.TestCase): """Unit tests for the `torax.core_profile_setters` module."""