Skip to content

Commit

Permalink
refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinRayAngus committed Aug 23, 2024
1 parent 8b2aa46 commit 31da2a8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 54 deletions.
86 changes: 43 additions & 43 deletions Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.H
Original file line number Diff line number Diff line change
Expand Up @@ -62,91 +62,91 @@ public:
[[nodiscard]] inline bool IsDefined () const { return m_is_defined; }

void Define ( WarpX* a_WarpX,
const warpx::fields::FieldType a_solver_vec_type );
const warpx::fields::FieldType a_field_type );

inline
void Define (const WarpXSolverVec& a_vec)
void Define ( const WarpXSolverVec& a_solver_vec )
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_vec.IsDefined(),
"WarpXSolverVec::Define(a_vec) called with undefined a_vec");
Define( a_vec.m_WarpX, a_vec.getVecType() );
a_solver_vec.IsDefined(),
"WarpXSolverVec::Define(solver_vec) called with undefined solver_vec");
Define( a_solver_vec.m_WarpX, a_solver_vec.getVecType() );
}

[[nodiscard]] RT dotProduct( const WarpXSolverVec& a_X ) const;

inline
void Copy ( const amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& a_solver_vec,
const warpx::fields::FieldType a_solver_vec_type )
void Copy ( const amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& a_field_vec,
const warpx::fields::FieldType a_field_type )
{

WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
IsDefined(),
"WarpXSolverVec::Copy() called on undefined WarpXSolverVec");
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_solver_vec_type==m_solver_vec_type,
a_field_type==m_field_type,
"WarpXSolverVec::Copy() called with vecs of different types");
for (int lev=0; lev<m_num_amr_levels; ++lev) {
for (int n=0; n<3; ++n) {
amrex::MultiFab::Copy( *m_solver_vec[lev][n], *a_solver_vec[lev][n], 0, 0, m_ncomp,
amrex::MultiFab::Copy( *m_field_vec[lev][n], *a_field_vec[lev][n], 0, 0, m_ncomp,
amrex::IntVect::TheZeroVector() );
}
}
}

inline
void Copy ( const WarpXSolverVec& a_vec )
void Copy ( const WarpXSolverVec& a_solver_vec )
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_vec.IsDefined(),
"WarpXSolverVec::Copy(a_vec) called with undefined a_vec");
a_solver_vec.IsDefined(),
"WarpXSolverVec::Copy(solver_vec) called with undefined solver_vec");
if (IsDefined()) {
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_vec.m_solver_vec_type==m_solver_vec_type,
"WarpXSolverVec::Copy(a_vec) called with vecs of different types");
a_solver_vec.m_field_type==m_field_type,
"WarpXSolverVec::Copy(solver_vec) called with vecs of different types");
}
else { Define(a_vec); }
const amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& solver_vec = a_vec.getVec();
Copy(solver_vec, a_vec.getVecType());
else { Define(a_solver_vec); }
const amrex::Vector<std::array<std::unique_ptr<amrex::MultiFab>,3>>& solver_vec = a_solver_vec.getVec();
Copy(solver_vec, a_solver_vec.getVecType());
}

// Prohibit Copy assignment operator
WarpXSolverVec& operator= ( const WarpXSolverVec& a_vec ) = delete;
WarpXSolverVec& operator= ( const WarpXSolverVec& a_solver_vec ) = delete;

// Move assignment operator
WarpXSolverVec(WarpXSolverVec&&) noexcept = default;
WarpXSolverVec& operator= ( WarpXSolverVec&& a_vec ) noexcept
WarpXSolverVec& operator= ( WarpXSolverVec&& a_solver_vec ) noexcept
{
if (this != &a_vec) {
m_solver_vec = std::move(a_vec.m_solver_vec);
m_solver_vec_type = a_vec.m_solver_vec_type;
if (this != &a_solver_vec) {
m_field_vec = std::move(a_solver_vec.m_field_vec);
m_field_type = a_solver_vec.m_field_type;
m_is_defined = true;
}
return *this;
}

inline
void operator+= ( const WarpXSolverVec& a_vec )
void operator+= ( const WarpXSolverVec& a_solver_vec )
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_vec.m_solver_vec_type==m_solver_vec_type,
"WarpXSolverVec operator += a_vec called with solver vecs of different types");
a_solver_vec.m_field_type==m_field_type,
"WarpXSolverVec operator += solver_vec called with solver vecs of different types");
for (int lev = 0; lev < m_num_amr_levels; ++lev) {
for (int n=0; n<3; n++) {
m_solver_vec[lev][n]->plus(*(a_vec.getVec()[lev][n]), 0, 1, 0);
m_field_vec[lev][n]->plus(*(a_solver_vec.getVec()[lev][n]), 0, 1, 0);
}
}
}

