diff --git a/include/torch_ucc.hpp b/include/torch_ucc.hpp
index d60f475..d7087a5 100644
--- a/include/torch_ucc.hpp
+++ b/include/torch_ucc.hpp
@@ -76,26 +76,19 @@ namespace c10d {
 } while (0)
 
 #ifdef USE_CUDA
-#define SAVE_TENSOR(_TENSOR, _DATA)                    \
-  if ((_TENSOR).device().is_cuda()) {                  \
-    c10::cuda::CUDACachingAllocator::recordStream(     \
-        (_TENSOR).storage().data_ptr(), (*stream));    \
-  } else {                                             \
-    (_DATA) = {(_TENSOR)};                             \
-  }
+#define SAVE_TENSORS(_TENSORS, _DATA)                       \
+  do {                                                      \
+    if ((_TENSORS)[0].device().is_cuda()) {                 \
+      for (const auto i : c10::irange((_TENSORS).size())) { \
+        c10::cuda::CUDACachingAllocator::recordStream(      \
+            (_TENSORS)[i].storage().data_ptr(), (*stream)); \
+      }                                                     \
+    } else {                                                \
+      (_DATA) = (_TENSORS);                                 \
+    }                                                       \
+  } while (0)
 
-#define SAVE_TENSORS(_TENSORS, _DATA)                       \
-  if ((_TENSORS)[0].device().is_cuda()) {                   \
-    for (const auto i : c10::irange((_TENSORS).size())) {   \
-      c10::cuda::CUDACachingAllocator::recordStream(        \
-          (_TENSORS)[i].storage().data_ptr(), (*stream));   \
-    }                                                       \
-  } else {                                                  \
-    (_DATA) = (_TENSORS);                                   \
-  }
 #else
-#define SAVE_TENSOR(_TENSOR, _DATA) (_DATA) = {(_TENSOR)};
-
 #define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
 #endif
 
@@ -135,8 +128,8 @@ class ProcessGroupUCC : public ProcessGroup {
   class AllgatherWorkData : public WorkData {
    public:
     AllgatherWorkData(int size)
-      : recv_lengths(size),
-        recv_offsets(size) {}
+        : recv_lengths(size),
+          recv_offsets(size) {}
     std::vector<uint64_t> recv_lengths;
     std::vector<uint64_t> recv_offsets;
   };
@@ -153,6 +146,7 @@
         ucc_ee_h ee,
         CommBase* comm)
         : ProcessGroup::Work(-1, opType),
+          ee_(ee),
          status_(status),
          request_(request),
          comm_(comm) {}
@@ -162,14 +156,29 @@
     bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
     void finalize();
     std::unique_ptr<WorkData> data;
+#ifdef USE_UCC_FUTURE
+    void finishWorkUCC();
+    void finishWorkUCCError(std::exception_ptr eptr);
+    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
+#endif
 #ifdef USE_CUDA
     std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
     event_pool_t* ep = nullptr;
+    at::cuda::CUDAStream* stream;
 #endif
 
    protected:
+    ucc_coll_args_t args;
+    ucc_team_h team;
+    ucc_ee_h ee_;
     ucc_status_t status_;
     ucc_coll_req_h request_;
     CommBase* comm_;
+
+#ifdef USE_UCC_FUTURE
+   private:
+    // The future returned by getFuture.
+    c10::intrusive_ptr<c10::ivalue::Future> future_;
+#endif
   };
 
   explicit ProcessGroupUCC(
@@ -293,8 +302,11 @@ class CommPG {
   std::condition_variable queue_produce_cv;
   std::condition_variable queue_consume_cv;
   std::deque<c10::intrusive_ptr<ProcessGroupUCC::WorkUCC>> progress_queue;
+  std::deque<c10::intrusive_ptr<ProcessGroupUCC::WorkUCC>> post_queue;
   bool stop_progress_loop;
 
+  void post_cpu_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req);
+  void post_cuda_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req);
  public:
   c10::DeviceIndex cuda_device_index;
   CommPG(torch_ucc_oob_coll_info_t* oob_info,
@@ -328,7 +340,7 @@
       ucc_team_h& team,
       ucc_ee_h ee,
       std::unique_ptr<at::cuda::CUDAEvent> cuda_ev,
-      const at::cuda::CUDAStream& stream,
+      at::cuda::CUDAStream& stream,
       event_pool_t* ep);
 #endif
 
diff --git a/src/torch_ucc.cpp b/src/torch_ucc.cpp
index 6e0819a..b05cf4a 100644
--- a/src/torch_ucc.cpp
+++ b/src/torch_ucc.cpp
@@ -58,6 +58,7 @@ const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
 struct torch_ucc_config_t {
   std::once_flag flag;
   std::array<bool, 32> blocking_wait;
+  bool use_future;
 } torch_ucc_config;
 
 void read_confg() {
@@ -84,6 +85,14 @@
     torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BROADCAST] =
         std::atoi(env);
   }
+#ifdef USE_UCC_FUTURE
+  env = std::getenv("TORCH_UCC_USE_FUTURE");
+  if (env) {
+    torch_ucc_config.use_future = !!std::atoi(env);
+  } else {
+    torch_ucc_config.use_future = true;
+  }
+#endif
 }
 
 void check_device(c10::Device dev1, c10::Device dev2) {
@@ -125,6 +134,8 @@ bool ProcessGroupUCC::WorkUCC::isSuccess() const {
 }
 
 bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
+  volatile ucc_status_t* coll_status = &status_;
+  while (*coll_status == UCC_OPERATION_INITIALIZED) {}
 #ifdef USE_CUDA
   if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
     // block user stream
@@ -138,6 +149,30 @@
   return true;
 }
 
+#ifdef USE_UCC_FUTURE
+void ProcessGroupUCC::WorkUCC::finishWorkUCCError(std::exception_ptr eptr) {
+  if (torch_ucc_config.use_future) {
+    future_->setError(eptr);
+  }
+  finish(eptr);
+}
+void ProcessGroupUCC::WorkUCC::finishWorkUCC() {
+  if (torch_ucc_config.use_future && future_) {
+    if (!data || data->dst.size() == 0) {
+      future_->markCompleted(c10::IValue(std::vector<at::Tensor>()));
+    } else {
+      future_->markCompleted(c10::IValue(data->dst));
+    }
+  }
+  finish();
+}
+
+c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::
+    getFuture() {
+  return future_;
+}
+#endif // #ifdef USE_UCC_FUTURE
+
 void ProcessGroupUCC::WorkUCC::finalize() {
   if (request_ != nullptr) {
     if (isP2POp(opType_)) {
@@ -149,6 +184,9 @@ void ProcessGroupUCC::WorkUCC::finalize() {
     status_ = UCC_OK;
     request_ = nullptr;
   }
+#ifdef USE_UCC_FUTURE
+  finishWorkUCC();
+#endif
 }
 
 CommPG::CommPG(torch_ucc_oob_coll_info_t* oob_info,
@@ -166,7 +204,7 @@
 CommPG::~CommPG() {
   std::unique_lock<std::mutex> lock(mutex);
-  queue_consume_cv.wait(lock, [&] { return progress_queue.empty(); });
+  queue_consume_cv.wait(lock, [&] { return progress_queue.empty() && post_queue.empty(); });
   stop_progress_loop = true;
   lock.unlock();
   queue_produce_cv.notify_all();
@@ -373,25 +411,22 @@ c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> CommPG::enqueue_collective(
     ucc_coll_args_t& coll,
     std::unique_ptr<ProcessGroupUCC::WorkData> data,
     ucc_team_h& team) {
-  ucc_coll_req_h request;
-  ucc_status_t st;
-  st = ucc_collective_init(&coll, &request, team);
-  if (st != UCC_OK) {
-    LOG(ERROR) << "failed to init collective: " << ucc_status_string(st);
-    throw std::runtime_error(ucc_status_string(st));
-  }
-  st = ucc_collective_post(request);
-  if (st != UCC_OK) {
-    LOG(ERROR) << "failed to post collective: " <<
-        ucc_status_string(st);
-    throw std::runtime_error(ucc_status_string(st));
-  }
   auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
-      opType, UCC_INPROGRESS, request, nullptr, &ucc_comm);
+      opType, UCC_OPERATION_INITIALIZED, nullptr, nullptr, &ucc_comm);
+  work->args = coll;
+  work->team = team;
   work->data = std::move(data);
+#ifdef USE_UCC_FUTURE
+  if (torch_ucc_config.use_future) {
+    work->future_ = c10::make_intrusive<c10::ivalue::Future>(
+        c10::ListType::create(c10::TensorType::get()));
+  }
+#endif
   std::unique_lock<std::mutex> lock(mutex);
-  progress_queue.push_back(work);
+  post_queue.push_back(work);
   lock.unlock();
   queue_produce_cv.notify_one();
+
   return work;
 }
 
@@ -403,42 +438,69 @@ c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> CommPG::enqueue_cuda_collective(
     ucc_team_h& team,
     ucc_ee_h ee,
     std::unique_ptr<at::cuda::CUDAEvent> cuda_ev,
-    const at::cuda::CUDAStream& stream,
+    at::cuda::CUDAStream& stream,
     event_pool_t* ep) {
-  ucc_coll_req_h request;
+  auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
+      opType, UCC_OPERATION_INITIALIZED, nullptr, ee, &ucc_comm);
+  work->args = coll;
+  work->team = team;
+  work->ep = ep;
+  work->data = std::move(data);
+  work->fence = std::move(cuda_ev);
+  work->stream = &stream;
+  std::unique_lock<std::mutex> lock(mutex);
+  post_queue.push_back(work);
+  lock.unlock();
+  queue_produce_cv.notify_one();
+  return work;
+}
+#endif
+
+void CommPG::post_cpu_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req) {
   ucc_status_t st;
-  st = ucc_collective_init(&coll, &request, team);
+
+  st = ucc_collective_init(&post_req->args, &post_req->request_, post_req->team);
   if (st != UCC_OK) {
     LOG(ERROR) << "failed to init collective: " << ucc_status_string(st);
     throw std::runtime_error(ucc_status_string(st));
   }
+
+  st = ucc_collective_post(post_req->request_);
+  if (st != UCC_OK) {
+    LOG(ERROR) << "failed to post collective: " << ucc_status_string(st);
+    throw std::runtime_error(ucc_status_string(st));
+  }
+  progress_queue.push_back(post_req);
+  post_req->status_ = UCC_INPROGRESS;
+}
+
+void CommPG::post_cuda_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req) {
+  ucc_status_t st;
   ucc_ev_t comp_ev, *post_ev;
+
+  st = ucc_collective_init(&post_req->args, &post_req->request_, post_req->team);
+  if (st != UCC_OK) {
+    LOG(ERROR) << "failed to init collective: " << ucc_status_string(st);
+    throw std::runtime_error(ucc_status_string(st));
+  }
+
   comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
   comp_ev.ev_context = nullptr;
   comp_ev.ev_context_size = 0;
-  comp_ev.req = request;
-  st = ucc_collective_triggered_post(ee, &comp_ev);
+  comp_ev.req = post_req->request_;
+  st = ucc_collective_triggered_post(post_req->ee_, &comp_ev);
   if (st != UCC_OK) {
     LOG(ERROR) << "failed to post triggered collective: "
                << ucc_status_string(st);
     throw std::runtime_error(ucc_status_string(st));
   }
-  st = ucc_ee_get_event(ee, &post_ev);
+  st = ucc_ee_get_event(post_req->ee_, &post_ev);
   TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
-  ucc_ee_ack_event(ee, post_ev);
-  auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
-      opType, UCC_INPROGRESS, request, ee, &ucc_comm);
-  work->data = std::move(data);
-  work->ep = ep;
-  cuda_ev->record(stream);
-  work->fence = std::move(cuda_ev);
-  std::unique_lock<std::mutex> lock(mutex);
-  progress_queue.push_back(work);
-  lock.unlock();
-  queue_produce_cv.notify_one();
-  return work;
+  ucc_ee_ack_event(post_req->ee_, post_ev);
+  post_req->fence->record(*(post_req->stream));
+  progress_queue.push_back(post_req);
+  post_req->status_ = UCC_INPROGRESS;
 }
-#endif
 
 void CommPG::progress_loop() {
   std::unique_lock<std::mutex> lock(mutex);
@@ -446,26 +508,42 @@
 #ifdef USE_CUDA
   bool device_set = false;
 #endif
   while (!stop_progress_loop) {
-    if (progress_queue.empty()) {
+    if (post_queue.empty() && progress_queue.empty()) {
       queue_produce_cv.wait(lock);
       continue;
     }
-    auto work = progress_queue.front();
-    progress_queue.pop_front();
-    lock.unlock();
-    queue_consume_cv.notify_one();
 #ifdef USE_CUDA
     if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
       c10::cuda::set_device(cuda_device_index);
       device_set = true;
     }
 #endif
-    while (work->request_->status > 0) {
-      // operation initialized is in progress or
-      work->comm_->progress();
+    while (!post_queue.empty()) {
+      auto work = post_queue.front();
+      post_queue.pop_front();
+      if (work->fence) {
+        post_cuda_collective(work);
+      } else {
+        post_cpu_collective(work);
+      }
+    }
+    lock.unlock();
+    queue_consume_cv.notify_one();
+    auto work = progress_queue.front();
+    progress_queue.pop_front();
+    try {
+      while (work->request_->status > 0) {
+        work->comm_->progress();
+      }
+      work->finalize();
+      work->data.reset();
+    } catch (...) {
+#ifdef USE_UCC_FUTURE
+      work->finishWorkUCCError(std::current_exception());
+#else
+      work->finish(std::current_exception());
+#endif
     }
-    work->finalize();
-    work->data.reset();
     lock.lock();
   }
 }
@@ -518,11 +596,27 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post(
   }
   cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
   cuda_ev->block(*stream);
-  auto work = comm->enqueue_cuda_collective(opType, coll, std::move(data),
-      team, cuda_ee, std::move(cuda_ev), *stream, &ep);
+  c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work =
+      comm->enqueue_cuda_collective(
+          opType,
+          coll,
+          std::move(data),
+          team,
+          cuda_ee,
+          std::move(cuda_ev),
+          *stream,
+          &ep);
+#ifdef USE_UCC_FUTURE
+  if (torch_ucc_config.use_future) {
+    c10::cuda::CUDAMultiStreamGuard streamGuard(*stream);
+    std::vector<c10::Device> devList{dev};
+    work->future_ = c10::make_intrusive<c10::ivalue::Future>(
+        c10::ListType::create(c10::TensorType::get()), devList);
+  }
+#endif // #ifdef USE_UCC_FUTURE
   return work;
 }
-#endif
+#endif // #ifdef USE_CUDA
   return comm->enqueue_collective(opType, coll, std::move(data), team);
 }
@@ -593,7 +687,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allreduce(
   coll.dst.info.count = tensor.numel();
   coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
   coll.dst.info.mem_type = ucc_mtype_map.at(tensor.device().type());
-  SAVE_TENSORS(tensors, data->src);
+  SAVE_TENSORS(tensors, data->dst);
 
   return collective_post(
       OpType::ALLREDUCE,
@@ -671,8 +765,10 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::alltoall_base(
     coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
         UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
   }
-  SAVE_TENSOR(inputTensor, data->src);
-  SAVE_TENSOR(outputTensor, data->dst);
+  std::vector<at::Tensor> inputTensors = {inputTensor};
+  std::vector<at::Tensor> outputTensors = {outputTensor};
+  SAVE_TENSORS(inputTensors, data->src);
+  SAVE_TENSORS(outputTensors, data->dst);
 
   return collective_post(
       OpType::ALLTOALL_BASE,
@@ -707,7 +803,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::broadcast(
   coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
   coll.src.info.mem_type = ucc_mtype_map.at(tensor.device().type());
   coll.root = opts.rootRank;
-  SAVE_TENSORS(tensors, data->src);
+  SAVE_TENSORS(tensors, data->dst);
 
   return collective_post(
       OpType::BROADCAST,
diff --git a/src/torch_ucc_comm.cpp b/src/torch_ucc_comm.cpp
index 16ecda4..84d395a 100644
--- a/src/torch_ucc_comm.cpp
+++ b/src/torch_ucc_comm.cpp
@@ -43,7 +43,7 @@ CommUCX::CommUCX(int comm_size) {
   }
   memset(&worker_params, 0, sizeof(ucp_worker_params_t));
   worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
-  worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
+  worker_params.thread_mode = UCS_THREAD_MODE_SINGLE;
   st = ucp_worker_create(context, &worker_params, &worker);
   if (st != UCS_OK) {
     LOG(ERROR) << "failed to create UCP worker: " << ucs_status_string(st);
@@ -133,7 +133,7 @@ CommUCC::CommUCC(torch_ucc_oob_coll_info_t* oob_info) {
   }
   memset(&lib_params, 0, sizeof(ucc_lib_params_t));
   lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
-  lib_params.thread_mode = UCC_THREAD_MULTIPLE;
+  lib_params.thread_mode = UCC_THREAD_SINGLE;
   st = ucc_init(&lib_params, lib_config, &lib);
   ucc_lib_config_release(lib_config);
   if (st != UCC_OK) {
@@ -147,11 +147,6 @@
     LOG(ERROR) << "failed to query for lib attr: " << ucc_status_string(st);
     throw std::runtime_error(ucc_status_string(st));
   }
-  if (lib_attr.thread_mode != UCC_THREAD_MULTIPLE) {
-    LOG(ERROR) << "ucc library wasn't initialized with mt support "
-               << "check ucc compile options ";
-    throw std::runtime_error("failed to init ucc lib");
-  }
   st = ucc_context_config_read(lib, NULL, &context_config);
   if (st != UCC_OK) {
     ucc_finalize(lib);
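
The core of this patch is the two-queue handoff: callers no longer run ucc_collective_init/ucc_collective_post themselves, they append a work item to post_queue; the single progress thread posts it (post_cpu_collective or post_cuda_collective) and then drives it from progress_queue, while WorkUCC::wait spins until status_ leaves UCC_OPERATION_INITIALIZED. The standalone sketch below (not part of the patch; Work, enqueue, and the fake progress ticks are illustrative stand-ins, not UCC or torch_ucc APIs) shows the same producer/consumer shape in plain C++:

// Illustration of the post-queue / progress-queue pattern introduced above.
// One progress thread both posts and progresses work, so the underlying
// library only ever sees a single thread (which is why the patch can drop
// UCC_THREAD_MULTIPLE / UCS_THREAD_MODE_MULTI).
#include <atomic>
#include <condition_variable>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

struct Work {
  std::atomic<int> status{0}; // 0 = initialized, 1 = in progress, 2 = done
  int progress_left = 3;      // fake "remaining progress ticks"
};

std::mutex mtx;
std::condition_variable produce_cv;
std::deque<std::shared_ptr<Work>> post_queue;
std::deque<std::shared_ptr<Work>> progress_queue;
bool stop = false;

void progress_loop() {
  std::unique_lock<std::mutex> lock(mtx);
  while (!stop) {
    if (post_queue.empty() && progress_queue.empty()) {
      produce_cv.wait(lock);
      continue;
    }
    // Drain the post queue first: "init + post" happens on this thread,
    // never on the caller's thread.
    while (!post_queue.empty()) {
      auto w = post_queue.front();
      post_queue.pop_front();
      w->status = 1; // posted, now in progress
      progress_queue.push_back(w);
    }
    auto w = progress_queue.front();
    progress_queue.pop_front();
    lock.unlock();
    while (w->progress_left > 0) { // drive completion outside the lock
      --w->progress_left;
    }
    w->status = 2; // done
    lock.lock();
  }
}

std::shared_ptr<Work> enqueue() {
  auto w = std::make_shared<Work>();
  {
    std::lock_guard<std::mutex> lock(mtx);
    post_queue.push_back(w);
  }
  produce_cv.notify_one();
  return w;
}

int main() {
  std::thread t(progress_loop);
  auto w = enqueue();
  while (w->status != 2) {} // caller-side spin mirrors WorkUCC::wait
  std::cout << "collective completed\n";
  {
    std::lock_guard<std::mutex> lock(mtx);
    stop = true;
  }
  produce_cv.notify_one();
  t.join();
}

The same single-consumer property is what makes the volatile spin in WorkUCC::wait safe here: only the progress thread ever moves status_ past UCC_OPERATION_INITIALIZED.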