Skip to content

Commit

Permalink
Changed how WarpXSolverVec is defined. Explicit call to SetDotMask() …
Browse files Browse the repository at this point in the history
…is no longer required.
  • Loading branch information
JustinRayAngus committed Aug 23, 2024
1 parent fac39a7 commit 8b2aa46
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 41 deletions.
9 changes: 2 additions & 7 deletions Source/FieldSolver/ImplicitSolvers/SemiImplicitEM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,8 @@ void SemiImplicitEM::Define ( WarpX* a_WarpX )
m_WarpX = a_WarpX;

// Define E and Eold vectors
m_E.Define( m_WarpX->getMultiLevelField(FieldType::Efield_fp), FieldType::Efield_fp );
m_Eold.Define( m_WarpX->getMultiLevelField(FieldType::Efield_fp), FieldType::Efield_fp );

// Need to define the WarpXSolverVec owned dot_mask to do dot
// product correctly for linear and nonlinear solvers
const amrex::Vector<amrex::Geometry>& Geom = m_WarpX->Geom();
m_E.SetDotMask(Geom);
m_E.Define( m_WarpX, FieldType::Efield_fp );
m_Eold.Define( m_E );

// Parse implicit solver parameters
const amrex::ParmParse pp("implicit_evolve");
Expand Down
9 changes: 2 additions & 7 deletions Source/FieldSolver/ImplicitSolvers/ThetaImplicitEM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,8 @@ void ThetaImplicitEM::Define ( WarpX* const a_WarpX )
m_WarpX = a_WarpX;

// Define E and Eold vectors
m_E.Define( m_WarpX->getMultiLevelField(FieldType::Efield_fp), FieldType::Efield_fp );
m_Eold.Define( m_WarpX->getMultiLevelField(FieldType::Efield_fp), FieldType::Efield_fp );

// Need to define the WarpXSolverVec owned dot_mask to do dot
// product correctly for linear and nonlinear solvers
const amrex::Vector<amrex::Geometry>& Geom = m_WarpX->Geom();
m_E.SetDotMask(Geom);
m_E.Define( m_WarpX, FieldType::Efield_fp );
m_Eold.Define( m_E );

// Define Bold MultiFab
const int num_levels = 1;
Expand Down
31 changes: 10 additions & 21 deletions Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.H
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include <AMReX.H>
#include <AMReX_Array.H>
#include <AMReX_BLassert.H>
#include <AMReX_Geometry.H>
#include <AMReX_IntVect.H>
#include <AMReX_LayoutData.H>
#include <AMReX_MultiFab.H>
Expand Down Expand Up @@ -46,6 +45,7 @@
* be used for other solver vectors, such as electrostatic (array size 1) or Darwin (array size 4).
*/

class WarpX;
class WarpXSolverVec
{
public:
Expand All @@ -61,34 +61,18 @@ public:

[[nodiscard]] inline bool IsDefined () const { return m_is_defined; }

void Define ( WarpX* a_WarpX,
const warpx::fields::FieldType a_solver_vec_type );

inline
void Define (const WarpXSolverVec& a_vec)
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_vec.IsDefined(),
"WarpXSolverVec::Define(a_vec) called with undefined a_vec");
Define( a_vec.getVec(), a_vec.getVecType() );
}

inline
void Define ( const amrex::Vector<std::array< std::unique_ptr<amrex::MultiFab>, 3 > >& a_solver_vec,
const warpx::fields::FieldType a_solver_vec_type )
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
!IsDefined(),
"WarpXSolverVec::Define() called on undefined WarpXSolverVec");
m_solver_vec.resize(m_num_amr_levels);
const int lev = 0;
for (int n=0; n<3; n++) {
const amrex::MultiFab& mf_model = *a_solver_vec[lev][n];
m_solver_vec[lev][n] = std::make_unique<amrex::MultiFab>( mf_model.boxArray(), mf_model.DistributionMap(),
mf_model.nComp(), amrex::IntVect::TheZeroVector() );
}
m_solver_vec_type = a_solver_vec_type;
m_is_defined = true;
Define( a_vec.m_WarpX, a_vec.getVecType() );
}

