Skip to content

Commit

Permalink
change strategy used in main kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
lucafedeli88 committed May 6, 2024
1 parent 8751ec9 commit 31bfd4d
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 81 deletions.
8 changes: 8 additions & 0 deletions Source/Particles/Radiation/RadiationHandler.H
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ private:
amrex::Gpu::DeviceVector<amrex::Real> m_det_n_y;
amrex::Gpu::DeviceVector<amrex::Real> m_det_n_z;

amrex::Gpu::DeviceVector<amrex::Real> m_b_x;
amrex::Gpu::DeviceVector<amrex::Real> m_b_y;
amrex::Gpu::DeviceVector<amrex::Real> m_b_z;

amrex::Gpu::DeviceVector<amrex::Real> m_bp_x;
amrex::Gpu::DeviceVector<amrex::Real> m_bp_y;
amrex::Gpu::DeviceVector<amrex::Real> m_bp_z;

amrex::Gpu::DeviceVector<ablastr::math::Complex> m_radiation_data;
amrex::Gpu::DeviceVector<amrex::Real> m_radiation_calculation;

Expand Down
163 changes: 82 additions & 81 deletions Source/Particles/Radiation/RadiationHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,64 +440,68 @@ void RadiationHandler::add_radiation_contribution(

WARPX_ALWAYS_ASSERT_WITH_MESSAGE((np-1) == static_cast<int>(np-1), "too many particles!");

#if defined(WARPX_DIM_3D)
const auto np_omegas_detpos = amrex::Box{
amrex::IntVect{0,0,0},
amrex::IntVect{0, omega_points-1, how_many_det_pos-1}};
#else
const auto np_omegas_detpos = amrex::Box{
amrex::IntVect{0,0},
amrex::IntVect{static_cast<int>(np-1), ((omega_points) * (how_many_det_pos) - 1)}};
amrex::ignore_unused(p_det_pos_y);
#endif


#if defined(WARPX_DIM_3D)
amrex::ParallelFor(
np_omegas_detpos, [=] AMREX_GPU_DEVICE(int, int i_om, int i_det){
#else
amrex::ParallelFor(
np_omegas_detpos, [=] AMREX_GPU_DEVICE(int ip, int i_om_det, int){
const int i_det = i_om_det % (how_many_det_pos);
const int i_om = i_om_det / (how_many_det_pos);
#endif

const auto i_omega_over_c = Complex{0.0_prt, 1.0_prt}*p_omegas[i_om]*inv_c;
m_b_x.resize(np);
m_b_y.resize(np);
m_b_z.resize(np);
m_bp_x.resize(np);
m_bp_y.resize(np);
m_bp_z.resize(np);

auto* p_b_x = m_b_x.dataPtr();
auto* p_b_y = m_b_y.dataPtr();
auto* p_b_z = m_b_z.dataPtr();
auto* p_bp_x = m_bp_x.dataPtr();
auto* p_bp_y = m_bp_y.dataPtr();
auto* p_bp_z = m_bp_z.dataPtr();


amrex::ParallelFor(np, [=] AMREX_GPU_DEVICE(int ip){
const auto ux = 0.5_prt*(p_ux[ip] + p_ux_old[ip]);
const auto uy = 0.5_prt*(p_uy[ip] + p_uy_old[ip]);
const auto uz = 0.5_prt*(p_uz[ip] + p_uz_old[ip]);

auto const u2 = ux*ux + uy*uy + uz*uz;

auto const one_over_gamma = 1._prt/std::sqrt(1.0_rt + u2*inv_c2);
auto const one_over_gamma_c = one_over_gamma*inv_c;
const auto one_over_dt_gamma_c = one_over_gamma_c/dt;

p_b_x[ip] = ux*one_over_gamma_c;
p_b_y[ip] = uy*one_over_gamma_c;
p_b_z[ip] = uz*one_over_gamma_c;
p_bp_x[ip] = (p_ux[ip] - p_ux_old[ip])*one_over_dt_gamma_c;
p_bp_y[ip] = (p_uy[ip] - p_uy_old[ip])*one_over_dt_gamma_c;
p_bp_z[ip] = (p_uz[ip] - p_uz_old[ip])*one_over_dt_gamma_c;
});

const auto nx = p_det_n_x[i_det];
const auto ny = p_det_n_y[i_det];
const auto nz = p_det_n_z[i_det];
const auto np_times_det_pos = np*how_many_det_pos;

auto sum_cx = Complex{0.0_prt, 0.0_prt};
auto sum_cy = Complex{0.0_prt, 0.0_prt};
auto sum_cz = Complex{0.0_prt, 0.0_prt};
amrex::ParallelFor(
np_times_det_pos, [=] AMREX_GPU_DEVICE(int ii){
const int ip = ii / (np);
const int i_det = ii % (np);

for (int ip = 0; ip < np; ++ip){
amrex::ParticleReal xp, yp, zp;
GetPosition.AsStored(ip, xp, yp, zp);

const auto ux = 0.5_prt*(p_ux[ip] + p_ux_old[ip]);
const auto uy = 0.5_prt*(p_uy[ip] + p_uy_old[ip]);
const auto uz = 0.5_prt*(p_uz[ip] + p_uz_old[ip]);

auto const u2 = ux*ux + uy*uy + uz*uz;

auto const one_over_gamma = 1._prt/std::sqrt(1.0_rt + u2*inv_c2);
auto const one_over_gamma_c = one_over_gamma*inv_c;
const auto bx = p_b_x[ip];
const auto by = p_b_y[ip];
const auto bz = p_b_z[ip];

const auto bx = ux*one_over_gamma_c;
const auto by = uy*one_over_gamma_c;
const auto bz = uz*one_over_gamma_c;
const auto bpx = p_bp_x[ip];
const auto bpy = p_bp_y[ip];
const auto bpz = p_bp_z[ip];

const auto one_over_dt_gamma_c = one_over_gamma_c/dt;
const auto w = p_w[ip];

const auto bpx = (p_ux[ip] - p_ux_old[ip])*one_over_dt_gamma_c;
const auto bpy = (p_uy[ip] - p_uy_old[ip])*one_over_dt_gamma_c;
const auto bpz = (p_uz[ip] - p_uz_old[ip])*one_over_dt_gamma_c;
const auto nx = p_det_n_x[i_det];
const auto ny = p_det_n_y[i_det];
const auto nz = p_det_n_z[i_det];

//Calculation of 1 - beta.n, where n corresponds to m_det_direction, the direction of the normal
const auto one_minus_b_dot_n = 1.0_prt - (bx*nx + by*ny + bz*nz);

//Calculation of 1 - beta.n, where n corresponds to m_det_direction, the direction of the normal
const auto n_minus_beta_x = nx - bx;
const auto n_minus_beta_y = ny - by;
const auto n_minus_beta_z = nz - bz;
Expand All @@ -512,45 +516,42 @@ void RadiationHandler::add_radiation_contribution(
const auto n_cross_n_minus_beta_cross_bp_y = nz*n_minus_beta_cross_bp_x - nx*n_minus_beta_cross_bp_z;
const auto n_cross_n_minus_beta_cross_bp_z = nx*n_minus_beta_cross_bp_y - ny*n_minus_beta_cross_bp_x;

const auto n_dot_r = nx*xp + ny*yp + nz*zp;
const auto phase_term = amrex::exp(i_omega_over_c*(c*current_time - (n_dot_r)));

const auto FF = p_m_FF[i_om*how_many_det_pos + i_det];
const auto form_factor = std::sqrt(p_w[ip] + (p_w[ip]*p_w[ip]-p_w[ip])*FF);

const auto coeff = q*phase_term/(one_minus_b_dot_n*one_minus_b_dot_n)*form_factor;

//Nyquist limiter
const amrex::Real nyquist_flag = (p_omegas[i_om] < ablastr::constant::math::pi/one_minus_b_dot_n/dt);

const auto cx = coeff*n_cross_n_minus_beta_cross_bp_x*nyquist_flag;
const auto cy = coeff*n_cross_n_minus_beta_cross_bp_y*nyquist_flag;
const auto cz = coeff*n_cross_n_minus_beta_cross_bp_z*nyquist_flag;
const auto nyquist_threshold = ablastr::constant::math::pi/one_minus_b_dot_n/dt;

sum_cx += cx;
sum_cy += cy;
sum_cz += cz;
}

const int ncomp = 3;
const int idx0 = (i_om*how_many_det_pos + i_det)*ncomp;
const int idx1 = idx0 + 1;
const int idx2 = idx0 + 2;

#if defined(AMREX_USE_OMP)
const auto n_dot_r = nx*xp + ny*yp + nz*zp;

amrex::HostDevice::Atomic::Add(&p_radiation_data[idx0].m_real, sum_cx.m_real);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx0].m_imag, sum_cx.m_imag);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx1].m_real, sum_cy.m_real);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx1].m_imag, sum_cy.m_imag);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx2].m_real, sum_cz.m_real);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx2].m_imag, sum_cz.m_imag);
#else
p_radiation_data[idx0] += sum_cx;
p_radiation_data[idx1] += sum_cy;
p_radiation_data[idx2] += sum_cz;
#ifdef AMREX_USE_OMP
#pragma omp simd
#endif
});
for (int i_om = 0; i_om < omega_points; ++i_om){
const auto i_omega_over_c = Complex{0.0_prt, 1.0_prt}*p_omegas[i_om]*inv_c;
const auto phase_term = amrex::exp(i_omega_over_c*(c*current_time - (n_dot_r)));

const auto FF = p_m_FF[i_om*how_many_det_pos + i_det];
const auto form_factor = std::sqrt(w + (w*w-w)*FF);

const auto coeff = q*phase_term/(one_minus_b_dot_n*one_minus_b_dot_n)*form_factor;

//Nyquist limiter
const amrex::Real nyquist_flag = (p_omegas[i_om] < nyquist_threshold);

const auto cx = coeff*n_cross_n_minus_beta_cross_bp_x*nyquist_flag;
const auto cy = coeff*n_cross_n_minus_beta_cross_bp_y*nyquist_flag;
const auto cz = coeff*n_cross_n_minus_beta_cross_bp_z*nyquist_flag;

constexpr int ncomp = 3;
const int idx0 = (i_om*how_many_det_pos + i_det)*ncomp;
const int idx1 = idx0 + 1;
const int idx2 = idx0 + 2;

amrex::HostDevice::Atomic::Add(&p_radiation_data[idx0].m_real, cx.m_real);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx0].m_imag, cx.m_imag);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx1].m_real, cy.m_real);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx1].m_imag, cy.m_imag);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx2].m_real, cz.m_real);
amrex::HostDevice::Atomic::Add(&p_radiation_data[idx2].m_imag, cz.m_imag);
}
});
}
}
}
Expand Down

0 comments on commit 31bfd4d

Please sign in to comment.