Skip to content

Commit

Permalink
compiles but does not work because curvealongz does not have dofs
Browse files Browse the repository at this point in the history
  • Loading branch information
smiet committed Nov 8, 2024
1 parent e33060c commit 46710ad
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ pybind11_add_module(${PROJECT_NAME}
src/simsoptpp/biot_savart_py.cpp
src/simsoptpp/biot_savart_vjp_py.cpp
src/simsoptpp/regular_grid_interpolant_3d_py.cpp
src/simsoptpp/curve.cpp src/simsoptpp/curverzfourier.cpp src/simsoptpp/curvexyzfourier.cpp src/simsoptpp/curveplanarfourier.cpp
src/simsoptpp/curve.cpp src/simsoptpp/curverzfourier.cpp src/simsoptpp/curvexyzfourier.cpp src/simsoptpp/curveplanarfourier.cpp src/simsoptpp/curvealongz.cpp
src/simsoptpp/surface.cpp src/simsoptpp/surfacerzfourier.cpp src/simsoptpp/surfacexyzfourier.cpp
src/simsoptpp/integral_BdotN.cpp
src/simsoptpp/dipole_field.cpp src/simsoptpp/permanent_magnet_optimization.cpp
Expand Down
1 change: 1 addition & 0 deletions src/simsopt/geo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .curvehelical import *
from .curverzfourier import *
from .curvexyzfourier import *
from .curvealongz import *
from .curvexyzfouriersymmetries import *
from .curveperturbed import *
from .curveobjectives import *
Expand Down
43 changes: 43 additions & 0 deletions src/simsopt/geo/curvealongz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from math import pi
from itertools import chain

import numpy as np
import jax.numpy as jnp
from scipy.fft import rfft

from .curve import Curve, JaxCurve
import simsoptpp as sopp


__all__ = ['CurveAlongZ',]


class CurveAlongZ(sopp.CurveAlongZ, Curve):

r"""
A class for representing a straight current along the z-axis. Useful for representing tokamak-like configurations or just adding a toroidal 1/r field to an existing configuratoin.
The quadrature points are placed at [0, 0, 10* tan(pi*(gamma+0.5*dgamma-0.5))].
a linear spacing in quadratures thus tigtly packs the quadrature points near the origin, with large separation at 0 (-inf) and 1 (inf).
"""

def __init__(self, quadpoints):
if isinstance(quadpoints, int):
quadpoints = list(np.linspace(0, 1, quadpoints, endpoint=False))
elif isinstance(quadpoints, np.ndarray):
quadpoints = list(quadpoints)
sopp.CurveAlongZ.__init__(self, quadpoints)
Curve.__init__(self, dofs=[])

def get_dofs(self):
"""
This function returns the dofs associated to this object.
"""
return np.asarray(sopp.CurveXYZFourier.get_dofs(self))

def set_dofs(self, dofs):
"""
This function sets the dofs associated to this object.
"""
self.local_x = dofs
sopp.CurveXYZFourier.set_dofs(self, dofs)
52 changes: 52 additions & 0 deletions src/simsoptpp/curvealongz.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "curvealongz.h"
#include <cmath>

template<class Array>
void CurveAlongZ<Array>::gamma_impl(Array& data, Array& quadpoints) {
int numquadpoints = quadpoints.size();
data *= 0;
for (int k = 0; k < numquadpoints; ++k) {
data(k, 0) = 0;
data(k, 1) = 0;
data(k, 2) = tan(M_PI * (quadpoints[k] - 0.5));
}
}

template<class Array>
void CurveAlongZ<Array>::gammadash_impl(Array& data) {
data *= 0;
for (int k = 0; k < numquadpoints; ++k) {
data(k, 0) = 0;
data(k, 1) = 0;
data(k, 2) = M_PI / (cos(M_PI * (quadpoints[k] - 0.5)) * cos(M_PI * (quadpoints[k] - 0.5)));
}
}

template<class Array>
void CurveAlongZ<Array>::gammadashdash_impl(Array& data) {
data *= 0;
for (int k = 0; k < numquadpoints; ++k) {
data(k, 0) = 0;
data(k, 1) = 0;
data(k, 2) = 2 * M_PI * M_PI * tan(M_PI * (quadpoints[k] - 0.5)) / (cos(M_PI * (quadpoints[k] - 0.5)) * cos(M_PI * (quadpoints[k] - 0.5)));
}
}

template<class Array>
void CurveAlongZ<Array>::dgamma_by_dcoeff_impl(Array& data) {
// Empty implementation
}

template<class Array>
void CurveAlongZ<Array>::dgammadash_by_dcoeff_impl(Array& data) {
// Empty implementation
}

template<class Array>
void CurveAlongZ<Array>::dgammadashdash_by_dcoeff_impl(Array& data) {
// Empty implementation
}

#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> Array;
template class CurveAlongZ<Array>;
84 changes: 84 additions & 0 deletions src/simsoptpp/curvealongz.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#pragma once

#include "curve.h"

template<class Array>
class CurveAlongZ : public Curve<Array> {
public:
using Curve<Array>::quadpoints;
using Curve<Array>::numquadpoints;
using Curve<Array>::check_the_persistent_cache;

CurveAlongZ(int _numquadpoints) : Curve<Array>(_numquadpoints) {}

CurveAlongZ(vector<double> _quadpoints) : Curve<Array>(_quadpoints) {}

CurveAlongZ(Array _quadpoints) : Curve<Array>(_quadpoints) {}

inline int num_dofs() override {
return 0;
}

void set_dofs_impl(const vector<double>& _dofs) override {
// No dofs to set
}

vector<double> get_dofs() override {
return vector<double>();
}

/**
* @brief Returns an empty array as the derivative of gamma with respect to coefficients.
*
* This function overrides the base class implementation to return an empty array.
*
* @return Array& An empty array.
*/
Array& dgamma_by_dcoeff() override {
static Array empty_array;
return empty_array;
}

/**
* @brief Returns an empty array as the derivative of gammadash with respect to coefficients.
*
* This function overrides the base class implementation to return an empty array.
*
* @return Array& An empty array.
*/
Array& dgammadash_by_dcoeff() override {
static Array empty_array;
return empty_array;
}

/**
* @brief Returns an empty array as the derivative of gammadashdash with respect to coefficients.
*
* This function overrides the base class implementation to return an empty array.
*
* @return Array& An empty array.
*/
Array& dgammadashdash_by_dcoeff() override {
static Array empty_array;
return empty_array;
}

/**
* @brief Returns an empty array as the derivative of gammadashdashdash with respect to coefficients.
*
* This function overrides the base class implementation to return an empty array.
*
* @return Array& An empty array.
*/
Array& dgammadashdashdash_by_dcoeff() override {
static Array empty_array;
return empty_array;
}

void gamma_impl(Array& data, Array& quadpoints) override;
void gammadash_impl(Array& data) override;
void gammadashdash_impl(Array& data) override;
void dgamma_by_dcoeff_impl(Array& data) override;
void dgammadash_by_dcoeff_impl(Array& data) override;
void dgammadashdash_by_dcoeff_impl(Array& data) override;
};
31 changes: 31 additions & 0 deletions src/simsoptpp/python_curves.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ typedef CurveXYZFourier<PyArray> PyCurveXYZFourier;
typedef CurveRZFourier<PyArray> PyCurveRZFourier;
#include "curveplanarfourier.h"
typedef CurvePlanarFourier<PyArray> PyCurvePlanarFourier;
#include "curvealongz.h"
typedef CurveAlongZ<PyArray> PyCurveAlongZ;


template <class PyCurveXYZFourierBase = PyCurveXYZFourier> class PyCurveXYZFourierTrampoline : public PyCurveTrampoline<PyCurveXYZFourierBase> {
public:
Expand Down Expand Up @@ -78,6 +81,28 @@ template <class PyCurvePlanarFourierBase = PyCurvePlanarFourier> class PyCurvePl
PyCurvePlanarFourierBase::gamma_impl(data, quadpoints);
}
};

