Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CurveXYZFourierSymmetries #404

Merged
merged 22 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
71db8b9
initial commit
andrewgiuliani Apr 18, 2024
c034678
linting
andrewgiuliani Apr 18, 2024
9ef07b1
more linting
andrewgiuliani Apr 18, 2024
c74f81b
added more unit tests
andrewgiuliani Apr 18, 2024
a5b9752
fixing typo in docstring
andrewgiuliani Apr 18, 2024
cdebaf3
cleaning up unit tests
andrewgiuliani Apr 18, 2024
fddc130
added a couple more unit tests
andrewgiuliani Apr 18, 2024
12f0ad2
Add CurveXYZHelical to docs and fix typos in docs
landreman Apr 19, 2024
814d126
CurveXYZHelical: reduce some redundant code, move imports to top of file
landreman Apr 19, 2024
42f13c9
fixing unit tests
andrewgiuliani Apr 19, 2024
b5123c0
first attempt at allowing CurveXYZHelical to wrap multiple times toro…
andrewgiuliani Apr 20, 2024
1a435df
fixing some docstrings in the tests
andrewgiuliani Apr 20, 2024
eb2ed51
renaming class, adding a trefoil unit test
andrewgiuliani Apr 22, 2024
4ae7c54
fixing unit tests
andrewgiuliani Apr 22, 2024
c91431f
cleaning up the condition checking if ntor and nfp are coprime
andrewgiuliani Apr 22, 2024
5cec5ac
testing on a stellsym and nonstellsym trefoil
andrewgiuliani Apr 22, 2024
ee48c83
fixing docstrings
andrewgiuliani Apr 22, 2024
ccea421
fixing documentation so it compiles
andrewgiuliani Apr 22, 2024
1c12267
minor change to docstring
andrewgiuliani Apr 22, 2024
6ab1f43
Merge branch 'master' into ag/curvexyzhelical
andrewgiuliani Apr 23, 2024
f7bdc53
modified unit tests for code coverage
andrewgiuliani Apr 23, 2024
6841058
linting fix
andrewgiuliani Apr 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/simsopt.geo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ simsopt.geo.curvexyzfourier module
:undoc-members:
:show-inheritance:

simsopt.geo.curvexyzhelical module
----------------------------------

.. automodule:: simsopt.geo.curvexyzhelical
:members:
:undoc-members:
:show-inheritance:

simsopt.geo.finitebuild module
------------------------------

Expand Down
2 changes: 2 additions & 0 deletions src/simsopt/geo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .curvehelical import *
from .curverzfourier import *
from .curvexyzfourier import *
from .curvexyzhelical import *
from .curveperturbed import *
from .curveobjectives import *
from .curveplanarfourier import *
Expand All @@ -28,6 +29,7 @@