void SetDotMask( const amrex::Vector<amrex::Geometry>& a_Geom );
[[nodiscard]] RT dotProduct( const WarpXSolverVec& a_X ) const;

inline
Expand Down Expand Up @@ -258,8 +242,13 @@ private:
static constexpr int m_ncomp = 1;
static constexpr int m_num_amr_levels = 1;

inline static bool m_warpx_ptr_defined = false;
inline static WarpX* m_WarpX = nullptr;
void SetWarpXPointer( WarpX* a_WarpX );

inline static bool m_dot_mask_defined = false;
inline static amrex::Vector<std::array<std::unique_ptr<amrex::iMultiFab>,3>> m_dotMask;
void SetDotMask();

};

Expand Down
46 changes: 40 additions & 6 deletions Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,53 @@
* License: BSD-3-Clause-LBNL
*/
#include "FieldSolver/ImplicitSolvers/WarpXSolverVec.H"
#include "WarpX.H"

void WarpXSolverVec::SetDotMask( const amrex::Vector<amrex::Geometry>& a_Geom )
void WarpXSolverVec::Define ( WarpX* a_WarpX,
const warpx::fields::FieldType a_solver_vec_type )
{
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
!IsDefined(),
"WarpXSolverVec::Define(a_vec, a_type) called on already defined WarpXSolverVec");

const amrex::Vector<std::array<std::unique_ptr<amrex::MultiFab>,3>>& a_solver_vec(a_WarpX->getMultiLevelField(a_solver_vec_type));
m_solver_vec.resize(m_num_amr_levels);
const int lev = 0;
for (int n=0; n<3; n++) {
const amrex::MultiFab& mf_model = *a_solver_vec[lev][n];
m_solver_vec[lev][n] = std::make_unique<amrex::MultiFab>( mf_model.boxArray(), mf_model.DistributionMap(),
mf_model.nComp(), amrex::IntVect::TheZeroVector() );
}
m_solver_vec_type = a_solver_vec_type;
m_is_defined = true;
SetWarpXPointer(a_WarpX);
SetDotMask();
}

void WarpXSolverVec::SetWarpXPointer( WarpX* a_WarpX )
{
if (m_warpx_ptr_defined) { return; }
m_WarpX = a_WarpX;
m_warpx_ptr_defined = true;
}

void WarpXSolverVec::SetDotMask()
{
if (m_dot_mask_defined) { return; }
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
IsDefined(),
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
IsDefined(),
"WarpXSolverVec::SetDotMask() called from undefined instance ");
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
m_warpx_ptr_defined,
"WarpXSolverVec::SetDotMask() called before SetWarpXPointer() ");

const amrex::Vector<amrex::Geometry>& Geom = m_WarpX->Geom();
m_dotMask.resize(m_num_amr_levels);
for ( int n = 0; n < 3; n++) {
const amrex::BoxArray& grids = m_solver_vec[0][n]->boxArray();
const amrex::MultiFab tmp( grids, m_solver_vec[0][n]->DistributionMap(),
1, 0, amrex::MFInfo().SetAlloc(false) );
const amrex::Periodicity& period = a_Geom[0].periodicity();
const amrex::Periodicity& period = Geom[0].periodicity();
m_dotMask[0][n] = tmp.OwnerMask(period);
}
m_dot_mask_defined = true;
Expand All @@ -35,8 +68,9 @@ void WarpXSolverVec::SetDotMask( const amrex::Vector<amrex::Geometry>& a_Geom )
m_dot_mask_defined,
"WarpXSolverVec::dotProduct called with m_dotMask not yet defined");
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
a_X.IsDefined(),
"WarpXSolverVec::dotProduct(a_X) called with undefined a_X");
a_X.m_solver_vec_type==m_solver_vec_type,
"WarpXSolverVec::dotProduct(X) called with solver vecs of different types");

amrex::Real result = 0.0;
const int lev = 0;
const bool local = true;
Expand Down

0 comments on commit 8b2aa46

Please sign in to comment.