Skip to content

Commit

Permalink
Try to improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
lucafedeli88 committed Sep 19, 2024
1 parent b9620e4 commit 2d5f8b5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
3 changes: 2 additions & 1 deletion Source/Particles/Radiation/RadiationHandler.H
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <AMReX_GpuContainers.H>
#include <AMReX_REAL.H>
#include <AMReX_Scan.H>
#include <AMReX_Vector.H>

#include <memory>
Expand Down Expand Up @@ -97,8 +98,8 @@ private:
std::optional<int> m_step_skip = std::nullopt;

std::optional<std::array<amrex::Real,2>> m_gamma_range = std::nullopt;
amrex::Gpu::DeviceVector<int> m_offset;
amrex::Gpu::DeviceVector<int> m_mask;
amrex::Gpu::DeviceVector<int> m_offset;
amrex::Gpu::DeviceVector<int> m_idx;
};
#endif // WARPX_PARTICLES_RADIATION_H
19 changes: 16 additions & 3 deletions Source/Particles/Radiation/RadiationHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,9 +564,10 @@ void RadiationHandler::add_radiation_contribution(
else{ //m_gamma_range.has_value()


m_offset.resize(np);
m_mask.resize(np);
m_offset.resize(np);
int* const AMREX_RESTRICT p_mask = m_mask.dataPtr();
int* const AMREX_RESTRICT p_offset = m_offset.dataPtr();

const auto gamma_min = m_gamma_range.value()[0];
const auto gamma_max = m_gamma_range.value()[1];
Expand All @@ -587,6 +588,18 @@ void RadiationHandler::add_radiation_contribution(
p_mask[ip] = is_in;
});

const auto nrad = amrex::Scan::ExclusiveSum(np, p_mask, p_offset);

m_idx.resize(nrad);
int* const AMREX_RESTRICT p_idx = m_idx.dataPtr();

amrex::ParallelFor(np, [=] AMREX_GPU_DEVICE(int ip){
if (p_mask[ip]){
p_idx[p_offset[ip]] = ip;
}
});


#if defined(WARPX_DIM_3D)
amrex::ParallelFor(
np_omegas_detpos, [=] AMREX_GPU_DEVICE(int, int i_om, int i_det){
Expand All @@ -607,9 +620,9 @@ void RadiationHandler::add_radiation_contribution(
auto sum_cy = Complex{0.0_prt, 0.0_prt};
auto sum_cz = Complex{0.0_prt, 0.0_prt};

for (int ip = 0; ip < np; ++ip){
for (int irad = 0; irad < nrad; ++irad){

if (!p_mask[ip]) {continue;}
const auto ip = p_idx[irad];

amrex::ParticleReal xp, yp, zp;
GetPosition.AsStored(ip, xp, yp, zp);
Expand Down

0 comments on commit 2d5f8b5

Please sign in to comment.