inline
void operator-= (const WarpXSolverVec& a_vec)
void operator-= (const WarpXSolverVec& a_solver_vec)
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_vec.m_solver_vec_type==m_solver_vec_type,
"WarpXSolverVec operator -= a_vec called with solver vecs of different types");
a_solver_vec.m_field_type==m_field_type,
"WarpXSolverVec operator -= solver_vec called with solver vecs of different types");
for (int lev = 0; lev < m_num_amr_levels; ++lev) {
for (int n=0; n<3; n++) {
m_solver_vec[lev][n]->minus(*(a_vec.getVec()[lev][n]), 0, 1, 0);
m_field_vec[lev][n]->minus(*(a_solver_vec.getVec()[lev][n]), 0, 1, 0);
}
}
}
Expand All @@ -158,12 +158,12 @@ public:
void linComb (const RT a, const WarpXSolverVec& X, const RT b, const WarpXSolverVec& Y)
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
X.m_solver_vec_type==m_solver_vec_type &&
Y.m_solver_vec_type==m_solver_vec_type,
X.m_field_type==m_field_type &&
Y.m_field_type==m_field_type,
"WarpXSolverVec::linComb(a,X,b,Y) called with solver vecs of different types");
for (int lev = 0; lev < m_num_amr_levels; ++lev) {
for (int n=0; n<3; n++) {
amrex::MultiFab::LinComb(*m_solver_vec[lev][n], a, *X.getVec()[lev][n], 0,
amrex::MultiFab::LinComb(*m_field_vec[lev][n], a, *X.getVec()[lev][n], 0,
b, *Y.getVec()[lev][n], 0,
0, 1, 0);
}
Expand All @@ -176,11 +176,11 @@ public:
void increment (const WarpXSolverVec& X, const RT a)
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
X.m_solver_vec_type==m_solver_vec_type,
"WarpXSolverVec::increment(X,z) called with solver vecs of different types");
X.m_field_type==m_field_type,
"WarpXSolverVec::increment(X,a) called with solver vecs of different types");
for (int lev = 0; lev < m_num_amr_levels; ++lev) {
for (int n=0; n<3; n++) {
amrex::MultiFab::Saxpy( *m_solver_vec[lev][n], a, *X.getVec()[lev][n],
amrex::MultiFab::Saxpy( *m_field_vec[lev][n], a, *X.getVec()[lev][n],
0, 0, 1, amrex::IntVect::TheZeroVector() );
}
}
Expand All @@ -194,7 +194,7 @@ public:
{
for (int lev = 0; lev < m_num_amr_levels; ++lev) {
for (int n=0; n<3; n++) {
m_solver_vec[lev][n]->mult(a_a, 0, 1);
m_field_vec[lev][n]->mult(a_a, 0, 1);
}
}
}
Expand All @@ -210,7 +210,7 @@ public:
"WarpXSolverVec::ones() called on undefined WarpXSolverVec");
for (int lev = 0; lev < m_num_amr_levels; ++lev) {
for (int n=0; n<3; n++) {
m_solver_vec[lev][n]->setVal(a_val);
m_field_vec[lev][n]->setVal(a_val);
}
}
}
Expand All @@ -221,11 +221,11 @@ public:
return std::sqrt(norm);
}

[[nodiscard]] const amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& getVec() const {return m_solver_vec;}
amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& getVec() {return m_solver_vec;}
[[nodiscard]] const amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& getVec() const {return m_field_vec;}
amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& getVec() {return m_field_vec;}

