Skip to content

Commit

Permalink
[Draft] track_covariance_map
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Jan 27, 2025
1 parent b6de9d4 commit 16a194b
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 24 deletions.
4 changes: 4 additions & 0 deletions src/ImpactX.H
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ namespace impactx
*/
void track_particles ();

/** Run the linear transport of the covariance matrix simulation loop
*/
void track_covariance_map ();

/** Query input for warning logger variables and set up warning logger accordingly
*
* Input variables are: ``always_warn_immediately`` and ``abort_on_warning_threshold``.
Expand Down
31 changes: 31 additions & 0 deletions src/ImpactX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,35 @@ namespace impactx {
}, element_variant);
}
}

void
ImpactX::track_covariance_map ()
{
// TODO: move whole body out in separate file

// TODO: init done?
amrex::ParmParse const pp_dist = amrex::ParmParse("beam");
auto ref = initialization::read_reference_particle(pp_dist);
auto dist = initialization::read_distribution(pp_dist);
auto cm = impactx::initialization::create_covariance_matrix(dist);

// TODO: output of init state?

// loop over all beamline elements
for (auto & element_variant : m_lattice)
{
std::visit([&ref, &cm](auto&& element){
// push reference particle in global coordinates
{
BL_PROFILE("impactx::Push::RefPart");
element(ref);
}

// push Covariance Matrix
// cm = cm * element.transport_map(ref);
}, element_variant);
}

// TODO: output
}
} // namespace impactx
5 changes: 4 additions & 1 deletion src/particles/CovarianceMatrix.H
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
namespace impactx
{
/** this is a 6x6 matrix */
using CovarianceMatrix = amrex::SmallMatrix<amrex::ParticleReal, 6, 6, amrex::Order::F, 1>;
using Map6x6 = amrex::SmallMatrix<amrex::ParticleReal, 6, 6, amrex::Order::F, 1>;

/** the covariance matrix is 6x6 */
using CovarianceMatrix = Map6x6;

} // namespace impactx::distribution

Expand Down
2 changes: 1 addition & 1 deletion src/particles/PushAll.H
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace impactx
}

// push covariance matrix
// TODO
// TODO ?
// note: decide what to do for elements that have no covariance matrix

// loop over refinement levels
Expand Down
5 changes: 3 additions & 2 deletions src/particles/elements/Drift.H
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace impactx
struct Drift
: public elements::mixin::Named,
public elements::mixin::BeamOptic<Drift>,
public elements::mixin::LinearTransport<Drift>,
public elements::mixin::Thick,
public elements::mixin::Alignment,
public elements::mixin::PipeAperture,
Expand Down Expand Up @@ -178,7 +179,7 @@ namespace impactx
* @returns 6x6 transport matrix
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
elements::mixin::LinearTransport::Map6x6
Map6x6
transport_map (RefPart & AMREX_RESTRICT refpart) const
{
using namespace amrex::literals; // for _rt and _prt
Expand All @@ -191,7 +192,7 @@ namespace impactx
amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt;

// assign linear map matrix elements
elements::mixin::LinearTransport::Map6x6 R = elements::mixin::LinearTransport::Map6x6::Identity();
Map6x6 R = Map6x6::Identity();
R(1,2) = slice_ds;
R(3,4) = slice_ds;
R(5,6) = slice_ds / betgam2;
Expand Down
22 changes: 22 additions & 0 deletions src/particles/elements/Empty.H
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ namespace impactx
// nothing to do
}

/** Dos nothing to the reference particle.
*
* @param[in,out] refpart reference particle
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() ([[maybe_unused]] RefPart & AMREX_RESTRICT refpart) const
{
// nothing to do
}

/** This function returns the linear transport map.
*
* @param[in] refpart reference particle
* @returns 6x6 transport matrix
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Map6x6
transport_map ([[maybe_unused]] RefPart & AMREX_RESTRICT refpart) const
{
// nothing to do
}

/** Does nothing to a particle.
*
* @param x particle position in x
Expand Down
6 changes: 3 additions & 3 deletions src/particles/elements/LinearMap.H
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace impactx
: public elements::mixin::Named,
public elements::mixin::BeamOptic<LinearMap>,
public elements::mixin::Alignment,
public elements::mixin::LinearTransport,
public elements::mixin::LinearTransport<Drift>,
public elements::mixin::NoFinalize
{
static constexpr auto type = "LinearMap";
Expand All @@ -48,7 +48,7 @@ namespace impactx
* @param name a user defined and not necessarily unique name of the element
*/
LinearMap (
LinearTransport::Map6x6 const & R,
Map6x6 const & R,
amrex::ParticleReal ds = 0,
amrex::ParticleReal dx = 0,
amrex::ParticleReal dy = 0,
Expand Down Expand Up @@ -173,7 +173,7 @@ namespace impactx
return m_ds;
}

