Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move circular geometry to its own file #669

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torax import sim as sim_lib
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import pedestal_model as pedestal_model_lib
Expand Down Expand Up @@ -126,14 +127,14 @@ def _build_circular_geometry_provider(
raise ValueError('n_rho must be set in the input config.')
geometries = {}
for time, c in kwargs['geometry_configs'].items():
geometries[time] = geometry.build_circular_geometry(
geometries[time] = circular_geometry.build_circular_geometry(
n_rho=kwargs['n_rho'], **c
)
return geometry.CircularAnalyticalGeometryProvider.create_provider(
return circular_geometry.CircularAnalyticalGeometryProvider.create_provider(
geometries
)
return geometry_provider.ConstantGeometryProvider(
geometry.build_circular_geometry(**kwargs)
circular_geometry.build_circular_geometry(**kwargs)
)


Expand All @@ -153,7 +154,7 @@ def build_geometry_provider_from_config(
expected in the rest of the config. See the following functions to get a full
list of the arguments exposed:

- `geometry.build_circular_geometry()`
- `circular_geometry.build_circular_geometry()`
- `geometry.StandardGeometryIntermediates.from_chease()`
- `geometry.StandardGeometryIntermediates.from_fbt()`

Expand Down
7 changes: 4 additions & 3 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torax.config import build_sim
from torax.config import runtime_params as runtime_params_lib
from torax.config import runtime_params_slice
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import set_tped_nped
Expand Down Expand Up @@ -105,7 +106,7 @@ def test_build_sim_with_full_config(self):
)
with self.subTest('geometry'):
geo = sim.geometry_provider(sim.initial_state.t)
self.assertIsInstance(geo, geometry.CircularAnalyticalGeometry)
self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry)
self.assertEqual(geo.torax_mesh.nx, 5)
with self.subTest('sources'):
self.assertEqual(
Expand Down Expand Up @@ -185,7 +186,7 @@ def test_general_runtime_params_with_time_dependent_args(self):
self.assertEqual(runtime_params.profile_conditions.ne_is_fGW, False)
self.assertEqual(runtime_params.numerics.q_correction_factor, 0.2)
self.assertEqual(runtime_params.output_dir, '/tmp/this/is/a/test')
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand Down Expand Up @@ -218,7 +219,7 @@ def test_build_circular_geometry(self):
)
geo = geo_provider(t=0)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5)
self.assertIsInstance(geo, geometry.CircularAnalyticalGeometry)
self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry)
np.testing.assert_array_equal(geo.B0, 5.3) # test a default.

def test_build_geometry_from_chease(self):
Expand Down
6 changes: 3 additions & 3 deletions torax/config/tests/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from absl.testing import parameterized
from torax import interpolated_param
from torax.config import numerics
from torax.geometry import geometry
from torax.geometry import circular_geometry


class NumericsTest(parameterized.TestCase):
"""Unit tests for the `torax.config.numerics` module."""

def test_numerics_make_provider(self):
nums = numerics.Numerics()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = nums.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -35,7 +35,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
nums = numerics.Numerics()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = nums.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
10 changes: 5 additions & 5 deletions torax/config/tests/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax import charge_states
from torax import interpolated_param
from torax.config import plasma_composition
from torax.geometry import geometry
from torax.geometry import circular_geometry


class PlasmaCompositionTest(parameterized.TestCase):
Expand All @@ -29,7 +29,7 @@ class PlasmaCompositionTest(parameterized.TestCase):
def test_plasma_composition_make_provider(self):
"""Checks provider construction with no issues."""
pc = plasma_composition.PlasmaComposition()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -40,7 +40,7 @@ def test_plasma_composition_make_provider(self):
)
def test_zeff_accepts_float_inputs(self, zeff: float):
"""Tests that zeff accepts a single float input."""
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
pc = plasma_composition.PlasmaComposition(Zeff=zeff)
provider = pc.make_provider(geo.torax_mesh)
dynamic_pc = provider.build_dynamic_params(t=0.0)
Expand All @@ -63,7 +63,7 @@ def test_zeff_and_zeff_face_match_expected(self):
1.0: {0.0: 1.8, 0.5: 2.1, 1.0: 2.4},
}

geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
pc = plasma_composition.PlasmaComposition(Zeff=zeff_profile)
provider = pc.make_provider(geo.torax_mesh)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
pc = plasma_composition.PlasmaComposition()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
14 changes: 7 additions & 7 deletions torax/config/tests/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax import interpolated_param
from torax.config import config_args
from torax.config import profile_conditions
from torax.geometry import geometry
from torax.geometry import circular_geometry
import xarray as xr


Expand All @@ -30,7 +30,7 @@ class ProfileConditionsTest(parameterized.TestCase):

