Skip to content

Commit

Permalink
optimize treelet sync for GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
sekelle committed Aug 6, 2024
1 parent ed4fb87 commit cf3407c
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 17 deletions.
37 changes: 37 additions & 0 deletions domain/include/cstone/focus/exchange_focus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "cstone/domain/index_ranges.hpp"
#include "cstone/primitives/gather.hpp"
#include "cstone/primitives/mpi_wrappers.hpp"
#include "cstone/primitives/primitives_gpu.h"
#include "cstone/tree/csarray.hpp"
#include "cstone/tree/octree.hpp"
#include "cstone/util/gsl-lite.hpp"
Expand Down Expand Up @@ -310,6 +311,42 @@ void syncTreelets(gsl::span<const int> peers,
indexTreelets<KeyType>(octree.prefixes, octree.levelRange, treelets, treeletIdx);
}

/*! @brief treelet synchronization with peers where the leaf rebalance step is carried out on the device
 *
 * @tparam        KeyType      key type of the tree leaves
 * @tparam        DAlloc       allocator type of the device leaf vector
 * @param[in]     peers        peer ranks, forwarded to exchangeTreelets, exchangeRejectedKeys and pruneTreelets
 * @param[in]     assignment   index ranges forwarded to exchangeTreelets
 * @param[in]     leaves       host copy of the leaf keys; taken by const reference and NOT updated here
 * @param[inout]  octreeAcc    accelerator octree; if a rebalance is required, its childOffsets and prefixes
 *                             buffers are overwritten as scratch space, then the octree is resized and rebuilt
 * @param[inout]  leavesAcc    device leaf keys; replaced by the rebalanced leaves if a rebalance is required
 * @param[inout]  treelets     per-peer treelets, passed to the exchange, check and prune helpers
 * @param[inout]  treeletIdx   per-peer treelet indices, passed to the check, rejected-key exchange and prune helpers
 *
 * This function does not call indexTreelets and leaves the host-side @p leaves untouched.
 * NOTE(review): callers presumably have to refresh their host copy of the tree and index the treelets
 * afterwards - confirm against the call site.
 */
template<class KeyType, class DAlloc>
void syncTreeletsGpu(gsl::span<const int> peers,
                     gsl::span<const IndexPair<TreeNodeIndex>> assignment,
                     const std::vector<KeyType>& leaves,
                     OctreeData<KeyType, GpuTag>& octreeAcc,
                     thrust::device_vector<KeyType, DAlloc>& leavesAcc,
                     std::vector<std::vector<KeyType>>& treelets,
                     std::vector<std::vector<TreeNodeIndex>>& treeletIdx)
{
    exchangeTreelets<KeyType>(peers, assignment, leaves, treelets);
    checkTreelets<KeyType>(leaves, treelets, treeletIdx);

    // one entry per element of leaves, initialized to 1; exchangeRejectedKeys may modify entries
    std::vector<TreeNodeIndex> nodeOps(leaves.size(), 1);
    exchangeRejectedKeys<KeyType>(peers, leaves, treelets, treeletIdx, nodeOps);
    pruneTreelets<KeyType>(peers, treelets, treeletIdx);

    // A rebalance is needed only if at least one node operation deviates from the initial value of 1.
    // any_of stops at the first deviation and avoids comparing a signed count against an unsigned size.
    const bool needRebalance =
        std::any_of(nodeOps.begin(), nodeOps.end(), [](TreeNodeIndex op) { return op != 1; });

    if (needRebalance)
    {
        // borrow the childOffsets buffer of the octree as scratch space for the node operations
        assert(octreeAcc.childOffsets.size() >= nodeOps.size());
        gsl::span<TreeNodeIndex> nops(rawPtr(octreeAcc.childOffsets), nodeOps.size());
        memcpyH2D(rawPtr(nodeOps), nodeOps.size(), nops.data());

        // in-place exclusive scan; afterwards the last element is read back as the new number of leaf nodes
        exclusiveScanGpu(nops.data(), nops.data() + nops.size(), nops.data());
        TreeNodeIndex newNumLeafNodes = 0;
        memcpyD2H(nops.data() + nops.size() - 1, 1, &newNumLeafNodes);

        // borrow the prefixes buffer of the octree as the destination for the rebalanced leaves
        auto& newLeaves = octreeAcc.prefixes;
        reallocateDestructive(newLeaves, newNumLeafNodes + 1, 1.05);
        rebalanceTreeGpu(rawPtr(leavesAcc), nNodes(leavesAcc), newNumLeafNodes, nops.data(), rawPtr(newLeaves));
        swap(newLeaves, leavesAcc);

        // the scratch usage above clobbered octree buffers: resize and rebuild from the new leaves
        octreeAcc.resize(nNodes(leavesAcc));
        buildOctreeGpu(rawPtr(leavesAcc), octreeAcc.data());
    }
}

