Post collectives from progress thread #18

Draft · wants to merge 2 commits into base: main
include/torch_ucc.hpp: 54 changes (33 additions, 21 deletions)
@@ -76,26 +76,19 @@ namespace c10d {
 } while (0)
 
 #ifdef USE_CUDA
-#define SAVE_TENSOR(_TENSOR, _DATA)                 \
-  if ((_TENSOR).device().is_cuda()) {               \
-    c10::cuda::CUDACachingAllocator::recordStream(  \
-        (_TENSOR).storage().data_ptr(), (*stream)); \
-  } else {                                          \
-    (_DATA) = {(_TENSOR)};                          \
-  }
-
-#define SAVE_TENSORS(_TENSORS, _DATA)                     \
-  if ((_TENSORS)[0].device().is_cuda()) {                 \
-    for (const auto i : c10::irange((_TENSORS).size())) { \
-      c10::cuda::CUDACachingAllocator::recordStream(      \
-          (_TENSORS)[i].storage().data_ptr(), (*stream)); \
-    }                                                     \
-  } else {                                                \
-    (_DATA) = (_TENSORS);                                 \
-  }
+#define SAVE_TENSORS(_TENSORS, _DATA)                       \
+  do {                                                      \
+    if ((_TENSORS)[0].device().is_cuda()) {                 \
+      for (const auto i : c10::irange((_TENSORS).size())) { \
+        c10::cuda::CUDACachingAllocator::recordStream(      \
+            (_TENSORS)[i].storage().data_ptr(), (*stream)); \
+      }                                                     \
+    } else {                                                \
+      (_DATA) = (_TENSORS);                                 \
+    }                                                       \
+  } while (0)
 #else
-#define SAVE_TENSOR(_TENSOR, _DATA) (_DATA) = {(_TENSOR)};
 
 #define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
 #endif

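The new do { ... } while (0) wrapper around SAVE_TENSORS is the standard idiom for making a multi-statement macro behave as a single statement: a macro that expands to a bare if/else (as the old definition did) breaks when used as the body of an enclosing if/else. A minimal sketch of the hazard, with illustrative names rather than code from this PR:

// Sketch only: handle/fallback/report stand in for real work.
void handle(int);
void fallback(int);
void report(int);

// Wrapped in do/while (0), the expansion is exactly one statement, so
// the trailing ';' at the call site is harmless and the outer else
// still binds to the outer if. Without the wrapper, that ';' would
// orphan the outer else and caller() would not compile.
#define SAFE_SAVE(_X) \
  do {                \
    if ((_X) > 0) {   \
      handle(_X);     \
    } else {          \
      fallback(_X);   \
    }                 \
  } while (0)

void caller(int v) {
  if (v != 0)
    SAFE_SAVE(v);
  else
    report(v);
}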
@@ -135,8 +128,8 @@ class ProcessGroupUCC : public ProcessGroup {
   class AllgatherWorkData : public WorkData {
    public:
     AllgatherWorkData(int size)
-      : recv_lengths(size),
-        recv_offsets(size) {}
+        : recv_lengths(size),
+          recv_offsets(size) {}
     std::vector<uint64_t> recv_lengths;
     std::vector<uint64_t> recv_offsets;
   };
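AllgatherWorkData exists to keep the per-rank receive lengths and offsets alive until the collective completes. A hedged sketch of how such vectors are typically populated for a vector-style allgather (the struct name, fill function, and uniform byte count are illustrative assumptions, not code from this repository):

#include <cstdint>
#include <vector>

struct AllgatherWorkDataSketch {
  explicit AllgatherWorkDataSketch(int size)
      : recv_lengths(size), recv_offsets(size) {}
  std::vector<uint64_t> recv_lengths;
  std::vector<uint64_t> recv_offsets;
};

// Assumed layout: every rank contributes the same number of bytes,
// packed back to back in rank order.
void fill_counts(AllgatherWorkDataSketch& wd, uint64_t bytes_per_rank) {
  uint64_t offset = 0;
  for (size_t r = 0; r < wd.recv_lengths.size(); ++r) {
    wd.recv_lengths[r] = bytes_per_rank; // bytes received from rank r
    wd.recv_offsets[r] = offset;         // where rank r's chunk lands
    offset += bytes_per_rank;
  }
}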
@@ -153,6 +146,7 @@ class ProcessGroupUCC : public ProcessGroup {
         ucc_ee_h ee,
         CommBase* comm)
         : ProcessGroup::Work(-1, opType),
+          ee_(ee),
           status_(status),
           request_(request),
           comm_(comm) {}
@@ -162,14 +156,29 @@ class ProcessGroupUCC : public ProcessGroup {
     bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
     void finalize();
     std::unique_ptr<WorkData> data;
+#ifdef USE_UCC_FUTURE
+    void finishWorkUCC();
+    void finishWorkUCCError(std::exception_ptr eptr);
+    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
+#endif
 #ifdef USE_CUDA
     std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
     event_pool_t* ep = nullptr;
+    at::cuda::CUDAStream* stream;
 #endif
    protected:
+    ucc_coll_args_t args;
+    ucc_team_h team;
+    ucc_ee_h ee_;
     ucc_status_t status_;
     ucc_coll_req_h request_;
     CommBase* comm_;
+
+#ifdef USE_UCC_FUTURE
+   private:
+    // The future returned by getFuture.
+    c10::intrusive_ptr<at::ivalue::Future> future_;
+#endif
   };
 
   explicit ProcessGroupUCC(
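From the caller's perspective, WorkUCC behaves like any other ProcessGroup::Work handle, and the USE_UCC_FUTURE additions above expose the same completion through a c10 future. A hedged usage sketch (the pg and tensors arguments are assumed to be set up elsewhere; error handling omitted):

// Illustrative caller-side usage, not code from this PR.
void example(c10::intrusive_ptr<c10d::ProcessGroupUCC> pg,
             std::vector<at::Tensor>& tensors) {
  auto work = pg->allreduce(tensors); // enqueues the collective
  work->wait();                       // blocks until UCC completes it
#ifdef USE_UCC_FUTURE
  auto fut = work->getFuture();       // completion as an ivalue::Future
  fut->wait();                        // or attach callbacks instead
#endif
}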
@@ -293,8 +302,11 @@ class CommPG {
   std::condition_variable queue_produce_cv;
   std::condition_variable queue_consume_cv;
   std::deque<c10::intrusive_ptr<ProcessGroupUCC::WorkUCC>> progress_queue;
+  std::deque<c10::intrusive_ptr<ProcessGroupUCC::WorkUCC>> post_queue;
   bool stop_progress_loop;
 
+  void post_cpu_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req);
+  void post_cuda_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req);
  public:
   c10::DeviceIndex cuda_device_index;
   CommPG(torch_ucc_oob_coll_info_t* oob_info,
@@ -328,7 +340,7 @@ class CommPG {
       ucc_team_h& team,
       ucc_ee_h ee,
       std::unique_ptr<at::cuda::CUDAEvent> cuda_ev,
-      const at::cuda::CUDAStream& stream,
+      at::cuda::CUDAStream& stream,
       event_pool_t* ep);
 #endif
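The post_queue and the post_cpu_collective/post_cuda_collective helpers are the heart of this PR: collectives are no longer posted on the caller's thread but handed to the progress thread, which posts them (directly to the UCC team for CPU tensors, or through the execution engine for CUDA) and then progresses them to completion. A hedged sketch of the loop shape under those assumptions; the mutex name, the dispatch test, and the loop body are illustrative, not the PR's actual implementation:

// Hypothetical drain-and-dispatch loop on the progress thread.
// Member names mirror the declarations above; `mutex_` is assumed.
void CommPG::progress_loop_sketch() {
  std::unique_lock<std::mutex> lock(mutex_);
  while (!stop_progress_loop) {
    queue_produce_cv.wait(lock, [&] {
      return stop_progress_loop || !post_queue.empty() ||
             !progress_queue.empty();
    });
    while (!post_queue.empty()) {
      auto req = post_queue.front();
      post_queue.pop_front();
      lock.unlock();
      // Assumed dispatch rule: CUDA work carries a stream to trigger on.
      if (req->stream != nullptr) {
        post_cuda_collective(req); // e.g. a triggered post via ee_
      } else {
        post_cpu_collective(req);  // e.g. ucc_collective_post on request_
      }
      lock.lock();
    }
    // ... progress outstanding requests in progress_queue here ...
    queue_consume_cv.notify_all();
  }
}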