support future in ProcessGroupUCC (#9)
Summary:
Pull Request resolved: #9

- Add `getFuture` and related functions to support proper synchronization on the CUDA stream (see the usage sketch after this list)
- Remove `SAVE_TENSOR`; use `SAVE_TENSORS` to cover all cases
- Add build and runtime flags to enable/disable this feature:
  - build flag: `-DUSE_UCC_FUTURE`
  - runtime environment variable: `TORCH_UCC_USE_FUTURE=1`
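
A minimal usage sketch of the future-based synchronization (a sketch under assumptions: the extension was built with `-DUSE_UCC_FUTURE`, `TORCH_UCC_USE_FUTURE=1` is set at runtime, and `pg` is an already-initialized process group; the helper name `allreduce_via_future` is hypothetical):

#include <vector>
#include <torch/torch.h>
#include "torch_ucc.hpp"

// Hypothetical helper: launch an allreduce, then synchronize through the
// future returned by getFuture() instead of calling work->wait().
void allreduce_via_future(c10d::ProcessGroupUCC& pg) {
  std::vector<at::Tensor> tensors{torch::ones({4})};
  auto work = pg.allreduce(tensors);
  // With TORCH_UCC_USE_FUTURE=1, the future is marked completed with the
  // destination tensors (data->dst) once the UCC request finalizes.
  auto fut = work->getFuture();
  fut->wait();
  std::vector<at::Tensor> out = fut->value().toTensorVector();
}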

Differential Revision: D29272193

fbshipit-source-id: 68fdd076e0d72069910c4f9ba3149e39bad5bda6
kingchc authored and facebook-github-bot committed Jul 7, 2021
1 parent 03064f0 commit d327b1e
Showing 2 changed files with 104 additions and 32 deletions.
44 changes: 24 additions & 20 deletions include/torch_ucc.hpp
@@ -76,26 +76,19 @@ namespace c10d {
} while (0)

#ifdef USE_CUDA
-#define SAVE_TENSOR(_TENSOR, _DATA)                   \
-  if ((_TENSOR).device().is_cuda()) {                 \
-    c10::cuda::CUDACachingAllocator::recordStream(    \
-        (_TENSOR).storage().data_ptr(), (*stream));   \
-  } else {                                            \
-    (_DATA) = {(_TENSOR)};                            \
-  }
+#define SAVE_TENSORS(_TENSORS, _DATA)                         \
+  do {                                                        \
+    if ((_TENSORS)[0].device().is_cuda()) {                   \
+      for (const auto i : c10::irange((_TENSORS).size())) {   \
+        c10::cuda::CUDACachingAllocator::recordStream(        \
+            (_TENSORS)[i].storage().data_ptr(), (*stream));   \
+      }                                                       \
+    } else {                                                  \
+      (_DATA) = (_TENSORS);                                   \
+    }                                                         \
+  } while (0)

-#define SAVE_TENSORS(_TENSORS, _DATA)                      \
-  if ((_TENSORS)[0].device().is_cuda()) {                  \
-    for (const auto i : c10::irange((_TENSORS).size())) {  \
-      c10::cuda::CUDACachingAllocator::recordStream(       \
-          (_TENSORS)[i].storage().data_ptr(), (*stream));  \
-    }                                                      \
-  } else {                                                 \
-    (_DATA) = (_TENSORS);                                  \
-  }
#else
-#define SAVE_TENSOR(_TENSOR, _DATA) (_DATA) = {(_TENSOR)};
-
#define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
#endif
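
The new SAVE_TENSORS wraps its body in do { ... } while (0), which the removed variants lacked: each expansion becomes a single statement, so the macro composes safely with if/else at call sites. A self-contained sketch of the pitfall (the SAVE_BARE/SAVE_WRAPPED macros are hypothetical illustrations mirroring the two shapes, not from this patch):

#include <iostream>

// SAVE_BARE mirrors the old unwrapped if/else shape; SAVE_WRAPPED mirrors
// the new do-while(0) shape. Neither macro is from this patch.
#define SAVE_BARE(_X, _OUT) \
  if ((_X) > 0) {           \
    (_OUT) = (_X);          \
  } else {                  \
    (_OUT) = 0;             \
  }

#define SAVE_WRAPPED(_X, _OUT) \
  do {                         \
    if ((_X) > 0) {            \
      (_OUT) = (_X);           \
    } else {                   \
      (_OUT) = 0;              \
    }                          \
  } while (0)

int main() {
  int out = -1;
  bool enabled = true;
  if (enabled)
    SAVE_WRAPPED(5, out); // expands to one statement; the trailing ';' is fine
  else
    out = 0;
  // Swapping SAVE_WRAPPED for SAVE_BARE above fails to compile: the macro's
  // closing '}' plus our ';' terminates the 'if', orphaning the 'else'.
  std::cout << out << std::endl;
  return 0;
}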

@@ -135,8 +128,8 @@ class ProcessGroupUCC : public ProcessGroup {
class AllgatherWorkData : public WorkData {
public:
AllgatherWorkData(int size)
: recv_lengths(size),
recv_offsets(size) {}
std::vector<uint64_t> recv_lengths;
std::vector<uint64_t> recv_offsets;
};
@@ -162,6 +155,11 @@ class ProcessGroupUCC : public ProcessGroup {
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
void finalize();
std::unique_ptr<WorkData> data;
+#ifdef USE_UCC_FUTURE
+  void finishWorkUCC();
+  void finishWorkUCCError(std::exception_ptr eptr);
+  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
+#endif
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
event_pool_t* ep = nullptr;
@@ -170,6 +168,12 @@
ucc_status_t status_;
ucc_coll_req_h request_;
CommBase* comm_;

+#ifdef USE_UCC_FUTURE
+ private:
+  // The future returned by getFuture.
+  c10::intrusive_ptr<at::ivalue::Future> future_;
+#endif
};

explicit ProcessGroupUCC(
92 changes: 80 additions & 12 deletions src/torch_ucc.cpp
@@ -58,6 +58,7 @@ const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
struct torch_ucc_config_t {
std::once_flag flag;
std::array<bool, 32> blocking_wait;
+  bool use_future;
} torch_ucc_config;

void read_confg() {
@@ -84,6 +85,14 @@ void read_confg() {
torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BROADCAST] =
std::atoi(env);
}
+#ifdef USE_UCC_FUTURE
+  env = std::getenv("TORCH_UCC_USE_FUTURE");
+  if (env) {
+    torch_ucc_config.use_future = !!std::atoi(env);
+  } else {
+    torch_ucc_config.use_future = true;
+  }
+#endif
}

void check_device(c10::Device dev1, c10::Device dev2) {
@@ -138,6 +147,30 @@ bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
return true;
}

+#ifdef USE_UCC_FUTURE
+void ProcessGroupUCC::WorkUCC::finishWorkUCCError(std::exception_ptr eptr) {
+  if (torch_ucc_config.use_future) {
+    future_->setError(eptr);
+  }
+  finish(eptr);
+}
+void ProcessGroupUCC::WorkUCC::finishWorkUCC() {
+  if (torch_ucc_config.use_future && future_) {
+    if (!data || data->dst.size() == 0) {
+      future_->markCompleted(c10::IValue(std::vector<at::Tensor>()));
+    } else {
+      future_->markCompleted(c10::IValue(data->dst));
+    }
+  }
+  finish();
+}
+
+c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::
+    getFuture() {
+  return future_;
+}
+#endif // #ifdef USE_UCC_FUTURE

void ProcessGroupUCC::WorkUCC::finalize() {
if (request_ != nullptr) {
if (isP2POp(opType_)) {
@@ -149,6 +182,9 @@ void ProcessGroupUCC::WorkUCC::finalize() {
status_ = UCC_OK;
request_ = nullptr;
}
+#ifdef USE_UCC_FUTURE
+  finishWorkUCC();
+#endif
}

CommPG::CommPG(torch_ucc_oob_coll_info_t* oob_info,
@@ -388,6 +424,12 @@ c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> CommPG::enqueue_collective(
auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
opType, UCC_INPROGRESS, request, nullptr, &ucc_comm);
work->data = std::move(data);
+#ifdef USE_UCC_FUTURE
+  if (torch_ucc_config.use_future) {
+    work->future_ = c10::make_intrusive<at::ivalue::Future>(
+        c10::ListType::create(c10::TensorType::get()));
+  }
+#endif
std::unique_lock<std::mutex> lock(mutex);
progress_queue.push_back(work);
lock.unlock();
@@ -460,12 +502,20 @@ void CommPG::progress_loop() {
device_set = true;
}
#endif
-    while (work->request_->status > 0) {
-      // operation initialized is in progress or
-      work->comm_->progress();
+    try {
+      while (work->request_->status > 0) {
+        // operation initialized is in progress or
+        work->comm_->progress();
+      }
+      work->finalize();
+      work->data.reset();
+    } catch (...) {
+#ifdef USE_UCC_FUTURE
+      work->finishWorkUCCError(std::current_exception());
+#else
+      work->finish(std::current_exception());
+#endif
    }
-    work->finalize();
-    work->data.reset();
lock.lock();
}
}
@@ -518,11 +568,27 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post(
}
cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
cuda_ev->block(*stream);
-  auto work = comm->enqueue_cuda_collective(opType, coll, std::move(data),
-      team, cuda_ee, std::move(cuda_ev), *stream, &ep);
+  c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work =
+      comm->enqueue_cuda_collective(
+          opType,
+          coll,
+          std::move(data),
+          team,
+          cuda_ee,
+          std::move(cuda_ev),
+          *stream,
+          &ep);
+#ifdef USE_UCC_FUTURE
+  if (torch_ucc_config.use_future) {
+    c10::cuda::CUDAMultiStreamGuard streamGuard(*stream);
+    std::vector<c10::Device> devList{dev};
+    work->future_ = c10::make_intrusive<at::ivalue::Future>(
+        c10::ListType::create(c10::TensorType::get()), devList);
+  }
+#endif // #ifdef USE_UCC_FUTURE
return work;
}
-#endif
+#endif // #ifdef USE_CUDA
return comm->enqueue_collective(opType, coll, std::move(data), team);
}

@@ -593,7 +659,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allreduce(
coll.dst.info.count = tensor.numel();
coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
coll.dst.info.mem_type = ucc_mtype_map.at(tensor.device().type());
-  SAVE_TENSORS(tensors, data->src);
+  SAVE_TENSORS(tensors, data->dst);

return collective_post(
OpType::ALLREDUCE,
@@ -671,8 +737,10 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::alltoall_base(
coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
}
-  SAVE_TENSOR(inputTensor, data->src);
-  SAVE_TENSOR(outputTensor, data->dst);
+  std::vector<at::Tensor> inputTensors = {inputTensor};
+  std::vector<at::Tensor> outputTensors = {outputTensor};
+  SAVE_TENSORS(inputTensors, data->src);
+  SAVE_TENSORS(outputTensors, data->dst);

return collective_post(
OpType::ALLTOALL_BASE,
@@ -707,7 +775,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::broadcast(
coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
coll.src.info.mem_type = ucc_mtype_map.at(tensor.device().type());
coll.root = opts.rootRank;
-  SAVE_TENSORS(tensors, data->src);
+  SAVE_TENSORS(tensors, data->dst);

return collective_post(
OpType::BROADCAST,