template <class PyCurveAlongZBase = PyCurveAlongZ> class PyCurveAlongZTrampoline : public PyCurveTrampoline<PyCurveAlongZBase> {
public:
using PyCurveTrampoline<PyCurveAlongZBase>::PyCurveTrampoline; // Inherit constructors

int num_dofs() override {
return PyCurveAlongZBase::num_dofs();
}

void set_dofs_impl(const vector<double>& _dofs) override {
PyCurveAlongZBase::set_dofs_impl(_dofs);
}

vector<double> get_dofs() override {
return PyCurveAlongZBase::get_dofs();
}

void gamma_impl(PyArray& data, PyArray& quadpoints) override {
PyCurveAlongZBase::gamma_impl(data, quadpoints);
}
};

template <typename T, typename S> void register_common_curve_methods(S &c) {
c.def("gamma", &T::gamma)
.def("gamma_impl", &T::gamma_impl)
Expand Down Expand Up @@ -144,4 +169,10 @@ void init_curves(py::module_ &m) {
.def_readonly("stellsym", &PyCurvePlanarFourier::stellsym)
.def_readonly("nfp", &PyCurvePlanarFourier::nfp);
register_common_curve_methods<PyCurvePlanarFourier>(pycurveplanarfourier);

auto pycurvealongz = py::class_<PyCurveAlongZ, shared_ptr<PyCurveAlongZ>, PyCurveAlongZTrampoline<PyCurveAlongZ>, PyCurve>(m, "CurveAlongZ")
.def(py::init<int>())
.def(py::init<std::vector<double>>())
.def(py::init<PyArray>());
register_common_curve_methods<PyCurveAlongZ>(pycurvealongz);
}

0 comments on commit 46710ad

Please sign in to comment.