diff --git a/torax/geometry/geometry.py b/torax/geometry/geometry.py index 56482cd6..cefb8a85 100644 --- a/torax/geometry/geometry.py +++ b/torax/geometry/geometry.py @@ -179,7 +179,7 @@ class Geometry: rho_hires: chex.Array vpr_hires: chex.Array Phibdot: chex.Array - _z_magnetic_axis: chex.Array + _z_magnetic_axis: chex.Array | None @property def rho_norm(self) -> chex.Array: @@ -250,7 +250,13 @@ def g1_over_vpr2_face(self) -> jax.Array: @property def z_magnetic_axis(self) -> chex.Numeric: - return self._z_magnetic_axis + z_magnetic_axis = self._z_magnetic_axis + if z_magnetic_axis is not None: + return z_magnetic_axis + else: + raise RuntimeError( + 'Geometry does not have a z magnetic axis.' + ) @chex.dataclass(frozen=True) @@ -427,17 +433,6 @@ def __call__(self, t: chex.Numeric) -> Geometry: return self._get_geometry_base(t, StandardGeometry) -@chex.dataclass(frozen=True) -class CheaseGeometry(StandardGeometry): - """CHEASE geometry type.""" - - @property - def z_magnetic_axis(self) -> chex.Numeric: - raise NotImplementedError( - 'CHEASE geometry does not have a z magnetic axis.' - ) - - def build_circular_geometry( n_rho: int = 25, elongation_LCFS: float = 1.72, @@ -722,7 +717,7 @@ class StandardGeometryIntermediates: vpr: chex.Array n_rho: int hires_fac: int - z_magnetic_axis: chex.Numeric + z_magnetic_axis: chex.Numeric | None def __post_init__(self): """Extrapolates edge values based on a Cubic spline fit.""" @@ -860,8 +855,7 @@ def from_chease( vpr=vpr, n_rho=n_rho, hires_fac=hires_fac, - # field doesn't exist in CHEASE, populate with 0. - z_magnetic_axis=np.array(0.0), + z_magnetic_axis=None, ) @classmethod @@ -1586,12 +1580,7 @@ def build_standard_geometry( area_hires = rhon_interpolation_func(rho_hires_norm, area_intermediate) area = rhon_interpolation_func(rho_norm, area_intermediate) - if intermediate.geometry_type == GeometryType.CHEASE: - geometry_type = CheaseGeometry - else: - geometry_type = StandardGeometry - - return geometry_type( + return StandardGeometry( geometry_type=intermediate.geometry_type.value, drho_norm=np.asarray(drho_norm), torax_mesh=mesh, diff --git a/torax/geometry/tests/geometry.py b/torax/geometry/tests/geometry_test.py similarity index 96% rename from torax/geometry/tests/geometry.py rename to torax/geometry/tests/geometry_test.py index d81361ff..5230e194 100644 --- a/torax/geometry/tests/geometry.py +++ b/torax/geometry/tests/geometry_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for torax.geometry.""" - import dataclasses import os @@ -313,20 +311,18 @@ def f(geo: geometry.Geometry): f(geo) - def test_build_standard_geometry_builds_correct_type_for_chease(self): - """Test that the default CHEASE geometry can be built and is of the correct type.""" - intermediate = geometry.StandardGeometryIntermediates.from_chease() - geo = geometry.build_standard_geometry(intermediate) - self.assertIsInstance(geo, geometry.CheaseGeometry) - def test_access_z_magnetic_axis_raises_error_for_chease_geometry(self): """Test that accessing z_magnetic_axis raises error for CHEASE geometry.""" intermediate = geometry.StandardGeometryIntermediates.from_chease() geo = geometry.build_standard_geometry(intermediate) # Check that a runtime error is raised under both JIT and non-JIT. - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex( + RuntimeError, 'does not have a z magnetic axis' + ): _ = geo.z_magnetic_axis - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex( + RuntimeError, 'does not have a z magnetic axis' + ): def f(): return geo.z_magnetic_axis