forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fa65cf6
commit 80390e6
Showing
13 changed files
with
508 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/** | ||
* Copyright 2023, XGBoost contributors | ||
*/ | ||
#include "federated_comm.h" | ||
|
||
#include <grpcpp/grpcpp.h> | ||
|
||
#include <cstdint> // for int32_t | ||
#include <cstdlib> // for getenv | ||
#include <string> // for string, stoi | ||
|
||
#include "../../src/common/common.h" // for Split | ||
#include "../../src/common/json_utils.h" // for OptionalArg | ||
#include "xgboost/json.h" // for Json | ||
#include "xgboost/logging.h" | ||
|
||
namespace xgboost::collective { | ||
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world, | ||
std::int32_t rank, std::string const& server_cert, | ||
std::string const& client_key, std::string const& client_cert) { | ||
this->rank_ = rank; | ||
this->world_ = world; | ||
|
||
this->tracker_.host = host; | ||
this->tracker_.port = port; | ||
this->tracker_.rank = rank; | ||
|
||
CHECK_GE(world, 1) << "Invalid world size."; | ||
CHECK_GE(rank, 0) << "Invalid worker rank."; | ||
CHECK_LT(rank, world) << "Invalid worker rank."; | ||
|
||
if (server_cert.empty()) { | ||
stub_ = [&] { | ||
grpc::ChannelArguments args; | ||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); | ||
return federated::Federated::NewStub( | ||
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args)); | ||
}(); | ||
} else { | ||
stub_ = [&] { | ||
grpc::SslCredentialsOptions options; | ||
options.pem_root_certs = server_cert; | ||
options.pem_private_key = client_key; | ||
options.pem_cert_chain = client_cert; | ||
grpc::ChannelArguments args; | ||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); | ||
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args); | ||
channel->WaitForConnected( | ||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); | ||
return federated::Federated::NewStub(channel); | ||
}(); | ||
} | ||
} | ||
|
||
FederatedComm::FederatedComm(Json const& config) { | ||
/** | ||
* Topology | ||
*/ | ||
std::string server_address{}; | ||
std::int32_t world_size{0}; | ||
std::int32_t rank{-1}; | ||
// Parse environment variables first. | ||
auto* value = std::getenv("FEDERATED_SERVER_ADDRESS"); | ||
if (value != nullptr) { | ||
server_address = value; | ||
} | ||
value = std::getenv("FEDERATED_WORLD_SIZE"); | ||
if (value != nullptr) { | ||
world_size = std::stoi(value); | ||
} | ||
value = std::getenv("FEDERATED_RANK"); | ||
if (value != nullptr) { | ||
rank = std::stoi(value); | ||
} | ||
|
||
server_address = OptionalArg<String>(config, "federated_server_address", server_address); | ||
world_size = | ||
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size)); | ||
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank)); | ||
|
||
auto parsed = common::Split(server_address, ':'); | ||
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address; | ||
|
||
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required"; | ||
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required."; | ||
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required."; | ||
|
||
/** | ||
* Certificates | ||
*/ | ||
std::string server_cert{}; | ||
std::string client_key{}; | ||
std::string client_cert{}; | ||
value = getenv("FEDERATED_SERVER_CERT_PATH"); | ||
if (value != nullptr) { | ||
server_cert = value; | ||
} | ||
value = getenv("FEDERATED_CLIENT_KEY_PATH"); | ||
if (value != nullptr) { | ||
client_key = value; | ||
} | ||
value = getenv("FEDERATED_CLIENT_CERT_PATH"); | ||
if (value != nullptr) { | ||
client_cert = value; | ||
} | ||
|
||
server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert); | ||
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key); | ||
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert); | ||
|
||
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key, | ||
client_cert); | ||
} | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/** | ||
* Copyright 2023, XGBoost contributors | ||
*/ | ||
#pragma once | ||
|
||
#include <federated.grpc.pb.h> | ||
#include <federated.pb.h> | ||
|
||
#include <cstdint> // for int32_t | ||
#include <memory> // for unique_ptr | ||
#include <string> // for string | ||
|
||
#include "../../src/collective/comm.h" // for Comm | ||
#include "../../src/common/json_utils.h" // for OptionalArg | ||
#include "xgboost/json.h" | ||
|
||
namespace xgboost::collective { | ||
class FederatedComm : public Comm { | ||
std::unique_ptr<federated::Federated::Stub> stub_; | ||
|
||
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank, | ||
std::string const& server_cert, std::string const& client_key, | ||
std::string const& client_cert); | ||
|
||
public: | ||
/** | ||
* @param config | ||
* | ||
* - federated_server_address: Tracker address | ||
* - federated_world_size: The number of workers | ||
* - federated_rank: Rank of federated worker | ||
* - federated_server_cert_path | ||
* - federated_client_key_path | ||
* - federated_client_cert_path | ||
*/ | ||
explicit FederatedComm(Json const& config); | ||
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world, | ||
std::int32_t rank) { | ||
this->Init(host, port, world, rank, {}, {}, {}); | ||
} | ||
~FederatedComm() override { stub_.reset(); } | ||
|
||
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override { | ||
LOG(FATAL) << "peer to peer communication is not allowed for federated learning."; | ||
return nullptr; | ||
} | ||
[[nodiscard]] Result LogTracker(std::string msg) const override { | ||
LOG(CONSOLE) << msg; | ||
return Success(); | ||
} | ||
[[nodiscard]] bool IsFederated() const override { return true; } | ||
}; | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/** | ||
* Copyright 2022-2023, XGBoost contributors | ||
*/ | ||
#include "federated_tracker.h" | ||
|
||
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ... | ||
#include <grpcpp/server_builder.h> // for ServerBuilder | ||
|
||
#include <chrono> // for ms | ||
#include <cstdint> // for int32_t | ||
#include <exception> // for exception | ||
#include <limits> // for numeric_limits | ||
#include <string> // for string | ||
#include <thread> // for sleep_for | ||
|
||
#include "../../src/common/io.h" // for ReadAll | ||
#include "../../src/common/json_utils.h" // for RequiredArg | ||
#include "../../src/common/timer.h" // for Timer | ||
#include "federated_server.h" // for FederatedService | ||
|
||
namespace xgboost::collective { | ||
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} { | ||
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__); | ||
if (is_secure) { | ||
server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__); | ||
server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__); | ||
client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__); | ||
} | ||
} | ||
|
||
std::future<Result> FederatedTracker::Run() { | ||
return std::async([this]() { | ||
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_); | ||
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)}; | ||
grpc::ServerBuilder builder; | ||
|
||
if (this->server_cert_file_.empty()) { | ||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max()); | ||
if (this->port_ == 0) { | ||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); | ||
} else { | ||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); | ||
} | ||
builder.RegisterService(&service); | ||
server_ = builder.BuildAndStart(); | ||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size " | ||
<< this->n_workers_; | ||
} else { | ||
auto options = grpc::SslServerCredentialsOptions( | ||
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); | ||
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_); | ||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); | ||
key.private_key = xgboost::common::ReadAll(server_key_path_); | ||
key.cert_chain = xgboost::common::ReadAll(server_cert_file_); | ||
options.pem_key_cert_pairs.push_back(key); | ||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max()); | ||
if (this->port_ == 0) { | ||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_); | ||
} else { | ||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); | ||
} | ||
builder.RegisterService(&service); | ||
server_ = builder.BuildAndStart(); | ||
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " | ||
<< n_workers_; | ||
} | ||
|
||
try { | ||
server_->Wait(); | ||
} catch (std::exception const& e) { | ||
return collective::Fail(std::string{e.what()}); | ||
} | ||
return collective::Success(); | ||
}); | ||
} | ||
|
||
FederatedTracker::~FederatedTracker() = default; | ||
|
||
Result FederatedTracker::Shutdown() { | ||
common::Timer timer; | ||
timer.Start(); | ||
using namespace std::chrono_literals; | ||
while (!server_) { | ||
timer.Stop(); | ||
auto ela = timer.ElapsedSeconds(); | ||
if (ela > this->Timeout().count()) { | ||
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) + | ||
" seconds."); | ||
} | ||
std::this_thread::sleep_for(10ms); | ||
} | ||
|
||
try { | ||
server_->Shutdown(); | ||
} catch (std::exception const& e) { | ||
return Fail("Failed to shutdown:" + std::string{e.what()}); | ||
} | ||
|
||
return Success(); | ||
} | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/** | ||
* Copyright 2022-2023, XGBoost contributors | ||
*/ | ||
#pragma once | ||
#include <federated.grpc.pb.h> // for Server | ||
|
||
#include <future> // for future | ||
#include <memory> // for unique_ptr | ||
#include <string> // for string | ||
|
||
#include "../../src/collective/tracker.h" // for Tracker | ||
#include "xgboost/collective/result.h" // for Result | ||
#include "xgboost/json.h" // for Json | ||
|
||
namespace xgboost::collective { | ||
class FederatedTracker : public collective::Tracker { | ||
std::unique_ptr<grpc::Server> server_; | ||
std::string server_key_path_; | ||
std::string server_cert_file_; | ||
std::string client_cert_file_; | ||
|
||
public: | ||
/** | ||
* @brief CTOR | ||
* | ||
* @param config Configuration, other than the base configuration from Tracker, we have: | ||
* | ||
* - federated_secure: bool whether this is a secure server. | ||
* - server_key_path: path to the key. | ||
* - server_cert_path: certificate path. | ||
* - client_cert_path: certificate path for client. | ||
*/ | ||
explicit FederatedTracker(Json const& config); | ||
~FederatedTracker() override; | ||
std::future<Result> Run() override; | ||
// federated tracker do not provide initialization parameters, users have to provide it | ||
// themseleves. | ||
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; } | ||
[[nodiscard]] Result Shutdown(); | ||
}; | ||
} // namespace xgboost::collective |
Oops, something went wrong.