diff --git a/include/torch_ucc.hpp b/include/torch_ucc.hpp index e0ac03d..d7087a5 100644 --- a/include/torch_ucc.hpp +++ b/include/torch_ucc.hpp @@ -146,6 +146,7 @@ class ProcessGroupUCC : public ProcessGroup { ucc_ee_h ee, CommBase* comm) : ProcessGroup::Work(-1, opType), + ee_(ee), status_(status), request_(request), comm_(comm) {} @@ -163,8 +164,12 @@ class ProcessGroupUCC : public ProcessGroup { #ifdef USE_CUDA std::unique_ptr fence = nullptr; event_pool_t* ep = nullptr; + at::cuda::CUDAStream* stream; #endif protected: + ucc_coll_args_t args; + ucc_team_h team; + ucc_ee_h ee_; ucc_status_t status_; ucc_coll_req_h request_; CommBase* comm_; @@ -297,8 +302,11 @@ class CommPG { std::condition_variable queue_produce_cv; std::condition_variable queue_consume_cv; std::deque> progress_queue; + std::deque> post_queue; bool stop_progress_loop; + void post_cpu_collective(c10::intrusive_ptr post_req); + void post_cuda_collective(c10::intrusive_ptr post_req); public: c10::DeviceIndex cuda_device_index; CommPG(torch_ucc_oob_coll_info_t* oob_info, @@ -332,7 +340,7 @@ class CommPG { ucc_team_h& team, ucc_ee_h ee, std::unique_ptr cuda_ev, - const at::cuda::CUDAStream& stream, + at::cuda::CUDAStream& stream, event_pool_t* ep); #endif diff --git a/src/torch_ucc.cpp b/src/torch_ucc.cpp index 349de9b..b05cf4a 100644 --- a/src/torch_ucc.cpp +++ b/src/torch_ucc.cpp @@ -134,6 +134,8 @@ bool ProcessGroupUCC::WorkUCC::isSuccess() const { } bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) { + volatile ucc_status_t *coll_status = &status_; + while (*coll_status == UCC_OPERATION_INITIALIZED) {} #ifdef USE_CUDA if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) { // block user stream @@ -202,7 +204,7 @@ CommPG::CommPG(torch_ucc_oob_coll_info_t* oob_info, CommPG::~CommPG() { std::unique_lock lock(mutex); - queue_consume_cv.wait(lock, [&] { return progress_queue.empty(); }); + queue_consume_cv.wait(lock, [&] { return progress_queue.empty() && post_queue.empty(); }); stop_progress_loop = true; lock.unlock(); queue_produce_cv.notify_all(); @@ -409,20 +411,10 @@ c10::intrusive_ptr CommPG::enqueue_collective( ucc_coll_args_t& coll, std::unique_ptr data, ucc_team_h& team) { - ucc_coll_req_h request; - ucc_status_t st; - st = ucc_collective_init(&coll, &request, team); - if (st != UCC_OK) { - LOG(ERROR) << "failed to init collective: " << ucc_status_string(st); - throw std::runtime_error(ucc_status_string(st)); - } - st = ucc_collective_post(request); - if (st != UCC_OK) { - LOG(ERROR) << "failed to post collective: " << ucc_status_string(st); - throw std::runtime_error(ucc_status_string(st)); - } auto work = c10::make_intrusive( - opType, UCC_INPROGRESS, request, nullptr, &ucc_comm); + opType, UCC_OPERATION_INITIALIZED, nullptr, nullptr, &ucc_comm); + work->args = coll; + work->team = team; work->data = std::move(data); #ifdef USE_UCC_FUTURE if (torch_ucc_config.use_future) { @@ -431,9 +423,10 @@ c10::intrusive_ptr CommPG::enqueue_collective( } #endif std::unique_lock lock(mutex); - progress_queue.push_back(work); + post_queue.push_back(work); lock.unlock(); queue_produce_cv.notify_one(); + return work; } @@ -445,42 +438,69 @@ c10::intrusive_ptr CommPG::enqueue_cuda_collective( ucc_team_h& team, ucc_ee_h ee, std::unique_ptr cuda_ev, - const at::cuda::CUDAStream& stream, + at::cuda::CUDAStream& stream, event_pool_t* ep) { - ucc_coll_req_h request; + auto work = c10::make_intrusive( + opType, UCC_OPERATION_INITIALIZED, nullptr, ee, &ucc_comm); + work->args = coll; + work->team = team; + work->ep = ep; + work->data = std::move(data); + work->fence = std::move(cuda_ev); + work->stream = &stream; + std::unique_lock lock(mutex); + post_queue.push_back(work); + lock.unlock(); + queue_produce_cv.notify_one(); + return work; +} +#endif + +void CommPG::post_cpu_collective(c10::intrusive_ptr post_req) { ucc_status_t st; - st = ucc_collective_init(&coll, &request, team); + + st = ucc_collective_init(&post_req->args, &post_req->request_, post_req->team); if (st != UCC_OK) { LOG(ERROR) << "failed to init collective: " << ucc_status_string(st); throw std::runtime_error(ucc_status_string(st)); } + + st = ucc_collective_post(post_req->request_); + if (st != UCC_OK) { + LOG(ERROR) << "failed to post collective: " << ucc_status_string(st); + throw std::runtime_error(ucc_status_string(st)); + } + progress_queue.push_back(post_req); + post_req->status_ = UCC_INPROGRESS; +} + +void CommPG::post_cuda_collective(c10::intrusive_ptr post_req) { + ucc_status_t st; ucc_ev_t comp_ev, *post_ev; + + st = ucc_collective_init(&post_req->args, &post_req->request_, post_req->team); + if (st != UCC_OK) { + LOG(ERROR) << "failed to init collective: " << ucc_status_string(st); + throw std::runtime_error(ucc_status_string(st)); + } + comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE; comp_ev.ev_context = nullptr; comp_ev.ev_context_size = 0; - comp_ev.req = request; - st = ucc_collective_triggered_post(ee, &comp_ev); + comp_ev.req = post_req->request_; + st = ucc_collective_triggered_post(post_req->ee_, &comp_ev); if (st != UCC_OK) { LOG(ERROR) << "failed to post triggered collective: " << ucc_status_string(st); throw std::runtime_error(ucc_status_string(st)); } - st = ucc_ee_get_event(ee, &post_ev); + st = ucc_ee_get_event(post_req->ee_, &post_ev); TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST); - ucc_ee_ack_event(ee, post_ev); - auto work = c10::make_intrusive( - opType, UCC_INPROGRESS, request, ee, &ucc_comm); - work->data = std::move(data); - work->ep = ep; - cuda_ev->record(stream); - work->fence = std::move(cuda_ev); - std::unique_lock lock(mutex); - progress_queue.push_back(work); - lock.unlock(); - queue_produce_cv.notify_one(); - return work; + ucc_ee_ack_event(post_req->ee_, post_ev); + post_req->fence->record(*(post_req->stream)); + progress_queue.push_back(post_req); + post_req->status_ = UCC_INPROGRESS; } -#endif void CommPG::progress_loop() { std::unique_lock lock(mutex); @@ -488,23 +508,31 @@ void CommPG::progress_loop() { bool device_set = false; #endif while (!stop_progress_loop) { - if (progress_queue.empty()) { + if (post_queue.empty() && progress_queue.empty()) { queue_produce_cv.wait(lock); continue; } - auto work = progress_queue.front(); - progress_queue.pop_front(); - lock.unlock(); - queue_consume_cv.notify_one(); #ifdef USE_CUDA if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) { c10::cuda::set_device(cuda_device_index); device_set = true; } #endif + while (!post_queue.empty()) { + auto work = post_queue.front(); + post_queue.pop_front(); + if (work->fence) { + post_cuda_collective(work); + } else { + post_cpu_collective(work); + } + } + lock.unlock(); + queue_consume_cv.notify_one(); + auto work = progress_queue.front(); + progress_queue.pop_front(); try { while (work->request_->status > 0) { - // operation initialized is in progress or work->comm_->progress(); } work->finalize(); diff --git a/src/torch_ucc_comm.cpp b/src/torch_ucc_comm.cpp index 16ecda4..84d395a 100644 --- a/src/torch_ucc_comm.cpp +++ b/src/torch_ucc_comm.cpp @@ -43,7 +43,7 @@ CommUCX::CommUCX(int comm_size) { } memset(&worker_params, 0, sizeof(ucp_worker_params_t)); worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = UCS_THREAD_MODE_MULTI; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; st = ucp_worker_create(context, &worker_params, &worker); if (st != UCS_OK) { LOG(ERROR) << "failed to create UCP worker: " << ucs_status_string(st); @@ -133,7 +133,7 @@ CommUCC::CommUCC(torch_ucc_oob_coll_info_t* oob_info) { } memset(&lib_params, 0, sizeof(ucc_lib_params_t)); lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE; - lib_params.thread_mode = UCC_THREAD_MULTIPLE; + lib_params.thread_mode = UCC_THREAD_SINGLE; st = ucc_init(&lib_params, lib_config, &lib); ucc_lib_config_release(lib_config); if (st != UCC_OK) { @@ -147,11 +147,6 @@ CommUCC::CommUCC(torch_ucc_oob_coll_info_t* oob_info) { LOG(ERROR) << "failed to query for lib attr: " << ucc_status_string(st); throw std::runtime_error(ucc_status_string(st)); } - if (lib_attr.thread_mode != UCC_THREAD_MULTIPLE) { - LOG(ERROR) << "ucc library wasn't initialized with mt support " - << "check ucc compile options "; - throw std::runtime_error("failed to init ucc lib"); - } st = ucc_context_config_read(lib, NULL, &context_config); if (st != UCC_OK) { ucc_finalize(lib);