From e99b0b3cc78071c07954d0cb7baf87e99ce48961 Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Wed, 22 Jan 2025 15:28:53 -0800 Subject: [PATCH] Move circular geometry to its own file PiperOrigin-RevId: 718551839 --- torax/config/build_sim.py | 9 +- torax/config/tests/build_sim.py | 7 +- torax/config/tests/numerics.py | 6 +- torax/config/tests/plasma_composition.py | 10 +- torax/config/tests/profile_conditions.py | 14 +- torax/config/tests/runtime_params.py | 4 +- torax/config/tests/runtime_params_slice.py | 14 +- torax/core_profile_setters.py | 3 +- torax/fvm/tests/calc_coeffs.py | 4 +- torax/fvm/tests/fvm.py | 10 +- torax/geometry/circular_geometry.py | 270 ++++++++++++++++++ torax/geometry/geometry.py | 245 ---------------- torax/geometry/geometry_provider.py | 2 +- .../geometry/tests/circular_geometry_test.py | 63 ++++ torax/geometry/tests/geometry_test.py | 55 ---- .../tests/set_pped_tpedratio_nped.py | 8 +- torax/pedestal_model/tests/set_tped_nped.py | 8 +- .../sources/tests/bootstrap_current_source.py | 4 +- .../tests/electron_cyclotron_source.py | 6 +- torax/sources/tests/generic_current_source.py | 6 +- .../tests/impurity_radiation_heat_sink.py | 6 +- torax/sources/tests/ion_cyclotron_source.py | 4 +- torax/sources/tests/qei_source.py | 4 +- torax/sources/tests/source.py | 15 +- torax/sources/tests/source_models.py | 9 +- torax/sources/tests/source_runtime_params.py | 4 +- torax/sources/tests/test_lib.py | 10 +- torax/tests/boundary_conditions.py | 4 +- torax/tests/core_profile_setters_test.py | 4 +- torax/tests/math_utils.py | 7 +- torax/tests/output.py | 4 +- torax/tests/physics.py | 7 +- torax/tests/post_processing.py | 6 +- torax/tests/sim.py | 6 +- torax/tests/sim_custom_sources.py | 3 +- torax/tests/sim_output_source_profiles.py | 5 +- torax/tests/sim_time_dependence.py | 3 +- torax/tests/state.py | 13 +- torax/tests/test_data/test_explicit.py | 4 +- torax/tests/test_lib/torax_refs.py | 3 +- torax/transport_model/tests/bohm_gyrobohm.py | 4 +- torax/transport_model/tests/constant.py | 4 +- .../tests/critical_gradient.py | 4 +- .../tests/qlknn_transport_model.py | 6 +- .../tests/qualikiz_based_transport_model.py | 3 +- .../tests/qualikiz_transport_model.py | 4 +- .../tests/quasilinear_transport_model.py | 5 +- .../transport_model/tests/transport_model.py | 5 +- .../tests/transport_model_runtime_params.py | 4 +- 49 files changed, 478 insertions(+), 430 deletions(-) create mode 100644 torax/geometry/circular_geometry.py create mode 100644 torax/geometry/tests/circular_geometry_test.py diff --git a/torax/config/build_sim.py b/torax/config/build_sim.py index 7c9e3902..07f35478 100644 --- a/torax/config/build_sim.py +++ b/torax/config/build_sim.py @@ -21,6 +21,7 @@ from torax import sim as sim_lib from torax.config import config_args from torax.config import runtime_params as runtime_params_lib +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider from torax.pedestal_model import pedestal_model as pedestal_model_lib @@ -126,14 +127,14 @@ def _build_circular_geometry_provider( raise ValueError('n_rho must be set in the input config.') geometries = {} for time, c in kwargs['geometry_configs'].items(): - geometries[time] = geometry.build_circular_geometry( + geometries[time] = circular_geometry.build_circular_geometry( n_rho=kwargs['n_rho'], **c ) - return geometry.CircularAnalyticalGeometryProvider.create_provider( + return circular_geometry.CircularAnalyticalGeometryProvider.create_provider( geometries ) return geometry_provider.ConstantGeometryProvider( - geometry.build_circular_geometry(**kwargs) + circular_geometry.build_circular_geometry(**kwargs) ) @@ -153,7 +154,7 @@ def build_geometry_provider_from_config( expected in the rest of the config. See the following functions to get a full list of the arguments exposed: - - `geometry.build_circular_geometry()` + - `circular_geometry.build_circular_geometry()` - `geometry.StandardGeometryIntermediates.from_chease()` - `geometry.StandardGeometryIntermediates.from_fbt()` diff --git a/torax/config/tests/build_sim.py b/torax/config/tests/build_sim.py index 9f5669cb..199148d3 100644 --- a/torax/config/tests/build_sim.py +++ b/torax/config/tests/build_sim.py @@ -20,6 +20,7 @@ from torax.config import build_sim from torax.config import runtime_params as runtime_params_lib from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider from torax.pedestal_model import set_tped_nped @@ -105,7 +106,7 @@ def test_build_sim_with_full_config(self): ) with self.subTest('geometry'): geo = sim.geometry_provider(sim.initial_state.t) - self.assertIsInstance(geo, geometry.CircularAnalyticalGeometry) + self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry) self.assertEqual(geo.torax_mesh.nx, 5) with self.subTest('sources'): self.assertEqual( @@ -185,7 +186,7 @@ def test_general_runtime_params_with_time_dependent_args(self): self.assertEqual(runtime_params.profile_conditions.ne_is_fGW, False) self.assertEqual(runtime_params.numerics.q_correction_factor, 0.2) self.assertEqual(runtime_params.output_dir, '/tmp/this/is/a/test') - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -218,7 +219,7 @@ def test_build_circular_geometry(self): ) geo = geo_provider(t=0) np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5) - self.assertIsInstance(geo, geometry.CircularAnalyticalGeometry) + self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry) np.testing.assert_array_equal(geo.B0, 5.3) # test a default. def test_build_geometry_from_chease(self): diff --git a/torax/config/tests/numerics.py b/torax/config/tests/numerics.py index 5fc1b394..4f5f7b72 100644 --- a/torax/config/tests/numerics.py +++ b/torax/config/tests/numerics.py @@ -18,7 +18,7 @@ from absl.testing import parameterized from torax import interpolated_param from torax.config import numerics -from torax.geometry import geometry +from torax.geometry import circular_geometry class NumericsTest(parameterized.TestCase): @@ -26,7 +26,7 @@ class NumericsTest(parameterized.TestCase): def test_numerics_make_provider(self): nums = numerics.Numerics() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = nums.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -35,7 +35,7 @@ def test_interpolated_vars_are_only_constructed_once( ): """Tests that interpolated vars are only constructed once.""" nums = numerics.Numerics() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = nums.make_provider(geo.torax_mesh) interpolated_params = {} for field in provider: diff --git a/torax/config/tests/plasma_composition.py b/torax/config/tests/plasma_composition.py index c6ac4968..20a81264 100644 --- a/torax/config/tests/plasma_composition.py +++ b/torax/config/tests/plasma_composition.py @@ -20,7 +20,7 @@ from torax import charge_states from torax import interpolated_param from torax.config import plasma_composition -from torax.geometry import geometry +from torax.geometry import circular_geometry class PlasmaCompositionTest(parameterized.TestCase): @@ -29,7 +29,7 @@ class PlasmaCompositionTest(parameterized.TestCase): def test_plasma_composition_make_provider(self): """Checks provider construction with no issues.""" pc = plasma_composition.PlasmaComposition() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = pc.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -40,7 +40,7 @@ def test_plasma_composition_make_provider(self): ) def test_zeff_accepts_float_inputs(self, zeff: float): """Tests that zeff accepts a single float input.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() pc = plasma_composition.PlasmaComposition(Zeff=zeff) provider = pc.make_provider(geo.torax_mesh) dynamic_pc = provider.build_dynamic_params(t=0.0) @@ -63,7 +63,7 @@ def test_zeff_and_zeff_face_match_expected(self): 1.0: {0.0: 1.8, 0.5: 2.1, 1.0: 2.4}, } - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() pc = plasma_composition.PlasmaComposition(Zeff=zeff_profile) provider = pc.make_provider(geo.torax_mesh) @@ -102,7 +102,7 @@ def test_interpolated_vars_are_only_constructed_once( ): """Tests that interpolated vars are only constructed once.""" pc = plasma_composition.PlasmaComposition() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = pc.make_provider(geo.torax_mesh) interpolated_params = {} for field in provider: diff --git a/torax/config/tests/profile_conditions.py b/torax/config/tests/profile_conditions.py index 10c20156..cb2d9998 100644 --- a/torax/config/tests/profile_conditions.py +++ b/torax/config/tests/profile_conditions.py @@ -20,7 +20,7 @@ from torax import interpolated_param from torax.config import config_args from torax.config import profile_conditions -from torax.geometry import geometry +from torax.geometry import circular_geometry import xarray as xr @@ -30,7 +30,7 @@ class ProfileConditionsTest(parameterized.TestCase): def test_profile_conditions_make_provider(self): pc = profile_conditions.ProfileConditions() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = pc.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -46,7 +46,7 @@ def test_profile_conditions_sets_Te_bound_right_correctly( Te={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}}, Te_bound_right=Te_bound_right, ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = pc.make_provider(geo.torax_mesh) dcs = provider.build_dynamic_params(t=0.0) self.assertEqual(dcs.Te_bound_right, expected_initial_value) @@ -65,7 +65,7 @@ def test_profile_conditions_sets_Ti_bound_right_correctly( Ti={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}}, Ti_bound_right=Ti_bound_right, ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = pc.make_provider(geo.torax_mesh) dcs = provider.build_dynamic_params(t=0.0) self.assertEqual(dcs.Ti_bound_right, expected_initial_value) @@ -84,7 +84,7 @@ def test_profile_conditions_sets_ne_bound_right_correctly( ne={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}}, ne_bound_right=ne_bound_right, ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = pc.make_provider(geo.torax_mesh) dcs = provider.build_dynamic_params(t=0.0) self.assertEqual(dcs.ne_bound_right, expected_initial_value) @@ -126,7 +126,7 @@ def test_profile_conditions_sets_psi_correctly( self, psi, expected_initial_value, expected_second_value ): """Tests that psi is set correctly.""" - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) pc = profile_conditions.ProfileConditions( psi=psi, ) @@ -147,7 +147,7 @@ def test_interpolated_vars_are_only_constructed_once( ): """Tests that interpolated vars are only constructed once.""" pc = profile_conditions.ProfileConditions() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = pc.make_provider(geo.torax_mesh) interpolated_params = {} for field in provider: diff --git a/torax/config/tests/runtime_params.py b/torax/config/tests/runtime_params.py index d13aa6c9..f5cfd34c 100644 --- a/torax/config/tests/runtime_params.py +++ b/torax/config/tests/runtime_params.py @@ -21,7 +21,7 @@ from torax.config import config_args from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params -from torax.geometry import geometry +from torax.geometry import circular_geometry # pylint: disable=invalid-name @@ -137,7 +137,7 @@ def test_runtime_params_make_provider(self): runtime_params = general_runtime_params.GeneralRuntimeParams( profile_conditions=profile_conditions_lib.ProfileConditions() ) - torax_mesh = geometry.build_circular_geometry().torax_mesh + torax_mesh = circular_geometry.build_circular_geometry().torax_mesh runtime_params_provider = runtime_params.make_provider(torax_mesh) runtime_params_provider.build_dynamic_params(0.0) diff --git a/torax/config/tests/runtime_params_slice.py b/torax/config/tests/runtime_params_slice.py index 53f4a8f4..8e9ce64d 100644 --- a/torax/config/tests/runtime_params_slice.py +++ b/torax/config/tests/runtime_params_slice.py @@ -23,7 +23,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.pedestal_model import set_tped_nped from torax.sources import electron_density_sources from torax.sources import generic_current_source @@ -37,7 +37,7 @@ class RuntimeParamsSliceTest(parameterized.TestCase): def setUp(self): super().setUp() - self._geo = geometry.build_circular_geometry() + self._geo = circular_geometry.build_circular_geometry() def test_dynamic_slice_can_be_input_to_jitted_function(self): """Tests that the slice can be input to a jitted function.""" @@ -351,7 +351,7 @@ def test_profile_conditions_set_electron_temperature_and_boundary_condition( runtime_params = general_runtime_params.GeneralRuntimeParams( profile_conditions=profile_conditions, ) - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, torax_mesh=geo.torax_mesh, @@ -394,7 +394,7 @@ def test_profile_conditions_set_electron_density_and_boundary_condition( ne_is_fGW=ne_is_fGW, ), ) - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, @@ -427,7 +427,7 @@ def test_update_dynamic_slice_provider_updates_runtime_params( Ti_bound_right={0.0: 1.0, 1.0: 2.0}, ), ) - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, torax_mesh=geo.torax_mesh, @@ -463,7 +463,7 @@ def test_update_dynamic_slice_provider_updates_sources( source_models_builder.runtime_params[ generic_current_source.GenericCurrentSource.SOURCE_NAME ].Iext = 1.0 - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, sources=source_models_builder.runtime_params, @@ -519,7 +519,7 @@ def test_update_dynamic_slice_provider_updates_transport( """Tests that the dynamic slice provider can be updated.""" runtime_params = general_runtime_params.GeneralRuntimeParams() transport = transport_params_lib.RuntimeParams(De_inner=1.0) - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, torax_mesh=geo.torax_mesh, diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index afe49d66..d2da7cfd 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -33,6 +33,7 @@ from torax.config import profile_conditions from torax.config import runtime_params_slice from torax.fvm import cell_variable +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.sources import ohmic_heat_source from torax.sources import source_models as source_models_lib @@ -648,7 +649,7 @@ def _init_psi_and_current( ) # Calculating j according to nu formula and psi from j. elif ( - isinstance(geo, geometry.CircularAnalyticalGeometry) + isinstance(geo, circular_geometry.CircularAnalyticalGeometry) or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j ): currents = _prescribe_currents_no_bootstrap( diff --git a/torax/fvm/tests/calc_coeffs.py b/torax/fvm/tests/calc_coeffs.py index 5f40e6a4..8914fda3 100644 --- a/torax/fvm/tests/calc_coeffs.py +++ b/torax/fvm/tests/calc_coeffs.py @@ -21,7 +21,7 @@ from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib from torax.fvm import calc_coeffs -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -55,7 +55,7 @@ def test_calc_coeffs_smoke_test( predictor_corrector=False, theta_imp=theta_imp, ) - geo = geometry.build_circular_geometry(n_rho=num_cells) + geo = circular_geometry.build_circular_geometry(n_rho=num_cells) transport_model_builder = ( constant_transport_model.ConstantTransportModelBuilder( runtime_params=constant_transport_model.RuntimeParams( diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index e78befef..4d1663a4 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -33,7 +33,7 @@ from torax.fvm import cell_variable from torax.fvm import implicit_solve_block from torax.fvm import residual_and_loss -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -390,7 +390,7 @@ def test_nonlinear_solve_block_loss_minimum( predictor_corrector=False, theta_imp=theta_imp, ) - geo = geometry.build_circular_geometry(n_rho=num_cells) + geo = circular_geometry.build_circular_geometry(n_rho=num_cells) transport_model_builder = ( constant_transport_model.ConstantTransportModelBuilder( runtime_params=constant_transport_model.RuntimeParams( @@ -558,7 +558,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): pedestal_model_builder = ( set_tped_nped.SetTemperatureDensityPedestalModelBuilder() ) - geo = geometry.build_circular_geometry(n_rho=num_cells) + geo = circular_geometry.build_circular_geometry(n_rho=num_cells) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -579,7 +579,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): stepper=stepper_params, ) ) - geo = geometry.build_circular_geometry(n_rho=num_cells) + geo = circular_geometry.build_circular_geometry(n_rho=num_cells) source_models = source_models_builder() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice, @@ -681,7 +681,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self): predictor_corrector=False, theta_imp=0.0, ) - geo = geometry.build_circular_geometry(n_rho=num_cells) + geo = circular_geometry.build_circular_geometry(n_rho=num_cells) transport_model_builder = ( constant_transport_model.ConstantTransportModelBuilder( runtime_params=constant_transport_model.RuntimeParams( diff --git a/torax/geometry/circular_geometry.py b/torax/geometry/circular_geometry.py new file mode 100644 index 00000000..00fc7a9e --- /dev/null +++ b/torax/geometry/circular_geometry.py @@ -0,0 +1,270 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing a circular geometry.""" + + +from __future__ import annotations + +import chex +import numpy as np +from torax import interpolated_param +from torax.geometry import geometry + + +# Using invalid-name because we are using the same naming convention as the +# external physics implementations +# pylint: disable=invalid-name + + +@chex.dataclass(frozen=True) +class CircularAnalyticalGeometry(geometry.Geometry): + """Circular geometry type used for testing only. + + Most users should default to using the Geometry class. + """ + + elongation_hires: chex.Array + + +@chex.dataclass(frozen=True) +class CircularAnalyticalGeometryProvider(geometry.GeometryProvider): + """Circular geometry type used for testing only. + + Most users should default to using the GeometryProvider class. + """ + + elongation_hires: interpolated_param.InterpolatedVarSingleAxis + + def __call__(self, t: chex.Numeric) -> geometry.Geometry: + """Returns a Geometry instance at the given time.""" + return self._get_geometry_base(t, CircularAnalyticalGeometry) + + +def build_circular_geometry( + n_rho: int = 25, + elongation_LCFS: float = 1.72, + Rmaj: float = 6.2, + Rmin: float = 2.0, + B0: float = 5.3, + hires_fac: int = 4, +) -> CircularAnalyticalGeometry: + """Constructs a CircularAnalyticalGeometry. + + This is the standard entrypoint for building a circular geometry, not + CircularAnalyticalGeometry.__init__(). chex.dataclasses do not allow + overriding __init__ functions with different parameters than the attributes of + the dataclass, so this builder function lives outside the class. + + Args: + n_rho: Radial grid points (num cells) + elongation_LCFS: Elongation at last closed flux surface. Defaults to 1.72 + for the ITER elongation, to approximately correct volume and area integral + Jacobians. + Rmaj: major radius (R) in meters + Rmin: minor radius (a) in meters + B0: Toroidal magnetic field on axis [T] + hires_fac: Grid refinement factor for poloidal flux <--> plasma current + calculations. + + Returns: + A CircularAnalyticalGeometry instance. + """ + # circular geometry assumption of r/Rmin = rho_norm, the normalized + # toroidal flux coordinate. + drho_norm = 1.0 / n_rho + # Define mesh (Slab Uniform 1D with Jacobian = 1) + mesh = geometry.Grid1D.construct(nx=n_rho, dx=drho_norm) + # toroidal flux coordinate (rho) at boundary (last closed flux surface) + rho_b = np.asarray(Rmin) + + # normalized and unnormalized toroidal flux coordinate (rho) + # on face and cell grids. See fvm documentation and paper for details on + # face and cell grids. + rho_face_norm = mesh.face_centers + rho_norm = mesh.cell_centers + rho_face = rho_face_norm * rho_b + rho = rho_norm * rho_b + + Rmaj = np.array(Rmaj) + B0 = np.array(B0) + + # Define toroidal flux + Phi = np.pi * B0 * rho**2 + Phi_face = np.pi * B0 * rho_face**2 + + # Elongation profile. + # Set to be a linearly increasing function from 1 to elongation_LCFS, which + # is the elongation value at the last closed flux surface, set in config. + elongation = 1 + rho_norm * (elongation_LCFS - 1) + elongation_face = 1 + rho_face_norm * (elongation_LCFS - 1) + + # Volume in elongated circular geometry is given by: + # V = 2*pi^2*R*rho^2*elongation + # S = pi*rho^2*elongation + + volume = 2 * np.pi**2 * Rmaj * rho**2 * elongation + volume_face = 2 * np.pi**2 * Rmaj * rho_face**2 * elongation_face + area = np.pi * rho**2 * elongation + area_face = np.pi * rho_face**2 * elongation_face + + # V' = dV/drnorm for volume integrations + # \nabla V = 4*pi^2*R*rho*elongation + # + V * (elongation_param - 1) / elongation / rho_b + # vpr = \nabla V * rho_b + vpr = 4 * np.pi**2 * Rmaj * rho * elongation * rho_b + volume / elongation * ( + elongation_LCFS - 1 + ) + vpr_face = ( + 4 * np.pi**2 * Rmaj * rho_face * elongation_face * rho_b + + volume_face / elongation_face * (elongation_LCFS - 1) + ) + # pylint: disable=invalid-name + # S' = dS/drnorm for area integrals on cell grid + spr_cell = 2 * np.pi * rho * elongation * rho_b + area / elongation * ( + elongation_LCFS - 1 + ) + spr_face = ( + 2 * np.pi * rho_face * elongation_face * rho_b + + area_face / elongation_face * (elongation_LCFS - 1) + ) + + delta_face = np.zeros(len(rho_face)) + + # Geometry variables for general geometry form of transport equations. + # With circular geometry approximation. + + # g0: <\nabla V> + g0 = vpr / rho_b + g0_face = vpr_face / rho_b + + # g1: <(\nabla V)^2> + g1 = vpr**2 / rho_b**2 + g1_face = vpr_face**2 / rho_b**2 + + # g2: <(\nabla V)^2 / R^2> + g2 = g1 / Rmaj**2 + g2_face = g1_face / Rmaj**2 + + # g3: <1/R^2> (done without a elongation correction) + # <1/R^2> = + # 1/2pi*int_0^2pi (1/(Rmaj+r*cosx)^2)dx = + # 1/( Rmaj^2 * (1 - (r/Rmaj)^2)^3/2 ) + g3 = 1 / (Rmaj**2 * (1 - (rho / Rmaj) ** 2) ** (3.0 / 2.0)) + g3_face = 1 / (Rmaj**2 * (1 - (rho_face / Rmaj) ** 2) ** (3.0 / 2.0)) + + # simplifying assumption for now, for J=R*B/(R0*B0) + J = np.ones(len(rho)) + J_face = np.ones(len(rho_face)) + # simplified (constant) version of the F=B*R function + F = np.ones(len(rho)) * Rmaj * B0 + F_face = np.ones(len(rho_face)) * Rmaj * B0 + + # Using an approximation where: + # g2g3_over_rhon = 16 * pi**4 * G2 / (J * R) where: + # G2 = vpr / (4 * pi**2) * <1/R^2> + # This is done due to our ad-hoc elongation assumption, which leads to more + # reasonable values for g2g3_over_rhon through the G2 definition. + # In the future, a more rigorous analytical geometry will be developed and + # the direct definition of g2g3_over_rhon will be used. + + g2g3_over_rhon = 4 * np.pi**2 * vpr * g3 / (J * Rmaj) + g2g3_over_rhon_face = 4 * np.pi**2 * vpr_face * g3_face / (J_face * Rmaj) + + # High resolution versions for j (plasma current) and psi (poloidal flux) + # manipulations. Needed if psi is initialized from plasma current, which is + # the only option for ad-hoc circular geometry. + rho_hires_norm = np.linspace(0, 1, n_rho * hires_fac) + rho_hires = rho_hires_norm * rho_b + + Rout = Rmaj + rho + Rout_face = Rmaj + rho_face + + Rin = Rmaj - rho + Rin_face = Rmaj - rho_face + + # assumed elongation profile on hires grid + elongation_hires = 1 + rho_hires_norm * (elongation_LCFS - 1) + + volume_hires = 2 * np.pi**2 * Rmaj * rho_hires**2 * elongation_hires + area_hires = np.pi * rho_hires**2 * elongation_hires + + # V' = dV/drnorm for volume integrations on hires grid + vpr_hires = ( + 4 * np.pi**2 * Rmaj * rho_hires * elongation_hires * rho_b + + volume_hires / elongation_hires * (elongation_LCFS - 1) + ) + # S' = dS/drnorm for area integrals on hires grid + spr_hires = ( + 2 * np.pi * rho_hires * elongation_hires * rho_b + + area_hires / elongation_hires * (elongation_LCFS - 1) + ) + + g3_hires = 1 / (Rmaj**2 * (1 - (rho_hires / Rmaj) ** 2) ** (3.0 / 2.0)) + F_hires = np.ones(len(rho_hires)) * B0 * Rmaj + g2g3_over_rhon_hires = 4 * np.pi**2 * vpr_hires * g3_hires * B0 / F_hires + + return CircularAnalyticalGeometry( + # Set the standard geometry params. + geometry_type=geometry.GeometryType.CIRCULAR.value, + drho_norm=np.asarray(drho_norm), + torax_mesh=mesh, + Phi=Phi, + Phi_face=Phi_face, + Rmaj=Rmaj, + Rmin=rho_b, + B0=B0, + volume=volume, + volume_face=volume_face, + area=area, + area_face=area_face, + vpr=vpr, + vpr_face=vpr_face, + spr_cell=spr_cell, + spr_face=spr_face, + delta_face=delta_face, + g0=g0, + g0_face=g0_face, + g1=g1, + g1_face=g1_face, + g2=g2, + g2_face=g2_face, + g3=g3, + g3_face=g3_face, + g2g3_over_rhon=g2g3_over_rhon, + g2g3_over_rhon_face=g2g3_over_rhon_face, + g2g3_over_rhon_hires=g2g3_over_rhon_hires, + F=F, + F_face=F_face, + F_hires=F_hires, + Rin=Rin, + Rin_face=Rin_face, + Rout=Rout, + Rout_face=Rout_face, + # Set the circular geometry-specific params. + elongation=elongation, + elongation_face=elongation_face, + volume_hires=volume_hires, + area_hires=area_hires, + spr_hires=spr_hires, + rho_hires_norm=rho_hires_norm, + rho_hires=rho_hires, + elongation_hires=elongation_hires, + vpr_hires=vpr_hires, + # always initialize Phibdot as zero. It will be replaced once both geo_t + # and geo_t_plus_dt are provided, and set to be the same for geo_t and + # geo_t_plus_dt for each given time interval. + Phibdot=np.asarray(0.0), + _z_magnetic_axis=np.asarray(0.0), + ) diff --git a/torax/geometry/geometry.py b/torax/geometry/geometry.py index c6bdf2b1..14deadca 100644 --- a/torax/geometry/geometry.py +++ b/torax/geometry/geometry.py @@ -379,30 +379,6 @@ def __call__(self, t: chex.Numeric) -> Geometry: return self._get_geometry_base(t, Geometry) -@chex.dataclass(frozen=True) -class CircularAnalyticalGeometry(Geometry): - """Circular geometry type used for testing only. - - Most users should default to using the Geometry class. - """ - - elongation_hires: chex.Array - - -@chex.dataclass(frozen=True) -class CircularAnalyticalGeometryProvider(GeometryProvider): - """Circular geometry type used for testing only. - - Most users should default to using the GeometryProvider class. - """ - - elongation_hires: interpolated_param.InterpolatedVarSingleAxis - - def __call__(self, t: chex.Numeric) -> Geometry: - """Returns a Geometry instance at the given time.""" - return self._get_geometry_base(t, CircularAnalyticalGeometry) - - @chex.dataclass(frozen=True) class StandardGeometry(Geometry): """Standard geometry object including additional useful attributes, like psi. @@ -441,227 +417,6 @@ def __call__(self, t: chex.Numeric) -> Geometry: return self._get_geometry_base(t, StandardGeometry) -def build_circular_geometry( - n_rho: int = 25, - elongation_LCFS: float = 1.72, - Rmaj: float = 6.2, - Rmin: float = 2.0, - B0: float = 5.3, - hires_fac: int = 4, -) -> CircularAnalyticalGeometry: - """Constructs a CircularAnalyticalGeometry. - - This is the standard entrypoint for building a circular geometry, not - CircularAnalyticalGeometry.__init__(). chex.dataclasses do not allow - overriding __init__ functions with different parameters than the attributes of - the dataclass, so this builder function lives outside the class. - - Args: - n_rho: Radial grid points (num cells) - elongation_LCFS: Elongation at last closed flux surface. Defaults to 1.72 - for the ITER elongation, to approximately correct volume and area integral - Jacobians. - Rmaj: major radius (R) in meters - Rmin: minor radius (a) in meters - B0: Toroidal magnetic field on axis [T] - hires_fac: Grid refinement factor for poloidal flux <--> plasma current - calculations. - - Returns: - A CircularAnalyticalGeometry instance. - """ - # circular geometry assumption of r/Rmin = rho_norm, the normalized - # toroidal flux coordinate. - drho_norm = 1.0 / n_rho - # Define mesh (Slab Uniform 1D with Jacobian = 1) - mesh = Grid1D.construct(nx=n_rho, dx=drho_norm) - # toroidal flux coordinate (rho) at boundary (last closed flux surface) - rho_b = np.asarray(Rmin) - - # normalized and unnormalized toroidal flux coordinate (rho) - # on face and cell grids. See fvm documentation and paper for details on - # face and cell grids. - rho_face_norm = mesh.face_centers - rho_norm = mesh.cell_centers - rho_face = rho_face_norm * rho_b - rho = rho_norm * rho_b - - Rmaj = np.array(Rmaj) - B0 = np.array(B0) - - # Define toroidal flux - Phi = np.pi * B0 * rho**2 - Phi_face = np.pi * B0 * rho_face**2 - - # Elongation profile. - # Set to be a linearly increasing function from 1 to elongation_LCFS, which - # is the elongation value at the last closed flux surface, set in config. - elongation = 1 + rho_norm * (elongation_LCFS - 1) - elongation_face = 1 + rho_face_norm * (elongation_LCFS - 1) - - # Volume in elongated circular geometry is given by: - # V = 2*pi^2*R*rho^2*elongation - # S = pi*rho^2*elongation - - volume = 2 * np.pi**2 * Rmaj * rho**2 * elongation - volume_face = 2 * np.pi**2 * Rmaj * rho_face**2 * elongation_face - area = np.pi * rho**2 * elongation - area_face = np.pi * rho_face**2 * elongation_face - - # V' = dV/drnorm for volume integrations - # \nabla V = 4*pi^2*R*rho*elongation - # + V * (elongation_param - 1) / elongation / rho_b - # vpr = \nabla V * rho_b - vpr = 4 * np.pi**2 * Rmaj * rho * elongation * rho_b + volume / elongation * ( - elongation_LCFS - 1 - ) - vpr_face = ( - 4 * np.pi**2 * Rmaj * rho_face * elongation_face * rho_b - + volume_face / elongation_face * (elongation_LCFS - 1) - ) - # pylint: disable=invalid-name - # S' = dS/drnorm for area integrals on cell grid - spr_cell = 2 * np.pi * rho * elongation * rho_b + area / elongation * ( - elongation_LCFS - 1 - ) - spr_face = ( - 2 * np.pi * rho_face * elongation_face * rho_b - + area_face / elongation_face * (elongation_LCFS - 1) - ) - - delta_face = np.zeros(len(rho_face)) - - # Geometry variables for general geometry form of transport equations. - # With circular geometry approximation. - - # g0: <\nabla V> - g0 = vpr / rho_b - g0_face = vpr_face / rho_b - - # g1: <(\nabla V)^2> - g1 = vpr**2 / rho_b**2 - g1_face = vpr_face**2 / rho_b**2 - - # g2: <(\nabla V)^2 / R^2> - g2 = g1 / Rmaj**2 - g2_face = g1_face / Rmaj**2 - - # g3: <1/R^2> (done without a elongation correction) - # <1/R^2> = - # 1/2pi*int_0^2pi (1/(Rmaj+r*cosx)^2)dx = - # 1/( Rmaj^2 * (1 - (r/Rmaj)^2)^3/2 ) - g3 = 1 / (Rmaj**2 * (1 - (rho / Rmaj) ** 2) ** (3.0 / 2.0)) - g3_face = 1 / (Rmaj**2 * (1 - (rho_face / Rmaj) ** 2) ** (3.0 / 2.0)) - - # simplifying assumption for now, for J=R*B/(R0*B0) - J = np.ones(len(rho)) - J_face = np.ones(len(rho_face)) - # simplified (constant) version of the F=B*R function - F = np.ones(len(rho)) * Rmaj * B0 - F_face = np.ones(len(rho_face)) * Rmaj * B0 - - # Using an approximation where: - # g2g3_over_rhon = 16 * pi**4 * G2 / (J * R) where: - # G2 = vpr / (4 * pi**2) * <1/R^2> - # This is done due to our ad-hoc elongation assumption, which leads to more - # reasonable values for g2g3_over_rhon through the G2 definition. - # In the future, a more rigorous analytical geometry will be developed and - # the direct definition of g2g3_over_rhon will be used. - - g2g3_over_rhon = 4 * np.pi**2 * vpr * g3 / (J * Rmaj) - g2g3_over_rhon_face = 4 * np.pi**2 * vpr_face * g3_face / (J_face * Rmaj) - - # High resolution versions for j (plasma current) and psi (poloidal flux) - # manipulations. Needed if psi is initialized from plasma current, which is - # the only option for ad-hoc circular geometry. - rho_hires_norm = np.linspace(0, 1, n_rho * hires_fac) - rho_hires = rho_hires_norm * rho_b - - Rout = Rmaj + rho - Rout_face = Rmaj + rho_face - - Rin = Rmaj - rho - Rin_face = Rmaj - rho_face - - # assumed elongation profile on hires grid - elongation_hires = 1 + rho_hires_norm * (elongation_LCFS - 1) - - volume_hires = 2 * np.pi**2 * Rmaj * rho_hires**2 * elongation_hires - area_hires = np.pi * rho_hires**2 * elongation_hires - - # V' = dV/drnorm for volume integrations on hires grid - vpr_hires = ( - 4 * np.pi**2 * Rmaj * rho_hires * elongation_hires * rho_b - + volume_hires / elongation_hires * (elongation_LCFS - 1) - ) - # S' = dS/drnorm for area integrals on hires grid - spr_hires = ( - 2 * np.pi * rho_hires * elongation_hires * rho_b - + area_hires / elongation_hires * (elongation_LCFS - 1) - ) - - g3_hires = 1 / (Rmaj**2 * (1 - (rho_hires / Rmaj) ** 2) ** (3.0 / 2.0)) - F_hires = np.ones(len(rho_hires)) * B0 * Rmaj - g2g3_over_rhon_hires = 4 * np.pi**2 * vpr_hires * g3_hires * B0 / F_hires - - return CircularAnalyticalGeometry( - # Set the standard geometry params. - geometry_type=GeometryType.CIRCULAR.value, - drho_norm=np.asarray(drho_norm), - torax_mesh=mesh, - Phi=Phi, - Phi_face=Phi_face, - Rmaj=Rmaj, - Rmin=rho_b, - B0=B0, - volume=volume, - volume_face=volume_face, - area=area, - area_face=area_face, - vpr=vpr, - vpr_face=vpr_face, - spr_cell=spr_cell, - spr_face=spr_face, - delta_face=delta_face, - g0=g0, - g0_face=g0_face, - g1=g1, - g1_face=g1_face, - g2=g2, - g2_face=g2_face, - g3=g3, - g3_face=g3_face, - g2g3_over_rhon=g2g3_over_rhon, - g2g3_over_rhon_face=g2g3_over_rhon_face, - g2g3_over_rhon_hires=g2g3_over_rhon_hires, - F=F, - F_face=F_face, - F_hires=F_hires, - Rin=Rin, - Rin_face=Rin_face, - Rout=Rout, - Rout_face=Rout_face, - # Set the circular geometry-specific params. - elongation=elongation, - elongation_face=elongation_face, - volume_hires=volume_hires, - area_hires=area_hires, - spr_hires=spr_hires, - rho_hires_norm=rho_hires_norm, - rho_hires=rho_hires, - elongation_hires=elongation_hires, - vpr_hires=vpr_hires, - # always initialize Phibdot as zero. It will be replaced once both geo_t - # and geo_t_plus_dt are provided, and set to be the same for geo_t and - # geo_t_plus_dt for each given time interval. - Phibdot=np.asarray(0.0), - _z_magnetic_axis=np.asarray(0.0), - ) - - -# pylint: disable=invalid-name - - @dataclasses.dataclass(frozen=True) class StandardGeometryIntermediates: """Holds the intermediate values used to build a StandardGeometry. diff --git a/torax/geometry/geometry_provider.py b/torax/geometry/geometry_provider.py index 9e7734b8..93761f7b 100644 --- a/torax/geometry/geometry_provider.py +++ b/torax/geometry/geometry_provider.py @@ -42,7 +42,7 @@ class GeometryProvider(Protocol): .. code-block:: python - geo = geometry.build_circular_geometry(...) + geo = circular_geometry.build_circular_geometry(...) constant_geo_provider = lamdba t: geo def func_expecting_geo_provider(gp: GeometryProvider): diff --git a/torax/geometry/tests/circular_geometry_test.py b/torax/geometry/tests/circular_geometry_test.py new file mode 100644 index 00000000..863d0efd --- /dev/null +++ b/torax/geometry/tests/circular_geometry_test.py @@ -0,0 +1,63 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +import jax +import numpy as np +from torax.geometry import circular_geometry +from torax.geometry import geometry + + +class CircularGeometryTest(absltest.TestCase): + + def test_build_geometry_provider_from_circular(self): + """Test that the circular geometry provider can be built.""" + geo_0 = circular_geometry.build_circular_geometry( + n_rho=25, + elongation_LCFS=1.72, + Rmaj=6.2, + Rmin=2.0, + B0=5.3, + hires_fac=4, + ) + geo_1 = circular_geometry.build_circular_geometry( + n_rho=25, + elongation_LCFS=1.72, + Rmaj=7.2, + Rmin=1.0, + B0=5.3, + hires_fac=4, + ) + provider = ( + circular_geometry.CircularAnalyticalGeometryProvider.create_provider( + {0.0: geo_0, 10.0: geo_1} + ) + ) + geo = provider(5.0) + np.testing.assert_allclose(geo.Rmaj, 6.7) + np.testing.assert_allclose(geo.Rmin, 1.5) + + def test_circular_geometry_can_be_input_to_jitted_function(self): + + @jax.jit + def foo(geo: geometry.Geometry): + return geo.Rmaj + + geo = circular_geometry.build_circular_geometry() + # Make sure you can call the function with geo as an arg. + foo(geo) + + +if __name__ == "__main__": + absltest.main() diff --git a/torax/geometry/tests/geometry_test.py b/torax/geometry/tests/geometry_test.py index 5230e194..ca774911 100644 --- a/torax/geometry/tests/geometry_test.py +++ b/torax/geometry/tests/geometry_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import dataclasses import os from absl.testing import absltest @@ -52,23 +51,6 @@ def test_face_to_cell(self, n_rho, seed): np.testing.assert_allclose(cell_jax, cell_np) - def test_frozen(self): - """Test that the Geometry class is frozen.""" - geo = geometry.build_circular_geometry() - with self.assertRaises(dataclasses.FrozenInstanceError): - geo.drho_norm = 0.1 - - def test_circular_geometry_can_be_input_to_jitted_function(self): - """Test that a circular geometry can be input to a jitted function.""" - - def foo(geo: geometry.Geometry): - _ = geo # do nothing. - - foo_jitted = jax.jit(foo) - geo = geometry.build_circular_geometry() - # Make sure you can call the function with geo as an arg. - foo_jitted(geo) - def test_standard_geometry_can_be_input_to_jitted_function(self): """Test that a StandardGeometry can be input to a jitted function.""" @@ -176,31 +158,6 @@ def test_build_geometry_provider(self): np.testing.assert_allclose(geo.Rmin, 1.5) np.testing.assert_allclose(geo.B0, 5.9) - def test_build_geometry_provider_from_circular(self): - """Test that the circular geometry provider can be built.""" - geo_0 = geometry.build_circular_geometry( - n_rho=25, - elongation_LCFS=1.72, - Rmaj=6.2, - Rmin=2.0, - B0=5.3, - hires_fac=4, - ) - geo_1 = geometry.build_circular_geometry( - n_rho=25, - elongation_LCFS=1.72, - Rmaj=7.2, - Rmin=1.0, - B0=5.3, - hires_fac=4, - ) - provider = geometry.CircularAnalyticalGeometryProvider.create_provider( - {0.0: geo_0, 10.0: geo_1} - ) - geo = provider(5.0) - np.testing.assert_allclose(geo.Rmaj, 6.7) - np.testing.assert_allclose(geo.Rmin, 1.5) - @parameterized.parameters([ dict(invalid_key='rBt', invalid_shape=(2,)), dict(invalid_key='aminor', invalid_shape=(10, 3)), @@ -299,18 +256,6 @@ def test_build_geometry_from_eqdsk(self, geometry_file): ) geometry.build_standard_geometry(intermediate) - def test_geometry_objects_can_be_used_in_jax_jitted_functions(self): - """Test public API of geometry objects can be used in jitted functions.""" - geo = geometry.build_circular_geometry() - - @jax.jit - def f(geo: geometry.Geometry): - for field in dir(geo): - if not field.startswith('_'): - getattr(geo, field) - - f(geo) - def test_access_z_magnetic_axis_raises_error_for_chease_geometry(self): """Test that accessing z_magnetic_axis raises error for CHEASE geometry.""" intermediate = geometry.StandardGeometryIntermediates.from_chease() diff --git a/torax/pedestal_model/tests/set_pped_tpedratio_nped.py b/torax/pedestal_model/tests/set_pped_tpedratio_nped.py index 39af5858..33180b3f 100644 --- a/torax/pedestal_model/tests/set_pped_tpedratio_nped.py +++ b/torax/pedestal_model/tests/set_pped_tpedratio_nped.py @@ -19,7 +19,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.pedestal_model import set_pped_tpedratio_nped from torax.sources import source_models as source_models_lib @@ -31,7 +31,7 @@ class SetPressureTemperatureRatioAndDensityPedestalModelTest( def test_runtime_params_builds_dynamic_params(self): runtime_params = set_pped_tpedratio_nped.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -61,14 +61,14 @@ def test_build_and_call_pedestal_model( runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, sources=source_models_builder.runtime_params, torax_mesh=geo.torax_mesh, pedestal=pedestal_runtime_params, ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() builder = set_pped_tpedratio_nped.SetPressureTemperatureRatioAndDensityPedestalModelBuilder( runtime_params=pedestal_runtime_params ) diff --git a/torax/pedestal_model/tests/set_tped_nped.py b/torax/pedestal_model/tests/set_tped_nped.py index 9e6e0939..51dd823c 100644 --- a/torax/pedestal_model/tests/set_tped_nped.py +++ b/torax/pedestal_model/tests/set_tped_nped.py @@ -18,7 +18,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib @@ -28,7 +28,7 @@ class SetTemperatureDensityPedestalModelTest(parameterized.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = set_tped_nped.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -62,14 +62,14 @@ def test_build_and_call_pedestal_model( runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, sources=source_models_builder.runtime_params, torax_mesh=geo.torax_mesh, pedestal=pedestal_runtime_params, ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() builder = set_tped_nped.SetTemperatureDensityPedestalModelBuilder( runtime_params=pedestal_runtime_params ) diff --git a/torax/sources/tests/bootstrap_current_source.py b/torax/sources/tests/bootstrap_current_source.py index 421d6b6b..6c28a892 100644 --- a/torax/sources/tests/bootstrap_current_source.py +++ b/torax/sources/tests/bootstrap_current_source.py @@ -17,7 +17,7 @@ from absl.testing import absltest import jax.numpy as jnp import numpy as np -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import bootstrap_current_source from torax.sources import source as source_lib from torax.sources import source_profiles @@ -39,7 +39,7 @@ def setUpClass(cls): def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" source = bootstrap_current_source.BootstrapCurrentSource() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() cell = source_lib.ProfileType.CELL.get_profile_shape(geo) face = source_lib.ProfileType.FACE.get_profile_shape(geo) fake_profile = source_profiles.BootstrapCurrentProfile( diff --git a/torax/sources/tests/electron_cyclotron_source.py b/torax/sources/tests/electron_cyclotron_source.py index d57fe5d1..ac8c6fca 100644 --- a/torax/sources/tests/electron_cyclotron_source.py +++ b/torax/sources/tests/electron_cyclotron_source.py @@ -21,7 +21,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import electron_cyclotron_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -55,7 +55,7 @@ def test_source_value(self): source = source_models.sources[self._source_name] source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, @@ -93,7 +93,7 @@ def test_source_value(self): def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source = self._source_class() cell = source_lib.ProfileType.CELL.get_profile_shape(geo) fake_profile = jnp.stack((jnp.ones(cell), 2 * jnp.ones(cell))) diff --git a/torax/sources/tests/generic_current_source.py b/torax/sources/tests/generic_current_source.py index 0a442968..45199856 100644 --- a/torax/sources/tests/generic_current_source.py +++ b/torax/sources/tests/generic_current_source.py @@ -20,7 +20,7 @@ import numpy as np from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import generic_current_source from torax.sources import source as source_lib from torax.sources.tests import test_lib @@ -40,7 +40,7 @@ def setUpClass(cls): def test_profile_is_on_cell_grid(self): """Tests that the profile is given on the cell grid.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_builder = self._source_class_builder() source = source_builder() self.assertEqual( @@ -99,7 +99,7 @@ def test_get_source_profile_for_affected_core_profile_with( source = source_builder() # Build a face profile with 3 values on a 2-cell grid. - geo = geometry.build_circular_geometry(n_rho=2) + geo = circular_geometry.build_circular_geometry(n_rho=2) cell_profile = np.array([1.5, 2.5]) np.testing.assert_allclose( diff --git a/torax/sources/tests/impurity_radiation_heat_sink.py b/torax/sources/tests/impurity_radiation_heat_sink.py index a7c3307b..f5def00a 100644 --- a/torax/sources/tests/impurity_radiation_heat_sink.py +++ b/torax/sources/tests/impurity_radiation_heat_sink.py @@ -23,7 +23,7 @@ from torax import math_utils from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import generic_ion_el_heat_source from torax.sources import ( impurity_radiation_heat_sink as impurity_radiation_heat_sink_lib, @@ -87,7 +87,7 @@ def test_source_value(self): self.assertIsInstance(impurity_radiation_sink, source_lib.Source) # Geometry, profiles, and dynamic runtime params - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, @@ -153,7 +153,7 @@ def test_source_value(self): def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_builder = self._source_class_builder() source_models_builder = source_models_lib.SourceModelsBuilder( {self._source_name: source_builder}, diff --git a/torax/sources/tests/ion_cyclotron_source.py b/torax/sources/tests/ion_cyclotron_source.py index ab3ab198..88bb772f 100644 --- a/torax/sources/tests/ion_cyclotron_source.py +++ b/torax/sources/tests/ion_cyclotron_source.py @@ -27,7 +27,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import ion_cyclotron_source from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -137,7 +137,7 @@ def test_source_value(self, mock_path): source_builder = self._source_class_builder() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( {ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: source_builder}, ) diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index 7c2ab405..f9c4cacb 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -17,7 +17,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import qei_source from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib @@ -44,7 +44,7 @@ def test_source_value(self): source_models = source_models_builder() source = source_models.sources['qei_source'] runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() static_slice = runtime_params_slice.build_static_runtime_params_slice( runtime_params=runtime_params, source_runtime_params=source_models_builder.runtime_params, diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index 1dfa98f0..553a743a 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -23,6 +23,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -174,7 +175,7 @@ def test_zero_profile_works_by_default(self): source_models = source_models_builder() source = source_models.sources['foo'] runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -226,7 +227,7 @@ def test_correct_mode_called(self, mode, expected_profile): source = source_models.sources['foo'] source_runtime_params = source_models_builder.runtime_params runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) source_runtime_params['foo'] = dataclasses.replace( source_models_builder.runtime_params['foo'], mode=mode, @@ -272,7 +273,7 @@ def test_defaults_output_zeros(self): source_models = source_models_builder() source = source_models.sources['foo'] runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -335,7 +336,7 @@ def test_defaults_output_zeros(self): def test_overriding_model(self): """The user-specified model should override the default model.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) expected_output = jnp.ones(output_shape) source_builder = source_lib.make_source_builder( @@ -379,7 +380,7 @@ def test_overriding_model(self): def test_overriding_prescribed_values(self): """Providing prescribed values results in the correct profile.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) # Define the expected output expected_output = jnp.ones(output_shape) @@ -447,7 +448,7 @@ def affected_core_profiles(self): source = TestSource( model_func=lambda _0, _1, _2, _3, _4, _5: profile, ) - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) psi_profile = source.get_source_profile_for_affected_core_profile( profile, source_lib.AffectedCoreProfile.PSI.value, geo ) @@ -476,7 +477,7 @@ def test_retrieving_profile_for_affected_state(self): source = test_lib.TestSource( model_func=lambda _0, _1, _2, _3, _4, _5: profile, ) - geo = geometry.build_circular_geometry(n_rho=4) + geo = circular_geometry.build_circular_geometry(n_rho=4) psi_profile = source.get_source_profile_for_affected_core_profile( profile, source_lib.AffectedCoreProfile.PSI.value, geo ) diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index b6a33b79..63dbc029 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -24,6 +24,7 @@ from torax import core_profile_setters from torax.config import runtime_params as runtime_params_lib from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.sources import generic_current_source from torax.sources import runtime_params as source_runtime_params_lib @@ -67,7 +68,7 @@ class SourceProfilesTest(parameterized.TestCase): def test_computing_source_profiles_works_with_all_defaults(self): """Tests that you can compute source profiles with all defaults.""" runtime_params = runtime_params_lib.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() dynamic_runtime_params_slice = ( @@ -118,7 +119,7 @@ def test_computing_source_profiles_works_with_all_defaults(self): def test_summed_temp_ion_profiles_dont_change_when_jitting(self): """Test that sum_sources_temp_{ion|el} works with jitting.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() # Use the default sources where the generic_ion_el_heat_source, # fusion_heat_source, and ohmic_heat_source are included and produce @@ -193,7 +194,7 @@ def foo_formula( ) source_models = source_models_builder() runtime_params = runtime_params_lib.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -285,7 +286,7 @@ def source_name(self) -> str: source_models_builder.runtime_params['bar'].prescribed_values = 1 source_models = source_models_builder() runtime_params = runtime_params_lib.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, diff --git a/torax/sources/tests/source_runtime_params.py b/torax/sources/tests/source_runtime_params.py index db94aa6c..588d892f 100644 --- a/torax/sources/tests/source_runtime_params.py +++ b/torax/sources/tests/source_runtime_params.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for runtime params for sources.""" from absl.testing import absltest -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import runtime_params as runtime_params_lib @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = runtime_params_lib.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) dynamic_params = provider.build_dynamic_params(t=0.0) self.assertIsInstance( diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 9b45fbc7..8c253f28 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -24,7 +24,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -91,7 +91,7 @@ def setUpClass( def test_runtime_params_builds_dynamic_params(self): runtime_params = self._runtime_params_class() self.assertIsInstance(runtime_params, runtime_params_lib.RuntimeParams) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) dynamic_params = provider.build_dynamic_params(t=0.0) self.assertIsInstance( @@ -140,7 +140,7 @@ def test_source_value(self): source = source_models.sources[self._source_name] source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, @@ -179,7 +179,7 @@ def test_source_value(self): source_builder = self._source_class_builder() # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( {self._source_name: source_builder}, ) @@ -216,7 +216,7 @@ def test_source_value(self): def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() # pylint: disable=missing-kwoa source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index b52cbc1a..6be50512 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -24,7 +24,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import source_models as source_models_lib @@ -69,7 +69,7 @@ def test_setting_boundary_conditions( ), ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() initial_dynamic_runtime_params_slice = ( diff --git a/torax/tests/core_profile_setters_test.py b/torax/tests/core_profile_setters_test.py index b1d755b7..df1f7858 100644 --- a/torax/tests/core_profile_setters_test.py +++ b/torax/tests/core_profile_setters_test.py @@ -24,7 +24,7 @@ from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib from torax.fvm import cell_variable -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.sources import source_models as source_models_lib from torax.stepper import runtime_params as stepper_params_lib from torax.transport_model import runtime_params as transport_params_lib @@ -40,7 +40,7 @@ class CoreProfileSettersTest(parameterized.TestCase): def setUp(self): super().setUp() jax_utils.enable_errors(True) - self.geo = geometry.build_circular_geometry(n_rho=4) + self.geo = circular_geometry.build_circular_geometry(n_rho=4) def test_updated_ion_temperature(self): bound = np.array(42.) diff --git a/torax/tests/math_utils.py b/torax/tests/math_utils.py index 1e4c7aa7..898cfb3e 100644 --- a/torax/tests/math_utils.py +++ b/torax/tests/math_utils.py @@ -21,6 +21,7 @@ import numpy as np import scipy.integrate from torax import math_utils +from torax.geometry import circular_geometry from torax.geometry import geometry jax.config.update('jax_enable_x64', True) @@ -72,7 +73,7 @@ def test_cell_integration(self, num_cell_grid_points: int): x = jax.random.uniform( jax.random.PRNGKey(0), shape=(num_cell_grid_points + 1,) ) - geo = geometry.build_circular_geometry(n_rho=num_cell_grid_points) + geo = circular_geometry.build_circular_geometry(n_rho=num_cell_grid_points) np.testing.assert_allclose( math_utils.cell_integration(geometry.face_to_cell(x), geo), @@ -143,7 +144,7 @@ def test_cell_to_face( preserved_quantity: math_utils.IntegralPreservationQuantity, ): """Test that the cell_to_face method works as expected.""" - geo = geometry.build_circular_geometry(n_rho=len(cell_values)) + geo = circular_geometry.build_circular_geometry(n_rho=len(cell_values)) cell_values = jnp.array(cell_values, dtype=jnp.float32) face_values = math_utils.cell_to_face(cell_values, geo, preserved_quantity) @@ -176,7 +177,7 @@ def test_cell_to_face( def test_cell_to_face_raises_when_too_few_values(self,): """Test that the cell_to_face method raises when too few values are provided.""" - geo = geometry.build_circular_geometry(n_rho=1) + geo = circular_geometry.build_circular_geometry(n_rho=1) with self.assertRaises(ValueError): math_utils.cell_to_face(jnp.array([1.0], dtype=np.float32), geo) diff --git a/torax/tests/output.py b/torax/tests/output.py index 205d2ee8..604b95c6 100644 --- a/torax/tests/output.py +++ b/torax/tests/output.py @@ -27,7 +27,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.geometry import geometry_provider from torax.sources import source as source_lib from torax.sources import source_profiles as source_profiles_lib @@ -56,7 +56,7 @@ def setUp(self): source_models_builder = default_sources.get_default_sources_builder() source_models = source_models_builder() # Make some dummy source profiles that could have come from these sources. - self.geo = geometry.build_circular_geometry() + self.geo = circular_geometry.build_circular_geometry() ones = jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(self.geo)) geo_provider = geometry_provider.ConstantGeometryProvider(self.geo) dynamic_runtime_params_slice, geo = ( diff --git a/torax/tests/physics.py b/torax/tests/physics.py index 9628cff9..561aceb0 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -27,6 +27,7 @@ from torax import state from torax.config import runtime_params_slice from torax.fvm import cell_variable +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.sources import generic_current_source from torax.sources import runtime_params as source_runtime_params @@ -135,7 +136,7 @@ def test_update_psi_from_j( ) # pylint: disable=protected-access - if isinstance(geo, geometry.CircularAnalyticalGeometry): + if isinstance(geo, circular_geometry.CircularAnalyticalGeometry): currents = core_profile_setters._prescribe_currents_no_bootstrap( static_slice, dynamic_runtime_params_slice, @@ -314,7 +315,7 @@ def test_get_main_ion_dilution_factor(self, Zi, Zimp, Zeff, expected): def test_calculate_plh_scaling_factor(self): """Compare `calculate_plh_scaling_factor` to a reference value.""" - geo = geometry.build_circular_geometry( + geo = circular_geometry.build_circular_geometry( n_rho=25, elongation_LCFS=1.0, hires_fac=4, @@ -411,7 +412,7 @@ def test_calculate_plh_scaling_factor(self): # pylint: disable=invalid-name def test_calculate_scaling_law_confinement_time(self, elongation_LCFS): """Compare `calculate_scaling_law_confinement_time` to reference values.""" - geo = geometry.build_circular_geometry( + geo = circular_geometry.build_circular_geometry( n_rho=25, elongation_LCFS=elongation_LCFS, hires_fac=4, diff --git a/torax/tests/post_processing.py b/torax/tests/post_processing.py index 26cf2f80..b985fd8f 100644 --- a/torax/tests/post_processing.py +++ b/torax/tests/post_processing.py @@ -28,7 +28,7 @@ from torax.config import runtime_params as runtime_params_lib from torax.config import runtime_params_slice from torax.fvm import cell_variable -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.geometry import geometry_provider from torax.sources import source_profiles as source_profiles_lib from torax.tests.test_lib import default_sources @@ -42,7 +42,7 @@ class PostProcessingTest(parameterized.TestCase): def setUp(self): super().setUp() runtime_params = runtime_params_lib.GeneralRuntimeParams() - self.geo = geometry.build_circular_geometry() + self.geo = circular_geometry.build_circular_geometry() geo_provider = geometry_provider.ConstantGeometryProvider(self.geo) source_models_builder = default_sources.get_default_sources_builder() source_models = source_models_builder() @@ -159,7 +159,7 @@ def _make_constant_core_profile( def test_compute_stored_thermal_energy(self): """Test that stored thermal energy is computed correctly.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() p_el = np.ones_like(geo.rho_face) p_ion = 2 * np.ones_like(geo.rho_face) p_tot = p_el + p_ion diff --git a/torax/tests/sim.py b/torax/tests/sim.py index 27ebaa7d..08eac104 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -30,7 +30,7 @@ from torax.config import build_sim as build_sim_lib from torax.config import numerics as numerics_lib from torax.config import runtime_params as runtime_params_lib -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.geometry import geometry_provider from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib @@ -492,7 +492,7 @@ def test_no_op(self): time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() geo_provider = geometry_provider.ConstantGeometryProvider( - geometry.build_circular_geometry() + circular_geometry.build_circular_geometry() ) sim = sim_lib.Sim.create( @@ -734,7 +734,7 @@ def test_update_new_mesh(self): with self.assertRaisesRegex(ValueError, 'different mesh'): sim.update_base_components( geometry_provider=geometry_provider.ConstantGeometryProvider( - geometry.build_circular_geometry(n_rho=10) + circular_geometry.build_circular_geometry(n_rho=10) ) ) diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index ffdbdb4e..c017b656 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -30,6 +30,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider from torax.pedestal_model import set_tped_nped @@ -189,7 +190,7 @@ def custom_source_formula( 'test_particle_sources_constant.nc', _ALL_PROFILES ) geo_provider = geometry_provider.ConstantGeometryProvider( - geometry.build_circular_geometry() + circular_geometry.build_circular_geometry() ) sim = sim_lib.Sim.create( runtime_params=self.test_particle_sources_constant_runtime_params, diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index e6a643c9..7d3c4370 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -28,6 +28,7 @@ from torax import sim as sim_lib from torax import state as state_module from torax.config import runtime_params as general_runtime_params +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib from torax.orchestration import step_function @@ -68,7 +69,7 @@ class SimOutputSourceProfilesTest(sim_test_case.SimTestCase): def test_merging_source_profiles(self): """Tests that the implicit and explicit source profiles merge correctly.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = default_sources.get_default_sources_builder() source_models = source_models_builder() # Technically, the merge_source_profiles() function should be called with @@ -153,7 +154,7 @@ def custom_source_formula( runtime_params = general_runtime_params.GeneralRuntimeParams() runtime_params.numerics.t_final = 2. runtime_params.numerics.fixed_dt = 1. - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() time_stepper = fixed_time_step_calculator.FixedTimeStepCalculator() def mock_step_fn( _, diff --git a/torax/tests/sim_time_dependence.py b/torax/tests/sim_time_dependence.py index e6375960..eff7ab79 100644 --- a/torax/tests/sim_time_dependence.py +++ b/torax/tests/sim_time_dependence.py @@ -29,6 +29,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib from torax.pedestal_model import pedestal_model as pedestal_model_lib @@ -67,7 +68,7 @@ def test_time_dependent_params_update_in_adaptive_dt( dt_reduction_factor=1.5, ), ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() geometry_provider = geometry_provider_lib.ConstantGeometryProvider(geo) transport_builder = FakeTransportModelBuilder() source_models_builder = source_models_lib.SourceModelsBuilder() diff --git a/torax/tests/state.py b/torax/tests/state.py index f20918ca..be42d1bb 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -30,6 +30,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider from torax.sources import generic_current_source @@ -172,7 +173,7 @@ def test_initial_boundary_condition_from_time_dependent_params(self): source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() geo_provider = geometry_provider.ConstantGeometryProvider( - geometry.build_circular_geometry() + circular_geometry.build_circular_geometry() ) dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( @@ -206,7 +207,7 @@ def test_core_profiles_quasineutrality_check(self): source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() geo_provider = geometry_provider.ConstantGeometryProvider( - geometry.build_circular_geometry() + circular_geometry.build_circular_geometry() ) dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( @@ -234,7 +235,7 @@ def test_core_profiles_quasineutrality_check(self): assert not core_profiles.quasineutrality_satisfied() @parameterized.parameters([ - dict(geo_builder=geometry.build_circular_geometry), + dict(geo_builder=circular_geometry.build_circular_geometry), dict( geo_builder=lambda: geometry.build_standard_geometry( geometry.StandardGeometryIntermediates.from_chease() @@ -441,7 +442,7 @@ def test_initial_psi_from_geo_noop_circular(self): ne_bound_right=0.5, ), ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() dcs1 = runtime_params_slice.DynamicRuntimeParamsSliceProvider( config1, sources=source_models_builder.runtime_params, @@ -470,7 +471,7 @@ def test_initial_psi_from_geo_noop_circular(self): core_profiles1 = core_profile_setters.initial_core_profiles( dynamic_runtime_params_slice=dcs1, static_runtime_params_slice=static_slice, - geo=geometry.build_circular_geometry(), + geo=circular_geometry.build_circular_geometry(), source_models=source_models, ) static_slice = runtime_params_slice.build_static_runtime_params_slice( @@ -481,7 +482,7 @@ def test_initial_psi_from_geo_noop_circular(self): core_profiles2 = core_profile_setters.initial_core_profiles( dynamic_runtime_params_slice=dcs2, static_runtime_params_slice=static_slice, - geo=geometry.build_circular_geometry(), + geo=circular_geometry.build_circular_geometry(), source_models=source_models, ) np.testing.assert_allclose( diff --git a/torax/tests/test_data/test_explicit.py b/torax/tests/test_data/test_explicit.py index 6cd2c320..a022df2b 100644 --- a/torax/tests/test_data/test_explicit.py +++ b/torax/tests/test_data/test_explicit.py @@ -19,7 +19,7 @@ from torax.config import numerics as numerics_lib from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.geometry import geometry_provider from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped @@ -49,7 +49,7 @@ def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: def get_geometry_provider() -> geometry_provider.ConstantGeometryProvider: return geometry_provider.ConstantGeometryProvider( - geometry.build_circular_geometry() + circular_geometry.build_circular_geometry() ) diff --git a/torax/tests/test_lib/torax_refs.py b/torax/tests/test_lib/torax_refs.py index 29935596..271d1ede 100644 --- a/torax/tests/test_lib/torax_refs.py +++ b/torax/tests/test_lib/torax_refs.py @@ -26,6 +26,7 @@ from torax.config import config_args from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib from torax.sources import runtime_params as sources_params @@ -89,7 +90,7 @@ def circular_references() -> References: }, }, ) - geo = geometry.build_circular_geometry( + geo = circular_geometry.build_circular_geometry( n_rho=25, elongation_LCFS=1.72, hires_fac=4, diff --git a/torax/transport_model/tests/bohm_gyrobohm.py b/torax/transport_model/tests/bohm_gyrobohm.py index e7ab26be..e1a4b99a 100644 --- a/torax/transport_model/tests/bohm_gyrobohm.py +++ b/torax/transport_model/tests/bohm_gyrobohm.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for bohm_gyrobohm.""" from absl.testing import absltest -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.transport_model import bohm_gyrobohm @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = bohm_gyrobohm.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/constant.py b/torax/transport_model/tests/constant.py index 8878495f..023ee024 100644 --- a/torax/transport_model/tests/constant.py +++ b/torax/transport_model/tests/constant.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for constant transport model.""" from absl.testing import absltest -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.transport_model import constant @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = constant.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/critical_gradient.py b/torax/transport_model/tests/critical_gradient.py index eb7436c3..cb4876de 100644 --- a/torax/transport_model/tests/critical_gradient.py +++ b/torax/transport_model/tests/critical_gradient.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for critical gradient transport model.""" from absl.testing import absltest -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.transport_model import critical_gradient @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = critical_gradient.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/qlknn_transport_model.py b/torax/transport_model/tests/qlknn_transport_model.py index 954e1af5..23b06a53 100644 --- a/torax/transport_model/tests/qlknn_transport_model.py +++ b/torax/transport_model/tests/qlknn_transport_model.py @@ -21,7 +21,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib from torax.transport_model import qlknn_transport_model @@ -38,7 +38,7 @@ def test_qlknn_transport_model_cache_works(self): qlknn_transport_model.get_default_model_path() ) runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() pedestal_model_builder = ( @@ -171,7 +171,7 @@ def test_clip_inputs(self): def test_runtime_params_builds_dynamic_params(self): runtime_params = qlknn_transport_model.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/qualikiz_based_transport_model.py b/torax/transport_model/tests/qualikiz_based_transport_model.py index 1484d1ff..e2c9fe81 100644 --- a/torax/transport_model/tests/qualikiz_based_transport_model.py +++ b/torax/transport_model/tests/qualikiz_based_transport_model.py @@ -21,6 +21,7 @@ from torax import state from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped @@ -33,7 +34,7 @@ def _get_model_inputs(transport: qualikiz_based_transport_model.RuntimeParams): """Returns the model inputs for testing.""" runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() pedestal_model_builder = ( diff --git a/torax/transport_model/tests/qualikiz_transport_model.py b/torax/transport_model/tests/qualikiz_transport_model.py index c5a53cbe..b38f02ee 100644 --- a/torax/transport_model/tests/qualikiz_transport_model.py +++ b/torax/transport_model/tests/qualikiz_transport_model.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for qualikiz transport_model transport model.""" from absl.testing import absltest -from torax.geometry import geometry +from torax.geometry import circular_geometry # pylint: disable=g-import-not-at-top try: from torax.transport_model import qualikiz_transport_model @@ -29,7 +29,7 @@ def test_runtime_params_builds_dynamic_params(self): if not _QUALIKIZ_TRANSPORT_MODEL_AVAILABLE: self.skipTest('Qualikiz transport model is not available.') runtime_params = qualikiz_transport_model.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/quasilinear_transport_model.py b/torax/transport_model/tests/quasilinear_transport_model.py index 7e23f9e7..22d7eab2 100644 --- a/torax/transport_model/tests/quasilinear_transport_model.py +++ b/torax/transport_model/tests/quasilinear_transport_model.py @@ -25,6 +25,7 @@ from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice from torax.fvm import cell_variable +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped @@ -39,7 +40,7 @@ def _get_model_inputs(transport: quasilinear_transport_model.RuntimeParams): """Returns the model inputs for testing.""" runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() pedestal_model_builder = ( @@ -265,7 +266,7 @@ def _call_implementation( def _get_dummy_core_profiles(value, right_face_constraint): """Returns dummy core profiles for testing.""" - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() currents = state.Currents.zeros(geo) dummy_cell_variable = cell_variable.CellVariable( value=value, diff --git a/torax/transport_model/tests/transport_model.py b/torax/transport_model/tests/transport_model.py index c14f8a7c..2551298d 100644 --- a/torax/transport_model/tests/transport_model.py +++ b/torax/transport_model/tests/transport_model.py @@ -25,6 +25,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice +from torax.geometry import circular_geometry from torax.geometry import geometry from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped @@ -45,7 +46,7 @@ def test_smoothing(self): ne_bound_right=0.5, ), ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() transport_model_builder = FakeTransportModelBuilder( @@ -197,7 +198,7 @@ def test_smoothing_everywhere(self): ne_bound_right=0.5, ), ) - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() transport_model_builder = FakeTransportModelBuilder( diff --git a/torax/transport_model/tests/transport_model_runtime_params.py b/torax/transport_model/tests/transport_model_runtime_params.py index eda96874..befb7861 100644 --- a/torax/transport_model/tests/transport_model_runtime_params.py +++ b/torax/transport_model/tests/transport_model_runtime_params.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for runtime params for transport model.""" from absl.testing import absltest -from torax.geometry import geometry +from torax.geometry import circular_geometry from torax.transport_model import runtime_params as runtime_params_lib @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = runtime_params_lib.RuntimeParams() - geo = geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0)