From e1f161e4137ab6736398baed48c2922219d9ec58 Mon Sep 17 00:00:00 2001 From: Frank Fu Date: Tue, 9 Jul 2024 00:26:10 -0400 Subject: [PATCH] wip unit test for CurveCWSFourier --- tests/geo/test_curve.py | 54 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/tests/geo/test_curve.py b/tests/geo/test_curve.py index a997a6e84..affab019a 100644 --- a/tests/geo/test_curve.py +++ b/tests/geo/test_curve.py @@ -12,6 +12,7 @@ from simsopt.geo.curveplanarfourier import CurvePlanarFourier from simsopt.geo.curvehelical import CurveHelical from simsopt.geo.curvexyzfouriersymmetries import CurveXYZFourierSymmetries +from simsopt.geo import SurfaceRZFourier, CurveCWSFourier, Curve2D from simsopt.geo.curve import RotatedCurve, curves_to_vtk from simsopt.geo import parameters from simsopt.configs.zoo import get_ncsx_data, get_w7x_data @@ -83,6 +84,30 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])): curve = CurveXYZFourierSymmetries(x, order, 2, False) elif curvetype == "CurveXYZFourierSymmetries3": curve = CurveXYZFourierSymmetries(x, order, 2, False, ntor=3) + elif curvetype in [ + "CurveCWSFourier_windowpane", + "CurveCWSFourier_helical", + "CurveCWSFourier_pol", + "CurveCWSFourier_tor"]: + surf_test = SurfaceRZFourier( + nfp=1, + stellsym=True, + mpol=1, + ntor=1, + quadpoints_phi=np.arange(50)/50, + quadpoints_theta=np.arange(50)/50, + ) + if curvetype == "CurveCWSFourier_windowpane": + test_curve2d = Curve2D(x, order) + elif curvetype == "CurveCWSFourier_helical": + test_curve2d = Curve2D(x, order, G=10, H=1) + elif curvetype == "CurveCWSFourier_pol": + test_curve2d = Curve2D(x, order, G=1) + elif curvetype == "CurveCWSFourier_tor": + test_curve2d = Curve2D(x, order, H=1) + else: + assert False + curve = CurveCWSFourier(test_curve2d, surf_test) else: assert False @@ -124,6 +149,16 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])): curve.set('zc(0)', 1) curve.set('zs(1)', r) dofs = curve.get_dofs() + elif curvetype in [ + "CurveCWSFourier_windowpane", + "CurveCWSFourier_helical", + "CurveCWSFourier_pol", + "CurveCWSFourier_tor"]: + curve.curve2d.set('thetas(1)', .1) + curve.curve2d.set('phic(1)', .05) + # The curve.curve2d.dofs and curve.x are not equivalent + # because curve.x includes dofs of the surface. + dofs = curve.x else: assert False @@ -136,7 +171,21 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])): class Testing(unittest.TestCase): - curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0"] + curvetypes = [ + "CurveXYZFourier", + "JaxCurveXYZFourier", + "CurveRZFourier", + "CurvePlanarFourier", + "CurveHelical", + "CurveXYZFourierSymmetries1", + "CurveXYZFourierSymmetries2", + "CurveXYZFourierSymmetries3", + "CurveHelicalInitx0", + "CurveCWSFourier_windowpane", + "CurveCWSFourier_helical", + "CurveCWSFourier_pol", + "CurveCWSFourier_tor" + ] def get_curvexyzfouriersymmetries(self, stellsym=True, x=None, nfp=None, ntor=1): # returns a CurveXYZFourierSymmetries that is randomly perturbed @@ -273,7 +322,6 @@ def test_trefoil_stellsym(self): curve.set('zs(1)', -1) np.testing.assert_allclose(curve.gamma(), XYZ, atol=1e-14) - def test_nonstellsym(self): # this test checks that you can obtain a stellarator symmetric magnetic field from two non-stellarator symmetric # CurveXYZFourierSymmetries curves. @@ -780,4 +828,4 @@ def test_load_curves_from_makegrid_file(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file