Skip to content

Commit

Permalink
[coll] Federated comm. (dmlc#9732)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 30, 2023
1 parent fa65cf6 commit 80390e6
Show file tree
Hide file tree
Showing 13 changed files with 508 additions and 16 deletions.
2 changes: 1 addition & 1 deletion plugin/federated/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc)
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
114 changes: 114 additions & 0 deletions plugin/federated/federated_comm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include "federated_comm.h"

#include <grpcpp/grpcpp.h>

#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <string> // for string, stoi

#include "../../src/common/common.h" // for Split
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
#include "xgboost/logging.h"

namespace xgboost::collective {
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank, std::string const& server_cert,
std::string const& client_key, std::string const& client_cert) {
this->rank_ = rank;
this->world_ = world;

this->tracker_.host = host;
this->tracker_.port = port;
this->tracker_.rank = rank;

CHECK_GE(world, 1) << "Invalid world size.";
CHECK_GE(rank, 0) << "Invalid worker rank.";
CHECK_LT(rank, world) << "Invalid worker rank.";

if (server_cert.empty()) {
stub_ = [&] {
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
return federated::Federated::NewStub(
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
}();
} else {
stub_ = [&] {
grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert;
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
}();
}
}

FederatedComm::FederatedComm(Json const& config) {
/**
* Topology
*/
std::string server_address{};
std::int32_t world_size{0};
std::int32_t rank{-1};
// Parse environment variables first.
auto* value = std::getenv("FEDERATED_SERVER_ADDRESS");
if (value != nullptr) {
server_address = value;
}
value = std::getenv("FEDERATED_WORLD_SIZE");
if (value != nullptr) {
world_size = std::stoi(value);
}
value = std::getenv("FEDERATED_RANK");
if (value != nullptr) {
rank = std::stoi(value);
}

server_address = OptionalArg<String>(config, "federated_server_address", server_address);
world_size =
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));

auto parsed = common::Split(server_address, ':');
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;

CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";

/**
* Certificates
*/
std::string server_cert{};
std::string client_key{};
std::string client_cert{};
value = getenv("FEDERATED_SERVER_CERT_PATH");
if (value != nullptr) {
server_cert = value;
}
value = getenv("FEDERATED_CLIENT_KEY_PATH");
if (value != nullptr) {
client_key = value;
}
value = getenv("FEDERATED_CLIENT_CERT_PATH");
if (value != nullptr) {
client_cert = value;
}

server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);

this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
client_cert);
}
} // namespace xgboost::collective
53 changes: 53 additions & 0 deletions plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Copyright 2023, XGBoost contributors
*/
#pragma once

#include <federated.grpc.pb.h>
#include <federated.pb.h>

#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <string> // for string

#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"

namespace xgboost::collective {
class FederatedComm : public Comm {
std::unique_ptr<federated::Federated::Stub> stub_;

void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
std::string const& client_cert);

public:
/**
* @param config
*
* - federated_server_address: Tracker address
* - federated_world_size: The number of workers
* - federated_rank: Rank of federated worker
* - federated_server_cert_path
* - federated_client_key_path
* - federated_client_cert_path
*/
explicit FederatedComm(Json const& config);
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
~FederatedComm() override { stub_.reset(); }

[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
return nullptr;
}
[[nodiscard]] Result LogTracker(std::string msg) const override {
LOG(CONSOLE) << msg;
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
};
} // namespace xgboost::collective
7 changes: 5 additions & 2 deletions plugin/federated/federated_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
#include "federated_server.h"

#include <grpcpp/grpcpp.h>
#include <grpcpp/server.h> // for Server
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>

#include <sstream>

#include "../../src/collective/comm.h"
#include "../../src/common/io.h"
#include "../../src/common/json_utils.h"

namespace xgboost::federated {
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
Expand Down Expand Up @@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};

grpc::ServerBuilder builder;
auto options =
Expand All @@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,

void RunInsecureServer(int port, std::size_t world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};

grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
Expand Down
20 changes: 11 additions & 9 deletions plugin/federated/federated_server.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once

#include <federated.grpc.pb.h>

#include "../../src/collective/in_memory_handler.h"
#include <cstdint> // for int32_t
#include <future> // for future

namespace xgboost {
namespace federated {
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result

namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}

grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
Expand All @@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file);

void RunInsecureServer(int port, std::size_t world_size);

} // namespace federated
} // namespace xgboost
} // namespace xgboost::federated
101 changes: 101 additions & 0 deletions plugin/federated/federated_tracker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "federated_tracker.h"

#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
#include <grpcpp/server_builder.h> // for ServerBuilder

#include <chrono> // for ms
#include <cstdint> // for int32_t
#include <exception> // for exception
#include <limits> // for numeric_limits
#include <string> // for string
#include <thread> // for sleep_for

#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for RequiredArg
#include "../../src/common/timer.h" // for Timer
#include "federated_server.h" // for FederatedService

namespace xgboost::collective {
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
if (is_secure) {
server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
}
}

std::future<Result> FederatedTracker::Run() {
return std::async([this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
grpc::ServerBuilder builder;

if (this->server_cert_file_.empty()) {
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
} else {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< this->n_workers_;
} else {
auto options = grpc::SslServerCredentialsOptions(
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = xgboost::common::ReadAll(server_key_path_);
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
} else {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< n_workers_;
}

try {
server_->Wait();
} catch (std::exception const& e) {
return collective::Fail(std::string{e.what()});
}
return collective::Success();
});
}

FederatedTracker::~FederatedTracker() = default;

Result FederatedTracker::Shutdown() {
common::Timer timer;
timer.Start();
using namespace std::chrono_literals;
while (!server_) {
timer.Stop();
auto ela = timer.ElapsedSeconds();
if (ela > this->Timeout().count()) {
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
std::this_thread::sleep_for(10ms);
}

try {
server_->Shutdown();
} catch (std::exception const& e) {
return Fail("Failed to shutdown:" + std::string{e.what()});
}

return Success();
}
} // namespace xgboost::collective
41 changes: 41 additions & 0 deletions plugin/federated/federated_tracker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h> // for Server

#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string

#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json

namespace xgboost::collective {
class FederatedTracker : public collective::Tracker {
std::unique_ptr<grpc::Server> server_;
std::string server_key_path_;
std::string server_cert_file_;
std::string client_cert_file_;

public:
/**
* @brief CTOR
*
* @param config Configuration, other than the base configuration from Tracker, we have:
*
* - federated_secure: bool whether this is a secure server.
* - server_key_path: path to the key.
* - server_cert_path: certificate path.
* - client_cert_path: certificate path for client.
*/
explicit FederatedTracker(Json const& config);
~FederatedTracker() override;
std::future<Result> Run() override;
// federated tracker do not provide initialization parameters, users have to provide it
// themseleves.
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
[[nodiscard]] Result Shutdown();
};
} // namespace xgboost::collective
Loading

0 comments on commit 80390e6

Please sign in to comment.