From ada377c57eec006889484d10e5ce83e4ac46c971 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 15 Nov 2023 14:16:19 +0800 Subject: [PATCH] [coll] Reduce the scope of lock in the event loop. (#9784) --- include/xgboost/collective/socket.h | 23 ++++--- rabit/src/allreduce_base.cc | 4 +- src/collective/allreduce.cc | 11 ++-- src/collective/comm.cc | 35 ++++++---- src/collective/loop.cc | 88 ++++++++++++++++++-------- src/collective/loop.h | 17 ++--- tests/cpp/collective/test_allreduce.cc | 9 ++- 7 files changed, 117 insertions(+), 70 deletions(-) diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 5dd1b9ffaff2..84453411046e 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -412,19 +412,24 @@ class TCPSocket { return Success(); } - void SetKeepAlive() { + [[nodiscard]] Result SetKeepAlive() { std::int32_t keepalive = 1; - xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, - reinterpret_cast(&keepalive), sizeof(keepalive)), - 0); + auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&keepalive), + sizeof(keepalive)); + if (rc != 0) { + return system::FailWithCode("Failed to set TCP keepalive."); + } + return Success(); } - void SetNoDelay() { + [[nodiscard]] Result SetNoDelay() { std::int32_t tcp_no_delay = 1; - xgboost_CHECK_SYS_CALL( - setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&tcp_no_delay), - sizeof(tcp_no_delay)), - 0); + auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&tcp_no_delay), + sizeof(tcp_no_delay)); + if (rc != 0) { + return system::FailWithCode("Failed to set TCP no delay."); + } + return Success(); } /** diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index 5cab4ae327c5..b99eb3763a3e 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -417,9 +417,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode, enable TCP keepalive CHECK(all_link.sock.NonBlocking(true).OK()); - all_link.sock.SetKeepAlive(); + CHECK(all_link.sock.SetKeepAlive().OK()); if (rabit_enable_tcp_no_delay) { - all_link.sock.SetNoDelay(); + CHECK(all_link.sock.SetNoDelay().OK()); } if (tree_neighbors.count(all_link.rank) != 0) { if (all_link.rank == parent_rank) { diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc index 65c06686860b..f95a9a9f1ed5 100644 --- a/src/collective/allreduce.cc +++ b/src/collective/allreduce.cc @@ -6,6 +6,7 @@ #include <algorithm> // for min #include <cstddef> // for size_t #include <cstdint> // for int32_t, int8_t +#include <utility> // for move #include <vector> // for vector #include "../data/array_interface.h" // for Type, DispatchDType @@ -47,7 +48,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, auto seg = s_buf.subspan(0, recv_seg.size()); prev_ch->RecvAll(seg); - auto rc = prev_ch->Block(); + auto rc = comm.Block(); if (!rc.OK()) { return rc; } @@ -83,11 +84,9 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons auto prev_ch = comm.Chan(prev); auto next_ch = comm.Chan(next); - rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch); - if (!rc.OK()) { - return rc; - } - return comm.Block(); + return std::move(rc) << [&] { + return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch); + } << [&] { return comm.Block(); }; }); } } // namespace xgboost::collective::cpu_impl diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 964137ff1307..9da9083f8e42 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -33,19 +33,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry, std::string const& task_id, TCPSocket* out, std::int32_t rank, std::int32_t world) { - // get information
from tracker + // Get information from the tracker CHECK(!info.host.empty()); - auto rc = Connect(info.host, info.port, retry, timeout, out); - if (!rc.OK()) { - return Fail("Failed to connect to the tracker.", std::move(rc)); - } - TCPSocket& tracker = *out; - return std::move(rc) - << [&] { return tracker.NonBlocking(false); } - << [&] { return tracker.RecvTimeout(timeout); } - << [&] { return proto::Magic{}.Verify(&tracker); } - << [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); }; + return Success() << [&] { + auto rc = Connect(info.host, info.port, retry, timeout, out); + if (rc.OK()) { + return rc; + } else { + return Fail("Failed to connect to the tracker.", std::move(rc)); + } + } << [&] { + return tracker.NonBlocking(false); + } << [&] { + return tracker.RecvTimeout(timeout); + } << [&] { + return proto::Magic{}.Verify(&tracker); + } << [&] { + return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); + } << [&] { + LOG(INFO) << "Task " << task_id << " connected to the tracker"; + return Success(); + }; } [[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const { @@ -257,8 +266,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se CHECK(this->channels_.empty()); for (auto& w : workers) { if (w) { - w->SetNoDelay(); - rc = w->NonBlocking(true); + rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); } + << [&] { return w->SetKeepAlive(); }; } if (!rc.OK()) { return rc; diff --git a/src/collective/loop.cc b/src/collective/loop.cc index 95a1019acdf3..10fce051630e 100644 --- a/src/collective/loop.cc +++ b/src/collective/loop.cc @@ -10,21 +10,26 @@ #include "xgboost/logging.h" // for CHECK namespace xgboost::collective { -Result Loop::EmptyQueue() { +Result Loop::EmptyQueue(std::queue* p_queue) const { timer_.Start(__func__); - auto error = [this] { - this->stop_ = true; + auto error = [this] { timer_.Stop(__func__); }; + + if (stop_) { 
timer_.Stop(__func__); - }; + return Success(); + } - while (!queue_.empty() && !stop_) { - std::queue qcopy; + auto& qcopy = *p_queue; + + // clear the copied queue + while (!qcopy.empty()) { rabit::utils::PollHelper poll; + std::size_t n_ops = qcopy.size(); - // watch all ops - while (!queue_.empty()) { - auto op = queue_.front(); - queue_.pop(); + // Iterate through all the ops for poll + for (std::size_t i = 0; i < n_ops; ++i) { + auto op = qcopy.front(); + qcopy.pop(); switch (op.code) { case Op::kRead: { @@ -40,6 +45,7 @@ Result Loop::EmptyQueue() { return Fail("Invalid socket operation."); } } + qcopy.push(op); } @@ -51,10 +57,12 @@ Result Loop::EmptyQueue() { error(); return rc; } + // we wonldn't be here if the queue is empty. CHECK(!qcopy.empty()); - while (!qcopy.empty() && !stop_) { + // Iterate through all the ops for performing the operations + for (std::size_t i = 0; i < n_ops; ++i) { auto op = qcopy.front(); qcopy.pop(); @@ -81,20 +89,21 @@ Result Loop::EmptyQueue() { } if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) { - stop_ = true; auto rc = system::FailWithCode("Invalid socket output."); error(); return rc; } + op.off += n_bytes_done; CHECK_LE(op.off, op.n); if (op.off != op.n) { // not yet finished, push back to queue for next round. 
- queue_.push(op); + qcopy.push(op); } } } + timer_.Stop(__func__); return Success(); } @@ -107,22 +116,42 @@ void Loop::Process() { if (stop_) { break; } - CHECK(!mu_.try_lock()); - this->rc_ = this->EmptyQueue(); - if (!rc_.OK()) { - stop_ = true; + auto unlock_notify = [&](bool is_blocking) { + if (!is_blocking) { + return; + } + lock.unlock(); cv_.notify_one(); - break; - } + }; - CHECK(queue_.empty()); - CHECK(!mu_.try_lock()); - cv_.notify_one(); - } + // move the queue + std::queue qcopy; + bool is_blocking = false; + while (!queue_.empty()) { + auto op = queue_.front(); + queue_.pop(); + if (op.code == Op::kBlock) { + is_blocking = true; + } else { + qcopy.push(op); + } + } + // unblock the queue + if (!is_blocking) { + lock.unlock(); + } + // clear the queue + auto rc = this->EmptyQueue(&qcopy); + // Handle error + if (!rc.OK()) { + this->rc_ = std::move(rc); + unlock_notify(is_blocking); + return; + } - if (rc_.OK()) { - CHECK(queue_.empty()); + CHECK(qcopy.empty()); + unlock_notify(is_blocking); } } @@ -140,6 +169,15 @@ Result Loop::Stop() { return Success(); } +[[nodiscard]] Result Loop::Block() { + this->Submit(Op{Op::kBlock}); + { + std::unique_lock lock{mu_}; + cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; }); + } + return std::move(rc_); +} + Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { timer_.Init(__func__); worker_ = std::thread{[this] { diff --git a/src/collective/loop.h b/src/collective/loop.h index 0bccbc0d09ef..4f5cb12b3eb8 100644 --- a/src/collective/loop.h +++ b/src/collective/loop.h @@ -20,13 +20,14 @@ namespace xgboost::collective { class Loop { public: struct Op { - enum Code : std::int8_t { kRead = 0, kWrite = 1 } code; + enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code; std::int32_t rank{-1}; std::int8_t* ptr{nullptr}; std::size_t n{0}; TCPSocket* sock{nullptr}; std::size_t off{0}; + explicit Op(Code c) : code{c} { CHECK(c == kBlock); } Op(Code c, std::int32_t rank, std::int8_t* ptr, 
std::size_t n, TCPSocket* sock, std::size_t off) : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} Op(Op const&) = default; @@ -44,9 +45,9 @@ class Loop { Result rc_; bool stop_{false}; std::exception_ptr curr_exce_{nullptr}; - common::Monitor timer_; + common::Monitor mutable timer_; - Result EmptyQueue(); + Result EmptyQueue(std::queue* p_queue) const; void Process(); public: @@ -60,15 +61,7 @@ class Loop { cv_.notify_one(); } - [[nodiscard]] Result Block() { - { - std::unique_lock lock{mu_}; - cv_.notify_all(); - } - std::unique_lock lock{mu_}; - cv_.wait(lock, [this] { return this->queue_.empty() || stop_; }); - return std::move(rc_); - } + [[nodiscard]] Result Block(); explicit Loop(std::chrono::seconds timeout); diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 744608dec009..21b4d9fd0fe2 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest { void Basic() { { std::vector data(13, 0.0); - Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { for (std::size_t i = 0; i < rhs.size(); ++i) { rhs[i] += lhs[i]; } }); + ASSERT_TRUE(rc.OK()); ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0); } { std::vector data(1, 1.0); - Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { for (std::size_t i = 0; i < rhs.size(); ++i) { rhs[i] += lhs[i]; } }); + ASSERT_TRUE(rc.OK()); ASSERT_EQ(data[0], static_cast(comm_.World())); } } void Acc() { std::vector data(314, 1.5); - Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { for 
(std::size_t i = 0; i < rhs.size(); ++i) { rhs[i] += lhs[i]; } }); + ASSERT_TRUE(rc.OK()); for (std::size_t i = 0; i < data.size(); ++i) { auto v = data[i]; ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i;