Skip to content

Commit

Permalink
Change tests module name and small change to remove separate Chease G…
Browse files Browse the repository at this point in the history
…eometry

PiperOrigin-RevId: 717722210
  • Loading branch information
tamaranorman authored and Torax team committed Jan 21, 2025
1 parent e854805 commit 6b0cbd1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 33 deletions.
43 changes: 20 additions & 23 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class Geometry:
rho_hires: chex.Array
vpr_hires: chex.Array
Phibdot: chex.Array
_z_magnetic_axis: chex.Array
_z_magnetic_axis: chex.Array | None

@property
def rho_norm(self) -> chex.Array:
Expand Down Expand Up @@ -250,7 +250,13 @@ def g1_over_vpr2_face(self) -> jax.Array:

@property
def z_magnetic_axis(self) -> chex.Numeric:
return self._z_magnetic_axis
z_magnetic_axis = self._z_magnetic_axis
if z_magnetic_axis is not None:
return z_magnetic_axis
else:
raise RuntimeError(
'Geometry does not have a z magnetic axis.'
)


@chex.dataclass(frozen=True)
Expand Down Expand Up @@ -300,7 +306,7 @@ class GeometryProvider:
rho_hires_norm: interpolated_param.InterpolatedVarSingleAxis
rho_hires: interpolated_param.InterpolatedVarSingleAxis
vpr_hires: interpolated_param.InterpolatedVarSingleAxis
_z_magnetic_axis: interpolated_param.InterpolatedVarSingleAxis
_z_magnetic_axis: interpolated_param.InterpolatedVarSingleAxis | None

@classmethod
def create_provider(
Expand Down Expand Up @@ -330,6 +336,10 @@ def create_provider(
or attr.name == 'Ip_from_parameters'
):
continue
if attr.name == '_z_magnetic_axis':
if initial_geometry._z_magnetic_axis is None: # pylint: disable=protected-access
kwargs[attr.name] = None
continue
kwargs[attr.name] = interpolated_param.InterpolatedVarSingleAxis(
(times, np.stack([getattr(g, attr.name) for g in geos], axis=0))
)
Expand All @@ -356,6 +366,10 @@ def _get_geometry_base(self, t: chex.Numeric, geometry_class: Type[Geometry]):
if attr.name == 'Phibdot':
kwargs[attr.name] = 0.0
continue
if attr.name == '_z_magnetic_axis':
if self._z_magnetic_axis is None:
kwargs[attr.name] = None
continue
kwargs[attr.name] = getattr(self, attr.name).get_value(t)
return geometry_class(**kwargs) # pytype: disable=wrong-keyword-args

Expand Down Expand Up @@ -427,17 +441,6 @@ def __call__(self, t: chex.Numeric) -> Geometry:
return self._get_geometry_base(t, StandardGeometry)


@chex.dataclass(frozen=True)
class CheaseGeometry(StandardGeometry):
"""CHEASE geometry type."""

@property
def z_magnetic_axis(self) -> chex.Numeric:
raise NotImplementedError(
'CHEASE geometry does not have a z magnetic axis.'
)


def build_circular_geometry(
n_rho: int = 25,
elongation_LCFS: float = 1.72,
Expand Down Expand Up @@ -722,7 +725,7 @@ class StandardGeometryIntermediates:
vpr: chex.Array
n_rho: int
hires_fac: int
z_magnetic_axis: chex.Numeric
z_magnetic_axis: chex.Numeric | None

def __post_init__(self):
"""Extrapolates edge values based on a Cubic spline fit."""
Expand Down Expand Up @@ -860,8 +863,7 @@ def from_chease(
vpr=vpr,
n_rho=n_rho,
hires_fac=hires_fac,
# field doesn't exist in CHEASE, populate with 0.
z_magnetic_axis=np.array(0.0),
z_magnetic_axis=None,
)

@classmethod
Expand Down Expand Up @@ -1586,12 +1588,7 @@ def build_standard_geometry(
area_hires = rhon_interpolation_func(rho_hires_norm, area_intermediate)
area = rhon_interpolation_func(rho_norm, area_intermediate)

if intermediate.geometry_type == GeometryType.CHEASE:
geometry_type = CheaseGeometry
else:
geometry_type = StandardGeometry

return geometry_type(
return StandardGeometry(
geometry_type=intermediate.geometry_type.value,
drho_norm=np.asarray(drho_norm),
torax_mesh=mesh,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for torax.geometry."""

import dataclasses
import os

Expand Down Expand Up @@ -313,20 +311,18 @@ def f(geo: geometry.Geometry):

f(geo)

def test_build_standard_geometry_builds_correct_type_for_chease(self):
"""Test that the default CHEASE geometry can be built and is of the correct type."""
intermediate = geometry.StandardGeometryIntermediates.from_chease()
geo = geometry.build_standard_geometry(intermediate)
self.assertIsInstance(geo, geometry.CheaseGeometry)

def test_access_z_magnetic_axis_raises_error_for_chease_geometry(self):
"""Test that accessing z_magnetic_axis raises error for CHEASE geometry."""
intermediate = geometry.StandardGeometryIntermediates.from_chease()
geo = geometry.build_standard_geometry(intermediate)
# Check that a runtime error is raised under both JIT and non-JIT.
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(
RuntimeError, 'does not have a z magnetic axis'
):
_ = geo.z_magnetic_axis
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(
RuntimeError, 'does not have a z magnetic axis'
):

def f():
return geo.z_magnetic_axis
Expand Down

0 comments on commit 6b0cbd1

Please sign in to comment.