Skip to content

Commit

Permalink
wip unit test for CurveCWSFourier
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Fu committed Jul 9, 2024
1 parent 103f8de commit e1f161e
Showing 1 changed file with 51 additions and 3 deletions.
54 changes: 51 additions & 3 deletions tests/geo/test_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from simsopt.geo.curveplanarfourier import CurvePlanarFourier
from simsopt.geo.curvehelical import CurveHelical
from simsopt.geo.curvexyzfouriersymmetries import CurveXYZFourierSymmetries
from simsopt.geo import SurfaceRZFourier, CurveCWSFourier, Curve2D
from simsopt.geo.curve import RotatedCurve, curves_to_vtk
from simsopt.geo import parameters
from simsopt.configs.zoo import get_ncsx_data, get_w7x_data
Expand Down Expand Up @@ -83,6 +84,30 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
curve = CurveXYZFourierSymmetries(x, order, 2, False)
elif curvetype == "CurveXYZFourierSymmetries3":
curve = CurveXYZFourierSymmetries(x, order, 2, False, ntor=3)
elif curvetype in [
"CurveCWSFourier_windowpane",
"CurveCWSFourier_helical",
"CurveCWSFourier_pol",
"CurveCWSFourier_tor"]:
surf_test = SurfaceRZFourier(
nfp=1,
stellsym=True,
mpol=1,
ntor=1,
quadpoints_phi=np.arange(50)/50,
quadpoints_theta=np.arange(50)/50,
)
if curvetype == "CurveCWSFourier_windowpane":
test_curve2d = Curve2D(x, order)
elif curvetype == "CurveCWSFourier_helical":
test_curve2d = Curve2D(x, order, G=10, H=1)
elif curvetype == "CurveCWSFourier_pol":
test_curve2d = Curve2D(x, order, G=1)
elif curvetype == "CurveCWSFourier_tor":
test_curve2d = Curve2D(x, order, H=1)
else:
assert False
curve = CurveCWSFourier(test_curve2d, surf_test)
else:
assert False

Expand Down Expand Up @@ -124,6 +149,16 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
curve.set('zc(0)', 1)
curve.set('zs(1)', r)
dofs = curve.get_dofs()
elif curvetype in [
"CurveCWSFourier_windowpane",
"CurveCWSFourier_helical",
"CurveCWSFourier_pol",
"CurveCWSFourier_tor"]:
curve.curve2d.set('thetas(1)', .1)
curve.curve2d.set('phic(1)', .05)
# The curve.curve2d.dofs and curve.x are not equivalent
# because curve.x includes dofs of the surface.
dofs = curve.x
else:
assert False

Expand All @@ -136,7 +171,21 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):

class Testing(unittest.TestCase):

curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0"]
curvetypes = [
"CurveXYZFourier",
"JaxCurveXYZFourier",
"CurveRZFourier",
"CurvePlanarFourier",
"CurveHelical",
"CurveXYZFourierSymmetries1",
"CurveXYZFourierSymmetries2",
"CurveXYZFourierSymmetries3",
"CurveHelicalInitx0",
"CurveCWSFourier_windowpane",
"CurveCWSFourier_helical",
"CurveCWSFourier_pol",
"CurveCWSFourier_tor"
]

def get_curvexyzfouriersymmetries(self, stellsym=True, x=None, nfp=None, ntor=1):
# returns a CurveXYZFourierSymmetries that is randomly perturbed
Expand Down Expand Up @@ -273,7 +322,6 @@ def test_trefoil_stellsym(self):
curve.set('zs(1)', -1)
np.testing.assert_allclose(curve.gamma(), XYZ, atol=1e-14)


def test_nonstellsym(self):
# this test checks that you can obtain a stellarator symmetric magnetic field from two non-stellarator symmetric
# CurveXYZFourierSymmetries curves.
Expand Down Expand Up @@ -780,4 +828,4 @@ def test_load_curves_from_makegrid_file(self):


if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit e1f161e

Please sign in to comment.