From c8bc677a62ebca1efc36989d1451552e46521f7c Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Wed, 15 Mar 2023 10:10:25 +1300 Subject: [PATCH] Refactor AWS SigV4 to middleware approach Signed-off-by: Thomas Farr --- .github/workflows/test.yml | 2 +- Makefile.toml | 10 +- opensearch/Cargo.toml | 19 +- opensearch/examples/aws_auth.rs | 17 +- opensearch/src/auth.rs | 76 +++-- opensearch/src/aws/mod.rs | 18 ++ opensearch/src/aws/sigv4.rs | 262 ++++++++++++++++++ opensearch/src/error.rs | 119 +++----- opensearch/src/http/aws_auth.rs | 92 ------ opensearch/src/http/headers.rs | 2 +- .../src/http/middleware/initializers.rs | 45 +++ opensearch/src/http/middleware/mod.rs | 58 ++++ .../src/http/middleware/request_pipeline.rs | 142 ++++++++++ opensearch/src/http/mod.rs | 8 +- opensearch/src/http/transport.rs | 256 ++++++----------- opensearch/src/lib.rs | 31 ++- opensearch/src/models/mod.rs | 13 + opensearch/tests/auth.rs | 85 +++--- opensearch/tests/aws_auth.rs | 139 +++++----- opensearch/tests/cert.rs | 49 ++-- opensearch/tests/client.rs | 175 ++++++------ opensearch/tests/common/client.rs | 82 ++++-- opensearch/tests/common/mod.rs | 18 +- opensearch/tests/common/server.rs | 262 +++++++++++------- opensearch/tests/error.rs | 6 +- 25 files changed, 1172 insertions(+), 814 deletions(-) create mode 100644 opensearch/src/aws/mod.rs create mode 100644 opensearch/src/aws/sigv4.rs delete mode 100644 opensearch/src/http/aws_auth.rs create mode 100644 opensearch/src/http/middleware/initializers.rs create mode 100644 opensearch/src/http/middleware/mod.rs create mode 100644 opensearch/src/http/middleware/request_pipeline.rs diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 82ecd08b..a9bbf290 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,7 @@ jobs: version: 2.8.0 secured: true - - name: Run Tests (${{ matrix.test-args }}) + - name: Run Tests working-directory: client run: cargo make test ${{ matrix.test-args }} env: diff --git a/Makefile.toml b/Makefile.toml index 95b496eb..c01a1dc2 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -20,18 +20,14 @@ OPENSEARCH_URL = { value = "${OPENSEARCH_PROTOCOL}://localhost:9200", condition category = "OpenSearch" description = "Generates SSL certificates used for integration tests" command = "bash" -args =["./.ci/generate-certs.sh"] +args = ["./.ci/generate-certs.sh"] [tasks.run-opensearch] category = "OpenSearch" private = true condition = { env_set = [ "STACK_VERSION"], env_false = ["CARGO_MAKE_CI"] } - -[tasks.run-opensearch.linux] -command = "./.ci/run-opensearch.sh" - -[tasks.run-opensearch.mac] -command = "./.ci/run-opensearch.sh" +command = "bash" +args = ["./.ci/run-opensearch.sh"] [tasks.run-opensearch.windows] script_runner = "cmd" diff --git a/opensearch/Cargo.toml b/opensearch/Cargo.toml index 28bb3ba3..85f5e09e 100644 --- a/opensearch/Cargo.toml +++ b/opensearch/Cargo.toml @@ -25,13 +25,19 @@ experimental-apis = ["beta-apis"] native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] -# AWS SigV4 Auth support -aws-auth = ["aws-credential-types", "aws-sigv4", "aws-smithy-runtime-api", "aws-types"] +aws-auth = ["dep:aws-config", "dep:aws-credential-types", "dep:aws-sigv4", "dep:aws-smithy-runtime-api", "dep:aws-types"] [dependencies] +async-trait = "0.1" +aws-config = { version = "1", optional = true } +aws-credential-types = { version = "1", optional = true } +aws-sigv4 = { version = "1", optional = true } +aws-smithy-runtime-api = { version = "1", features = ["client"], optional = true } +aws-types = { version = "1", optional = true } base64 = "0.21" bytes = "1.0" dyn-clone = "1" +futures-util = "0.3" lazy_static = "1.4" percent-encoding = "2.1.0" reqwest = { version = "0.11", default-features = false, features = ["gzip", "json"] } @@ -39,23 +45,18 @@ url = "2.1" serde = { version = "1", features = ["derive"] } serde_json = "1" serde_with = "3" +thiserror = "1" void = "1.0.2" -aws-credential-types = { version = "1", optional = true } -aws-sigv4 = { version = "1", optional = true } -aws-smithy-runtime-api = { version = "1", optional = true, features = ["client"]} -aws-types = { version = "1", optional = true } [dev-dependencies] anyhow = "1.0" -aws-config = "1" aws-smithy-async = "1" chrono = { version = "0.4", features = ["serde"] } clap = "2" -futures = "0.3.1" +futures = "0.3" http-body-util = "0.1.0" hyper = { version = "1", features = ["full"] } hyper-util = { version = "0.1", features = ["full"] } -regex="1.4" sysinfo = "0.29.0" test-case = "3" textwrap = "0.16" diff --git a/opensearch/examples/aws_auth.rs b/opensearch/examples/aws_auth.rs index 81783474..26450ff7 100644 --- a/opensearch/examples/aws_auth.rs +++ b/opensearch/examples/aws_auth.rs @@ -9,23 +9,26 @@ * GitHub history for details. */ -#[tokio::main] #[cfg(feature = "aws-auth")] +#[tokio::main] pub async fn main() -> Result<(), Box> { use std::convert::TryInto; + use aws_config::BehaviorVersion; use opensearch::{ cat::CatIndicesParts, - http::transport::{SingleNodeConnectionPool, TransportBuilder}, + http::{ + transport::{SingleNodeConnectionPool, TransportBuilder}, + Url, + }, OpenSearch, }; - use url::Url; - let aws_config = aws_config::load_from_env().await; + let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await; let host = ""; // e.g. https://search-mydomain.us-west-1.es.amazonaws.com let transport = TransportBuilder::new(SingleNodeConnectionPool::new(Url::parse(host).unwrap())) - .auth(aws_config.try_into()?) + .aws_sigv4(aws_config.try_into()?) .build()?; let client = OpenSearch::new(transport); @@ -42,6 +45,6 @@ pub async fn main() -> Result<(), Box> { } #[cfg(not(feature = "aws-auth"))] -pub fn main() { - panic!("Requires the `aws-auth` feature to be enabled") +fn main() { + panic!("This example requires the `aws-auth` feature to be enabled") } diff --git a/opensearch/src/auth.rs b/opensearch/src/auth.rs index 787917ba..65d32dea 100644 --- a/opensearch/src/auth.rs +++ b/opensearch/src/auth.rs @@ -30,6 +30,12 @@ //! Authentication components +use crate::{ + http::middleware::{ClientInitializer, RequestInitializer}, + BoxError, +}; +use reqwest::Identity; + /// Credentials for authentication #[derive(Debug, Clone)] pub enum Credentials { @@ -45,16 +51,6 @@ pub enum Credentials { Certificate(ClientCertificate), /// An id and api_key to use for API key authentication ApiKey(String, String), - /// AWS credentials used for AWS SigV4 request signing. - /// - /// # Optional - /// - /// This requires the `aws-auth` feature to be enabled. - #[cfg(feature = "aws-auth")] - AwsSigV4( - aws_credential_types::provider::SharedCredentialsProvider, - aws_types::region::Region, - ), } #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] @@ -90,28 +86,50 @@ impl From for Credentials { } } -#[cfg(any(feature = "aws-auth"))] -impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials { - type Error = super::Error; - - fn try_from(value: &aws_types::SdkConfig) -> Result { - let credentials = value - .credentials_provider() - .ok_or_else(|| super::error::lib("SdkConfig does not have a credentials_provider"))? - .clone(); - let region = value - .region() - .ok_or_else(|| super::error::lib("SdkConfig does not have a region"))? - .clone(); - Ok(Credentials::AwsSigV4(credentials, region)) +impl ClientInitializer for Credentials { + fn init( + &self, + client: reqwest::ClientBuilder, + ) -> Result> { + match &self { + #[cfg(feature = "native-tls")] + Credentials::Certificate(ClientCertificate::Pkcs12(b, p)) => { + Ok(client.identity(Identity::from_pkcs12_der(b, p.as_deref().unwrap_or(""))?)) + } + #[cfg(feature = "rustls-tls")] + Credentials::Certificate(ClientCertificate::Pem(b)) => { + Ok(client.identity(Identity::from_pem(b)?)) + } + _ => Ok(client), + } } } -#[cfg(any(feature = "aws-auth"))] -impl std::convert::TryFrom for Credentials { - type Error = super::Error; +impl RequestInitializer for Credentials { + fn init( + &self, + request: reqwest::RequestBuilder, + ) -> Result> { + Ok(match &self { + Credentials::Basic(u, p) => request.basic_auth(u, Some(p)), + Credentials::Bearer(t) => request.bearer_auth(t), + Credentials::ApiKey(id, key) => { + use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; + use reqwest::header::{HeaderValue, AUTHORIZATION}; + use std::io::Write; - fn try_from(value: aws_types::SdkConfig) -> Result { - Credentials::try_from(&value) + let mut header_value = b"ApiKey ".to_vec(); + { + let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); + write!(encoder, "{}:", id).unwrap(); + write!(encoder, "{}", key).unwrap(); + } + request.header( + AUTHORIZATION, + HeaderValue::from_bytes(&header_value).unwrap(), + ) + } + _ => request, + }) } } diff --git a/opensearch/src/aws/mod.rs b/opensearch/src/aws/mod.rs new file mode 100644 index 00000000..bc548d11 --- /dev/null +++ b/opensearch/src/aws/mod.rs @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +pub use aws_config; +pub use aws_credential_types; +pub use aws_types; + +mod sigv4; + +pub use sigv4::*; diff --git a/opensearch/src/aws/sigv4.rs b/opensearch/src/aws/sigv4.rs new file mode 100644 index 00000000..7b59e79e --- /dev/null +++ b/opensearch/src/aws/sigv4.rs @@ -0,0 +1,262 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use std::{borrow::Cow, convert::TryFrom, str::Utf8Error}; + +use async_trait::async_trait; +use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; +use aws_sigv4::{ + http_request::{ + sign, PayloadChecksumKind, SignableBody, SignableRequest, SignatureLocation, SigningParams, + SigningSettings, + }, + sign::v4, +}; +use aws_smithy_runtime_api::client::identity::Identity; +use aws_types::{sdk_config::SharedTimeSource, SdkConfig}; +use reqwest::{ + header::{HeaderName, HeaderValue, CONTENT_LENGTH, USER_AGENT}, + Request, Response, +}; +use thiserror::Error; + +use crate::http::{ + middleware::{RequestHandler, RequestPipeline, RequestPipelineError}, + transport::TransportBuilder, +}; + +#[derive(Error, Debug)] +pub enum AwsSigV4BuildError { + #[error("the region for signing must be provided")] + MissingRegion, + #[error("the credentials provider for signing must be provided")] + MissingCredentialsProvider, +} + +#[derive(Debug, Clone)] +pub struct AwsSigV4Builder { + service_name: Option, + credentials_provider: Option, + region: Option, + ignored_headers: Vec, + time_source: Option, +} + +impl AwsSigV4Builder { + pub fn service_name(mut self, service_name: impl AsRef) -> Self { + self.service_name = Some(service_name.as_ref().to_owned()); + self + } + + pub fn credentials_provider( + mut self, + credentials_provider: impl ProvideCredentials + 'static, + ) -> Self { + self.credentials_provider = Some(SharedCredentialsProvider::new(credentials_provider)); + self + } + + pub fn region(mut self, region: impl AsRef) -> Self { + self.region = Some(region.as_ref().to_owned()); + self + } + + pub fn ignore_header(mut self, header_name: impl AsRef) -> Self { + self.ignored_headers.push(header_name.as_ref().to_owned()); + self + } + + #[doc(hidden)] + pub fn time_source(mut self, time_source: impl Into) -> Self { + self.time_source = Some(time_source.into()); + self + } + + pub fn build(self) -> Result { + Ok(AwsSigV4 { + service_name: self.service_name.unwrap_or_else(|| "es".into()), + credentials_provider: self + .credentials_provider + .ok_or(AwsSigV4BuildError::MissingCredentialsProvider)?, + region: self.region.ok_or(AwsSigV4BuildError::MissingRegion)?, + ignored_headers: self.ignored_headers.into_iter().map(Cow::Owned).collect(), + time_source: self.time_source.unwrap_or_default(), + }) + } +} + +impl Default for AwsSigV4Builder { + fn default() -> Self { + Self { + service_name: None, + credentials_provider: None, + region: None, + ignored_headers: vec![USER_AGENT.as_str().into(), CONTENT_LENGTH.as_str().into()], + time_source: None, + } + } +} + +impl From<&SdkConfig> for AwsSigV4Builder { + fn from(value: &SdkConfig) -> Self { + Self { + credentials_provider: value.credentials_provider(), + region: value.region().map(|r| r.to_string()), + time_source: value.time_source(), + ..Default::default() + } + } +} + +impl From for AwsSigV4Builder { + fn from(value: SdkConfig) -> Self { + >::from(&value) + } +} + +#[derive(Error, Debug)] +pub enum AwsSigV4Error { + #[error("invalid signing params: {0}")] + InvalidSigningParams(#[from] v4::signing_params::BuildError), + #[error("unable to retrieve credentials: {0}")] + FailedCredentialsRetrieval(#[from] aws_credential_types::provider::error::CredentialsError), + #[error("unable to sign request: {0}")] + FailedSigning(#[from] aws_sigv4::http_request::SigningError), + #[error("unable to sign a non UTF-8 header {0:?}: {1}")] + NonUtf8Header(HeaderName, Utf8Error), +} + +#[derive(Debug, Clone)] +pub struct AwsSigV4 { + service_name: String, + credentials_provider: SharedCredentialsProvider, + region: String, + ignored_headers: Vec>, + time_source: SharedTimeSource, +} + +impl AwsSigV4 { + pub fn builder() -> AwsSigV4Builder { + AwsSigV4Builder::default() + } + + async fn sign_request(&self, request: &mut Request) -> Result<(), AwsSigV4Error> { + let identity = self + .credentials_provider + .provide_credentials() + .await? + .into(); + + let params = self.build_params(&identity)?; + + let signable_request = self.build_signable_request(request)?; + + let (new_headers, new_query_params) = { + let (instructions, _) = sign(signable_request, ¶ms)?.into_parts(); + instructions.into_parts() + }; + + for header in new_headers.into_iter() { + let mut value = HeaderValue::from_str(header.value()) + .expect("AWS signing header value must be a valid header"); + value.set_sensitive(header.sensitive()); + + request.headers_mut().insert(header.name(), value); + } + + for (key, value) in new_query_params.into_iter() { + request.url_mut().query_pairs_mut().append_pair(key, &value); + } + + Ok(()) + } + + fn build_params<'a>( + &'a self, + identity: &'a Identity, + ) -> Result, AwsSigV4Error> { + let mut signing_settings = SigningSettings::default(); + signing_settings.signature_location = SignatureLocation::Headers; + signing_settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; // required for OpenSearch Serverless + signing_settings.excluded_headers = Some(self.ignored_headers.clone()); + + let params = v4::SigningParams::builder() + .identity(&identity) + .name(&self.service_name) + .region(self.region.as_ref()) + .time(self.time_source.now()) + .settings(signing_settings) + .build()?; + + Ok(SigningParams::V4(params)) + } + + fn build_signable_request<'a>( + &'a self, + request: &'a Request, + ) -> Result, AwsSigV4Error> { + let method = request.method().as_str(); + let uri = request.url().as_str(); + + let mut headers = Vec::with_capacity(request.headers().len()); + for (name, value) in request.headers().iter() { + let value = std::str::from_utf8(value.as_bytes()) + .map_err(|e| AwsSigV4Error::NonUtf8Header(name.clone(), e))?; + headers.push((name.as_str(), value)) + } + + let body = match request.body() { + Some(b) => match b.as_bytes() { + Some(bytes) => SignableBody::Bytes(bytes), + None => SignableBody::UnsignedPayload, // Body is not in memory (ie. streaming), so we can't sign it + }, + None => SignableBody::Bytes(&[]), + }; + + SignableRequest::new(method, uri, headers.into_iter(), body).map_err(Into::into) + } +} + +impl TryFrom<&SdkConfig> for AwsSigV4 { + type Error = AwsSigV4BuildError; + + fn try_from(value: &SdkConfig) -> Result { + AwsSigV4Builder::from(value).build() + } +} + +impl TryFrom for AwsSigV4 { + type Error = AwsSigV4BuildError; + + fn try_from(value: SdkConfig) -> Result { + >::try_from(&value) + } +} + +#[async_trait] +impl RequestHandler for AwsSigV4 { + async fn handle( + &self, + mut request: Request, + next: RequestPipeline<'_>, + ) -> Result { + self.sign_request(&mut request) + .await + .map_err(|e| RequestPipelineError::Pipeline(e.into()))?; + next.run(request).await + } +} + +impl TransportBuilder { + pub fn aws_sigv4(self, aws_sigv4: AwsSigV4) -> Self { + self.with_handler(aws_sigv4) + } +} diff --git a/opensearch/src/error.rs b/opensearch/src/error.rs index 72709550..ea1d30bb 100644 --- a/opensearch/src/error.rs +++ b/opensearch/src/error.rs @@ -32,125 +32,82 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -use crate::http::{transport::BuildError, StatusCode}; -use std::{error, fmt, io}; + +use crate::http::{middleware::RequestPipelineError, transport::BuildError, StatusCode}; + +pub type BoxError<'a> = Box; /// An error with the client. /// /// Errors that can occur include IO and parsing errors, as well as specific -/// errors from Elasticsearch and internal errors from the client. -#[derive(Debug)] -pub struct Error { - kind: Kind, -} - -#[derive(Debug)] -enum Kind { +/// errors from OpenSearch and internal errors from the client. +#[derive(Debug, thiserror::Error)] +pub enum Error { /// An error building the client - Build(BuildError), + #[error("Error building the client: {0}")] + Build(#[from] BuildError), /// A general error from this library + #[error("Library error: {0}")] Lib(String), - /// HTTP library error - Http(reqwest::Error), + /// Reqwest error + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), + + /// URL parse error + #[error("URL parse error: {0}")] + UrlParse(#[from] url::ParseError), /// IO error - Io(io::Error), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), /// JSON error - Json(serde_json::error::Error), -} - -impl From for Error { - fn from(err: io::Error) -> Error { - Error { - kind: Kind::Io(err), - } - } -} - -impl From for Error { - fn from(err: reqwest::Error) -> Error { - Error { - kind: Kind::Http(err), - } - } -} + #[error("JSON error: {0}")] + Json(#[from] serde_json::error::Error), -impl From for Error { - fn from(err: serde_json::error::Error) -> Error { - Error { - kind: Kind::Json(err), - } - } -} + /// Request initializer error + #[error("Request initializer error: {0}")] + RequestInitializer(#[source] BoxError<'static>), -impl From for Error { - fn from(err: url::ParseError) -> Error { - Error { - kind: Kind::Lib(err.to_string()), - } - } + /// Request pipeline error + #[error("Request pipeline error: {0}")] + RequestPipeline(#[source] BoxError<'static>), } -impl From for Error { - fn from(err: BuildError) -> Error { - Error { - kind: Kind::Build(err), +impl From for Error { + fn from(err: RequestPipelineError) -> Self { + match err { + RequestPipelineError::Reqwest(err) => Self::Reqwest(err), + RequestPipelineError::Pipeline(err) => Self::RequestPipeline(err), } } } pub(crate) fn lib(err: impl Into) -> Error { - Error { - kind: Kind::Lib(err.into()), - } + Error::Lib(err.into()) } impl Error { /// The status code, if the error was generated from a response pub fn status_code(&self) -> Option { - match &self.kind { - Kind::Http(err) => err.status(), + match &self { + Self::Reqwest(err) => err.status(), _ => None, } } /// Returns true if the error is related to a timeout pub fn is_timeout(&self) -> bool { - match &self.kind { - Kind::Http(err) => err.is_timeout(), + match &self { + Self::Reqwest(err) => err.is_timeout(), _ => false, } } /// Returns true if the error is related to serialization or deserialization pub fn is_json(&self) -> bool { - matches!(self.kind, Kind::Json(_)) - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match &self.kind { - Kind::Build(err) => Some(err), - Kind::Lib(_) => None, - Kind::Http(err) => Some(err), - Kind::Io(err) => Some(err), - Kind::Json(err) => Some(err), - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match &self.kind { - Kind::Build(err) => err.fmt(f), - Kind::Lib(err) => err.fmt(f), - Kind::Http(err) => err.fmt(f), - Kind::Io(err) => err.fmt(f), - Kind::Json(err) => err.fmt(f), - } + matches!(self, Self::Json(_)) } } diff --git a/opensearch/src/http/aws_auth.rs b/opensearch/src/http/aws_auth.rs deleted file mode 100644 index 6150ffd0..00000000 --- a/opensearch/src/http/aws_auth.rs +++ /dev/null @@ -1,92 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -use crate::http::headers::HeaderValue; -use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; -use aws_sigv4::{ - http_request::{ - sign, PayloadChecksumKind, SignableBody, SignableRequest, SigningParams, SigningSettings, - }, - sign::v4, -}; -use aws_smithy_runtime_api::client::identity::Identity; -use aws_types::{region::Region, sdk_config::SharedTimeSource}; -use reqwest::Request; - -pub async fn sign_request( - request: &mut Request, - credentials_provider: &SharedCredentialsProvider, - service_name: &str, - region: &Region, - time_source: &SharedTimeSource, -) -> Result<(), Box> { - let identity = { - let c = credentials_provider.provide_credentials().await?; - let e = c.expiry(); - Identity::new(c, e) - }; - - let signing_settings = { - let mut s = SigningSettings::default(); - s.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; // required for OpenSearch Serverless - s - }; - - let params = { - let p = v4::SigningParams::builder() - .identity(&identity) - .name(service_name) - .region(region.as_ref()) - .time(time_source.now()) - .settings(signing_settings) - .build()?; - SigningParams::V4(p) - }; - - let signable_request = { - let method = request.method().as_str(); - let uri = request.url().as_str(); - let headers = request.headers().iter().map(|(k, v)| { - ( - k.as_str(), - std::str::from_utf8(v.as_bytes()).expect("only utf-8 headers are signable"), - ) - }); - let body = match request.body() { - Some(b) => match b.as_bytes() { - Some(bytes) => SignableBody::Bytes(bytes), - None => SignableBody::UnsignedPayload, // Body is not in memory (ie. streaming), so we can't sign it - }, - None => SignableBody::Bytes(&[]), - }; - - SignableRequest::new(method, uri, headers, body)? - }; - - let (new_headers, new_query_params) = { - let (instructions, _) = sign(signable_request, ¶ms)?.into_parts(); - instructions.into_parts() - }; - - for header in new_headers.into_iter() { - let mut value = HeaderValue::from_str(header.value()) - .expect("AWS signing header value must be a valid header"); - value.set_sensitive(header.sensitive()); - - request.headers_mut().insert(header.name(), value); - } - - for (key, value) in new_query_params.into_iter() { - request.url_mut().query_pairs_mut().append_pair(key, &value); - } - - Ok(()) -} diff --git a/opensearch/src/http/headers.rs b/opensearch/src/http/headers.rs index 713168eb..4f616720 100644 --- a/opensearch/src/http/headers.rs +++ b/opensearch/src/http/headers.rs @@ -28,7 +28,7 @@ * GitHub history for details. */ -//! HTTP header names and values, including those specific to Elasticsearch +//! HTTP header names and values, including those specific to OpenSearch pub use reqwest::header::*; diff --git a/opensearch/src/http/middleware/initializers.rs b/opensearch/src/http/middleware/initializers.rs new file mode 100644 index 00000000..07fb194c --- /dev/null +++ b/opensearch/src/http/middleware/initializers.rs @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::shared_middleware; +use crate::BoxError; +use reqwest::{ClientBuilder, RequestBuilder}; + +pub trait ClientInitializer: Send + Sync + 'static { + fn init(&self, client: ClientBuilder) -> Result>; +} + +impl ClientInitializer for F +where + F: Fn(ClientBuilder) -> Result> + Send + Sync + 'static, +{ + fn init(&self, client: ClientBuilder) -> Result> { + self(client) + } +} + +pub trait RequestInitializer: Send + Sync + 'static { + fn init(&self, request: RequestBuilder) -> Result>; +} + +impl RequestInitializer for F +where + F: Fn(RequestBuilder) -> Result> + Send + Sync + 'static, +{ + fn init(&self, request: RequestBuilder) -> Result> { + self(request) + } +} + +shared_middleware!( + SharedClientInitializer(ClientInitializer), + SharedRequestInitializer(RequestInitializer) +); diff --git a/opensearch/src/http/middleware/mod.rs b/opensearch/src/http/middleware/mod.rs new file mode 100644 index 00000000..b29ce054 --- /dev/null +++ b/opensearch/src/http/middleware/mod.rs @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod initializers; +mod request_pipeline; + +pub use async_trait::async_trait; +pub use initializers::*; +pub use request_pipeline::*; + +macro_rules! shared_middleware { + ($($shared:ident($trait:ident)),*) => { + $( + #[derive(Clone)] + pub struct $shared(std::sync::Arc); + + impl From for $shared + where M: $trait + { + fn from(middleware: M) -> Self { + Self(std::sync::Arc::new(middleware)) + } + } + + impl From> for $shared + where M: $trait + { + fn from(middleware: std::sync::Arc) -> Self { + Self(middleware) + } + } + + impl From> for $shared { + fn from(middleware: std::sync::Arc) -> Self { + Self(middleware) + } + } + + impl std::ops::Deref for $shared { + type Target = dyn $trait; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } + } + )* + } +} + +pub(self) use shared_middleware; diff --git a/opensearch/src/http/middleware/request_pipeline.rs b/opensearch/src/http/middleware/request_pipeline.rs new file mode 100644 index 00000000..d9e2c59e --- /dev/null +++ b/opensearch/src/http/middleware/request_pipeline.rs @@ -0,0 +1,142 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +//! Request pipeline types +//! +//! Examples: +//! +//! ```no_run +//! use futures_util::future::BoxFuture; +//! use opensearch::http::{ +//! middleware::{RequestPipeline, RequestPipelineError}, +//! reqwest::{Request, Response}, +//! }; +//! +//! fn logger<'a>(req: Request, next: RequestPipeline<'a>) -> BoxFuture<'a, Result> { +//! Box::pin(async move { +//! println!("sending request to {}", req.url()); +//! let now = std::time::Instant::now(); +//! let res = next.run(req).await?; +//! println!("request completed ({:?})", now.elapsed()); +//! Ok(res) +//! }) +//! } +//! +//! # #[tokio::main] +//! # async fn main() { +//! # use opensearch::http::{ +//! # transport::{SingleNodeConnectionPool, TransportBuilder}, +//! # Url, +//! # }; +//! # let conn_pool = SingleNodeConnectionPool::new(Url::parse("http://localhost:9200").unwrap()); +//! # let _ = TransportBuilder::new(conn_pool).with_handler(logger); +//! # } +//! ``` +//! +//! ```no_run +//! use futures_util::future::BoxFuture; +//! use opensearch::http::{ +//! middleware::{async_trait, RequestHandler, RequestPipeline, RequestPipelineError}, +//! reqwest::{Request, Response}, +//! }; +//! +//! struct Logger; +//! +//! #[async_trait] +//! impl RequestHandler for Logger { +//! async fn handle(&self, request: Request, next: RequestPipeline<'_>) -> Result { +//! println!("sending request to {}", request.url()); +//! let now = std::time::Instant::now(); +//! let res = next.run(request).await?; +//! println!("request completed ({:?})", now.elapsed()); +//! Ok(res) +//! } +//! } +//! +//! # #[tokio::main] +//! # async fn main() { +//! # use opensearch::http::{ +//! # transport::{SingleNodeConnectionPool, TransportBuilder}, +//! # Url, +//! # }; +//! # let conn_pool = SingleNodeConnectionPool::new(Url::parse("http://localhost:9200").unwrap()); +//! # let _ = TransportBuilder::new(conn_pool).with_handler(Logger); +//! # } +//! ``` + +use super::{async_trait, shared_middleware}; +use crate::BoxError; +use futures_util::future::BoxFuture; +use reqwest::{Client, Request, Response}; +use std::fmt::Debug; + +#[derive(Debug, thiserror::Error)] +pub enum RequestPipelineError { + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), + + #[error("Pipeline error: {0}")] + Pipeline(#[from] BoxError<'static>), +} + +#[async_trait] +pub trait RequestHandler: Send + Sync + 'static { + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result; +} + +#[async_trait] +impl RequestHandler for F +where + F: for<'a> Fn( + Request, + RequestPipeline<'a>, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, +{ + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result { + self(request, next).await + } +} + +shared_middleware!(SharedRequestHandler(RequestHandler)); + +pub struct RequestPipeline<'a> { + pub client: &'a Client, + pipeline: &'a [SharedRequestHandler], +} + +impl<'a> RequestPipeline<'a> { + pub(crate) fn new(client: &'a Client, pipeline: &'a [SharedRequestHandler]) -> Self { + Self { client, pipeline } + } + + pub fn run( + mut self, + request: Request, + ) -> BoxFuture<'a, Result> { + if let Some((head, tail)) = self.pipeline.split_first() { + self.pipeline = tail; + head.handle(request, self) + } else { + Box::pin(async move { self.client.execute(request).await.map_err(Into::into) }) + } + } +} diff --git a/opensearch/src/http/mod.rs b/opensearch/src/http/mod.rs index 9e8a6417..7ae54293 100644 --- a/opensearch/src/http/mod.rs +++ b/opensearch/src/http/mod.rs @@ -30,18 +30,16 @@ //! HTTP components -#[cfg(feature = "aws-auth")] -pub(crate) mod aws_auth; - pub mod headers; +pub mod middleware; pub mod request; pub mod response; pub mod transport; -pub use reqwest::StatusCode; +pub use reqwest::{self, Request, StatusCode}; pub use url::Url; -/// Http methods supported by Elasticsearch +/// Http methods supported by OpenSearch #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Method { /// get diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index ca981bd2..035a23cf 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -30,8 +30,9 @@ //! HTTP transport and connection components -#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] -use crate::auth::ClientCertificate; +use super::middleware::{ + RequestPipeline, SharedClientInitializer, SharedRequestHandler, SharedRequestInitializer, +}; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use crate::cert::CertificateValidation; use crate::{ @@ -39,75 +40,32 @@ use crate::{ error::Error, http::{ headers::{ - HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, - DEFAULT_ACCEPT, DEFAULT_CONTENT_TYPE, DEFAULT_USER_AGENT, USER_AGENT, + HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE, DEFAULT_ACCEPT, + DEFAULT_CONTENT_TYPE, DEFAULT_USER_AGENT, USER_AGENT, }, request::Body, response::Response, Method, }, + BoxError, }; -#[cfg(feature = "aws-auth")] -use aws_types::sdk_config::SharedTimeSource; -use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; use bytes::BytesMut; use dyn_clone::clone_trait_object; use lazy_static::lazy_static; use serde::Serialize; -use std::{ - error, fmt, - fmt::Debug, - io::{self, Write}, - time::Duration, -}; +use std::{fmt::Debug, sync::Arc, time::Duration}; use url::Url; /// Error that can occur when building a [Transport] -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum BuildError { - /// IO error - Io(io::Error), - - /// Certificate error - Cert(reqwest::Error), -} - -impl From for BuildError { - fn from(err: io::Error) -> BuildError { - BuildError::Io(err) - } -} + /// Reqwest error + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), -impl From for BuildError { - fn from(err: reqwest::Error) -> BuildError { - BuildError::Cert(err) - } -} - -impl error::Error for BuildError { - #[allow(warnings)] - fn description(&self) -> &str { - match *self { - BuildError::Io(ref err) => err.description(), - BuildError::Cert(ref err) => err.description(), - } - } - - fn cause(&self) -> Option<&dyn error::Error> { - match *self { - BuildError::Io(ref err) => Some(err as &dyn error::Error), - BuildError::Cert(ref err) => Some(err as &dyn error::Error), - } - } -} - -impl fmt::Display for BuildError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - BuildError::Io(ref err) => fmt::Display::fmt(err, f), - BuildError::Cert(ref err) => fmt::Display::fmt(err, f), - } - } + /// Client initializer error + #[error("Client initializer error: {0}")] + ClientInitializer(#[from] BoxError<'static>), } /// Default address to OpenSearch running on `http://localhost:9200` @@ -148,7 +106,9 @@ fn build_meta() -> String { pub struct TransportBuilder { client_builder: reqwest::ClientBuilder, conn_pool: Box, - credentials: Option, + init_stack: Vec, + req_init_stack: Vec, + req_handler_stack: Vec, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] cert_validation: Option, proxy: Option, @@ -156,10 +116,6 @@ pub struct TransportBuilder { disable_proxy: bool, headers: HeaderMap, timeout: Option, - #[cfg(feature = "aws-auth")] - sigv4_service_name: String, - #[cfg(feature = "aws-auth")] - sigv4_time_source: Option, } impl TransportBuilder { @@ -172,7 +128,9 @@ impl TransportBuilder { Self { client_builder: reqwest::ClientBuilder::new(), conn_pool: Box::new(conn_pool), - credentials: None, + init_stack: Vec::new(), + req_init_stack: Vec::new(), + req_handler_stack: Vec::new(), #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] cert_validation: None, proxy: None, @@ -180,13 +138,24 @@ impl TransportBuilder { disable_proxy: false, headers: HeaderMap::new(), timeout: None, - #[cfg(feature = "aws-auth")] - sigv4_service_name: "es".to_string(), - #[cfg(feature = "aws-auth")] - sigv4_time_source: None, } } + pub fn with_init(mut self, init: impl Into) -> Self { + self.init_stack.push(init.into()); + self + } + + pub fn with_req_init(mut self, init: impl Into) -> Self { + self.req_init_stack.push(init.into()); + self + } + + pub fn with_handler(mut self, handler: impl Into) -> Self { + self.req_handler_stack.push(handler.into()); + self + } + /// Configures a proxy. /// /// An optional username and password will be used to set the @@ -210,9 +179,10 @@ impl TransportBuilder { } /// Credentials for the client to use for authentication to OpenSearch. - pub fn auth(mut self, credentials: Credentials) -> Self { - self.credentials = Some(credentials); - self + pub fn auth(self, credentials: Credentials) -> Self { + let credentials = Arc::new(credentials); + self.with_init(credentials.clone()) + .with_req_init(credentials) } /// Validation applied to the certificate provided to establish a HTTPS connection. @@ -251,25 +221,7 @@ impl TransportBuilder { self } - /// Sets the AWS SigV4 signing service name. - /// - /// Default is "es". Other supported services are "aoss" for OpenSearch Serverless. - #[cfg(feature = "aws-auth")] - pub fn service_name(mut self, service_name: &str) -> Self { - self.sigv4_service_name = service_name.to_string(); - self - } - - /// Sets the AWS SigV4 signing time source. - /// - /// Default is `SystemTimeSource` - #[cfg(feature = "aws-auth")] - pub fn sigv4_time_source(mut self, sigv4_time_source: SharedTimeSource) -> Self { - self.sigv4_time_source = Some(sigv4_time_source); - self - } - - /// Builds a [Transport] to use to send API calls to Elasticsearch. + /// Builds a [Transport] to use to send API calls to OpenSearch. pub fn build(self) -> Result { let mut client_builder = self.client_builder; @@ -277,28 +229,6 @@ impl TransportBuilder { client_builder = client_builder.timeout(t); } - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - { - if let Some(Credentials::Certificate(cert)) = &self.credentials { - client_builder = match cert { - #[cfg(feature = "native-tls")] - ClientCertificate::Pkcs12(b, p) => { - let password = match p { - Some(pass) => pass.as_str(), - None => "", - }; - let pkcs12 = reqwest::Identity::from_pkcs12_der(b, password)?; - client_builder.identity(pkcs12) - } - #[cfg(feature = "rustls-tls")] - ClientCertificate::Pem(b) => { - let pem = reqwest::Identity::from_pem(b)?; - client_builder.identity(pem) - } - } - } - } - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] if let Some(v) = self.cert_validation { client_builder = match v { @@ -332,16 +262,20 @@ impl TransportBuilder { client_builder = client_builder.proxy(proxy); } + client_builder = self + .init_stack + .into_iter() + .try_fold(client_builder, |client_builder, init| { + init.init(client_builder) + })?; + let client = client_builder.build()?; Ok(Transport { client, - conn_pool: self.conn_pool, - credentials: self.credentials, default_headers: self.headers, - #[cfg(feature = "aws-auth")] - sigv4_service_name: self.sigv4_service_name, - #[cfg(feature = "aws-auth")] - sigv4_time_source: self.sigv4_time_source.unwrap_or_default(), + conn_pool: self.conn_pool, + req_init_stack: self.req_init_stack.into_boxed_slice(), + req_handler_stack: self.req_handler_stack.into_boxed_slice(), }) } } @@ -353,7 +287,7 @@ impl Default for TransportBuilder { } } -/// A connection to an Elasticsearch node, used to send an API request +/// A connection to an OpenSearch node, used to send an API request #[derive(Debug, Clone)] pub struct Connection { url: Url, @@ -374,18 +308,15 @@ impl Connection { } } -/// A HTTP transport responsible for making the API requests to Elasticsearch, +/// A HTTP transport responsible for making the API requests to OpenSearch, /// using a [Connection] selected from a [ConnectionPool] -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Transport { client: reqwest::Client, - credentials: Option, - conn_pool: Box, default_headers: HeaderMap, - #[cfg(feature = "aws-auth")] - sigv4_service_name: String, - #[cfg(feature = "aws-auth")] - sigv4_time_source: SharedTimeSource, + req_init_stack: Box<[SharedRequestInitializer]>, + req_handler_stack: Box<[SharedRequestHandler]>, + conn_pool: Box, } impl Transport { @@ -430,38 +361,20 @@ impl Transport { let connection = self.conn_pool.next(); let url = connection.url.join(path.trim_start_matches('/'))?; let reqwest_method = self.method(method); - let mut request_builder = self.client.request(reqwest_method, url); + + let mut request_builder = self + .req_init_stack + .iter() + .try_fold( + self.client.request(reqwest_method, url), + |request_builder, init| init.init(request_builder), + ) + .map_err(Error::RequestInitializer)?; if let Some(t) = timeout { request_builder = request_builder.timeout(t); } - // set credentials before any headers, as credentials append to existing headers in reqwest, - // whilst setting headers() overwrites, so if an Authorization header has been specified - // on a specific request, we want it to overwrite. - if let Some(c) = &self.credentials { - request_builder = match c { - Credentials::Basic(u, p) => request_builder.basic_auth(u, Some(p)), - Credentials::Bearer(t) => request_builder.bearer_auth(t), - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - Credentials::Certificate(_) => request_builder, - Credentials::ApiKey(i, k) => { - let mut header_value = b"ApiKey ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "{}:", i).unwrap(); - write!(encoder, "{}", k).unwrap(); - } - request_builder.header( - AUTHORIZATION, - HeaderValue::from_bytes(&header_value).unwrap(), - ) - } - #[cfg(feature = "aws-auth")] - Credentials::AwsSigV4(_, _) => request_builder, - } - } - // default headers first, overwrite with any provided let mut request_headers = HeaderMap::with_capacity(4 + self.default_headers.len() + headers.len()); @@ -493,27 +406,11 @@ impl Transport { request_builder = request_builder.query(q); } - #[cfg_attr(not(feature = "aws-auth"), allow(unused_mut))] - let mut request = request_builder.build()?; - - #[cfg(feature = "aws-auth")] - if let Some(Credentials::AwsSigV4(credentials_provider, region)) = &self.credentials { - super::aws_auth::sign_request( - &mut request, - credentials_provider, - &self.sigv4_service_name, - region, - &self.sigv4_time_source, - ) - .await - .map_err(|e| crate::error::lib(format!("AWSV4 Signing Failed: {}", e)))?; - } + let response = RequestPipeline::new(&self.client, &self.req_handler_stack) + .run(request_builder.build()?) + .await?; - let response = self.client.execute(request).await; - match response { - Ok(r) => Ok(Response::new(r, method)), - Err(e) => Err(e.into()), - } + Ok(Response::new(response, method)) } } @@ -523,7 +420,18 @@ impl Default for Transport { } } -/// A pool of [Connection]s, used to make API calls to Elasticsearch. +impl std::fmt::Debug for Transport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Transport") + .field("client", &self.client) + // .field("req_init_stack", &self.req_init_stack) + // .field("req_handler_stack", &self.req_handler_stack) + .field("conn_pool", &self.conn_pool) + .finish() + } +} + +/// A pool of [Connection]s, used to make API calls to OpenSearch. /// /// A [ConnectionPool] manages the connections, with different implementations determining how /// to get the next [Connection]. The simplest type of [ConnectionPool] is [SingleNodeConnectionPool], @@ -536,7 +444,7 @@ pub trait ConnectionPool: Debug + dyn_clone::DynClone + Sync + Send { clone_trait_object!(ConnectionPool); -/// A connection pool that manages the single connection to an Elasticsearch cluster. +/// A connection pool that manages the single connection to an OpenSearch cluster. #[derive(Debug, Clone)] pub struct SingleNodeConnectionPool { connection: Connection, diff --git a/opensearch/src/lib.rs b/opensearch/src/lib.rs index 226f6324..78a89102 100644 --- a/opensearch/src/lib.rs +++ b/opensearch/src/lib.rs @@ -74,7 +74,7 @@ //! //! ```toml,no_run //! [dependencies] -//! opensearch = "1.0.0" +//! opensearch = "3" //! ``` //! The following _optional_ dependencies may also be useful to create requests and read responses //! @@ -341,29 +341,30 @@ //! //! ```toml,no_run //! [dependencies] -//! opensearch = { version = "1", features = ["aws-auth"] } -//! aws-config = "0.10" +//! opensearch = { version = "3", features = ["aws-auth"] } +//! aws-config = "1" //! ``` //! //! ```rust,no_run -//! # use aws_config::meta::region::RegionProviderChain; -//! # use opensearch::{ -//! # Error, OpenSearch, -//! # http::transport::{TransportBuilder,SingleNodeConnectionPool}, -//! # }; -//! # use url::Url; -//! # use std::convert::TryInto; //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # #[cfg(feature = "aws-auth")] { -//! let creds = aws_config::load_from_env().await; +//! use opensearch::{ +//! aws::{aws_config::{BehaviorVersion, self, meta::region::RegionProviderChain}, AwsSigV4Builder}, +//! Error, OpenSearch, +//! http::transport::{TransportBuilder,SingleNodeConnectionPool}, +//! }; +//! use url::Url; +//! //! let url = Url::parse("https://...")?; //! let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); -//! let aws_config = aws_config::from_env().region(region_provider).load().await.clone(); +//! let aws_config = aws_config::defaults(BehaviorVersion::latest()).region(region_provider).load().await.clone(); +//! let sigv4 = AwsSigV4Builder::from(aws_config) +//! .service_name("es") // use "aoss" for OpenSearch Serverless +//! .build()?; //! let conn_pool = SingleNodeConnectionPool::new(url); //! let transport = TransportBuilder::new(conn_pool) -//! .auth(aws_config.clone().try_into()?) -//! .service_name("es") // use "aoss" for OpenSearch Serverless +//! .aws_sigv4(sigv4) //! .build()?; //! let client = OpenSearch::new(transport); //! # } @@ -392,6 +393,8 @@ mod readme { } pub mod auth; +#[cfg(feature = "aws-auth")] +pub mod aws; pub mod cert; pub mod http; pub mod models; diff --git a/opensearch/src/models/mod.rs b/opensearch/src/models/mod.rs index 0fb2d8a8..88c1987a 100644 --- a/opensearch/src/models/mod.rs +++ b/opensearch/src/models/mod.rs @@ -1,3 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#![allow(unused)] + use serde::Deserialize; #[derive(Deserialize, Debug)] diff --git a/opensearch/tests/auth.rs b/opensearch/tests/auth.rs index 96719305..4069f3c1 100644 --- a/opensearch/tests/auth.rs +++ b/opensearch/tests/auth.rs @@ -29,77 +29,58 @@ */ pub mod common; -use common::*; +use crate::common::server::MockServer; use opensearch::auth::Credentials; -use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; -use std::io::Write; - #[tokio::test] async fn basic_auth_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let mut header_value = b"Basic ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "username:password").unwrap(); - } - - assert_header_eq!( - req, - "authorization", - String::from_utf8(header_value).unwrap() - ); - server::empty_response() - }); - - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .auth(Credentials::Basic("username".into(), "password".into())); - - let client = client::create(builder); - let _response = client.ping().send().await?; + let mut server = MockServer::start().await?; + + let client = + server.client_with(|b| b.auth(Credentials::Basic("username".into(), "password".into()))); + + let _ = client.ping().send().await?; + + let request = server.received_request().await?; + + assert_eq!( + request.header("authorization"), + Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ=") + ); Ok(()) } #[tokio::test] async fn api_key_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let mut header_value = b"ApiKey ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "id:api_key").unwrap(); - } - - assert_header_eq!( - req, - "authorization", - String::from_utf8(header_value).unwrap() - ); - server::empty_response() - }); - - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .auth(Credentials::ApiKey("id".into(), "api_key".into())); - - let client = client::create(builder); - let _response = client.ping().send().await?; + let mut server = MockServer::start().await?; + + let client = server.client_with(|b| b.auth(Credentials::ApiKey("id".into(), "api_key".into()))); + + let _ = client.ping().send().await?; + + let request = server.received_request().await?; + + assert_eq!( + request.header("authorization"), + Some("ApiKey aWQ6YXBpX2tleQ==") + ); Ok(()) } #[tokio::test] async fn bearer_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "authorization", "Bearer access_token"); - server::empty_response() - }); + let mut server = MockServer::start().await?; + + let client = server.client_with(|b| b.auth(Credentials::Bearer("access_token".into()))); + + let _ = client.ping().send().await?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .auth(Credentials::Bearer("access_token".into())); + let request = server.received_request().await?; - let client = client::create(builder); - let _response = client.ping().send().await?; + assert_eq!(request.header("authorization"), Some("Bearer access_token")); Ok(()) } diff --git a/opensearch/tests/aws_auth.rs b/opensearch/tests/aws_auth.rs index 200d421f..233d6dc2 100644 --- a/opensearch/tests/aws_auth.rs +++ b/opensearch/tests/aws_auth.rs @@ -12,19 +12,34 @@ #![cfg(feature = "aws-auth")] pub mod common; -use aws_config::SdkConfig; -use aws_credential_types::provider::SharedCredentialsProvider; use aws_credential_types::Credentials as AwsCredentials; use aws_smithy_async::time::StaticTimeSource; use aws_types::region::Region; -use common::*; -use opensearch::{auth::Credentials, indices::IndicesCreateParts, OpenSearch}; -use regex::Regex; -use reqwest::header::HOST; +use common::{server::MockServer, tracing_init}; +use opensearch::{ + aws::{AwsSigV4, AwsSigV4BuildError, AwsSigV4Builder}, + http::headers::HOST, + indices::IndicesCreateParts, +}; use serde_json::json; -use std::convert::TryInto; use test_case::test_case; +fn sigv4_config_builder(service_name: &str) -> AwsSigV4Builder { + let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test"); + let region = Region::new("ap-southeast-2"); + let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000 + + AwsSigV4::builder() + .credentials_provider(aws_creds) + .region(region) + .service_name(service_name.to_owned()) + .time_source(time_source) +} + +fn sigv4_config(service_name: &str) -> Result { + sigv4_config_builder(service_name).build() +} + #[test_case("es", "10c9be415f4b9f15b12abbb16bd3e3730b2e6c76e0cf40db75d08a44ed04a3a1"; "when service name is es")] #[test_case("aoss", "34903aef90423aa7dd60575d3d45316c6ef2d57bbe564a152b41bf8f5917abf6"; "when service name is aoss")] #[test_case("arbitrary", "156e65c504ea2b2722a481b7515062e7692d27217b477828854e715f507e6a36"; "when service name is arbitrary")] @@ -35,22 +50,12 @@ async fn aws_auth_signs_correctly( ) -> anyhow::Result<()> { tracing_init(); - let (server, mut rx) = server::capturing_http(); + let mut server = MockServer::start().await?; - let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test"); - let region = Region::new("ap-southeast-2"); - let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000 let host = format!("aaabbbcccddd111222333.ap-southeast-2.{service_name}.amazonaws.com"); + let sigv4 = sigv4_config(service_name)?; - let transport_builder = client::create_builder(&format!("http://{}", server.addr())) - .auth(Credentials::AwsSigV4( - SharedCredentialsProvider::new(aws_creds), - region, - )) - .service_name(service_name) - .sigv4_time_source(time_source.into()) - .header(HOST, host.parse().unwrap()); - let client = client::create(transport_builder); + let client = server.client_with(|b| b.aws_sigv4(sigv4).header(HOST, host.parse().unwrap())); let _ = client .indices() @@ -74,59 +79,55 @@ async fn aws_auth_signs_correctly( .send() .await?; - let sent_req = rx.recv().await.expect("should have sent a request"); + let sent_req = server.received_request().await?; - assert_header_eq!(sent_req, "accept", "application/json"); - assert_header_eq!(sent_req, "content-type", "application/json"); - assert_header_eq!(sent_req, "host", host); - assert_header_eq!(sent_req, "x-amz-date", "20230113T160837Z"); - assert_header_eq!( - sent_req, - "x-amz-content-sha256", - "4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564" + assert_eq!(sent_req.header("accept"), Some("application/json")); + assert_eq!(sent_req.header("content-type"), Some("application/json")); + assert_eq!(sent_req.header("host"), Some(host.as_str())); + assert_eq!(sent_req.header("x-amz-date"), Some("20230113T160837Z")); + assert_eq!( + sent_req.header("x-amz-content-sha256"), + Some("4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564") ); - assert_header_eq!(sent_req, "authorization", format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}")); + assert_eq!(sent_req.header("authorization"), Some(format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}").as_str())); Ok(()) } #[tokio::test] async fn aws_auth_get() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let authorization_header = req.headers()["authorization"].to_str().unwrap(); - let re = Regex::new(r"^AWS4-HMAC-SHA256 Credential=id/\d*/us-west-1/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=[a-f,0-9].*$").unwrap(); - assert!( - re.is_match(authorization_header), - "{}", - authorization_header - ); - assert_header_eq!( - req, - "x-amz-content-sha256", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - ); // SHA of empty string - server::empty_response() - }); - - let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?; - let _response = client.ping().send().await?; + tracing_init(); + + let mut server = MockServer::start().await?; + let sigv4 = sigv4_config_builder("custom") + .ignore_header("host") + .build()?; + + let client = server.client_with(|b| b.aws_sigv4(sigv4)); + + let _ = client.ping().send().await?; + + let sent_req = server.received_request().await?; + + assert_eq!(sent_req.header("authorization"), Some("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/custom/aws4_request, SignedHeaders=accept;content-type;x-amz-content-sha256;x-amz-date, Signature=8c882ad6cff05cb6c5bc91a030a92582787f34ef4af858a728c6f943c4ff2f21")); + assert_eq!( + sent_req.header("x-amz-content-sha256"), + Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + ); Ok(()) } #[tokio::test] async fn aws_auth_post() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!( - req, - "x-amz-content-sha256", - "f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9" - ); // SHA of the JSON - server::empty_response() - }); - - let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?; - client + tracing_init(); + + let mut server = MockServer::start().await?; + let sigv4 = sigv4_config("custom")?; + + let client = server.client_with(|b| b.aws_sigv4(sigv4)); + + let _ = client .index(opensearch::IndexParts::Index("movies")) .body(serde_json::json!({ "title": "Moneyball", @@ -137,18 +138,12 @@ async fn aws_auth_post() -> anyhow::Result<()> { .send() .await?; - Ok(()) -} + let sent_req = server.received_request().await?; + + assert_eq!( + sent_req.header("x-amz-content-sha256"), + Some("f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9") + ); // SHA of the JSON -fn create_aws_client(addr: &str) -> anyhow::Result { - let aws_creds = AwsCredentials::new("id", "secret", None, None, "token"); - let creds_provider = SharedCredentialsProvider::new(aws_creds); - let aws_config = SdkConfig::builder() - .credentials_provider(creds_provider) - .region(Region::new("us-west-1")) - .build(); - let builder = client::create_builder(addr) - .auth(aws_config.clone().try_into()?) - .service_name("custom"); - Ok(client::create(builder)) + Ok(()) } diff --git a/opensearch/tests/cert.rs b/opensearch/tests/cert.rs index df5f6493..1b85d29b 100644 --- a/opensearch/tests/cert.rs +++ b/opensearch/tests/cert.rs @@ -50,8 +50,7 @@ fn expected_error_message() -> &'static str { #[tokio::test] #[cfg(feature = "native-tls")] async fn default_certificate_validation() -> anyhow::Result<()> { - let builder = client::create_default_builder().cert_validation(CertificateValidation::Default); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Default)); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -77,8 +76,7 @@ async fn default_certificate_validation() -> anyhow::Result<()> { #[tokio::test] #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))] async fn default_certificate_validation_rustls_tls() -> anyhow::Result<()> { - let builder = client::create_default_builder().cert_validation(CertificateValidation::Default); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Default)); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -103,8 +101,7 @@ async fn default_certificate_validation_rustls_tls() -> anyhow::Result<()> { /// Allows any certificate through #[tokio::test] async fn none_certificate_validation() -> anyhow::Result<()> { - let builder = client::create_default_builder().cert_validation(CertificateValidation::None); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::None)); let _response = client.ping().send().await?; Ok(()) } @@ -118,9 +115,7 @@ async fn none_certificate_validation() -> anyhow::Result<()> { ))] async fn full_certificate_ca_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(CA_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -135,9 +130,7 @@ async fn full_certificate_ca_chain_validation() -> anyhow::Result<()> { let mut cert = Certificate::from_pem(CA_CHAIN_CERT)?; cert.append(Certificate::from_pem(CA_CERT)?); assert_eq!(cert.len(), 3, "expected three certificates in CA chain"); - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -147,9 +140,7 @@ async fn full_certificate_ca_chain_validation() -> anyhow::Result<()> { #[cfg(all(windows, feature = "native-tls"))] async fn full_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -163,9 +154,7 @@ async fn full_certificate_validation_rustls_tls() -> anyhow::Result<()> { chain.extend(ESNODE_CERT); let cert = Certificate::from_pem(chain.as_slice())?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -176,9 +165,7 @@ async fn full_certificate_validation_rustls_tls() -> anyhow::Result<()> { #[cfg(all(unix, feature = "native-tls"))] async fn full_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -205,9 +192,8 @@ async fn full_certificate_validation() -> anyhow::Result<()> { #[cfg(all(windows, feature = "native-tls"))] async fn certificate_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -218,9 +204,8 @@ async fn certificate_certificate_validation() -> anyhow::Result<()> { #[cfg(all(unix, feature = "native-tls"))] async fn certificate_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -248,9 +233,8 @@ async fn certificate_certificate_validation() -> anyhow::Result<()> { #[cfg(all(feature = "native-tls", not(target_os = "macos")))] async fn certificate_certificate_ca_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(CA_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -260,9 +244,8 @@ async fn certificate_certificate_ca_validation() -> anyhow::Result<()> { #[cfg(feature = "native-tls")] async fn fail_certificate_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_NO_SAN_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); match client.ping().send().await { Ok(response) => Err(anyhow!( diff --git a/opensearch/tests/client.rs b/opensearch/tests/client.rs index f4358f67..7f3f2e75 100644 --- a/opensearch/tests/client.rs +++ b/opensearch/tests/client.rs @@ -30,7 +30,10 @@ pub mod common; use common::*; +use hyper::Method; +use crate::common::{client::index_documents, server::MockServer}; +use bytes::Bytes; use opensearch::{ http::{ headers::{ @@ -42,60 +45,56 @@ use opensearch::{ params::TrackTotalHits, SearchParts, }; - -use crate::common::client::index_documents; -use bytes::Bytes; -use hyper::Method; use serde_json::{json, Value}; use std::time::Duration; #[tokio::test] async fn default_user_agent_content_type_accept_headers() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "user-agent", DEFAULT_USER_AGENT); - assert_header_eq!(req, "content-type", "application/json"); - assert_header_eq!(req, "accept", "application/json"); - server::empty_response() - }); + let mut server = MockServer::start().await?; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client.ping().send().await?; + let _ = server.client().ping().send().await?; + + let request = server.received_request().await?; + + assert_eq!(request.header("user-agent"), Some(DEFAULT_USER_AGENT)); + assert_eq!(request.header("content-type"), Some(DEFAULT_CONTENT_TYPE)); + assert_eq!(request.header("accept"), Some(DEFAULT_ACCEPT)); Ok(()) } #[tokio::test] async fn default_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "x-opaque-id", "foo"); - server::empty_response() + let mut server = MockServer::start().await?; + + let client = server.client_with(|b| { + b.header( + HeaderName::from_static(X_OPAQUE_ID), + HeaderValue::from_static("foo"), + ) }); - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()).header( - HeaderName::from_static(X_OPAQUE_ID), - HeaderValue::from_static("foo"), - ); + let _ = client.ping().send().await?; + + let request = server.received_request().await?; - let client = client::create(builder); - let _response = client.ping().send().await?; + assert_eq!(request.header("x-opaque-id"), Some("foo")); Ok(()) } #[tokio::test] async fn override_default_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "x-opaque-id", "bar"); - server::empty_response() - }); + let mut server = MockServer::start().await?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()).header( - HeaderName::from_static(X_OPAQUE_ID), - HeaderValue::from_static("foo"), - ); + let client = server.client_with(|b| { + b.header( + HeaderName::from_static(X_OPAQUE_ID), + HeaderValue::from_static("foo"), + ) + }); - let client = client::create(builder); - let _response = client + let _ = client .ping() .header( HeaderName::from_static(X_OPAQUE_ID), @@ -104,18 +103,19 @@ async fn override_default_header() -> anyhow::Result<()> { .send() .await?; + let request = server.received_request().await?; + + assert_eq!(request.header("x-opaque-id"), Some("bar")); + Ok(()) } #[tokio::test] async fn x_opaque_id_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "x-opaque-id", "foo"); - server::empty_response() - }); + let mut server = MockServer::start().await?; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client + let _ = server + .client() .ping() .header( HeaderName::from_static(X_OPAQUE_ID), @@ -124,39 +124,41 @@ async fn x_opaque_id_header() -> anyhow::Result<()> { .send() .await?; + let request = server.received_request().await?; + + assert_eq!(request.header("x-opaque-id"), Some("foo")); + Ok(()) } #[tokio::test] -async fn uses_global_request_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - server::empty_response() - }); +async fn uses_global_request_timeout() -> anyhow::Result<()> { + let server = MockServer::builder() + .response_delay(Duration::from_secs(1)) + .start() + .await?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_millis(500)); + let client = server.client_with(|b| b.timeout(Duration::from_millis(500))); - let client = client::create(builder); let response = client.ping().send().await; match response { Ok(_) => panic!("Expected timeout error, but response received"), Err(e) => assert!(e.is_timeout(), "Expected timeout error, but was {:?}", e), } + + Ok(()) } #[tokio::test] -async fn uses_call_request_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - server::empty_response() - }); +async fn uses_call_request_timeout() -> anyhow::Result<()> { + let server = MockServer::builder() + .response_delay(Duration::from_secs(1)) + .start() + .await?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_secs(2)); + let client = server.client_with(|b| b.timeout(Duration::from_secs(2))); - let client = client::create(builder); let response = client .ping() .request_timeout(Duration::from_millis(500)) @@ -167,34 +169,37 @@ async fn uses_call_request_timeout() { Ok(_) => panic!("Expected timeout error, but response received"), Err(e) => assert!(e.is_timeout(), "Expected timeout error, but was {:?}", e), } + + Ok(()) } #[tokio::test] -async fn call_request_timeout_supersedes_global_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - server::empty_response() - }); +async fn call_request_timeout_supersedes_global_timeout() -> anyhow::Result<()> { + let server = MockServer::builder() + .response_delay(Duration::from_secs(1)) + .start() + .await?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_millis(500)); + let client = server.client_with(|b| b.timeout(Duration::from_millis(500))); - let client = client::create(builder); let response = client .ping() .request_timeout(Duration::from_secs(2)) .send() .await; - match response { - Ok(_) => (), - Err(e) => assert!(e.is_timeout(), "Did not expect error, but was {:?}", e), - } + assert!( + response.is_ok(), + "Expected response, but was: {:?}", + response + ); + + Ok(()) } #[tokio::test] async fn deprecation_warning_headers() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -239,18 +244,10 @@ async fn deprecation_warning_headers() -> anyhow::Result<()> { #[tokio::test] async fn serialize_querystring() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.method(), Method::GET); - assert_eq!(req.uri().path(), "/_search"); - assert_eq!( - req.uri().query(), - Some("filter_path=took%2C_shards&pretty=true&q=title%3AOpenSearch&track_total_hits=100000") - ); - server::empty_response() - }); + let mut server = MockServer::start().await?; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client + let _ = server + .client() .search(SearchParts::None) .pretty(true) .filter_path(&["took", "_shards"]) @@ -259,12 +256,20 @@ async fn serialize_querystring() -> anyhow::Result<()> { .send() .await?; + let request = server.received_request().await?; + assert_eq!(request.method(), Method::GET); + assert_eq!(request.path(), "/_search"); + assert_eq!( + request.query(), + Some("filter_path=took%2C_shards&pretty=true&q=title%3AOpenSearch&track_total_hits=100000") + ); + Ok(()) } #[tokio::test] async fn search_with_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -307,7 +312,7 @@ async fn search_with_body() -> anyhow::Result<()> { #[tokio::test] async fn search_with_no_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -330,7 +335,7 @@ async fn search_with_no_body() -> anyhow::Result<()> { #[tokio::test] async fn read_response_as_bytes() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -356,7 +361,7 @@ async fn read_response_as_bytes() -> anyhow::Result<()> { #[tokio::test] async fn cat_health_format_json() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let response = client .cat() .health() @@ -380,7 +385,7 @@ async fn cat_health_format_json() -> anyhow::Result<()> { #[tokio::test] async fn cat_health_header_json() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let response = client .cat() .health() @@ -404,7 +409,7 @@ async fn cat_health_header_json() -> anyhow::Result<()> { #[tokio::test] async fn cat_health_text() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let response = client.cat().health().pretty(true).send().await?; assert_eq!(response.status_code(), StatusCode::OK); @@ -422,7 +427,7 @@ async fn cat_health_text() -> anyhow::Result<()> { #[tokio::test] async fn clone_search_with_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let base_request = client.search(SearchParts::None); @@ -448,7 +453,7 @@ async fn clone_search_with_body() -> anyhow::Result<()> { #[tokio::test] async fn byte_slice_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let body = b"{\"query\":{\"match_all\":{}}}"; let response = client diff --git a/opensearch/tests/common/client.rs b/opensearch/tests/common/client.rs index f974bf8e..edba76c1 100644 --- a/opensearch/tests/common/client.rs +++ b/opensearch/tests/common/client.rs @@ -35,12 +35,12 @@ use opensearch::{ http::{ response::Response, transport::{SingleNodeConnectionPool, TransportBuilder}, + StatusCode, }, indices::IndicesExistsParts, params::Refresh, BulkOperation, BulkParts, Error, OpenSearch, DEFAULT_ADDRESS, }; -use reqwest::StatusCode; use serde_json::json; use sysinfo::{ProcessRefreshKind, RefreshKind, System, SystemExt}; use url::Url; @@ -63,45 +63,69 @@ fn running_proxy() -> bool { has_fiddler } -pub fn create_default_builder() -> TransportBuilder { - create_builder(cluster_addr().as_str()) -} +pub struct TestClientBuilder(TransportBuilder); + +impl TestClientBuilder { + pub fn new() -> Self { + Self::with_url(&cluster_addr()) + } + + pub fn with_url(url: &str) -> Self { + let url = Url::parse(url).unwrap(); + let secure = url.scheme() == "https"; + let conn_pool = SingleNodeConnectionPool::new(url); + let mut builder = TransportBuilder::new(conn_pool); + + // assume if we're running with HTTPS then authentication is also enabled and disable + // certificate validation - we'll change this for tests that need to. + if secure { + builder = builder.auth(Credentials::Basic("admin".into(), "admin".into())); + + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + { + builder = builder.cert_validation(CertificateValidation::None); + } + } + + Self(builder) + } + + pub fn with(mut self, configurator: impl FnOnce(TransportBuilder) -> TransportBuilder) -> Self { + self.0 = configurator(self.0); + self + } + + pub fn build(self) -> OpenSearch { + let mut builder = self.0; -pub fn create_builder(addr: &str) -> TransportBuilder { - let url = Url::parse(addr).unwrap(); - let conn_pool = SingleNodeConnectionPool::new(url.clone()); - let mut builder = TransportBuilder::new(conn_pool); - // assume if we're running with HTTPS then authentication is also enabled and disable - // certificate validation - we'll change this for tests that need to. - if url.scheme() == "https" { - builder = builder.auth(Credentials::Basic("admin".into(), "admin".into())); - - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - { - builder = builder.cert_validation(CertificateValidation::None); + if running_proxy() { + let proxy_url = Url::parse("http://localhost:8888").unwrap(); + builder = builder.proxy(proxy_url, None, None); } + + let transport = builder.build().unwrap(); + OpenSearch::new(transport) } +} - builder +pub fn builder() -> TestClientBuilder { + TestClientBuilder::new() } -pub fn create_default() -> OpenSearch { - create_for_url(cluster_addr().as_str()) +pub fn builder_with_url(url: &str) -> TestClientBuilder { + TestClientBuilder::with_url(url) } -pub fn create_for_url(url: &str) -> OpenSearch { - let builder = create_builder(url); - create(builder) +pub fn create() -> OpenSearch { + builder().build() } -pub fn create(mut builder: TransportBuilder) -> OpenSearch { - if running_proxy() { - let proxy_url = Url::parse("http://localhost:8888").unwrap(); - builder = builder.proxy(proxy_url, None, None); - } +pub fn create_with(configurator: impl FnOnce(TransportBuilder) -> TransportBuilder) -> OpenSearch { + builder().with(configurator).build() +} - let transport = builder.build().unwrap(); - OpenSearch::new(transport) +pub fn create_with_url(url: &str) -> OpenSearch { + builder_with_url(url).build() } /// index some documents into a posts index. If the posts index already exists, do nothing. diff --git a/opensearch/tests/common/mod.rs b/opensearch/tests/common/mod.rs index 93dacd5c..e7fc241c 100644 --- a/opensearch/tests/common/mod.rs +++ b/opensearch/tests/common/mod.rs @@ -28,23 +28,13 @@ * GitHub history for details. */ -#![allow(unused)] +pub mod client; +pub mod server; -pub(crate) mod client; -pub(crate) mod server; - -pub(crate) static DEFAULT_USER_AGENT: &str = concat!("opensearch-rs/", env!("CARGO_PKG_VERSION")); - -macro_rules! assert_header_eq { - ($req:expr, $header:expr, $value:expr) => { - assert_eq!($req.headers()[$header], $value); - }; -} +pub static DEFAULT_USER_AGENT: &str = concat!("opensearch-rs/", env!("CARGO_PKG_VERSION")); static TRACING: std::sync::Once = std::sync::Once::new(); -pub(crate) fn tracing_init() { +pub fn tracing_init() { TRACING.call_once(|| tracing_subscriber::fmt::init()) } - -pub(crate) use assert_header_eq; diff --git a/opensearch/tests/common/server.rs b/opensearch/tests/common/server.rs index 07069105..26d86831 100644 --- a/opensearch/tests/common/server.rs +++ b/opensearch/tests/common/server.rs @@ -32,47 +32,153 @@ // Licensed under Apache License, Version 2.0 // https://github.com/seanmonstar/reqwest/blob/master/LICENSE-APACHE -use std::{ - convert::Infallible, - future::Future, - net::{self, SocketAddr}, - sync::mpsc as std_mpsc, - thread, - time::Duration, -}; +use std::{net::SocketAddr, sync::mpsc as std_mpsc, thread, time::Duration}; use bytes::Bytes; use http_body_util::Empty; use hyper::{ - body::{Body, Incoming}, - server::conn::http1, - service::service_fn, - Request, Response, + body::Incoming, server::conn::http1, service::service_fn, HeaderMap, Method, Request, Response, + Uri, }; use hyper_util::rt::TokioIo; +use opensearch::{http::transport::TransportBuilder, OpenSearch}; use tokio::{ - net::{TcpListener, TcpStream}, - sync::{broadcast, mpsc}, + net::TcpListener, + runtime, + sync::{mpsc, watch}, }; -use tokio::runtime; +use super::client::TestClientBuilder; + +#[derive(Default)] +pub struct MockServerBuilder { + response_delay: Option, +} + +impl MockServerBuilder { + pub fn response_delay(mut self, delay: Duration) -> Self { + self.response_delay = Some(delay); + self + } + + pub async fn start(self) -> anyhow::Result { + let thread_name = thread::current().name().unwrap_or("").to_owned(); + + thread::spawn(move || { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("new rt"); + let _ = rt.enter(); + + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + let (requests_tx, requests_rx) = mpsc::unbounded_channel(); + let listener = rt + .block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))) + .unwrap(); + let addr = listener.local_addr().unwrap(); + let response_delay = self.response_delay.clone(); + + let srv = async move { + loop { + let (stream, _) = tokio::select! { + res = listener.accept() => res?, + _ = shutdown_rx.changed() => break + }; + let io = TokioIo::new(stream); + + let mut shutdown_rx = shutdown_rx.clone(); + let requests_tx = requests_tx.clone(); + + tokio::task::spawn(async move { + let conn = http1::Builder::new().serve_connection( + io, + service_fn(move |req| { + let requests_tx = requests_tx.clone(); + async move { + requests_tx.send(req.into())?; + if let Some(response_delay) = response_delay.clone() { + tokio::time::sleep(response_delay).await; + } + Ok::<_, anyhow::Error>(Response::>::default()) + } + }), + ); + tokio::pin!(conn); + tokio::select! { + _ = conn.as_mut() => {}, + _ = shutdown_rx.changed() => conn.as_mut().graceful_shutdown() + } + }); + } + Ok::<(), anyhow::Error>(()) + }; + + let (panic_tx, panic_rx) = std_mpsc::channel(); + let thread_name = format!("test({})-support-server", thread_name); + thread::Builder::new() + .name(thread_name) + .spawn(move || { + rt.block_on(srv).unwrap(); + let _ = panic_tx.send(()); + }) + .expect("thread spawn"); + + MockServer { + uri: format!("http://{}", addr), + requests_rx, + panic_rx, + shutdown_tx: Some(shutdown_tx), + } + }) + .join() + .map_err(|e| anyhow::anyhow!("MockServer construction failed: {:?}", e)) + } +} -pub struct Server { - addr: net::SocketAddr, +pub struct MockServer { + uri: String, + requests_rx: mpsc::UnboundedReceiver, panic_rx: std_mpsc::Receiver<()>, - shutdown_tx: Option>, + shutdown_tx: Option>, } -impl Server { - pub fn addr(&self) -> net::SocketAddr { - self.addr +impl MockServer { + pub fn builder() -> MockServerBuilder { + MockServerBuilder::default() + } + + pub async fn start() -> anyhow::Result { + MockServerBuilder::default().start().await + } + + pub fn client(&self) -> OpenSearch { + self.client_builder().build() + } + + pub fn client_with( + &self, + configurator: impl FnOnce(TransportBuilder) -> TransportBuilder, + ) -> OpenSearch { + self.client_builder().with(configurator).build() + } + + pub fn client_builder(&self) -> TestClientBuilder { + super::client::builder_with_url(&self.uri) + } + + pub async fn received_request(&mut self) -> anyhow::Result { + self.requests_rx + .recv() + .await + .ok_or_else(|| anyhow::anyhow!("no request received")) } } -impl Drop for Server { +impl Drop for MockServer { fn drop(&mut self) { if let Some(tx) = self.shutdown_tx.take() { - tx.send(()).unwrap(); + tx.send(true).unwrap(); } if !::std::thread::panicking() { @@ -83,90 +189,36 @@ impl Drop for Server { } } -pub fn http(func: F) -> Server -where - F: Fn(Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, - B: Body + Send + 'static, - B::Data: Send, - B::Error: std::error::Error + Send + Sync, -{ - let thread_name = thread::current().name().unwrap_or("").to_owned(); - - thread::spawn(move || { - let rt = runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("new rt"); - let _ = rt.enter(); - - let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1); - let listener = rt - .block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))) - .unwrap(); - let addr = listener.local_addr().unwrap(); - - let srv = async move { - loop { - let (stream, _) = tokio::select! { - res = listener.accept() => res?, - _ = shutdown_rx.recv() => break - }; - let io = TokioIo::new(stream); - - let mut func = func.clone(); - let mut shutdown_rx = shutdown_rx.resubscribe(); - - tokio::task::spawn(async move { - let conn = http1::Builder::new().serve_connection( - io, - service_fn(move |req| { - let func = func.clone(); - async move { Ok::<_, Infallible>(func(req).await) } - }), - ); - tokio::pin!(conn); - tokio::select! { - res = conn.as_mut() => {}, - _ = shutdown_rx.recv() => conn.as_mut().graceful_shutdown() - } - }); - } - Ok::<(), anyhow::Error>(()) - }; - - let (panic_tx, panic_rx) = std_mpsc::channel(); - let thread_name = format!("test({})-support-server", thread_name); - thread::Builder::new() - .name(thread_name) - .spawn(move || { - rt.block_on(srv).unwrap(); - let _ = panic_tx.send(()); - }) - .expect("thread spawn"); - - Server { - addr, - panic_rx, - shutdown_tx: Some(shutdown_tx), - } - }) - .join() - .unwrap() +pub struct ReceivedRequest { + method: Method, + uri: Uri, + headers: HeaderMap, } -pub fn capturing_http() -> (Server, mpsc::UnboundedReceiver>) { - let (tx, rx) = mpsc::unbounded_channel(); - let server = http(move |req| { - let tx = tx.clone(); - async move { - tx.send(req).unwrap(); - empty_response() - } - }); - (server, rx) +impl ReceivedRequest { + pub fn method(&self) -> &Method { + &self.method + } + + pub fn path(&self) -> &str { + self.uri.path() + } + + pub fn query(&self) -> Option<&str> { + self.uri.query() + } + + pub fn header(&self, name: &str) -> Option<&str> { + self.headers.get(name).and_then(|v| v.to_str().ok()) + } } -pub fn empty_response() -> Response> { - Default::default() +impl From> for ReceivedRequest { + fn from(req: Request) -> Self { + ReceivedRequest { + method: req.method().clone(), + uri: req.uri().clone(), + headers: req.headers().clone(), + } + } } diff --git a/opensearch/tests/error.rs b/opensearch/tests/error.rs index 382eab69..eb97e65d 100644 --- a/opensearch/tests/error.rs +++ b/opensearch/tests/error.rs @@ -38,8 +38,7 @@ use serde_json::{json, Value}; /// Responses in the range 400-599 return Response body #[tokio::test] async fn bad_request_returns_response() -> anyhow::Result<()> { - let client = client::create_default(); - let response = client + let response = client::create() .explain(ExplainParts::IndexId("non_existent_index", "id")) .body(json!({})) .send() @@ -63,8 +62,7 @@ async fn bad_request_returns_response() -> anyhow::Result<()> { #[tokio::test] async fn deserialize_exception() -> anyhow::Result<()> { - let client = client::create_default(); - let response = client + let response = client::create() .explain(ExplainParts::IndexId("non_existent_index", "id")) .error_trace(true) .body(json!({}))