diff --git a/Source/FieldSolver/Fields.H b/Source/FieldSolver/Fields.H index 88cd9474a78..1a42da215e7 100644 --- a/Source/FieldSolver/Fields.H +++ b/Source/FieldSolver/Fields.H @@ -36,6 +36,20 @@ namespace warpx::fields Efield_avg_cp, Bfield_avg_cp }; + + inline bool + isFieldArray (const FieldType field_type) + { + if (field_type == FieldType::Efield_aux || field_type == FieldType::Bfield_aux || + field_type == FieldType::Efield_fp || field_type == FieldType::Bfield_fp || + field_type == FieldType::current_fp || field_type == FieldType::current_fp_nodal || + field_type == FieldType::vector_potential_fp || + field_type == FieldType::Efield_cp || field_type == FieldType::Bfield_cp || + field_type == FieldType::current_cp || + field_type == FieldType::Efield_avg_fp || field_type == FieldType::Bfield_avg_fp || + field_type == FieldType::Efield_avg_cp || field_type == FieldType::Bfield_avg_cp) { return true; } + else { return false; } + } } #endif //WARPX_FIELDS_H_ diff --git a/Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.cpp b/Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.cpp index 4a3ffc742a9..0c3a0b0ad1a 100644 --- a/Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.cpp +++ b/Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.cpp @@ -15,7 +15,13 @@ void WarpXSolverVec::Define ( WarpX* a_WarpX, { WARPX_ALWAYS_ASSERT_WITH_MESSAGE( !IsDefined(), - "WarpXSolverVec::Define(a_vec, a_type) called on already defined WarpXSolverVec"); + "WarpXSolverVec::Define() called on already defined WarpXSolverVec"); + + // Define static member pointer to WarpX + if (!m_warpx_ptr_defined) { + m_WarpX = a_WarpX; + m_warpx_ptr_defined = true; + } m_array_type = a_array_type; m_scalar_type = a_scalar_type; @@ -26,12 +32,17 @@ void WarpXSolverVec::Define ( WarpX* a_WarpX, // Define the 3D vector field data container if (m_array_type != FieldType::None) { + WARPX_ALWAYS_ASSERT_WITH_MESSAGE( + isFieldArray(m_array_type), + "WarpXSolverVec::Define() called with array_type not an array field"); + for (int lev = 0; lev < m_num_amr_levels; ++lev) { + using arr_mf_type = std::array; + const arr_mf_type this_array = m_WarpX->getFieldPointerArray(m_array_type, lev); for (int n = 0; n < 3; n++) { - const amrex::MultiFab* mf_model = a_WarpX->getFieldPointer(m_array_type,lev,n); - m_array_vec[lev][n] = std::make_unique( mf_model->boxArray(), - mf_model->DistributionMap(), - mf_model->nComp(), + m_array_vec[lev][n] = std::make_unique( this_array[n]->boxArray(), + this_array[n]->DistributionMap(), + this_array[n]->nComp(), amrex::IntVect::TheZeroVector() ); } } @@ -41,23 +52,21 @@ void WarpXSolverVec::Define ( WarpX* a_WarpX, // Define the scalar data container if (m_scalar_type != FieldType::None) { + WARPX_ALWAYS_ASSERT_WITH_MESSAGE( + !isFieldArray(m_scalar_type), + "WarpXSolverVec::Define() called with scalar_type not a scalar field "); + for (int lev = 0; lev < m_num_amr_levels; ++lev) { - const amrex::MultiFab* mf_model = a_WarpX->getFieldPointer(m_scalar_type,lev,0); - m_scalar_vec[lev] = std::make_unique( mf_model->boxArray(), - mf_model->DistributionMap(), - mf_model->nComp(), + const amrex::MultiFab* this_mf = m_WarpX->getFieldPointer(m_scalar_type,lev,0); + m_scalar_vec[lev] = std::make_unique( this_mf->boxArray(), + this_mf->DistributionMap(), + this_mf->nComp(), amrex::IntVect::TheZeroVector() ); } } m_is_defined = true; - - // Define static member pointer to WarpX - if (!m_warpx_ptr_defined) { - m_WarpX = a_WarpX; - m_warpx_ptr_defined = true; - } } void WarpXSolverVec::Copy ( FieldType a_array_type, @@ -73,9 +82,10 @@ void WarpXSolverVec::Copy ( FieldType a_array_type, for (int lev = 0; lev < m_num_amr_levels; ++lev) { if (m_array_type != FieldType::None) { + using arr_mf_type = std::array; + const arr_mf_type this_array = m_WarpX->getFieldPointerArray(m_array_type, lev); for (int n = 0; n < 3; ++n) { - const amrex::MultiFab* this_field = m_WarpX->getFieldPointer(m_array_type,lev,n); - amrex::MultiFab::Copy( *m_array_vec[lev][n], *this_field, 0, 0, m_ncomp, + amrex::MultiFab::Copy( *m_array_vec[lev][n], *this_array[n], 0, 0, m_ncomp, amrex::IntVect::TheZeroVector() ); } }