// solver vec types are for now limited to warpx::fields::FieldType
warpx::fields::FieldType getVecType () const { return m_solver_vec_type; }
warpx::fields::FieldType getVecType () const { return m_field_type; }

// clearDotMask() must be called by the highest class that owns WarpXSolverVec()
// after it is done being used ( typically in the destructor ) to avoid the
Expand All @@ -236,8 +236,8 @@ public:
private:

bool m_is_defined = false;
amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > > m_solver_vec;
warpx::fields::FieldType m_solver_vec_type;
amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > > m_field_vec;
warpx::fields::FieldType m_field_type;

static constexpr int m_ncomp = 1;
static constexpr int m_num_amr_levels = 1;
Expand Down
22 changes: 11 additions & 11 deletions Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
#include "WarpX.H"

void WarpXSolverVec::Define ( WarpX* a_WarpX,
const warpx::fields::FieldType a_solver_vec_type )
const warpx::fields::FieldType a_field_type )
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
!IsDefined(),
"WarpXSolverVec::Define(a_vec, a_type) called on already defined WarpXSolverVec");

const amrex::Vector<std::array<std::unique_ptr<amrex::MultiFab>,3>>& a_solver_vec(a_WarpX->getMultiLevelField(a_solver_vec_type));
m_solver_vec.resize(m_num_amr_levels);
const amrex::Vector<std::array<std::unique_ptr<amrex::MultiFab>,3>>& this_vec(a_WarpX->getMultiLevelField(a_field_type));
m_field_vec.resize(m_num_amr_levels);
const int lev = 0;
for (int n=0; n<3; n++) {
const amrex::MultiFab& mf_model = *a_solver_vec[lev][n];
m_solver_vec[lev][n] = std::make_unique<amrex::MultiFab>( mf_model.boxArray(), mf_model.DistributionMap(),
mf_model.nComp(), amrex::IntVect::TheZeroVector() );
const amrex::MultiFab& mf_model = *this_vec[lev][n];
m_field_vec[lev][n] = std::make_unique<amrex::MultiFab>( mf_model.boxArray(), mf_model.DistributionMap(),
mf_model.nComp(), amrex::IntVect::TheZeroVector() );
}
m_solver_vec_type = a_solver_vec_type;
m_field_type = a_field_type;
m_is_defined = true;
SetWarpXPointer(a_WarpX);
SetDotMask();
Expand All @@ -48,8 +48,8 @@ void WarpXSolverVec::SetDotMask()
const amrex::Vector<amrex::Geometry>& Geom = m_WarpX->Geom();
m_dotMask.resize(m_num_amr_levels);
for ( int n = 0; n < 3; n++) {
const amrex::BoxArray& grids = m_solver_vec[0][n]->boxArray();
const amrex::MultiFab tmp( grids, m_solver_vec[0][n]->DistributionMap(),
const amrex::BoxArray& grids = m_field_vec[0][n]->boxArray();
const amrex::MultiFab tmp( grids, m_field_vec[0][n]->DistributionMap(),
1, 0, amrex::MFInfo().SetAlloc(false) );
const amrex::Periodicity& period = Geom[0].periodicity();
m_dotMask[0][n] = tmp.OwnerMask(period);
Expand All @@ -68,15 +68,15 @@ void WarpXSolverVec::SetDotMask()
m_dot_mask_defined,
"WarpXSolverVec::dotProduct called with m_dotMask not yet defined");
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_X.m_solver_vec_type==m_solver_vec_type,
a_X.m_field_type==m_field_type,
"WarpXSolverVec::dotProduct(X) called with solver vecs of different types");

amrex::Real result = 0.0;
const int lev = 0;
const bool local = true;
for (int n = 0; n < 3; ++n) {
auto rtmp = amrex::MultiFab::Dot( *m_dotMask[lev][n],
*m_solver_vec[lev][n], 0,
*m_field_vec[lev][n], 0,
*a_X.getVec()[lev][n], 0, 1, 0, local);
result += rtmp;
}
Expand Down

0 comments on commit 31da2a8

Please sign in to comment.