diff --git a/examples/gcs_demo.cc b/examples/gcs_demo.cc index faf492c8..2b8a694d 100644 --- a/examples/gcs_demo.cc +++ b/examples/gcs_demo.cc @@ -29,22 +29,19 @@ void Run(SSL_CTX* ctx) { CHECK(!ec) << "Could not load credentials " << ec.message(); cloud::GCS gcs(&provider, ctx, pb); - ec = gcs.Connect(connect_ms); - CHECK(!ec) << "Could not connect " << ec; string prefix = GetFlag(FLAGS_prefix); - if (!prefix.empty()) { - string bucket = GetFlag(FLAGS_bucket); - auto conn_pool = gcs.CreateConnectionPool(); - CHECK(!bucket.empty()); + string bucket = GetFlag(FLAGS_bucket); + if (!bucket.empty()) { + auto conn_pool = gcs.GetConnectionPool(); if (GetFlag(FLAGS_write) > 0) { auto src = io::ReadFileToString("/proc/self/exe"); CHECK(src); for (unsigned i = 0; i < GetFlag(FLAGS_write); ++i) { string dest_key = absl::StrCat(prefix, "_", i); io::Result dest_res = - cloud::OpenWriteGcsFile(bucket, dest_key, &provider, conn_pool.get()); + cloud::OpenWriteGcsFile(bucket, dest_key, &provider, conn_pool); CHECK(dest_res) << "Could not open " << dest_key << " " << dest_res.error().message(); unique_ptr dest(*dest_res); error_code ec = dest->Write(*src); diff --git a/util/cloud/gcp/gcp_utils.cc b/util/cloud/gcp/gcp_utils.cc index ac7aa8dd..b11ae4dc 100644 --- a/util/cloud/gcp/gcp_utils.cc +++ b/util/cloud/gcp/gcp_utils.cc @@ -8,7 +8,9 @@ #include #include "base/logging.h" + #include "util/cloud/gcp/gcp_creds_provider.h" +#include "util/http/http_client.h" namespace util::cloud { using namespace std; @@ -66,20 +68,26 @@ std::error_code DynamicBodyRequestImpl::Send(http::Client* client) { } // namespace detail -RobustSender::RobustSender(unsigned num_iterations, GCPCredsProvider* provider) - : num_iterations_(num_iterations), provider_(provider) { +RobustSender::RobustSender(http::ClientPool* pool, GCPCredsProvider* provider) + : pool_(pool), provider_(provider) { } -auto RobustSender::Send(http::Client* client, +auto RobustSender::Send(unsigned num_iterations, detail::HttpRequestBase* req) -> io::Result { error_code ec; - for (unsigned i = 0; i < num_iterations_; ++i) { // Iterate for possible token refresh. - VLOG(1) << "HttpReq " << client->host() << ": " << req->GetHeaders() << ", [" - << client->native_handle() << "]"; + for (unsigned i = 0; i < num_iterations; ++i) { // Iterate for possible token refresh. + auto res = pool_->GetHandle(); + if (!res) + return nonstd::make_unexpected(res.error()); + + auto client_handle = std::move(res.value()); + + VLOG(1) << "HttpReq " << client_handle->host() << ": " << req->GetHeaders() << ", [" + << client_handle->native_handle() << "]"; - RETURN_UNEXPECTED(req->Send(client)); + RETURN_UNEXPECTED(req->Send(client_handle.get())); HeaderParserPtr parser(new h2::response_parser()); - RETURN_UNEXPECTED(client->ReadHeader(parser.get())); + RETURN_UNEXPECTED(client_handle->ReadHeader(parser.get())); { const auto& msg = parser->get(); VLOG(1) << "RespHeader" << i << ": " << msg; @@ -95,18 +103,19 @@ auto RobustSender::Send(http::Client* client, // We have some kind of error, possibly with body that needs to be drained. h2::response_parser drainer(std::move(*parser)); - RETURN_UNEXPECTED(client->Recv(&drainer)); + RETURN_UNEXPECTED(client_handle->Recv(&drainer)); const auto& msg = drainer.get(); if (DoesServerPushback(msg.result())) { - LOG(INFO) << "Retrying(" << client->native_handle() << ") with " << msg; + LOG(INFO) << "Retrying(" << client_handle->native_handle() << ") with " << msg; ThisFiber::SleepFor(100ms); continue; } if (IsUnauthorized(msg)) { - RETURN_UNEXPECTED(provider_->RefreshToken(client->proactor())); + VLOG(1) << "Refreshing token"; + RETURN_UNEXPECTED(provider_->RefreshToken(client_handle->proactor())); req->SetHeader(h2::field::authorization, AuthHeader(provider_->access_token())); continue; diff --git a/util/cloud/gcp/gcp_utils.h b/util/cloud/gcp/gcp_utils.h index 794b0d4f..1579cbda 100644 --- a/util/cloud/gcp/gcp_utils.h +++ b/util/cloud/gcp/gcp_utils.h @@ -4,10 +4,12 @@ #pragma once #include +#include +#include #include #include "io/io.h" -#include "util/http/http_client.h" +#include "util/http/https_client_pool.h" namespace util::cloud { class GCPCredsProvider; @@ -99,12 +101,12 @@ class RobustSender { using HeaderParserPtr = std::unique_ptr>; - RobustSender(unsigned num_iterations, GCPCredsProvider* provider); + RobustSender(http::ClientPool* pool, GCPCredsProvider* provider); - io::Result Send(http::Client* client, detail::HttpRequestBase* req); + io::Result Send(unsigned num_iterations, detail::HttpRequestBase* req); private: - unsigned num_iterations_; + http::ClientPool* pool_; GCPCredsProvider* provider_; }; diff --git a/util/cloud/gcp/gcs.cc b/util/cloud/gcp/gcs.cc index 03a70e61..55dbee92 100644 --- a/util/cloud/gcp/gcs.cc +++ b/util/cloud/gcp/gcs.cc @@ -10,6 +10,7 @@ #include #include +#include "base/flags.h" #include "base/logging.h" #include "io/file.h" #include "io/file_util.h" @@ -21,6 +22,8 @@ using namespace std; namespace h2 = boost::beast::http; namespace rj = rapidjson; +ABSL_FLAG(string, gcs_auth_token, "", ""); + namespace util { namespace cloud { @@ -30,11 +33,15 @@ auto Unexpected(std::errc code) { return nonstd::make_unexpected(make_error_code(code)); } -#define RETURN_ERROR(x) \ - do { \ - auto ec = (x); \ - if (ec) \ - return ec; \ +const char kInstanceTokenUrl[] = "/computeMetadata/v1/instance/service-accounts/default/token"; + +#define RETURN_ERROR(x) \ + do { \ + auto ec = (x); \ + if (ec) { \ + VLOG(1) << "Error calling " << #x << ": " << ec.message(); \ + return ec; \ + } \ } while (false) io::Result ExpandFilePath(string_view path) { @@ -115,7 +122,7 @@ std::error_code ParseADC(string_view adc_file, string* client_id, string* client using TokenTtl = pair; io::Result ParseTokenResponse(std::string&& response) { - VLOG(1) << "Refresh Token response: " << response; + VLOG(2) << "Refresh Token response: " << response; rj::Document doc; constexpr unsigned kFlags = rj::kParseTrailingCommasFlag | rj::kParseCommentsFlag; @@ -177,6 +184,56 @@ error_code EnableKeepAlive(int fd) { return std::error_code{}; } +error_code ConfigureMetadataClient(fb2::ProactorBase* pb, http::Client* client) { + client->set_connect_timeout_ms(1000); + static const char kMetaDataHost[] = "metadata.google.internal"; + return client->Connect(kMetaDataHost, "80"); +} + +error_code ReadGCPConfigFromMetadata(fb2::ProactorBase* pb, string* account_id, string* project_id, + TokenTtl* token) { + http::Client client(pb); + RETURN_ERROR(ConfigureMetadataClient(pb, &client)); + + const char kEmailUrl[] = "/computeMetadata/v1/instance/service-accounts/default/email"; + h2::request req{h2::verb::get, kEmailUrl, 11}; + req.set("Metadata-Flavor", "Google"); + + h2::response resp; + RETURN_ERROR(client.Send(req, &resp)); + if (resp.result() != h2::status::ok) { + LOG(WARNING) << "Http error: " << string(resp.reason()) << ", Body: ", resp.body(); + return make_error_code(errc::permission_denied); + } + + *account_id = std::move(resp.body()); + resp.clear(); + + const char kProjectIdUrl[] = "/computeMetadata/v1/project/project-id"; + req.target(kProjectIdUrl); + RETURN_ERROR(client.Send(req, &resp)); + if (resp.result() != h2::status::ok) { + LOG(WARNING) << "Http error: " << string(resp.reason()) << ", Body: ", resp.body(); + return make_error_code(errc::permission_denied); + } + + *project_id = std::move(resp.body()); + resp.clear(); + + req.target(kInstanceTokenUrl); + RETURN_ERROR(client.Send(req, &resp)); + if (resp.result() != h2::status::ok) { + LOG(WARNING) << "Http error: " << string(resp.reason()) << ", Body: ", resp.body(); + return make_error_code(errc::permission_denied); + } + io::Result token_res = ParseTokenResponse(std::move(resp.body())); + if (!token_res) + return token_res.error(); + *token = std::move(*token_res); + + return {}; +} + #define FETCH_ARRAY_MEMBER(val) \ if (!(val).IsArray()) \ return make_error_code(errc::bad_message); \ @@ -203,9 +260,21 @@ error_code GCPCredsProvider::Init(unsigned connect_ms, fb2::ProactorBase* pb) { is_cloud_env = true; } + connect_ms_ = connect_ms; + if (is_cloud_env) { use_instance_metadata_ = true; - LOG(FATAL) << "TBD: do not support reading from instance metadata"; + VLOG(1) << "Reading from instance metadata"; + TokenTtl token_ttl; + RETURN_ERROR(ReadGCPConfigFromMetadata(pb, &account_id_, &project_id_, &token_ttl)); + + string inject_token = absl::GetFlag(FLAGS_gcs_auth_token); + if (!inject_token.empty()) { + token_ttl.first = inject_token; + } + folly::RWSpinLock::WriteHolder lock(lock_); + access_token_ = token_ttl.first; + expire_time_.store(time(nullptr) + token_ttl.second, std::memory_order_release); } else { RETURN_ERROR(LoadGCPConfig(&account_id_, &project_id_)); @@ -221,43 +290,54 @@ error_code GCPCredsProvider::Init(unsigned connect_ms, fb2::ProactorBase* pb) { LOG(WARNING) << "Bad ADC file " << adc_file; return make_error_code(errc::bad_message); } + + // At this point we should have all the data to get an access token. + RETURN_ERROR(RefreshToken(pb)); } - // At this point we should have all the data to get an access token. - connect_ms_ = connect_ms; - return RefreshToken(pb); + return {}; } error_code GCPCredsProvider::RefreshToken(fb2::ProactorBase* pb) { - constexpr char kDomain[] = "oauth2.googleapis.com"; + h2::response resp; - http::TlsClient https_client(pb); - https_client.set_connect_timeout_ms(connect_ms_); - SSL_CTX* context = http::TlsClient::CreateSslContext(); - error_code ec = https_client.Connect(kDomain, "443", context); - http::TlsClient::FreeContext(context); + if (use_instance_metadata_) { + http::Client client(pb); + RETURN_ERROR(ConfigureMetadataClient(pb, &client)); - if (ec) { - VLOG(1) << "Could not connect to " << kDomain; - return ec; - } - h2::request req{h2::verb::post, "/token", 11}; - req.set(h2::field::host, kDomain); - req.set(h2::field::content_type, "application/x-www-form-urlencoded"); + h2::request req{h2::verb::get, kInstanceTokenUrl, 11}; + req.set("Metadata-Flavor", "Google"); + RETURN_ERROR(client.Send(req, &resp)); + } else { + constexpr char kDomain[] = "oauth2.googleapis.com"; - string& body = req.body(); - body = absl::StrCat("grant_type=refresh_token&client_secret=", client_secret_, - "&refresh_token=", refresh_token_); - absl::StrAppend(&body, "&client_id=", client_id_); - req.prepare_payload(); - VLOG(1) << "Req: " << req; + http::TlsClient https_client(pb); + https_client.set_connect_timeout_ms(connect_ms_); + SSL_CTX* context = http::TlsClient::CreateSslContext(); + error_code ec = https_client.Connect(kDomain, "443", context); + http::TlsClient::FreeContext(context); - h2::response resp; - RETURN_ERROR(https_client.Send(req, &resp)); - - if (resp.result() != h2::status::ok) { - LOG(WARNING) << "Http error: " << string(resp.reason()) << ", Body: ", resp.body(); - return make_error_code(errc::permission_denied); + if (ec) { + VLOG(1) << "Could not connect to " << kDomain; + return ec; + } + h2::request req{h2::verb::post, "/token", 11}; + req.set(h2::field::host, kDomain); + req.set(h2::field::content_type, "application/x-www-form-urlencoded"); + + string& body = req.body(); + body = absl::StrCat("grant_type=refresh_token&client_secret=", client_secret_, + "&refresh_token=", refresh_token_); + absl::StrAppend(&body, "&client_id=", client_id_); + req.prepare_payload(); + VLOG(1) << "Req: " << req; + + RETURN_ERROR(https_client.Send(req, &resp)); + + if (resp.result() != h2::status::ok) { + LOG(WARNING) << "Http error: " << string(resp.reason()) << ", Body: ", resp.body(); + return make_error_code(errc::permission_denied); + } } io::Result token = ParseTokenResponse(std::move(resp.body())); @@ -273,19 +353,17 @@ error_code GCPCredsProvider::RefreshToken(fb2::ProactorBase* pb) { GCS::GCS(GCPCredsProvider* provider, SSL_CTX* ssl_cntx, fb2::ProactorBase* pb) : creds_provider_(*provider), ssl_ctx_(ssl_cntx) { - client_.reset(new http::TlsClient(pb)); -} - -GCS::~GCS() { -} - -std::error_code GCS::Connect(unsigned msec) { - client_->set_connect_timeout_ms(msec); - client_->AssignOnConnect([](int fd) { + client_pool_.reset(new http::ClientPool(GCS_API_DOMAIN, ssl_ctx_, pb)); + client_pool_->SetOnConnect([](int fd) { auto ec = EnableKeepAlive(fd); LOG_IF(WARNING, ec) << "Error setting keep alive " << ec.message() << " " << fd; }); - return client_->Connect(GCS_API_DOMAIN, "443", ssl_ctx_); + + // TODO: to make it configurable. + client_pool_->set_connect_timeout(2000); +} + +GCS::~GCS() { } error_code GCS::ListBuckets(ListBucketCb cb) { @@ -296,15 +374,20 @@ error_code GCS::ListBuckets(ListBucketCb cb) { rj::Document doc; - RobustSender sender(2, &creds_provider_); + RobustSender sender(client_pool_.get(), &creds_provider_); while (true) { - io::Result parse_res = sender.Send(client_.get(), &empty_req); + io::Result parse_res = sender.Send(2, &empty_req); if (!parse_res) return parse_res.error(); RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res); h2::response_parser resp(std::move(*empty_parser)); - RETURN_ERROR(client_->Recv(&resp)); + auto res = client_pool_->GetHandle(); + if (!res) + return res.error(); + auto client = std::move(*res); + + RETURN_ERROR(client->Recv(&resp)); auto msg = resp.release(); @@ -352,15 +435,20 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive, Lis detail::EmptyRequestImpl empty_req(h2::verb::get, url, creds_provider_.access_token()); rj::Document doc; - RobustSender sender(2, &creds_provider_); + RobustSender sender(client_pool_.get(), &creds_provider_); while (true) { - io::Result parse_res = sender.Send(client_.get(), &empty_req); + io::Result parse_res = sender.Send(2, &empty_req); if (!parse_res) return parse_res.error(); RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res); - h2::response_parser resp(std::move(*empty_parser)); - RETURN_ERROR(client_->Recv(&resp)); + + auto res = client_pool_->GetHandle(); + if (!res) + return res.error(); + auto client = std::move(*res); + + RETURN_ERROR(client->Recv(&resp)); auto msg = resp.release(); @@ -406,11 +494,5 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive, Lis return {}; } -unique_ptr GCS::CreateConnectionPool() const { - unique_ptr res( - new http::ClientPool(GCS_API_DOMAIN, ssl_ctx_, client_->proactor())); - return res; -} - } // namespace cloud } // namespace util \ No newline at end of file diff --git a/util/cloud/gcp/gcs.h b/util/cloud/gcp/gcs.h index 50c63692..97e239df 100644 --- a/util/cloud/gcp/gcs.h +++ b/util/cloud/gcp/gcs.h @@ -31,18 +31,16 @@ class GCS { GCS(GCPCredsProvider* creds_provider, SSL_CTX* ssl_cntx, fb2::ProactorBase* pb); ~GCS(); - std::error_code Connect(unsigned msec); - std::error_code ListBuckets(ListBucketCb cb); std::error_code List(std::string_view bucket, std::string_view prefix, bool recursive, ListObjectCb cb); - std::unique_ptr CreateConnectionPool() const; + http::ClientPool* GetConnectionPool() { return client_pool_.get(); } private: GCPCredsProvider& creds_provider_; SSL_CTX* ssl_ctx_; - std::unique_ptr client_; + std::unique_ptr client_pool_; }; } // namespace cloud diff --git a/util/cloud/gcp/gcs_file.cc b/util/cloud/gcp/gcs_file.cc index 3dcc5536..094c3d5f 100644 --- a/util/cloud/gcp/gcs_file.cc +++ b/util/cloud/gcp/gcs_file.cc @@ -7,6 +7,7 @@ #include #include +#include // for operator<< #include #include "base/flags.h" @@ -14,6 +15,7 @@ #include "strings/escaping.h" #include "util/cloud/gcp/gcp_utils.h" #include "util/http/http_common.h" +#include "util/http/http_client.h" ABSL_FLAG(bool, gcs_dry_upload, false, ""); @@ -100,9 +102,8 @@ error_code GcsWriteFile::Close() { string body; if (!absl::GetFlag(FLAGS_gcs_dry_upload)) { - RobustSender sender(3, creds_provider_); - auto client_handle = pool_->GetHandle(); - io::Result res = sender.Send(client_handle.get(), req.get()); + RobustSender sender(pool_, creds_provider_); + io::Result res = sender.Send(3, req.get()); if (!res) { LOG(ERROR) << "Error closing GCS file " << create_file_name() << " for request: \n" << req->GetHeaders() << ", status " << res.error().message(); @@ -110,9 +111,15 @@ error_code GcsWriteFile::Close() { } HeaderParserPtr head_parser = std::move(*res); h2::response_parser resp(std::move(*head_parser)); + auto handle_res = pool_->GetHandle(); + if (!handle_res) + return handle_res.error(); + + auto client_handle = std::move(*handle_res); auto ec = client_handle->Recv(&resp); if (ec) return ec; + body = std::move(resp.get().body()); /* @@ -188,11 +195,11 @@ error_code GcsWriteFile::Upload() { error_code res; if (!absl::GetFlag(FLAGS_gcs_dry_upload)) { // TODO: RobustSender must access the entire pool, not just a single client. - RobustSender sender(3, creds_provider_); - auto client_handle = pool_->GetHandle(); - io::Result res = sender.Send(client_handle.get(), req.get()); + RobustSender sender(pool_, creds_provider_); + io::Result res = sender.Send(3, req.get()); if (!res) return res.error(); + // auto client_handle = pool_->GetHandle(); VLOG(1) << "Uploaded range " << uploaded_ << "/" << to << " for " << upload_id_; HeaderParserPtr parser_ptr = std::move(*res); @@ -236,9 +243,8 @@ io::Result OpenWriteGcsFile(const string& bucket, const string& detail::EmptyRequestImpl empty_req(h2::verb::post, url, token); empty_req.Finalize(); // it's post request so it's required. - RobustSender sender(3, creds_provider); - auto client_handle = pool->GetHandle(); - io::Result res = sender.Send(client_handle.get(), &empty_req); + RobustSender sender(pool, creds_provider); + io::Result res = sender.Send(3, &empty_req); if (!res) { return nonstd::make_unexpected(res.error()); } diff --git a/util/http/http_client.h b/util/http/http_client.h index 5718bd83..06cf4d9b 100644 --- a/util/http/http_client.h +++ b/util/http/http_client.h @@ -136,6 +136,7 @@ template auto Client::Send(const Req& req, Resp* r template auto Client::Send(const Req& req) -> BoostError { BoostError ec; AsioStreamAdapter<> adapter(*socket_); + assert(socket_); for (uint32_t i = 0; i < retry_cnt_; ++i) { ::boost::beast::http::write(adapter, req, ec); diff --git a/util/http/https_client_pool.cc b/util/http/https_client_pool.cc index 0e43b25f..154fd1a9 100644 --- a/util/http/https_client_pool.cc +++ b/util/http/https_client_pool.cc @@ -37,7 +37,7 @@ ClientPool::~ClientPool() { } } -auto ClientPool::GetHandle() -> ClientHandle { +auto ClientPool::GetHandle() -> io::Result { while (!available_handles_.empty()) { // Pulling the oldest handles first. std::unique_ptr ptr{std::move(available_handles_.front())}; @@ -60,10 +60,14 @@ auto ClientPool::GetHandle() -> ClientHandle { // TODO: create tls/Non-tls clients based on whether ssl_cntx_ is null. std::unique_ptr client(new TlsClient{&proactor_}); client->set_retry_count(retry_cnt_); - + if (on_connect_) { + client->AssignOnConnect(on_connect_); + } auto ec = client->Connect(domain_, "443", ssl_cntx_); LOG_IF(WARNING, ec) << "ClientPool: Could not connect " << ec; + if (ec) + return nonstd::make_unexpected(ec); ++existing_handles_; return ClientHandle{client.release(), HandleGuard{this}}; diff --git a/util/http/https_client_pool.h b/util/http/https_client_pool.h index b3f437bb..29ab8982 100644 --- a/util/http/https_client_pool.h +++ b/util/http/https_client_pool.h @@ -5,9 +5,12 @@ #pragma once #include +#include #include #include +#include "io/io.h" + typedef struct ssl_ctx_st SSL_CTX; namespace util { @@ -43,12 +46,12 @@ class ClientPool { /*! @brief Returns https client connection from the pool. * - * Must be called withing IoContext thread. Once ClientHandle destructs, + * Must be called withing Proactor thread. Once ClientHandle destructs, * the connection returns to the pool. GetHandle() might block the calling fiber for * connect_msec_ millis in case it creates a new connection. * Note that all allocated handles must be destroyed before destroying their parent pool. */ - ClientHandle GetHandle(); + io::Result GetHandle(); void set_connect_timeout(unsigned msec) { connect_msec_ = msec; @@ -68,6 +71,14 @@ class ClientPool { return domain_; } + fb2::ProactorBase& proactor() { + return proactor_; + } + + void SetOnConnect(std::function cb) { + on_connect_ = std::move(cb); + } + private: std::string domain_; SSL_CTX* ssl_cntx_; @@ -77,6 +88,7 @@ class ClientPool { int existing_handles_ = 0; std::deque available_handles_; // Using queue to allow round-robin access. + std::function on_connect_; }; } // namespace http