template<class T>
void exchangeTreeletGeneral(gsl::span<const int> peerRanks,
const std::vector<std::vector<TreeNodeIndex>>& peerTrees,
Expand Down
14 changes: 11 additions & 3 deletions domain/include/cstone/focus/octree_focus_mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ class FocusedOctree
while (not macRefineGpu(octreeAcc_, leavesAcc_, centersAcc_, macsAcc_, prevFocusStart, prevFocusEnd,
focusStart, focusEnd, invThetaRefine, box))
;

reallocateDestructive(leaves_, leavesAcc_.size(), allocGrowthRate_);
memcpyD2H(rawPtr(leavesAcc_), leavesAcc_.size(), rawPtr(leaves_));
}
else
{
Expand All @@ -156,12 +159,17 @@ class FocusedOctree
focusEnd, invThetaRefine, box))
;
}
downloadOctree();
translateAssignment<KeyType>(assignment, leaves_, peers_, myRank_, assignment_);

syncTreelets(peers_, assignment_, treeData_, leaves_, treelets_, treeletIdx_);
if constexpr (HaveGpu<Accelerator>{})
{
syncTreeletsGpu(peers_, assignment_, leaves_, octreeAcc_, leavesAcc_, treelets_, treeletIdx_);
downloadOctree();
indexTreelets<KeyType>(treeData_.prefixes, treeData_.levelRange, treelets_, treeletIdx_);
}
else { syncTreelets(peers_, assignment_, treeData_, leaves_, treelets_, treeletIdx_); }

translateAssignment<KeyType>(assignment, leaves_, peers_, myRank_, assignment_);
uploadOctree();

prevFocusStart = focusStart;
prevFocusEnd = focusEnd;
Expand Down
4 changes: 2 additions & 2 deletions domain/test/integration_mpi/exchange_focus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void exchangeFocusIrregular(int myRank, int numRanks)
// finer resolution at one location outside the regular grid + cells that don't exist on rank 1
octreeMaker.divide(7).divide(7, 0);
treeLeavesRef[0] = octreeMaker.makeTree();
octreeMaker.divide(7,0,3);
octreeMaker.divide(7, 0, 3);
treeLeavesInitial[0] = octreeMaker.makeTree();
EXPECT_EQ(treeLeavesRef[0].size() + 7, treeLeavesInitial[0].size());
}
Expand All @@ -83,7 +83,7 @@ void exchangeFocusIrregular(int myRank, int numRanks)
}
// finer resolution at one location outside the regular grid
octreeMaker.divide(1).divide(1, 6);
treeLeavesRef[1] = octreeMaker.makeTree();
treeLeavesRef[1] = octreeMaker.makeTree();
treeLeavesInitial[1] = treeLeavesRef[1];
}

Expand Down
8 changes: 4 additions & 4 deletions domain/test/unit/focus/octree_focus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,15 +464,15 @@ TEST_F(MacRefinement, fullSurface)
while (!macRefine(octree, leaves, centers, macs, focusEnd, focusEnd, focusStart, focusEnd, invTheta, box)) {}

int numNodesVertex = 7 + 8;
int numNodesEdge = 6 + 2 * 8;
int numNodesFace = 4 + 4 * 8;
int numNodesEdge = 6 + 2 * 8;
int numNodesFace = 4 + 4 * 8;
EXPECT_EQ(nNodes(leaves), 64 + 7 + 3 * numNodesFace + 3 * numNodesEdge + numNodesVertex);
}