__all__ = (curve.__all__ + curvehelical.__all__ +
curverzfourier.__all__ + curvexyzfourier.__all__ +
curvexyzhelical.__all__ +
curveperturbed.__all__ + curveobjectives.__all__ +
curveplanarfourier.__all__ +
finitebuild.__all__ + plotting.__all__ +
Expand Down
123 changes: 123 additions & 0 deletions src/simsopt/geo/curvexyzhelical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import jax.numpy as jnp
import numpy as np
from .curve import JaxCurve

__all__ = ['CurveXYZHelical']


def jaxXYZHelicalFouriercurve_pure(dofs, quadpoints, order, nfp, stellsym):

if stellsym:
xc = dofs[:order+1]
ys = dofs[order+1:2*order+1]
zs = dofs[2*order+1:]

theta, m = jnp.meshgrid(quadpoints, jnp.arange(order+1), indexing='ij')
xhat = np.sum(xc[None, :] * jnp.cos(2 * jnp.pi * nfp*m*theta), axis=1)
yhat = np.sum(ys[None, :] * jnp.sin(2 * jnp.pi * nfp*m[:, 1:]*theta[:, 1:]), axis=1)

x = jnp.cos(2*jnp.pi*quadpoints) * xhat - jnp.sin(2*jnp.pi*quadpoints) * yhat
y = jnp.sin(2*jnp.pi*quadpoints) * xhat + jnp.cos(2*jnp.pi*quadpoints) * yhat
z = jnp.sum(zs[None, :] * jnp.sin(2*jnp.pi*nfp * m[:, 1:]*theta[:, 1:]), axis=1)
else:
xc = dofs[0 : order+1]
xs = dofs[order+1 : 2*order+1]
yc = dofs[2*order+1: 3*order+2]
ys = dofs[3*order+2: 4*order+2]
zc = dofs[4*order+2: 5*order+3]
zs = dofs[5*order+3: ]

theta, m = jnp.meshgrid(quadpoints, jnp.arange(order+1), indexing='ij')
xhat = np.sum(xc[None, :] * jnp.cos(2*jnp.pi*nfp*m*theta), axis=1) + np.sum(xs[None, :] * jnp.sin(2*jnp.pi*nfp*m[:, 1:]*theta[:, 1:]), axis=1)
yhat = np.sum(yc[None, :] * jnp.cos(2*jnp.pi*nfp*m*theta), axis=1) + np.sum(ys[None, :] * jnp.sin(2*jnp.pi*nfp*m[:, 1:]*theta[:, 1:]), axis=1)

x = jnp.cos(2*jnp.pi*quadpoints) * xhat - jnp.sin(2*jnp.pi*quadpoints) * yhat
y = jnp.sin(2*jnp.pi*quadpoints) * xhat + jnp.cos(2*jnp.pi*quadpoints) * yhat
z = np.sum(zc[None, :] * jnp.cos(2*jnp.pi*nfp*m*theta), axis=1) + np.sum(zs[None, :] * jnp.sin(2*jnp.pi*nfp*m[:, 1:]*theta[:, 1:]), axis=1)

gamma = jnp.zeros((len(quadpoints),3))
gamma = gamma.at[:, 0].add(x)
gamma = gamma.at[:, 1].add(y)
gamma = gamma.at[:, 2].add(z)
return gamma


class CurveXYZHelical(JaxCurve):
r'''A curve representation for a helical coil that does not lie on a torus.
The coordinates of the curve are given by:

.. math::
\hat x(\theta) &= x_{c, 0} + \sum_{m=1}^{\text{order}} x_{c,m} \cos(2 \pi n_{\text{fp}} m \theta)\\
\hat y(\theta) &= \sum_{m=1}^{\text{order}} y_{s,m} \sin(2 \pi n_{\text{fp}} m \theta)\\
x(\theta) &= \hat x(\theta) \cos(2 \pi \theta) - \hat y(\theta) \sin(2 \pi \theta)\\
y(\theta) &= \hat x(\theta) \sin(2 \pi \theta) + \hat y(\theta) \cos(2 \pi \theta)\\
z(\theta) &= \sum_{m=1}^{\text{order}} z_{s,m} \sin(2 \pi n_{\text{fp}} m \theta)

if the coil is stellarator symmetric. When the coil is not stellarator symmetric, the formulas above
become

.. math::
\hat x(\theta) &= x_{c, 0} + \sum_{m=1}^{\text{order}} \left[ x_{c, m} \cos(2 \pi n_{\text{fp}} m \theta) + x_{s, m} \sin(2 \pi n_{\text{fp}} m \theta) \right] \\
\hat y(\theta) &= y_{c, 0} + \sum_{m=1}^{\text{order}} \left[ y_{c, m} \cos(2 \pi n_{\text{fp}} m \theta) + y_{s, m} \sin(2 \pi n_{\text{fp}} m \theta) \right] \\
x(\theta) &= \hat x(\theta) \cos(2 \pi \theta) - \hat y(\theta) \sin(2 \pi \theta)\\
y(\theta) &= \hat x(\theta) \sin(2 \pi \theta) + \hat y(\theta) \cos(2 \pi \theta)\\
z(\theta) &= z_{c, 0} + \sum_{m=1}^{\text{order}} \left[ z_{c, m} \cos(2 \pi n_{\text{fp}} m \theta) + z_{s, m} \sin(2 \pi n_{\text{fp}} m \theta) \right]

Args:
quadpoints: number of grid points/resolution along the curve,
order: how many Fourier harmonics to include in the Fourier representation,
nfp: discrete rotational symmetry number,
stellsym: stellaratory symmetry if True, not stellarator symmetric otherwise.
'''

def __init__(self, quadpoints, order, nfp, stellsym, **kwargs):
if isinstance(quadpoints, int):
quadpoints = np.linspace(0, 1, quadpoints, endpoint=False)
pure = lambda dofs, points: jaxXYZHelicalFouriercurve_pure(
dofs, points, order, nfp, stellsym)

self.order = order
self.nfp = nfp
self.stellsym = stellsym
self.coefficients = np.zeros(self.num_dofs())
if "dofs" not in kwargs:
if "x0" not in kwargs:
kwargs["x0"] = self.coefficients
else:
self.set_dofs_impl(kwargs["x0"])

super().__init__(quadpoints, pure, names=self._make_names(order), **kwargs)

def _make_names(self, order):
if self.stellsym:
x_cos_names = [f'xc({i})' for i in range(0, order + 1)]
x_names = x_cos_names
y_sin_names = [f'ys({i})' for i in range(1, order + 1)]
y_names = y_sin_names
z_sin_names = [f'zs({i})' for i in range(1, order + 1)]
z_names = z_sin_names
andrewgiuliani marked this conversation as resolved.
Show resolved Hide resolved
else:
x_names = ['xc(0)']
x_cos_names = [f'xc({i})' for i in range(1, order + 1)]
x_sin_names = [f'xs({i})' for i in range(1, order + 1)]
x_names += x_cos_names + x_sin_names
y_names = ['yc(0)']
y_cos_names = [f'yc({i})' for i in range(1, order + 1)]
y_sin_names = [f'ys({i})' for i in range(1, order + 1)]
y_names += y_cos_names + y_sin_names
z_names = ['zc(0)']
z_cos_names = [f'zc({i})' for i in range(1, order + 1)]
z_sin_names = [f'zs({i})' for i in range(1, order + 1)]
z_names += z_cos_names + z_sin_names

return x_names + y_names + z_names

def num_dofs(self):
return (self.order+1) + self.order + self.order if self.stellsym else 3*(2*self.order+1)

def get_dofs(self):
return self.coefficients

def set_dofs_impl(self, dofs):
self.coefficients[:] = dofs[:]

150 changes: 146 additions & 4 deletions tests/geo/test_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from simsopt.geo.curverzfourier import CurveRZFourier
from simsopt.geo.curveplanarfourier import CurvePlanarFourier
from simsopt.geo.curvehelical import CurveHelical
from simsopt.geo.curvexyzhelical import CurveXYZHelical
from simsopt.geo.curve import RotatedCurve, curves_to_vtk
from simsopt.geo import parameters
from simsopt.configs.zoo import get_ncsx_data, get_w7x_data
Expand All @@ -36,7 +37,7 @@ def taylor_test(f, df, x, epsilons=None, direction=None):
dfx = df(x)@direction
if epsilons is None:
epsilons = np.power(2., -np.asarray(range(7, 20)))
# print("################################################################################")
print("################################################################################")
err_old = 1e9
counter = 0
for eps in epsilons:
Expand All @@ -46,16 +47,19 @@ def taylor_test(f, df, x, epsilons=None, direction=None):
fminuseps = f(x - eps * direction)
dfest = (fpluseps-fminuseps)/(2*eps)
err = np.linalg.norm(dfest - dfx)
print(err, err/err_old)

assert err < 1e-9 or err < 0.3 * err_old
if err < 1e-9:
break
err_old = err
counter += 1
if err > 1e-10:
assert counter > 3
# print("################################################################################")
print("################################################################################")


#def get_curve(curvetype, rotated, x=np.linspace(0, 1, 100, endpoint=False)):
def get_curve(curvetype, rotated, x=np.asarray([0.5])):
np.random.seed(2)
rand_scale = 0.01
Expand All @@ -73,9 +77,13 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
curve = CurveHelical(x, order, 5, 2, 1.0, 0.3, x0=np.ones((2*order,)))
elif curvetype == "CurvePlanarFourier":
curve = CurvePlanarFourier(x, order, 2, True)
elif curvetype == "CurveXYZHelical1":
curve = CurveXYZHelical(x, order, 2, True)
elif curvetype == "CurveXYZHelical2":
curve = CurveXYZHelical(x, order, 2, False)
else:
assert False

dofs = np.zeros((curve.dof_size, ))
if curvetype in ["CurveXYZFourier", "JaxCurveXYZFourier"]:
dofs[1] = 1.
Expand All @@ -87,18 +95,152 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
dofs[order+1] = 0.1
elif curvetype in ["CurveHelical", "CurveHelicalInitx0"]:
dofs[0] = np.pi/2
elif curvetype == "CurveXYZHelical1":
R = 1
r = 0.5
curve.set('xc(0)', R)
curve.set('xc(1)', -r)
curve.set('zs(1)', -r)
dofs = curve.get_dofs()
elif curvetype == "CurveXYZHelical2":
R = 1
r = 0.5
curve.set('xc(0)', R)
curve.set('xs(1)', -0.1*r)
curve.set('xc(1)', -r)
curve.set('zs(1)', -r)
curve.set('zc(0)', 1)
curve.set('zs(1)', r)
dofs = curve.get_dofs()
else:
assert False

curve.x = dofs + rand_scale * np.random.rand(len(dofs)).reshape(dofs.shape)

if rotated:
curve = RotatedCurve(curve, 0.5, flip=False)
return curve


class Testing(unittest.TestCase):

curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveHelicalInitx0"]
curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZHelical1","CurveXYZHelical2", "CurveHelicalInitx0"]

def get_curvexyzhelical(self, stellsym=True, x=None, nfp=None):
# returns a CurveXYZHelical that is randomly perturbed

np.random.seed(1)
rand_scale = 1e-2

if nfp is None:
nfp = 3
if x is None:
x = np.linspace(0, 1, 200, endpoint=False)

order = 2
curve = CurveXYZHelical(x, order, nfp, stellsym)
R = 1
r = 0.25
curve.set('xc(0)', R)
curve.set('xc(2)', r)
curve.set('ys(2)', -r)
curve.set('zs(1)', -2*r)
curve.set('zs(2)', r)
dofs = curve.x.copy()
curve.x = dofs + rand_scale * np.random.rand(len(dofs)).reshape(dofs.shape)

return curve

def test_curvehelical_is_curvexyzhelical(self):
# this test checks that both helical coil representations can produce the same helical curve on a torus
order = 1
nfp = 2
x = np.linspace(0, 1, 100, endpoint=False)
curve1 = CurveXYZHelical(x, order, nfp, True)
R = 1
r = 0.5
curve1.set('xc(0)', R)
curve1.set('xc(1)', r)
curve1.set('zs(1)', -r)
curve2 = CurveHelical(x, order, nfp, 1, R, r, x0=np.zeros((2*order,)))
assert np.mean(np.linalg.norm(curve1.gamma()-curve2.gamma(), axis=-1)) == 0
andrewgiuliani marked this conversation as resolved.
Show resolved Hide resolved

def test_nonstellsym(self):
# this test checks that you can obtain a stellarator symmetric magnetic field from two non-stellarator symmetric
# CurveXYZHelical curves.
nfp = 3
curve = self.get_curvexyzhelical(stellsym=False, nfp=nfp)
from simsopt.field import BiotSavart, Current, coils_via_symmetries, Coil
current = Current(1e5)
coils = coils_via_symmetries([curve], [current], 1, True)
bs = BiotSavart(coils)
bs.set_points([[1, 1, 1], [1, -1, -1]])
B=bs.B_cyl()
assert np.abs(B[0, 0]+B[1, 0]) <1e-15
assert np.abs(B[0, 1]-B[1, 1]) <1e-15
assert np.abs(B[0, 2]-B[1, 2]) <1e-15

# sanity check that a nonstellarator symmetric CurveXYZHelical produces a non-stellsym magnetic field
bs = BiotSavart([Coil(curve, Current(1e5))])
bs.set_points([[1, 1, 1], [1, -1, -1]])
B=bs.B_cyl()
assert not np.abs(B[0, 0]+B[1, 0]) <1e-15
assert not np.abs(B[0, 1]-B[1, 1]) <1e-15
assert not np.abs(B[0, 2]-B[1, 2]) <1e-15


def test_xyzhelical_symmetries(self):

nfp = 3
# does the stellarator symmetric curve have rotational symmetry?
curve = self.get_curvexyzhelical(stellsym=True, nfp=nfp, x = np.array([0.123, 0.123+1/nfp]))
out = curve.gamma()
alpha = -2*np.pi/nfp
R = np.array([[np.cos(alpha), np.sin(alpha), 0], [-np.sin(alpha), np.cos(alpha), 0], [0, 0, 1]])
print(R@out[0], out[1])
assert np.linalg.norm(out[1]-R@out[0])<1e-15

# does the stellarator symmetric curve indeed pass through (x0, 0, 0)?
curve = self.get_curvexyzhelical(stellsym=True, nfp=nfp, x = np.array([0]))
out = curve.gamma()
assert out[0, 0] !=0
assert out[0, 1] == 0
assert out[0, 2] == 0


# does the non-stellarator symmetric curve not pass through (x0, 0, 0)?
curve = self.get_curvexyzhelical(stellsym=False, nfp=nfp, x = np.array([0]))
out = curve.gamma()
assert out[0, 0] !=0
assert out[0, 1] != 0
assert out[0, 2] != 0

# is the stellarator symmetric curve actually stellarator symmetric?
curve = self.get_curvexyzhelical(stellsym=True, nfp=nfp, x = np.array([0.123, -0.123]))
pts = curve.gamma()
assert np.abs(pts[0, 0]-pts[1, 0]) <1e-15
assert np.abs(pts[0, 1]+pts[1, 1]) <1e-15
assert np.abs(pts[0, 2]+pts[1, 2]) <1e-15

# is the field from the stellarator symmetric curve actually stellarator symmetric?
from simsopt.field import BiotSavart, Current, Coil
curve = self.get_curvexyzhelical(stellsym=True, nfp=nfp, x=np.linspace(0, 1, 200, endpoint=False))
current = Current(1e5)
coil = Coil(curve, current)
bs = BiotSavart([coil])
bs.set_points([[1, 1, 1], [1, -1, -1]])
B=bs.B_cyl()
assert np.abs(B[0, 0]+B[1, 0]) <1e-15
assert np.abs(B[0, 1]-B[1, 1]) <1e-15
assert np.abs(B[0, 2]-B[1, 2]) <1e-15

# does the non-stellarator symmetric curve have rotational symmetry still?
curve = self.get_curvexyzhelical(stellsym=False, nfp=nfp, x = np.array([0.123, 0.123+1/nfp]))
out = curve.gamma()
alpha = -2*np.pi/nfp
R = np.array([[np.cos(alpha), np.sin(alpha), 0], [-np.sin(alpha), np.cos(alpha), 0], [0, 0, 1]])
print(R@out[0], out[1])
assert np.linalg.norm(out[1]-R@out[0])<1e-15

def test_curve_helical_xyzfourier(self):
x = np.asarray([0.6])
Expand Down
Loading