Skip to content

Commit

Permalink
use ParallelFor for particle loops
Browse files Browse the repository at this point in the history
  • Loading branch information
BenWibking committed Jan 19, 2024
1 parent 273ea0d commit fa8c675
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/simulation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1113,25 +1113,29 @@ template <typename problem_t> void AMRSimulation<problem_t>::kickParticlesAllLev
AMREX_ALWAYS_ASSERT(!accel_mf.contains_nan(0, AMREX_SPACEDIM));
AMREX_ALWAYS_ASSERT(!accel_mf.contains_nan());

// loop over particles on this level
// loop over boxes of particles on this level
for (quokka::CICParticleIterator pIter(*CICParticles, lev); pIter.isValid(); ++pIter) {
auto &particles = pIter.GetArrayOfStructs();
quokka::CICParticleContainer::ParticleType *pData = particles().data();
const amrex::Long np = pIter.numParticles();

amrex::Array4<const amrex::Real> const &accel = accel_mf.array(pIter);
const auto plo = geom[lev].ProbLoArray();
const auto dx_inv = geom[lev].InvCellSizeArray();

for (auto &p : particles) {
amrex::ParallelFor(np, [=] AMREX_GPU_DEVICE(int idx) {
quokka::CICParticleContainer::ParticleType &p = pData[idx];
amrex::ParticleInterpolator::Linear interp(p, plo, dx_inv);
interp.MeshToParticle(
p, accel, 0, quokka::ParticleVxIdx, AMREX_SPACEDIM,
[=] AMREX_GPU_DEVICE(amrex::Array4<const amrex::Real> const &arr, int i, int j, int k, int comp) {
return arr(i, j, k, comp); // no weighting
},
[=] AMREX_GPU_DEVICE(quokka::CICParticleContainer::ParticleType & p, int const comp, amrex::Real const acc_comp) {
[=] AMREX_GPU_DEVICE(quokka::CICParticleContainer::ParticleType & p, int comp, amrex::Real acc_comp) {
// kick particle by updating its velocity
p.rdata(comp) += 0.5 * dt * static_cast<amrex::ParticleReal>(acc_comp);
});
}
});
}
}
}
Expand All @@ -1145,12 +1149,16 @@ template <typename problem_t> void AMRSimulation<problem_t>::driftParticlesAllLe
for (int lev = 0; lev <= finest_level; ++lev) {
for (quokka::CICParticleIterator pIter(*CICParticles, lev); pIter.isValid(); ++pIter) {
auto &particles = pIter.GetArrayOfStructs();
for (auto &particle : particles) {
quokka::CICParticleContainer::ParticleType *pData = particles().data();
const amrex::Long np = pIter.numParticles();

amrex::ParallelFor(np, [=] AMREX_GPU_DEVICE(int idx) {
quokka::CICParticleContainer::ParticleType &p = pData[idx];
// update particle position
for (int i = 0; i < AMREX_SPACEDIM; ++i) {
particle.pos(i) += dt * particle.rdata(quokka::ParticleVxIdx + i);
p.pos(i) += dt * p.rdata(quokka::ParticleVxIdx + i);
}
}
});
}
}
}
Expand Down

0 comments on commit fa8c675

Please sign in to comment.