TEST_F(MacRefinement, noSurface)
{
Box<T> box(0, 1);
float invTheta = sqrt(3) / 2 + 1e-6;
float invTheta = sqrt(3) / 2 + 1e-6;
TreeNodeIndex numNodesStart = octree.numLeafNodes;

KeyType oldFStart = decodePlaceholderBit(KeyType(0101));
Expand All @@ -487,7 +487,7 @@ TEST_F(MacRefinement, noSurface)
TEST_F(MacRefinement, partialSurface)
{
Box<T> box(0, 1);
float invTheta = sqrt(3) / 2 + 1e-6;
float invTheta = sqrt(3) / 2 + 1e-6;
TreeNodeIndex numNodesStart = octree.numLeafNodes;

KeyType oldFStart = 0;
Expand Down
10 changes: 5 additions & 5 deletions domain/test/unit_cuda/focus/octree_focus.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protected:

// L3 in first octant, L1 otherwise
h_leaves = om.makeTree();
leaves = h_leaves;
leaves = h_leaves;

octree.resize(nNodes(h_leaves));
buildOctreeGpu<KeyType>(rawPtr(leaves), octree.data());
Expand All @@ -57,15 +57,15 @@ TEST_F(MacRefinementGpu, fullSurface)
while (!macRefineGpu(octree, leaves, centers, macs, focusEnd, focusEnd, focusStart, focusEnd, invTheta, box)) {}

int numNodesVertex = 7 + 8;
int numNodesEdge = 6 + 2 * 8;
int numNodesFace = 4 + 4 * 8;
int numNodesEdge = 6 + 2 * 8;
int numNodesFace = 4 + 4 * 8;
EXPECT_EQ(nNodes(leaves), 64 + 7 + 3 * numNodesFace + 3 * numNodesEdge + numNodesVertex);
}

TEST_F(MacRefinementGpu, noSurface)
{
Box<T> box(0, 1);
float invTheta = sqrt(3) / 2 + 1e-6;
float invTheta = sqrt(3) / 2 + 1e-6;
TreeNodeIndex numNodesStart = octree.numLeafNodes;

KeyType oldFStart = decodePlaceholderBit(KeyType(0101));
Expand All @@ -80,7 +80,7 @@ TEST_F(MacRefinementGpu, noSurface)
TEST_F(MacRefinementGpu, partialSurface)
{
Box<T> box(0, 1);
float invTheta = sqrt(3) / 2 + 1e-6;
float invTheta = sqrt(3) / 2 + 1e-6;
TreeNodeIndex numNodesStart = octree.numLeafNodes;

KeyType oldFStart = 0;
Expand Down
5 changes: 2 additions & 3 deletions domain/test/unit_cuda/traversal/macs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "cstone/cuda/cuda_utils.cuh"
#include "cstone/focus/source_center.hpp"
#include "cstone/traversal/collisions_gpu.h"
#include "cstone/traversal/macs.hpp"
#include "cstone/tree/cs_util.hpp"
#include "cstone/tree/octree_gpu.h"

Expand Down Expand Up @@ -42,13 +41,13 @@ TEST(Macs, limitSource4x4_matchCPU)
markMacsGpu(ov.prefixes, ov.childOffsets, rawPtr(centers), box, rawPtr(leaves) + 0, 32, true, rawPtr(macs));
thrust::host_vector<char> h_macs = macs;

thrust::host_vector<char> macRef{1, 0, 0, 0, 0, 1, 1, 1, 1};
thrust::host_vector<char> macRef = std::vector<char>{1, 0, 0, 0, 0, 1, 1, 1, 1};
macRef.resize(ov.numNodes);
EXPECT_EQ(macRef, h_macs);

thrust::fill(macs.begin(), macs.end(), 0);
markMacsGpu(ov.prefixes, ov.childOffsets, rawPtr(centers), box, rawPtr(leaves) + 0, 32, false, rawPtr(macs));
h_macs = macs;
h_macs = macs;
int numMacs = std::accumulate(h_macs.begin(), h_macs.end(), 0);
EXPECT_EQ(numMacs, 5 + 16);
}

0 comments on commit cf3407c

Please sign in to comment.