diff --git a/examples/gcs_demo.cc b/examples/gcs_demo.cc index e531dc62..0baf631f 100644 --- a/examples/gcs_demo.cc +++ b/examples/gcs_demo.cc @@ -14,18 +14,22 @@ using namespace util; using absl::GetFlag; ABSL_FLAG(string, bucket, "", ""); -ABSL_FLAG(string, access_token, "", ""); ABSL_FLAG(uint32_t, connect_ms, 2000, ""); ABSL_FLAG(bool, epoll, false, "Whether to use epoll instead of io_uring"); void Run(SSL_CTX* ctx) { fb2::ProactorBase* pb = fb2::ProactorBase::me(); - cloud::GCS gcs(ctx, pb); - error_code ec = gcs.Connect(GetFlag(FLAGS_connect_ms)); + cloud::GCPCredsProvider provider; + unsigned connect_ms = GetFlag(FLAGS_connect_ms); + error_code ec = provider.Init(connect_ms, pb); + CHECK(!ec) << "Could not load credentials " << ec.message(); + + cloud::GCS gcs(&provider, ctx, pb); + ec = gcs.Connect(connect_ms); CHECK(!ec) << "Could not connect " << ec; auto res = gcs.ListBuckets(); - CHECK(res) << res.error(); + CHECK(res) << res.error().message(); for (auto v : *res) { CONSOLE_INFO << v; } diff --git a/util/cloud/gcp/CMakeLists.txt b/util/cloud/gcp/CMakeLists.txt index 532eb973..2ce50be3 100644 --- a/util/cloud/gcp/CMakeLists.txt +++ b/util/cloud/gcp/CMakeLists.txt @@ -1,3 +1,3 @@ add_library(gcp_lib gcs.cc) -cxx_link(gcp_lib http_client_lib) +cxx_link(gcp_lib http_client_lib TRDP::rapidjson) diff --git a/util/cloud/gcp/gcs.cc b/util/cloud/gcp/gcs.cc index 4dd0aded..9defcecb 100644 --- a/util/cloud/gcp/gcs.cc +++ b/util/cloud/gcp/gcs.cc @@ -3,15 +3,303 @@ #include "util/cloud/gcp/gcs.h" +#include +#include +#include + +#include +#include + +#include "base/logging.h" +#include "io/file.h" +#include "io/file_util.h" +#include "io/line_reader.h" + +using namespace std; +namespace h2 = boost::beast::http; +namespace rj = rapidjson; + namespace util { namespace cloud { -namespace { +namespace { constexpr char kDomain[] = "www.googleapis.com"; + +using EmptyRequest = h2::request; + +auto Unexpected(std::errc code) { + return nonstd::make_unexpected(make_error_code(code)); +} + +string AuthHeader(string_view access_token) { + return absl::StrCat("Bearer ", access_token); +} + +EmptyRequest PrepareRequest(h2::verb req_verb, boost::beast::string_view url, + const string_view access_token) { + EmptyRequest req(req_verb, url, 11); + req.set(h2::field::host, kDomain); + req.set(h2::field::authorization, AuthHeader(access_token)); + req.keep_alive(true); + + return req; +} + +bool IsUnauthorized(const h2::header& resp) { + if (resp.result() != h2::status::unauthorized) { + return false; + } + auto it = resp.find("WWW-Authenticate"); + + return it != resp.end(); +} + +struct PrintTag { + const h2::header& msg; +}; + +std::ostream& operator<<(std::ostream& os, const PrintTag& tag) { + os << tag.msg.reason() << endl; + for (const auto& f : tag.msg) { + os << f.name_string() << " : " << f.value() << endl; + } + os << "-------------------------"; + + return os; +} + + +io::Result ExpandFile(string_view path) { + io::Result res = io::StatFiles(path); + + if (!res) { + return nonstd::make_unexpected(res.error()); + } + + if (res->empty()) { + VLOG(1) << "Could not find " << path; + return Unexpected(errc::no_such_file_or_directory); + } + return res->front().name; +} + +std::error_code LoadGCPConfig(string* account_id, string* project_id) { + io::Result path = ExpandFile("~/.config/gcloud/configurations/config_default"); + if (!path) { + return path.error(); + } + + io::Result config = io::ReadFileToString(*path); + if (!config) { + return config.error(); + } + + io::BytesSource bs(*config); + io::LineReader reader(&bs, DO_NOT_TAKE_OWNERSHIP, 11); + string scratch; + string_view line; + while (reader.Next(&line, &scratch)) { + vector vals = absl::StrSplit(line, "="); + if (vals.size() != 2) + continue; + for (auto& v : vals) { + v = absl::StripAsciiWhitespace(v); + } + if (vals[0] == "account") { + *account_id = string(vals[1]); + } else if (vals[0] == "project") { + *project_id = string(vals[1]); + } + } + + return {}; +} + +std::error_code ParseADC(string_view adc_file, string* client_id, string* client_secret, + string* refresh_token) { + io::Result adc = io::ReadFileToString(adc_file); + if (!adc) { + return adc.error(); + } + + rj::Document adc_doc; + constexpr unsigned kFlags = rj::kParseTrailingCommasFlag | rj::kParseCommentsFlag; + adc_doc.ParseInsitu(&adc->front()); + + if (adc_doc.HasParseError()) { + return make_error_code(errc::protocol_error); + } + + for (auto it = adc_doc.MemberBegin(); it != adc_doc.MemberEnd(); ++it) { + if (it->name == "client_id") { + *client_id = it->value.GetString(); + } else if (it->name == "client_secret") { + *client_secret = it->value.GetString(); + } else if (it->name == "refresh_token") { + *refresh_token = it->value.GetString(); + } + } + + return {}; +} + +// token, expire_in (seconds) +using TokenTtl = pair; + +io::Result ParseTokenResponse(std::string&& response) { + VLOG(1) << "Refresh Token response: " << response; + + rj::Document doc; + constexpr unsigned kFlags = rj::kParseTrailingCommasFlag | rj::kParseCommentsFlag; + doc.ParseInsitu(&response.front()); + + if (doc.HasParseError()) { + return Unexpected(errc::bad_message); + } + + TokenTtl result; + auto it = doc.FindMember("token_type"); + if (it == doc.MemberEnd() || string_view{it->value.GetString()} != "Bearer"sv) { + return Unexpected(errc::bad_message); + } + + it = doc.FindMember("access_token"); + if (it == doc.MemberEnd()) { + return Unexpected(errc::bad_message); + } + result.first = it->value.GetString(); + it = doc.FindMember("expires_in"); + if (it == doc.MemberEnd() || !it->value.IsUint()) { + return Unexpected(errc::bad_message); + } + result.second = it->value.GetUint(); + + return result; +} + +template +error_code SendWithToken(GCPCredsProvider* provider, http::Client* client, EmptyRequest* req, h2::response* resp) { + for (unsigned i = 0; i < 2; ++i) { // Iterate for possible token refresh. + VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle(); + + error_code ec = client->Send(*req, resp); + if (ec) { + return ec; + } + VLOG(1) << "HttpResp" << i << ": " << *resp; + + if (resp->result() == h2::status::ok) { + break; + }; + + if (IsUnauthorized(*resp)) { + ec = provider->RefreshToken(client->proactor()); + if (ec) { + return ec; + } + + *resp = {}; + req->set(h2::field::authorization, AuthHeader(provider->access_token())); + + continue; + } + LOG(FATAL) << "Unexpected response " << *resp; + } + return {}; +} + } // namespace +error_code GCPCredsProvider::Init(unsigned connect_ms, fb2::ProactorBase* pb) { + CHECK_GT(connect_ms, 0u); + + io::Result root_path = ExpandFile("~/.config/gcloud"); + if (!root_path) { + return root_path.error(); + } + + bool is_cloud_env = false; + string gce_file = absl::StrCat(*root_path, "/gce"); -GCS::GCS(SSL_CTX* ssl_cntx, fb2::ProactorBase* pb) { + VLOG(1) << "Reading from " << gce_file; + + io::Result gce_file_str = io::ReadFileToString(gce_file); + + if (gce_file_str && *gce_file_str == "True") { + is_cloud_env = true; + } + + if (is_cloud_env) { + use_instance_metadata_ = true; + LOG(FATAL) << "TBD: do not support reading from instance metadata"; + } else { + error_code ec = LoadGCPConfig(&account_id_, &project_id_); + if (ec) + return ec; + if (account_id_.empty() || project_id_.empty()) { + LOG(WARNING) << "gcloud config file is not valid"; + return make_error_code(errc::not_supported); + } + string adc_file = absl::StrCat(*root_path, "/legacy_credentials/", account_id_, "/adc.json"); + VLOG(1) << "ADC file: " << adc_file; + ec = ParseADC(adc_file, &client_id_, &client_secret_, &refresh_token_); + if (ec) + return ec; + if (client_id_.empty() || client_secret_.empty() || refresh_token_.empty()) { + LOG(WARNING) << "Bad ADC file " << adc_file; + return make_error_code(errc::bad_message); + } + } + + // At this point we should have all the data to get an access token. + connect_ms_ = connect_ms; + return RefreshToken(pb); +} + +error_code GCPCredsProvider::RefreshToken(fb2::ProactorBase* pb) { + constexpr char kDomain[] = "oauth2.googleapis.com"; + + http::TlsClient https_client(pb); + https_client.set_connect_timeout_ms(connect_ms_); + SSL_CTX* context = http::TlsClient::CreateSslContext(); + error_code ec = https_client.Connect(kDomain, "443", context); + http::TlsClient::FreeContext(context); + + if (ec) + return ec; + h2::request req{h2::verb::post, "/token", 11}; + req.set(h2::field::host, kDomain); + req.set(h2::field::content_type, "application/x-www-form-urlencoded"); + + string& body = req.body(); + body = absl::StrCat("grant_type=refresh_token&client_secret=", client_secret_, + "&refresh_token=", refresh_token_); + absl::StrAppend(&body, "&client_id=", client_id_); + req.prepare_payload(); + VLOG(1) << "Req: " << req; + + h2::response resp; + ec = https_client.Send(req, &resp); + if (ec) + return ec; + if (resp.result() != h2::status::ok) { + LOG(WARNING) << "Http error: " << string(resp.reason()) << ", Body: ", resp.body(); + return make_error_code(errc::permission_denied); + } + + io::Result token = ParseTokenResponse(std::move(resp.body())); + if (!token) + return token.error(); + + folly::RWSpinLock::WriteHolder lock(lock_); + access_token_ = token->first; + expire_time_.store(time(nullptr) + token->second, std::memory_order_release); + + return {}; +} + +GCS::GCS(GCPCredsProvider* provider, SSL_CTX* ssl_cntx, fb2::ProactorBase* pb) + : creds_provider_(*provider), ssl_ctx_(ssl_cntx) { client_.reset(new http::TlsClient(pb)); } @@ -21,10 +309,21 @@ GCS::~GCS() { std::error_code GCS::Connect(unsigned msec) { client_->set_connect_timeout_ms(msec); - return client_->Connect(kDomain, "443"); + return client_->Connect(kDomain, "443", ssl_ctx_); } auto GCS::ListBuckets() -> ListBucketResult { + string url = absl::StrCat("/storage/v1/b?project=", creds_provider_.project_id()); + absl::StrAppend(&url, "&fields=items,nextPageToken"); + + auto http_req = PrepareRequest(h2::verb::get, url, creds_provider_.access_token()); + + rj::Document doc; + h2::response resp_msg; + error_code ec = SendWithToken(&creds_provider_, client_.get(), &http_req, &resp_msg); + if (ec) + return nonstd::make_unexpected(ec); + VLOG(2) << "ListResponse: " << resp_msg.body(); return {}; } diff --git a/util/cloud/gcp/gcs.h b/util/cloud/gcp/gcs.h index 8e1775cf..23b78cc7 100644 --- a/util/cloud/gcp/gcs.h +++ b/util/cloud/gcp/gcs.h @@ -8,6 +8,7 @@ #include #include "util/http/http_client.h" +#include "base/RWSpinLock.h" typedef struct ssl_ctx_st SSL_CTX; @@ -19,11 +20,58 @@ class ProactorBase; namespace cloud { +class GCPCredsProvider { + GCPCredsProvider(const GCPCredsProvider&) = delete; + GCPCredsProvider& operator=(const GCPCredsProvider&) = delete; + + public: + GCPCredsProvider() = default; + + std::error_code Init(unsigned connect_ms, fb2::ProactorBase* pb); + + const std::string& project_id() const { + return project_id_; + } + + const std::string& client_id() const { + return client_id_; + } + + // Thread-safe method to access the token. + std::string access_token() const { + folly::RWSpinLock::ReadHolder lock(lock_); + return access_token_; + } + + time_t expire_time() const { + return expire_time_.load(std::memory_order_acquire); + } + + // Thread-safe method issues refresh of the token. + // Right now will do the refresh unconditonally. + // TODO: to use expire_time_ to skip the refresh if expire time is far away. + std::error_code RefreshToken(fb2::ProactorBase* pb); + + private: + bool use_instance_metadata_ = false; + unsigned connect_ms_ = 0; + + fb2::ProactorBase* pb_ = nullptr; + std::string account_id_; + std::string project_id_; + + std::string client_id_, client_secret_, refresh_token_; + + mutable folly::RWSpinLock lock_; // protects access_token_ + std::string access_token_; + std::atomic expire_time_ = 0; // seconds since epoch +}; + class GCS { public: using ListBucketResult = io::Result>; - GCS(SSL_CTX* ssl_cntx, fb2::ProactorBase* pb); + GCS(GCPCredsProvider* creds_provider, SSL_CTX* ssl_cntx, fb2::ProactorBase* pb); ~GCS(); std::error_code Connect(unsigned msec); @@ -31,8 +79,9 @@ class GCS { ListBucketResult ListBuckets(); private: + GCPCredsProvider& creds_provider_; SSL_CTX* ssl_ctx_; - std::unique_ptr client_; + std::unique_ptr client_; }; } // namespace cloud diff --git a/util/http/http_client.h b/util/http/http_client.h index 6d9c97ef..5718bd83 100644 --- a/util/http/http_client.h +++ b/util/http/http_client.h @@ -80,6 +80,10 @@ class Client { return socket_->native_handle(); } + ProactorBase* proactor() const { + return proactor_; + } + protected: std::unique_ptr socket_; @@ -123,6 +127,7 @@ template auto Client::Send(const Req& req, Resp* r } *resp = Resp{}; + ec = read_ec; } return HandleError(ec);