LinearTransport::Map6x6 m_transport_map; // 6x6 transport map
Map6x6 m_transport_map; // 6x6 transport map
amrex::ParticleReal m_ds; // finite ds allowed for bookkeeping, but we do not allow slicing
};

Expand Down
34 changes: 21 additions & 13 deletions src/particles/elements/mixin/lineartransport.H
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#ifndef IMPACTX_ELEMENTS_MIXIN_LINEAR_TRANSPORT_H
#define IMPACTX_ELEMENTS_MIXIN_LINEAR_TRANSPORT_H

#include "particles/ImpactXParticleContainer.H"
#include "particles/CovarianceMatrix.H"

#include <ablastr/constant.H>

Expand All @@ -24,27 +24,35 @@ namespace impactx::elements::mixin
{
/** This is a helper class for lattice elements that can be expressed as linear transport maps.
*/
template<typename T_Element>
struct LinearTransport
{
/** ...
*/
LinearTransport (
)
{
}

//LinearTransport () = default;
LinearTransport () = default;
LinearTransport (LinearTransport const &) = default;
LinearTransport& operator= (LinearTransport const &) = default;
LinearTransport (LinearTransport&&) = default;
LinearTransport& operator= (LinearTransport&& rhs) = default;

~LinearTransport () = default;

// 6x6 linear transport map
using Map6x6 = amrex::SmallMatrix<amrex::ParticleReal, 6, 6, amrex::Order::F, 1>;
// note: for most elements, R is returned by a member function. Some store it also internally as a member.
// Map6x6 m_transport_map; ///< linearized map
/** Linear push of the covariance matrix through an element
*
* @param cm covariance matrix
*/
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
void
operator() (Map6x6 & cm) const
{
static_assert(
std::is_base_of_v<LinearTransport, T_Element>,
"LinearTransport can only be used as a mixin class!"
);

// small trick to force every derived class has to implement a method transport_map
// (w/o using a purely virtual function)
T_Element& element = *static_cast<T_Element*>(this);
cm *= element.transport_map(cm);
}
};

} // namespace impactx::elements::mixin
Expand Down
18 changes: 14 additions & 4 deletions src/python/elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ namespace

void init_elements(py::module& m)
{
/*
m.def_property_readonly_static(
"Map6x6",
[](py::object){ return py::type::of<Map6x6>(); },
"1-indexed, Fortran-ordered, 6x6 linear transport map type"
);
*/

py::module_ me = m.def_submodule(
"elements",
"Accelerator lattice elements in ImpactX"
Expand Down Expand Up @@ -219,10 +227,11 @@ void init_elements(py::module& m)
)
;

/*
py::class_<elements::mixin::LinearTransport>(mx, "LinearTransport")
// type of map
.def_property_readonly_static("Map6x6",
[](py::object /* lt */){ return py::type::of<elements::mixin::LinearTransport::Map6x6>(); },
[](py::object){ return py::type::of<elements::mixin::LinearTransport::R>(); },
"1-indexed, Fortran-ordered, 6x6 linear transport map type"
)
// values of the map
Expand All @@ -231,6 +240,7 @@ void init_elements(py::module& m)
// "1-indexed, Fortran-ordered, 6x6 linear transport map values"
//)
;
*/

// diagnostics

Expand Down Expand Up @@ -1633,7 +1643,7 @@ void init_elements(py::module& m)
;
register_beamoptics_push(py_TaperedPL);

py::class_<LinearMap, elements::mixin::Named, elements::mixin::Alignment, elements::mixin::LinearTransport> py_LinearMap(me, "LinearMap");
py::class_<LinearMap, elements::mixin::Named, elements::mixin::Alignment> py_LinearMap(me, "LinearMap");
py_LinearMap
.def("__repr__",
[](LinearMap const & linearmap) {
Expand All @@ -1643,7 +1653,7 @@ void init_elements(py::module& m)
}
)
.def(py::init<
elements::mixin::LinearTransport::Map6x6,
Map6x6,
amrex::ParticleReal,
amrex::ParticleReal,
amrex::ParticleReal,
Expand All @@ -1660,7 +1670,7 @@ void init_elements(py::module& m)
)
.def_property("R",
[](LinearMap & linearmap) { return linearmap.m_transport_map; },
[](LinearMap & linearmap, elements::mixin::LinearTransport::Map6x6 R) { linearmap.m_transport_map = R; },
[](LinearMap & linearmap, Map6x6 R) { linearmap.m_transport_map = R; },
"linear map as a 6x6 transport matrix"
)
.def_property("ds",
Expand Down

0 comments on commit 16a194b

Please sign in to comment.