def test_profile_conditions_make_provider(self):
pc = profile_conditions.ProfileConditions()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -46,7 +46,7 @@ def test_profile_conditions_sets_Te_bound_right_correctly(
Te={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
Te_bound_right=Te_bound_right,
)
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.Te_bound_right, expected_initial_value)
Expand All @@ -65,7 +65,7 @@ def test_profile_conditions_sets_Ti_bound_right_correctly(
Ti={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
Ti_bound_right=Ti_bound_right,
)
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.Ti_bound_right, expected_initial_value)
Expand All @@ -84,7 +84,7 @@ def test_profile_conditions_sets_ne_bound_right_correctly(
ne={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
ne_bound_right=ne_bound_right,
)
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.ne_bound_right, expected_initial_value)
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_profile_conditions_sets_psi_correctly(
self, psi, expected_initial_value, expected_second_value
):
"""Tests that psi is set correctly."""
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
pc = profile_conditions.ProfileConditions(
psi=psi,
)
Expand All @@ -147,7 +147,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
pc = profile_conditions.ProfileConditions()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
4 changes: 2 additions & 2 deletions torax/config/tests/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torax.config import config_args
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.geometry import geometry
from torax.geometry import circular_geometry


# pylint: disable=invalid-name
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_runtime_params_make_provider(self):
runtime_params = general_runtime_params.GeneralRuntimeParams(
profile_conditions=profile_conditions_lib.ProfileConditions()
)
torax_mesh = geometry.build_circular_geometry().torax_mesh
torax_mesh = circular_geometry.build_circular_geometry().torax_mesh
runtime_params_provider = runtime_params.make_provider(torax_mesh)
runtime_params_provider.build_dynamic_params(0.0)

Expand Down
14 changes: 7 additions & 7 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.geometry import geometry
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import electron_density_sources
from torax.sources import generic_current_source
Expand All @@ -37,7 +37,7 @@ class RuntimeParamsSliceTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self._geo = geometry.build_circular_geometry()
self._geo = circular_geometry.build_circular_geometry()

def test_dynamic_slice_can_be_input_to_jitted_function(self):
"""Tests that the slice can be input to a jitted function."""
Expand Down Expand Up @@ -351,7 +351,7 @@ def test_profile_conditions_set_electron_temperature_and_boundary_condition(
runtime_params = general_runtime_params.GeneralRuntimeParams(
profile_conditions=profile_conditions,
)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_profile_conditions_set_electron_density_and_boundary_condition(
ne_is_fGW=ne_is_fGW,
),
)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)

dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
Expand Down Expand Up @@ -427,7 +427,7 @@ def test_update_dynamic_slice_provider_updates_runtime_params(
Ti_bound_right={0.0: 1.0, 1.0: 2.0},
),
)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_update_dynamic_slice_provider_updates_sources(
source_models_builder.runtime_params[
generic_current_source.GenericCurrentSource.SOURCE_NAME
].Iext = 1.0
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources=source_models_builder.runtime_params,
Expand Down Expand Up @@ -519,7 +519,7 @@ def test_update_dynamic_slice_provider_updates_transport(
"""Tests that the dynamic slice provider can be updated."""
runtime_params = general_runtime_params.GeneralRuntimeParams()
transport = transport_params_lib.RuntimeParams(De_inner=1.0)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down
3 changes: 2 additions & 1 deletion torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torax.config import profile_conditions
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
Expand Down Expand Up @@ -648,7 +649,7 @@ def _init_psi_and_current(
)
# Calculating j according to nu formula and psi from j.
elif (
isinstance(geo, geometry.CircularAnalyticalGeometry)
isinstance(geo, circular_geometry.CircularAnalyticalGeometry)
or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
currents = _prescribe_currents_no_bootstrap(
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/tests/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.fvm import calc_coeffs
from torax.geometry import geometry
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_calc_coeffs_smoke_test(
predictor_corrector=False,
theta_imp=theta_imp,
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down
10 changes: 5 additions & 5 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from torax.fvm import cell_variable
from torax.fvm import implicit_solve_block
from torax.fvm import residual_and_loss
from torax.geometry import geometry
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
Expand Down Expand Up @@ -390,7 +390,7 @@ def test_nonlinear_solve_block_loss_minimum(
predictor_corrector=False,
theta_imp=theta_imp,
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down Expand Up @@ -558,7 +558,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
pedestal_model_builder = (
set_tped_nped.SetTemperatureDensityPedestalModelBuilder()
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand All @@ -579,7 +579,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
stepper=stepper_params,
)
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
source_models = source_models_builder()
initial_core_profiles = core_profile_setters.initial_core_profiles(
static_runtime_params_slice,
Expand Down Expand Up @@ -681,7 +681,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
predictor_corrector=False,
theta_imp=0.0,
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down
Loading
Loading