From d9eb1232b8981110accc60ff06c03fd4253d0ec5 Mon Sep 17 00:00:00 2001 From: Allan Zyne Date: Mon, 18 Nov 2024 22:40:13 -0800 Subject: [PATCH 1/2] WaitPermute only allow move --- src/idtr.cpp | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/idtr.cpp b/src/idtr.cpp index 11869ea..b8ebf64 100644 --- a/src/idtr.cpp +++ b/src/idtr.cpp @@ -525,7 +525,7 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype, if (isStrided) { unpack(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims, oDataPtr); - delete[](char *) rBuff; + delete[] (char *)rBuff; } }; assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() && @@ -735,18 +735,28 @@ template class WaitPermute { SHARPY::rank_type cRank, SHARPY::rank_type nRanks, std::vector &&parts, std::vector &&axes, std::vector oGShape, ndarray &&input, - ndarray &&output, std::vector &&receiveBuffer, - std::vector &&receiveOffsets, + ndarray &&output, std::vector &&sendBuffer, + std::vector &&sendOffsets, std::vector &&sendSizes, + std::vector &&receiveBuffer, std::vector &&receiveOffsets, std::vector &&receiveSizes) : tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)), axes(std::move(axes)), oGShape(std::move(oGShape)), input(std::move(input)), output(std::move(output)), + sendBuffer(std::move(sendBuffer)), sendOffsets(std::move(sendOffsets)), + sendSizes(std::move(sendSizes)), receiveBuffer(std::move(receiveBuffer)), receiveOffsets(std::move(receiveOffsets)), receiveSizes(std::move(receiveSizes)) {} + // Only allow move + WaitPermute(const WaitPermute &) = delete; + WaitPermute &operator=(const WaitPermute &) = delete; + WaitPermute(WaitPermute &&) = default; + WaitPermute &operator=(WaitPermute &&) = default; + void operator()() { tc->wait(hdl); + std::vector> receiveRankBuffer(nRanks); for (size_t rank = 0; rank < nRanks; ++rank) { auto &rankBuffer = receiveRankBuffer[rank]; @@ -755,6 +765,7 @@ template class WaitPermute { receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]); } + // FIXME: very low efficiency, need to improve std::vector receiveRankBufferCount(nRanks, 0); input.globalIndices([&](const id &inputIndex) { id outputIndex = inputIndex.permute(axes); @@ -777,6 +788,9 @@ template class WaitPermute { std::vector oGShape; ndarray input; ndarray output; + std::vector sendBuffer; + std::vector sendOffsets; + std::vector sendSizes; std::vector receiveBuffer; std::vector receiveOffsets; std::vector receiveSizes; @@ -870,6 +884,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype, for (auto i = 0ul; i < nRanks; ++i) { dspl[i] = 4 * i; } + tc->gather(parts.data(), counts.data(), dspl.data(), SHARPY::INT64, SHARPY::REPLICATED); @@ -919,10 +934,12 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype, sendOffsets.data(), sharpytype, receiveBuffer.data(), receiveSizes.data(), receiveOffsets.data()); - auto wait = WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), - std::move(axes), std::move(oGShape), std::move(input), - std::move(output), std::move(receiveBuffer), - std::move(receiveOffsets), std::move(receiveSizes)); + auto wait = + WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), std::move(axes), + std::move(oGShape), std::move(input), std::move(output), + std::move(sendBuffer), std::move(sendOffsets), + std::move(sendSizes), std::move(receiveBuffer), + std::move(receiveOffsets), std::move(receiveSizes)); assert(parts.empty() && axes.empty() && receiveBuffer.empty() && receiveOffsets.empty() && receiveSizes.empty()); From d66450400b2fe32d87606c4a2dd221bd165ed991 Mon Sep 17 00:00:00 2001 From: Yang Zhao Date: Tue, 19 Nov 2024 14:49:54 +0800 Subject: [PATCH 2/2] Update src/idtr.cpp format --- src/idtr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/idtr.cpp b/src/idtr.cpp index b8ebf64..f0da0d7 100644 --- a/src/idtr.cpp +++ b/src/idtr.cpp @@ -525,7 +525,7 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype, if (isStrided) { unpack(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims, oDataPtr); - delete[] (char *)rBuff; + delete[](char *) rBuff; } }; assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&