diff --git a/src/server/worker.cc b/src/server/worker.cc
index d5a751e1578..155298f8438 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -22,6 +22,8 @@
 #include
 #include
+#include <tbb/parallel_for.h>
+#include <tbb/parallel_reduce.h>
 #include
 #include

@@ -313,9 +315,7 @@ void Worker::Stop(uint32_t wait_seconds) {
 }

 Status Worker::AddConnection(redis::Connection *c) {
-  std::unique_lock lock(conns_mu_);
-  auto iter = conns_.find(c->GetFD());
-  if (iter != conns_.end()) {
+  if (ConnMap::const_accessor accessor; conns_.find(accessor, c->GetFD())) {
     return {Status::NotOK, "connection already exists"};
   }

@@ -325,7 +325,8 @@ Status Worker::AddConnection(redis::Connection *c) {
     return {Status::NotOK, "max number of clients reached"};
   }

-  conns_.emplace(c->GetFD(), c);
+  ConnMap::accessor accessor;
+  conns_.insert(accessor, std::make_pair(c->GetFD(), c));

   uint64_t id = srv->GetClientID();
   c->SetID(id);

@@ -335,18 +336,17 @@ Status Worker::AddConnection(redis::Connection *c) {

 redis::Connection *Worker::removeConnection(int fd) {
   redis::Connection *conn = nullptr;

-  std::unique_lock lock(conns_mu_);
-  auto iter = conns_.find(fd);
-  if (iter != conns_.end()) {
-    conn = iter->second;
-    conns_.erase(iter);
+  if (ConnMap::accessor accessor; conns_.find(accessor, fd)) {
+    conn = accessor->second;
+    conns_.erase(accessor);
     srv->DecrClientNum();
   }

-  iter = monitor_conns_.find(fd);
-  if (iter != monitor_conns_.end()) {
-    conn = iter->second;
-    monitor_conns_.erase(iter);
+  if (ConnMap::accessor accessor; monitor_conns_.find(accessor, fd)) {
+    conn = accessor->second;
+    monitor_conns_.erase(accessor);
     srv->DecrClientNum();
     srv->DecrMonitorClientNum();
   }
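The accessor idiom above is the heart of the migration: `find()` pins the matching entry for as long as the accessor lives, and `erase(accessor)` both removes the entry and releases the lock. Below is a minimal, self-contained sketch of that idiom with a toy int-to-int map; it is not code from this patch, and it assumes oneTBB is available:

```cpp
#include <cassert>
#include <tbb/concurrent_hash_map.h>

using Map = tbb::concurrent_hash_map<int, int>;

int main() {
  Map m;
  m.insert({1, 100});

  // find() pins the entry via the accessor. Read the value *before* erase(),
  // since erase(accessor) releases the accessor (see the oneTBB references
  // cited in the patch comments).
  if (Map::accessor acc; m.find(acc, 1)) {
    int value = acc->second;
    m.erase(acc);
    assert(value == 100);
  }
  assert(m.count(1) == 0);
  return 0;
}
```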
@@ -411,31 +411,30 @@ void Worker::FreeConnection(redis::Connection *conn) {
 }

 void Worker::FreeConnectionByID(int fd, uint64_t id) {
-  std::unique_lock lock(conns_mu_);
-  auto iter = conns_.find(fd);
-  if (iter != conns_.end() && iter->second->GetID() == id) {
+  if (ConnMap::accessor accessor; conns_.find(accessor, fd) && accessor->second->GetID() == id) {
     if (rate_limit_group_ != nullptr) {
-      bufferevent_remove_from_rate_limit_group(iter->second->GetBufferEvent());
+      bufferevent_remove_from_rate_limit_group(accessor->second->GetBufferEvent());
     }
-    delete iter->second;
-    conns_.erase(iter);
+
+    // refer to https://github.com/oneapi-src/oneTBB/blob/v2021.13.0/include/oneapi/tbb/concurrent_hash_map.h#L826 and
+    // https://github.com/oneapi-src/oneTBB/blob/v2021.13.0/include/oneapi/tbb/concurrent_hash_map.h#L1418: erase will
+    // release the accessor, so we must access the value before the erase.
+    delete accessor->second;
+    conns_.erase(accessor);
+
     srv->DecrClientNum();
   }
-
-  iter = monitor_conns_.find(fd);
-  if (iter != monitor_conns_.end() && iter->second->GetID() == id) {
-    delete iter->second;
-    monitor_conns_.erase(iter);
+  if (ConnMap::accessor accessor; monitor_conns_.find(accessor, fd) && accessor->second->GetID() == id) {
+    delete accessor->second;
+    monitor_conns_.erase(accessor);
     srv->DecrClientNum();
     srv->DecrMonitorClientNum();
   }
 }

 Status Worker::EnableWriteEvent(int fd) {
-  std::unique_lock lock(conns_mu_);
-  auto iter = conns_.find(fd);
-  if (iter != conns_.end()) {
-    auto bev = iter->second->GetBufferEvent();
+  if (ConnMap::const_accessor accessor; conns_.find(accessor, fd)) {
+    auto bev = accessor->second->GetBufferEvent();
     bufferevent_enable(bev, EV_WRITE);
     return Status::OK();
   }

@@ -444,11 +443,9 @@ Status Worker::EnableWriteEvent(int fd) {
 }

 Status Worker::Reply(int fd, const std::string &reply) {
-  std::unique_lock lock(conns_mu_);
-  auto iter = conns_.find(fd);
-  if (iter != conns_.end()) {
-    iter->second->SetLastInteraction();
-    redis::Reply(iter->second->Output(), reply);
+  if (ConnMap::accessor accessor; conns_.find(accessor, fd)) {
+    accessor->second->SetLastInteraction();
+    redis::Reply(accessor->second->Output(), reply);
     return Status::OK();
   }

@@ -456,72 +453,90 @@ Status Worker::Reply(int fd, const std::string &reply) {
 }

 void Worker::BecomeMonitorConn(redis::Connection *conn) {
-  {
-    std::lock_guard guard(conns_mu_);
-    conns_.erase(conn->GetFD());
-    monitor_conns_[conn->GetFD()] = conn;
+  if (ConnMap::accessor accessor; conns_.find(accessor, conn->GetFD())) {
+    conns_.erase(accessor);
+  }
+  if (ConnMap::accessor accessor; monitor_conns_.find(accessor, conn->GetFD())) {
+    accessor->second = conn;
+  } else {
+    monitor_conns_.insert(accessor, std::make_pair(conn->GetFD(), conn));
   }
   srv->IncrMonitorClientNum();
   conn->EnableFlag(redis::Connection::kMonitor);
 }

 void Worker::QuitMonitorConn(redis::Connection *conn) {
-  {
-    std::lock_guard guard(conns_mu_);
-    monitor_conns_.erase(conn->GetFD());
-    conns_[conn->GetFD()] = conn;
+  if (ConnMap::accessor accessor; monitor_conns_.find(accessor, conn->GetFD())) {
+    monitor_conns_.erase(accessor);
+    accessor.release();
+    if (ConnMap::accessor accessor; conns_.find(accessor, conn->GetFD())) {
+      accessor->second = conn;
+    } else {
+      conns_.insert(accessor, std::make_pair(conn->GetFD(), conn));
+    }
   }
   srv->DecrMonitorClientNum();
   conn->DisableFlag(redis::Connection::kMonitor);
 }

 void Worker::FeedMonitorConns(redis::Connection *conn, const std::string &response) {
-  std::unique_lock lock(conns_mu_);
-
-  for (const auto &iter : monitor_conns_) {
-    if (conn == iter.second) continue;  // skip the monitor command
-
-    if (conn->GetNamespace() == iter.second->GetNamespace() || iter.second->GetNamespace() == kDefaultNamespace) {
-      iter.second->Reply(response);
-    }
-  }
+  tbb::task_arena one_thread_arena(tbb::task_arena::constraints{}.set_max_concurrency(1));
+  one_thread_arena.execute([this, conn, response]() {
+    tbb::parallel_for(monitor_conns_.range(), [conn, response](const ConnMap::range_type &range) {
+      for (auto &it : range) {
+        const auto &value = it.second;
+        if (conn == value) continue;  // skip the monitor command
+        if (conn->GetNamespace() == value->GetNamespace() || value->GetNamespace() == kDefaultNamespace) {
+          value->Reply(response);
+        }
+      }
+    });
+  });
 }
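Traversal is the subtle part of this change: `tbb::concurrent_hash_map` only supports whole-map iteration through its `range()` interface with `tbb::parallel_for` or `tbb::parallel_reduce`, and the patch additionally confines that traversal to a `task_arena` capped at one thread. A sketch of the same idiom, where `Map` and `PrintAll` are illustrative names rather than code from the patch:

```cpp
#include <iostream>
#include <string>
#include <tbb/concurrent_hash_map.h>
#include <tbb/parallel_for.h>
#include <tbb/task_arena.h>

using Map = tbb::concurrent_hash_map<int, std::string>;

// Traverse a concurrent_hash_map through range()/parallel_for, but run the
// parallel_for inside an arena whose concurrency is capped at 1, mirroring
// FeedMonitorConns/GetClientsStr above.
void PrintAll(const Map &m) {
  tbb::task_arena one_thread(tbb::task_arena::constraints{}.set_max_concurrency(1));
  one_thread.execute([&m] {
    tbb::parallel_for(m.range(), [](const Map::const_range_type &range) {
      for (const auto &kv : range) {
        std::cout << kv.first << " => " << kv.second << '\n';
      }
    });
  });
}
```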

 std::string Worker::GetClientsStr() {
-  std::unique_lock lock(conns_mu_);
-
-  std::string output;
-  for (const auto &iter : conns_) {
-    redis::Connection *conn = iter.second;
-    output.append(conn->ToString());
-  }
-
-  return output;
+  tbb::task_arena one_thread_arena(tbb::task_arena::constraints{}.set_max_concurrency(1));
+  return one_thread_arena.execute([this]() {
+    return tbb::parallel_reduce(
+        conns_.range(), std::string{},
+        [](const ConnMap::range_type &range, std::string result) {
+          for (auto &it : range) {
+            result.append(it.second->ToString());
+          }
+          return result;
+        },
+        [](const std::string &lhs, const std::string &rhs) { return lhs + rhs; });
+  });
 }

 void Worker::KillClient(redis::Connection *self, uint64_t id, const std::string &addr, uint64_t type, bool skipme,
                         int64_t *killed) {
-  std::lock_guard guard(conns_mu_);
-
-  for (const auto &iter : conns_) {
-    redis::Connection *conn = iter.second;
-    if (skipme && self == conn) continue;
-
-    // no need to kill the client again if the kCloseAfterReply flag is set
-    if (conn->IsFlagEnabled(redis::Connection::kCloseAfterReply)) {
-      continue;
-    }
-
-    if ((type & conn->GetClientType()) ||
-        (!addr.empty() && (conn->GetAddr() == addr || conn->GetAnnounceAddr() == addr)) ||
-        (id != 0 && conn->GetID() == id)) {
-      conn->EnableFlag(redis::Connection::kCloseAfterReply);
-      // enable write event to notify worker wake up ASAP, and remove the connection
-      if (!conn->IsFlagEnabled(redis::Connection::kSlave)) {  // don't enable any event in slave connection
-        auto bev = conn->GetBufferEvent();
-        bufferevent_enable(bev, EV_WRITE);
+  for (const auto key : getConnFds()) {
+    if (ConnMap::accessor accessor; conns_.find(accessor, key)) {
+      auto conn = accessor->second;
+      if (skipme && self == conn) continue;
+
+      // no need to kill the client again if the kCloseAfterReply flag is set
+      if (conn->IsFlagEnabled(redis::Connection::kCloseAfterReply)) {
+        continue;
+      }
+      if ((type & conn->GetClientType()) ||
+          (!addr.empty() && (conn->GetAddr() == addr || conn->GetAnnounceAddr() == addr)) ||
+          (id != 0 && conn->GetID() == id)) {
+        conn->EnableFlag(redis::Connection::kCloseAfterReply);
+        // enable the write event to wake the worker up ASAP so it can remove the connection
+        if (!conn->IsFlagEnabled(redis::Connection::kSlave)) {  // don't enable any event in slave connection
+          auto bev = conn->GetBufferEvent();
+          bufferevent_enable(bev, EV_WRITE);
+        }
+        (*killed)++;
       }
-      (*killed)++;
     }
   }
 }
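`KillClient` shows the second traversal strategy in this patch: snapshot the keys first, then revalidate each key with `find()` before touching the connection, so entries erased concurrently in the meantime are skipped rather than dereferenced. A self-contained sketch of that pattern, using toy types and hypothetical names (`SnapshotKeys`, `VisitAll`) rather than the patch's own:

```cpp
#include <vector>
#include <tbb/concurrent_hash_map.h>
#include <tbb/parallel_reduce.h>

using Map = tbb::concurrent_hash_map<int, int>;

// Collect a snapshot of the current keys, the same parallel_reduce shape as
// getConnFds() in the patch.
std::vector<int> SnapshotKeys(const Map &m) {
  return tbb::parallel_reduce(
      m.range(), std::vector<int>{},
      [](const Map::const_range_type &range, std::vector<int> keys) {
        for (const auto &kv : range) keys.push_back(kv.first);
        return keys;
      },
      [](std::vector<int> lhs, const std::vector<int> &rhs) {
        lhs.insert(lhs.end(), rhs.begin(), rhs.end());
        return lhs;
      });
}

void VisitAll(Map &m) {
  for (int key : SnapshotKeys(m)) {
    if (Map::accessor acc; m.find(acc, key)) {
      acc->second += 1;  // safe: the accessor write-locks this entry
    }  // a key erased since the snapshot is simply skipped here
  }
}
```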
@@ -529,30 +544,49 @@ void Worker::KillClient(redis::Connection *self, uint64_t id, const std::string

 void Worker::KickoutIdleClients(int timeout) {
   std::vector<std::pair<int, uint64_t>> to_be_killed_conns;

-  {
-    std::lock_guard guard(conns_mu_);
-    if (conns_.empty()) {
-      return;
-    }
+  auto fd_list = getConnFds();
+  if (fd_list.empty()) {
+    return;
+  }

-    int iterations = std::min(static_cast<int>(conns_.size()), 50);
-    auto iter = conns_.upper_bound(last_iter_conn_fd_);
-    while (iterations--) {
-      if (iter == conns_.end()) iter = conns_.begin();
-      if (static_cast<int>(iter->second->GetIdleTime()) >= timeout) {
-        to_be_killed_conns.emplace_back(iter->first, iter->second->GetID());
-      }
-      iter++;
+  std::set<int> fds(fd_list.cbegin(), fd_list.cend());
+
+  int iterations = std::min(static_cast<int>(conns_.size()), 50);
+  auto iter = fds.upper_bound(last_iter_conn_fd_);
+  while (iterations--) {
+    if (iter == fds.end()) {
+      iter = fds.begin();
+    }
+    if (ConnMap::const_accessor accessor;
+        conns_.find(accessor, *iter) && static_cast<int>(accessor->second->GetIdleTime()) >= timeout) {
+      to_be_killed_conns.emplace_back(accessor->first, accessor->second->GetID());
     }
-    iter--;
-    last_iter_conn_fd_ = iter->first;
+    iter++;
   }
+  iter--;
+  last_iter_conn_fd_ = *iter;

   for (const auto &conn : to_be_killed_conns) {
     FreeConnectionByID(conn.first, conn.second);
   }
 }

+std::vector<int> Worker::getConnFds() const {
+  return tbb::parallel_reduce(
+      conns_.range(), std::vector<int>{},
+      [](const ConnMap::const_range_type &range, std::vector<int> result) {
+        for (const auto &fd : range) {
+          result.emplace_back(fd.first);
+        }
+        return result;
+      },
+      [](const std::vector<int> &lhs, const std::vector<int> &rhs) {
+        std::vector<int> result = lhs;
+        result.insert(result.end(), rhs.begin(), rhs.end());
+        return result;
+      });
+}
+
 void WorkerThread::Start() {
   auto s = util::CreateThread("worker", [this] { this->worker_->Run(std::this_thread::get_id()); });

@@ -566,6 +600,24 @@ void WorkerThread::Start() {
   LOG(INFO) << "[worker] Thread #" << t_.get_id() << " started";
 }

+std::map<int, redis::Connection *> Worker::GetConnections() const {
+  return tbb::parallel_reduce(
+      conns_.range(), std::map<int, redis::Connection *>{},
+      [](const ConnMap::const_range_type &range, std::map<int, redis::Connection *> tmp_result) {
+        for (auto &it : range) {
+          tmp_result.emplace(it.first, it.second);
+        }
+        return tmp_result;
+      },
+      [](const std::map<int, redis::Connection *> &lhs, const std::map<int, redis::Connection *> &rhs) {
+        std::map<int, redis::Connection *> result = lhs;
+        result.insert(rhs.cbegin(), rhs.cend());
+        return result;
+      });
+}
+
 void WorkerThread::Stop(uint32_t wait_seconds) { worker_->Stop(wait_seconds); }

 void WorkerThread::Join() {
diff --git a/src/server/worker.h b/src/server/worker.h
index b6918ba9296..a6b13391155 100644
--- a/src/server/worker.h
+++ b/src/server/worker.h
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <tbb/concurrent_hash_map.h>
 #include
 #include

@@ -75,22 +76,28 @@ class Worker : EventCallbackBase<Worker>, EvconnlistenerBase<Worker> {
   void TimerCB(int, int16_t events);

   lua_State *Lua() { return lua_; }
-  std::map<int, redis::Connection *> GetConnections() const { return conns_; }
+  std::map<int, redis::Connection *> GetConnections() const;

   Server *srv;

 private:
+  using ConnMap = tbb::concurrent_hash_map<int, redis::Connection *>;
+
   Status listenTCP(const std::string &host, uint32_t port, int backlog);
   void newTCPConnection(evconnlistener *listener, evutil_socket_t fd, sockaddr *address, int socklen);
   void newUnixSocketConnection(evconnlistener *listener, evutil_socket_t fd, sockaddr *address, int socklen);
   redis::Connection *removeConnection(int fd);
+  std::vector<int> getConnFds() const;

   event_base *base_;
   UniqueEvent timer_;
   std::thread::id tid_;
   std::vector listen_events_;

-  std::mutex conns_mu_;
-  std::map<int, redis::Connection *> conns_;
-  std::map<int, redis::Connection *> monitor_conns_;
+
+  // must traverse with tbb::parallel_for or tbb::parallel_reduce; refer to:
+  // https://github.com/oneapi-src/oneTBB/blob/v2021.13.0/include/oneapi/tbb/concurrent_hash_map.h#L1033-L1051
+  ConnMap conns_;
+  ConnMap monitor_conns_;

   int last_iter_conn_fd_ = 0;  // fd of last processed connection in previous cron

   struct bufferevent_rate_limit_group *rate_limit_group_ = nullptr;
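For callers outside the worker, `GetConnections()` flattens the unordered concurrent map into an ordered `std::map` snapshot with the same `parallel_reduce` shape as `getConnFds()`. A compilable sketch of that reduction, with a toy value type and a hypothetical `Snapshot` name standing in for the patch's code:

```cpp
#include <map>
#include <tbb/concurrent_hash_map.h>
#include <tbb/parallel_reduce.h>

using Map = tbb::concurrent_hash_map<int, const char *>;

// Reduce each sub-range into an ordered map, then merge the partial maps.
// The result is a point-in-time snapshot; concurrent writers may change the
// source map while (or after) the snapshot is taken.
std::map<int, const char *> Snapshot(const Map &m) {
  return tbb::parallel_reduce(
      m.range(), std::map<int, const char *>{},
      [](const Map::const_range_type &range, std::map<int, const char *> acc) {
        for (const auto &kv : range) acc.emplace(kv.first, kv.second);
        return acc;
      },
      [](std::map<int, const char *> lhs, const std::map<int, const char *> &rhs) {
        lhs.insert(rhs.begin(), rhs.end());
        return lhs;
      });
}
```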