diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9ebcc08f..780a7e5b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] test-args: - - "--features aws-auth" + - "" - "--no-default-features --features rustls-tls --package opensearch --test cert" runs-on: ${{ matrix.os }} steps: @@ -33,7 +33,7 @@ jobs: version: 2.6.0 secured: true - - name: Run Tests (${{ matrix.test-args }}) + - name: Run Tests working-directory: client run: cargo make test ${{ matrix.test-args }} env: diff --git a/Cargo.toml b/Cargo.toml index 67c89ece..adb63fc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,5 +2,6 @@ members = [ "api_generator", "opensearch", + "opensearch-auth-awssigv4", "yaml_test_runner" ] \ No newline at end of file diff --git a/Makefile.toml b/Makefile.toml index 1d23ddea..c01a1dc2 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -19,18 +19,15 @@ OPENSEARCH_URL = { value = "${OPENSEARCH_PROTOCOL}://localhost:9200", condition [tasks.generate-certs] category = "OpenSearch" description = "Generates SSL certificates used for integration tests" -command = "bash ./.ci/generate-certs.sh" +command = "bash" +args = ["./.ci/generate-certs.sh"] [tasks.run-opensearch] category = "OpenSearch" private = true condition = { env_set = [ "STACK_VERSION"], env_false = ["CARGO_MAKE_CI"] } - -[tasks.run-opensearch.linux] -command = "./.ci/run-opensearch.sh" - -[tasks.run-opensearch.mac] -command = "./.ci/run-opensearch.sh" +command = "bash" +args = ["./.ci/run-opensearch.sh"] [tasks.run-opensearch.windows] script_runner = "cmd" diff --git a/api_generator/Cargo.toml b/api_generator/Cargo.toml index 082ee83b..9a1c9d48 100644 --- a/api_generator/Cargo.toml +++ b/api_generator/Cargo.toml @@ -28,7 +28,7 @@ semver = "1.0.14" serde = { version = "~1", features = ["derive"] } serde_json = "~1" simple_logger = "4.0.0" -syn = { version = "~1.0", features = ["full"] } +syn = { version = "~1.0", features = ["full", "extra-traits"] } tar = "~0.4" toml = "0.7.1" url = "2.1.1" diff --git a/opensearch-auth-awssigv4/Cargo.toml b/opensearch-auth-awssigv4/Cargo.toml new file mode 100644 index 00000000..197795ae --- /dev/null +++ b/opensearch-auth-awssigv4/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "opensearch-auth-awssigv4" +version = "2.0.0" +edition = "2018" +authors = ["OpenSearch Contributors"] +description = "Official OpenSearch Rust client AWS SigV4 support" +repository = "https://github.com/opensearch-project/opensearch-rs" +keywords = ["opensearch", "elasticsearch", "search", "lucene"] +categories = ["api-bindings", "database"] +documentation = "https://opensearch.org/docs/latest" +license = "Apache-2.0" +readme = "../README.md" + +[dependencies] +aws-config = "0.54" +aws-credential-types = "0.54" +aws-sigv4 = "0.54" +aws-types = "0.54" +http = "0.2" +opensearch = { path = "../opensearch" } +thiserror = "1.0" +tokio = { version = "~1" } + +[dev-dependencies] +anyhow = "1.0" +serde_json = "~1" +tokio = { version = "~1", features = ["full"] } +wiremock = "0.5" \ No newline at end of file diff --git a/opensearch/examples/aws_auth.rs b/opensearch-auth-awssigv4/examples/aws_auth.rs similarity index 70% rename from opensearch/examples/aws_auth.rs rename to opensearch-auth-awssigv4/examples/aws_auth.rs index 81783474..4152db3e 100644 --- a/opensearch/examples/aws_auth.rs +++ b/opensearch-auth-awssigv4/examples/aws_auth.rs @@ -9,23 +9,25 @@ * GitHub history for details. */ -#[tokio::main] -#[cfg(feature = "aws-auth")] -pub async fn main() -> Result<(), Box> { - use std::convert::TryInto; +use std::convert::TryInto; - use opensearch::{ - cat::CatIndicesParts, - http::transport::{SingleNodeConnectionPool, TransportBuilder}, - OpenSearch, - }; - use url::Url; +use opensearch::{ + cat::CatIndicesParts, + http::{ + transport::{SingleNodeConnectionPool, TransportBuilder}, + Url, + }, + OpenSearch, +}; +use opensearch_auth_awssigv4::TransportBuilderExt; +#[tokio::main] +pub async fn main() -> Result<(), Box> { let aws_config = aws_config::load_from_env().await; let host = ""; // e.g. https://search-mydomain.us-west-1.es.amazonaws.com let transport = TransportBuilder::new(SingleNodeConnectionPool::new(Url::parse(host).unwrap())) - .auth(aws_config.try_into()?) + .aws_sigv4(aws_config.try_into()?) .build()?; let client = OpenSearch::new(transport); @@ -40,8 +42,3 @@ pub async fn main() -> Result<(), Box> { println!("{}", text); Ok(()) } - -#[cfg(not(feature = "aws-auth"))] -pub fn main() { - panic!("Requires the `aws-auth` feature to be enabled") -} diff --git a/opensearch-auth-awssigv4/src/lib.rs b/opensearch-auth-awssigv4/src/lib.rs new file mode 100644 index 00000000..37099737 --- /dev/null +++ b/opensearch-auth-awssigv4/src/lib.rs @@ -0,0 +1,242 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use std::convert::TryFrom; + +use aws_config::SdkConfig; +use aws_credential_types::{ + provider::{ProvideCredentials, SharedCredentialsProvider}, + time_source::TimeSource, + Credentials, +}; +use aws_sigv4::http_request::{ + sign, PayloadChecksumKind, SignableBody, SignableRequest, SignatureLocation, SigningParams, + SigningSettings, +}; +use http::header::{HeaderName, CONTENT_LENGTH, USER_AGENT}; +use opensearch::http::{ + middleware::{async_trait, RequestHandler, RequestPipeline, RequestPipelineError}, + reqwest::Response, + transport::TransportBuilder, + Request, +}; +use thiserror::Error; + +pub use aws_config; +pub use aws_credential_types; +pub use aws_types; + +#[derive(Error, Debug)] +pub enum BuildError { + #[error("the region for signing must be provided")] + MissingRegion, + #[error("the credentials provider for signing must be provided")] + MissingCredentialsProvider, +} + +#[derive(Debug, Clone)] +pub struct Builder { + service_name: Option, + credentials_provider: Option, + region: Option, + ignored_headers: Vec, + time_source: Option, +} + +impl Builder { + pub fn service_name(mut self, service_name: impl AsRef) -> Self { + self.service_name = Some(service_name.as_ref().into()); + self + } + + pub fn credentials_provider( + mut self, + credentials_provider: impl ProvideCredentials + 'static, + ) -> Self { + self.credentials_provider = Some(SharedCredentialsProvider::new(credentials_provider)); + self + } + + pub fn region(mut self, region: impl AsRef) -> Self { + self.region = Some(region.as_ref().into()); + self + } + + pub fn ignore_header(mut self, header_name: HeaderName) -> Self { + self.ignored_headers.push(header_name); + self + } + + #[doc(hidden)] + pub fn time_source(mut self, time_source: TimeSource) -> Self { + self.time_source = Some(time_source); + self + } + + pub fn build(self) -> Result { + Ok(AwsSigV4 { + service_name: self.service_name.unwrap_or_else(|| "es".to_string()), + credentials_provider: self + .credentials_provider + .ok_or(BuildError::MissingCredentialsProvider)?, + region: self.region.ok_or(BuildError::MissingRegion)?, + ignored_headers: self.ignored_headers, + time_source: self.time_source.unwrap_or_default(), + }) + } +} + +impl Default for Builder { + fn default() -> Self { + Self { + service_name: None, + credentials_provider: None, + region: None, + ignored_headers: vec![USER_AGENT, CONTENT_LENGTH], + time_source: None, + } + } +} + +impl From<&SdkConfig> for Builder { + fn from(value: &SdkConfig) -> Self { + let credentials_provider = value.credentials_provider().cloned(); + let region = value.region().map(|r| r.as_ref().into()); + + Self { + credentials_provider, + region, + ..Default::default() + } + } +} + +impl From for Builder { + fn from(value: SdkConfig) -> Self { + >::from(&value) + } +} + +#[derive(Error, Debug)] +pub enum AwsSigV4Error { + #[error("invalid signing params: {0}")] + InvalidSigningParams(#[from] aws_sigv4::signing_params::BuildError), + #[error("unable to retrieve credentials: {0}")] + FailedCredentialsRetrieval(#[from] aws_credential_types::provider::error::CredentialsError), + #[error("invalid uri: {0}")] + InvalidUri(#[from] http::uri::InvalidUri), + #[error("unable to sign request: {0}")] + FailedSigning(#[from] aws_sigv4::http_request::SigningError), +} + +#[derive(Debug, Clone)] +pub struct AwsSigV4 { + service_name: String, + credentials_provider: SharedCredentialsProvider, + region: String, + ignored_headers: Vec, + time_source: TimeSource, +} + +impl AwsSigV4 { + pub fn builder() -> Builder { + Builder::default() + } + + async fn sign_request(&self, request: &mut Request) -> Result<(), AwsSigV4Error> { + let credentials = self.credentials_provider.provide_credentials().await?; + let params = self.build_params(&credentials)?; + + let uri = request.url().as_str().parse()?; + + let signable_request = SignableRequest::new( + request.method(), + &uri, + request.headers(), + SignableBody::Bytes(request.body().and_then(|b| b.as_bytes()).unwrap_or(&[])), + ); + + let (mut instructions, _) = sign(signable_request, ¶ms)?.into_parts(); + + if let Some(new_headers) = instructions.take_headers() { + for (name, value) in new_headers.into_iter() { + request.headers_mut().insert( + name.expect("AWS SigV4 signing header name must never be None"), + value, + ); + } + } + + Ok(()) + } + + fn build_params<'a>( + &'a self, + credentials: &'a Credentials, + ) -> Result, aws_sigv4::signing_params::BuildError> { + let mut settings = SigningSettings::default(); + settings.signature_location = SignatureLocation::Headers; + settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; + settings.excluded_headers = Some(self.ignored_headers.clone()); + + let mut builder = SigningParams::builder() + .access_key(credentials.access_key_id()) + .secret_key(credentials.secret_access_key()) + .service_name(&self.service_name) + .region(&self.region) + .time(self.time_source.now()) + .settings(settings); + + builder.set_security_token(credentials.session_token()); + + builder.build() + } +} + +impl TryFrom<&SdkConfig> for AwsSigV4 { + type Error = BuildError; + + fn try_from(value: &SdkConfig) -> Result { + Builder::from(value).build() + } +} + +impl TryFrom for AwsSigV4 { + type Error = BuildError; + + fn try_from(value: SdkConfig) -> Result { + >::try_from(&value) + } +} + +#[async_trait] +impl RequestHandler for AwsSigV4 { + async fn handle( + &self, + mut request: Request, + next: RequestPipeline<'_>, + ) -> Result { + self.sign_request(&mut request) + .await + .map_err(|e| RequestPipelineError::Pipeline(e.into()))?; + next.run(request).await + } +} + +pub trait TransportBuilderExt { + fn aws_sigv4(self, aws_sigv4: AwsSigV4) -> Self; +} + +impl TransportBuilderExt for TransportBuilder { + fn aws_sigv4(self, aws_sigv4: AwsSigV4) -> Self { + self.with_handler(aws_sigv4) + } +} diff --git a/opensearch-auth-awssigv4/tests/aws_auth.rs b/opensearch-auth-awssigv4/tests/aws_auth.rs new file mode 100644 index 00000000..9abdc858 --- /dev/null +++ b/opensearch-auth-awssigv4/tests/aws_auth.rs @@ -0,0 +1,112 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use std::time::SystemTime; + +use aws_credential_types::time_source::{TestingTimeSource, TimeSource}; +use aws_credential_types::Credentials; +use http::header::HOST; +use opensearch::http::transport::{SingleNodeConnectionPool, TransportBuilder}; +use opensearch::OpenSearch; +use opensearch_auth_awssigv4::{AwsSigV4, TransportBuilderExt}; +use wiremock::http::HeaderValue; +use wiremock::{MockServer, Request}; + +#[tokio::test] +async fn aws_auth_head() -> anyhow::Result<()> { + let server = MockServer::start().await; + + create_aws_client(&server.uri())?.ping().send().await?; + + let requests = server.received_requests().await.unwrap(); + + assert_eq!( + header_values(&requests[0], "x-amz-date"), + &["19700101T000000Z"] + ); + assert_eq!( + header_values(&requests[0], "x-amz-content-sha256"), + &["e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"] + ); + assert_eq!( + &header_values(&requests[0], "authorization"), + &[ + "AWS4-HMAC-SHA256 Credential=id/19700101/us-west-1/custom/aws4_request", + "SignedHeaders=accept;content-type;x-amz-content-sha256;x-amz-date", + "Signature=a8d6f3596419544dad3d9d35fd8a7ca89a2a0d20208f583f35c8898b1ce218db" + ] + ); + + Ok(()) +} + +#[tokio::test] +async fn aws_auth_post() -> anyhow::Result<()> { + let server = MockServer::start().await; + + create_aws_client(&server.uri())? + .index(opensearch::IndexParts::Index("movies")) + .body(serde_json::json!({ + "title": "Moneyball", + "director": "Bennett Miller", + "year": 2011 + } + )) + .send() + .await?; + + let requests = server.received_requests().await.unwrap(); + + assert_eq!( + header_values(&requests[0], "x-amz-date"), + &["19700101T000000Z"] + ); + assert_eq!( + header_values(&requests[0], "x-amz-content-sha256"), + &["f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9"] + ); + assert_eq!( + &header_values(&requests[0], "authorization"), + &[ + "AWS4-HMAC-SHA256 Credential=id/19700101/us-west-1/custom/aws4_request", + "SignedHeaders=accept;content-type;x-amz-content-sha256;x-amz-date", + "Signature=75183ec5cb26555dee2c19bf5d3ce56911d1def2936c36a63af207cf21b29d66" + ] + ); + + Ok(()) +} + +fn create_aws_client(addr: &str) -> anyhow::Result { + let aws_creds = Credentials::new("id", "secret", None, None, "token"); + let aws_sigv4 = AwsSigV4::builder() + .credentials_provider(aws_creds) + .region("us-west-1") + .service_name("custom") + .ignore_header(HOST) + .time_source(TimeSource::testing(&TestingTimeSource::new( + SystemTime::UNIX_EPOCH, + ))) + .build()?; + let builder = + TransportBuilder::new(SingleNodeConnectionPool::new(addr.parse()?)).aws_sigv4(aws_sigv4); + Ok(OpenSearch::new(builder.build()?)) +} + +fn header_values<'r>(request: &'r Request, header_name: &str) -> Vec<&'r str> { + request + .headers + .get(&header_name.into()) + .into_iter() + .flatten() + .map(HeaderValue::as_str) + .collect::>() +} diff --git a/opensearch/Cargo.toml b/opensearch/Cargo.toml index a3218e07..3649ac82 100644 --- a/opensearch/Cargo.toml +++ b/opensearch/Cargo.toml @@ -25,37 +25,33 @@ experimental-apis = ["beta-apis"] native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] -# AWS SigV4 Auth support -aws-auth = ["aws-credential-types", "aws-sigv4", "aws-types"] - [dependencies] +async-trait = "0.1" base64 = "^0.21" bytes = "^1.0" dyn-clone = "~1" +futures-util = "0.3" lazy_static = "1.4" percent-encoding = "2.1.0" reqwest = { version = "~0.11", default-features = false, features = ["gzip", "json"] } -url = "^2.1" serde = { version = "~1", features = ["derive"] } serde_json = "~1" serde_with = "~2" +thiserror = "1.0" +url = "^2.1" void = "1.0.2" -aws-credential-types = { version = ">= 0.53", optional = true } -aws-sigv4 = { version = ">= 0.53", optional = true } -aws-types = { version = ">= 0.53", optional = true } [dev-dependencies] anyhow = "1.0" -aws-config = ">= 0.53" chrono = { version = "^0.4", features = ["serde"] } clap = "~2" futures = "0.3.1" http = "0.2" -hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "server"] } regex="1.4" sysinfo = "0.28.0" textwrap = "^0.16" tokio = { version = "1.0", default-features = false, features = ["macros", "net", "time", "rt-multi-thread"] } +wiremock = "0.5" xml-rs = "^0.8" [build-dependencies] diff --git a/opensearch/src/.generated.toml b/opensearch/src/.generated.toml index 0a22bf58..88c136e0 100644 --- a/opensearch/src/.generated.toml +++ b/opensearch/src/.generated.toml @@ -1,16 +1,16 @@ written = [ - 'cat.rs', - 'cluster.rs', - 'dangling_indices.rs', - 'indices.rs', - 'ingest.rs', - 'nodes.rs', - 'root/mod.rs', - 'snapshot.rs', - 'tasks.rs', - 'text_structure.rs', + "cat.rs", + "cluster.rs", + "dangling_indices.rs", + "indices.rs", + "ingest.rs", + "nodes.rs", + "root/mod.rs", + "snapshot.rs", + "tasks.rs", + "text_structure.rs", ] merged = [ - 'lib.rs', - 'params.rs', + "lib.rs", + "params.rs", ] diff --git a/opensearch/src/auth.rs b/opensearch/src/auth.rs index 787917ba..65d32dea 100644 --- a/opensearch/src/auth.rs +++ b/opensearch/src/auth.rs @@ -30,6 +30,12 @@ //! Authentication components +use crate::{ + http::middleware::{ClientInitializer, RequestInitializer}, + BoxError, +}; +use reqwest::Identity; + /// Credentials for authentication #[derive(Debug, Clone)] pub enum Credentials { @@ -45,16 +51,6 @@ pub enum Credentials { Certificate(ClientCertificate), /// An id and api_key to use for API key authentication ApiKey(String, String), - /// AWS credentials used for AWS SigV4 request signing. - /// - /// # Optional - /// - /// This requires the `aws-auth` feature to be enabled. - #[cfg(feature = "aws-auth")] - AwsSigV4( - aws_credential_types::provider::SharedCredentialsProvider, - aws_types::region::Region, - ), } #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] @@ -90,28 +86,50 @@ impl From for Credentials { } } -#[cfg(any(feature = "aws-auth"))] -impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials { - type Error = super::Error; - - fn try_from(value: &aws_types::SdkConfig) -> Result { - let credentials = value - .credentials_provider() - .ok_or_else(|| super::error::lib("SdkConfig does not have a credentials_provider"))? - .clone(); - let region = value - .region() - .ok_or_else(|| super::error::lib("SdkConfig does not have a region"))? - .clone(); - Ok(Credentials::AwsSigV4(credentials, region)) +impl ClientInitializer for Credentials { + fn init( + &self, + client: reqwest::ClientBuilder, + ) -> Result> { + match &self { + #[cfg(feature = "native-tls")] + Credentials::Certificate(ClientCertificate::Pkcs12(b, p)) => { + Ok(client.identity(Identity::from_pkcs12_der(b, p.as_deref().unwrap_or(""))?)) + } + #[cfg(feature = "rustls-tls")] + Credentials::Certificate(ClientCertificate::Pem(b)) => { + Ok(client.identity(Identity::from_pem(b)?)) + } + _ => Ok(client), + } } } -#[cfg(any(feature = "aws-auth"))] -impl std::convert::TryFrom for Credentials { - type Error = super::Error; +impl RequestInitializer for Credentials { + fn init( + &self, + request: reqwest::RequestBuilder, + ) -> Result> { + Ok(match &self { + Credentials::Basic(u, p) => request.basic_auth(u, Some(p)), + Credentials::Bearer(t) => request.bearer_auth(t), + Credentials::ApiKey(id, key) => { + use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; + use reqwest::header::{HeaderValue, AUTHORIZATION}; + use std::io::Write; - fn try_from(value: aws_types::SdkConfig) -> Result { - Credentials::try_from(&value) + let mut header_value = b"ApiKey ".to_vec(); + { + let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); + write!(encoder, "{}:", id).unwrap(); + write!(encoder, "{}", key).unwrap(); + } + request.header( + AUTHORIZATION, + HeaderValue::from_bytes(&header_value).unwrap(), + ) + } + _ => request, + }) } } diff --git a/opensearch/src/error.rs b/opensearch/src/error.rs index 72709550..ea1d30bb 100644 --- a/opensearch/src/error.rs +++ b/opensearch/src/error.rs @@ -32,125 +32,82 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -use crate::http::{transport::BuildError, StatusCode}; -use std::{error, fmt, io}; + +use crate::http::{middleware::RequestPipelineError, transport::BuildError, StatusCode}; + +pub type BoxError<'a> = Box; /// An error with the client. /// /// Errors that can occur include IO and parsing errors, as well as specific -/// errors from Elasticsearch and internal errors from the client. -#[derive(Debug)] -pub struct Error { - kind: Kind, -} - -#[derive(Debug)] -enum Kind { +/// errors from OpenSearch and internal errors from the client. +#[derive(Debug, thiserror::Error)] +pub enum Error { /// An error building the client - Build(BuildError), + #[error("Error building the client: {0}")] + Build(#[from] BuildError), /// A general error from this library + #[error("Library error: {0}")] Lib(String), - /// HTTP library error - Http(reqwest::Error), + /// Reqwest error + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), + + /// URL parse error + #[error("URL parse error: {0}")] + UrlParse(#[from] url::ParseError), /// IO error - Io(io::Error), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), /// JSON error - Json(serde_json::error::Error), -} - -impl From for Error { - fn from(err: io::Error) -> Error { - Error { - kind: Kind::Io(err), - } - } -} - -impl From for Error { - fn from(err: reqwest::Error) -> Error { - Error { - kind: Kind::Http(err), - } - } -} + #[error("JSON error: {0}")] + Json(#[from] serde_json::error::Error), -impl From for Error { - fn from(err: serde_json::error::Error) -> Error { - Error { - kind: Kind::Json(err), - } - } -} + /// Request initializer error + #[error("Request initializer error: {0}")] + RequestInitializer(#[source] BoxError<'static>), -impl From for Error { - fn from(err: url::ParseError) -> Error { - Error { - kind: Kind::Lib(err.to_string()), - } - } + /// Request pipeline error + #[error("Request pipeline error: {0}")] + RequestPipeline(#[source] BoxError<'static>), } -impl From for Error { - fn from(err: BuildError) -> Error { - Error { - kind: Kind::Build(err), +impl From for Error { + fn from(err: RequestPipelineError) -> Self { + match err { + RequestPipelineError::Reqwest(err) => Self::Reqwest(err), + RequestPipelineError::Pipeline(err) => Self::RequestPipeline(err), } } } pub(crate) fn lib(err: impl Into) -> Error { - Error { - kind: Kind::Lib(err.into()), - } + Error::Lib(err.into()) } impl Error { /// The status code, if the error was generated from a response pub fn status_code(&self) -> Option { - match &self.kind { - Kind::Http(err) => err.status(), + match &self { + Self::Reqwest(err) => err.status(), _ => None, } } /// Returns true if the error is related to a timeout pub fn is_timeout(&self) -> bool { - match &self.kind { - Kind::Http(err) => err.is_timeout(), + match &self { + Self::Reqwest(err) => err.is_timeout(), _ => false, } } /// Returns true if the error is related to serialization or deserialization pub fn is_json(&self) -> bool { - matches!(self.kind, Kind::Json(_)) - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match &self.kind { - Kind::Build(err) => Some(err), - Kind::Lib(_) => None, - Kind::Http(err) => Some(err), - Kind::Io(err) => Some(err), - Kind::Json(err) => Some(err), - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match &self.kind { - Kind::Build(err) => err.fmt(f), - Kind::Lib(err) => err.fmt(f), - Kind::Http(err) => err.fmt(f), - Kind::Io(err) => err.fmt(f), - Kind::Json(err) => err.fmt(f), - } + matches!(self, Self::Json(_)) } } diff --git a/opensearch/src/http/aws_auth.rs b/opensearch/src/http/aws_auth.rs deleted file mode 100644 index ffa33c76..00000000 --- a/opensearch/src/http/aws_auth.rs +++ /dev/null @@ -1,79 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -use std::time::SystemTime; - -use aws_credential_types::{ - provider::{ProvideCredentials, SharedCredentialsProvider}, - Credentials, -}; -use aws_sigv4::{ - http_request::{ - sign, PayloadChecksumKind, SignableBody, SignableRequest, SigningParams, SigningSettings, - }, - signing_params::BuildError, -}; -use aws_types::region::Region; -use reqwest::Request; - -fn get_signing_params<'a>( - credentials: &'a Credentials, - service_name: &'a str, - region: &'a Region, -) -> Result, BuildError> { - let mut signing_settings = SigningSettings::default(); - signing_settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; // required for OpenSearch Serverless - - let mut builder = SigningParams::builder() - .access_key(credentials.access_key_id()) - .secret_key(credentials.secret_access_key()) - .service_name(service_name) - .region(region.as_ref()) - .time(SystemTime::now()) - .settings(signing_settings); - - builder.set_security_token(credentials.session_token()); - - builder.build() -} - -pub async fn sign_request( - request: &mut Request, - credentials_provider: &SharedCredentialsProvider, - service_name: &str, - region: &Region, -) -> Result<(), Box> { - let credentials = credentials_provider.provide_credentials().await?; - - let params = get_signing_params(&credentials, service_name, region)?; - - let uri = request.url().as_str().parse()?; - - let signable_request = SignableRequest::new( - request.method(), - &uri, - request.headers(), - SignableBody::Bytes(request.body().and_then(|b| b.as_bytes()).unwrap_or(&[])), - ); - - let (mut instructions, _) = sign(signable_request, ¶ms)?.into_parts(); - - if let Some(new_headers) = instructions.take_headers() { - for (name, value) in new_headers.into_iter() { - request.headers_mut().insert( - name.expect("AWS signing header name must never be None"), - value, - ); - } - } - - Ok(()) -} diff --git a/opensearch/src/http/headers.rs b/opensearch/src/http/headers.rs index 713168eb..4f616720 100644 --- a/opensearch/src/http/headers.rs +++ b/opensearch/src/http/headers.rs @@ -28,7 +28,7 @@ * GitHub history for details. */ -//! HTTP header names and values, including those specific to Elasticsearch +//! HTTP header names and values, including those specific to OpenSearch pub use reqwest::header::*; diff --git a/opensearch/src/http/middleware/initializers.rs b/opensearch/src/http/middleware/initializers.rs new file mode 100644 index 00000000..07fb194c --- /dev/null +++ b/opensearch/src/http/middleware/initializers.rs @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::shared_middleware; +use crate::BoxError; +use reqwest::{ClientBuilder, RequestBuilder}; + +pub trait ClientInitializer: Send + Sync + 'static { + fn init(&self, client: ClientBuilder) -> Result>; +} + +impl ClientInitializer for F +where + F: Fn(ClientBuilder) -> Result> + Send + Sync + 'static, +{ + fn init(&self, client: ClientBuilder) -> Result> { + self(client) + } +} + +pub trait RequestInitializer: Send + Sync + 'static { + fn init(&self, request: RequestBuilder) -> Result>; +} + +impl RequestInitializer for F +where + F: Fn(RequestBuilder) -> Result> + Send + Sync + 'static, +{ + fn init(&self, request: RequestBuilder) -> Result> { + self(request) + } +} + +shared_middleware!( + SharedClientInitializer(ClientInitializer), + SharedRequestInitializer(RequestInitializer) +); diff --git a/opensearch/src/http/middleware/mod.rs b/opensearch/src/http/middleware/mod.rs new file mode 100644 index 00000000..b29ce054 --- /dev/null +++ b/opensearch/src/http/middleware/mod.rs @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod initializers; +mod request_pipeline; + +pub use async_trait::async_trait; +pub use initializers::*; +pub use request_pipeline::*; + +macro_rules! shared_middleware { + ($($shared:ident($trait:ident)),*) => { + $( + #[derive(Clone)] + pub struct $shared(std::sync::Arc); + + impl From for $shared + where M: $trait + { + fn from(middleware: M) -> Self { + Self(std::sync::Arc::new(middleware)) + } + } + + impl From> for $shared + where M: $trait + { + fn from(middleware: std::sync::Arc) -> Self { + Self(middleware) + } + } + + impl From> for $shared { + fn from(middleware: std::sync::Arc) -> Self { + Self(middleware) + } + } + + impl std::ops::Deref for $shared { + type Target = dyn $trait; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } + } + )* + } +} + +pub(self) use shared_middleware; diff --git a/opensearch/src/http/middleware/request_pipeline.rs b/opensearch/src/http/middleware/request_pipeline.rs new file mode 100644 index 00000000..faeb9847 --- /dev/null +++ b/opensearch/src/http/middleware/request_pipeline.rs @@ -0,0 +1,142 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +//! Request pipeline types +//! +//! Examples: +//! +//! ```no_run +//! use futures_util::future::BoxFuture; +//! use opensearch::http::{ +//! middleware::{RequestPipeline, RequestPipelineError}, +//! reqwest::{Request, Response}, +//! }; +//! +//! fn logger<'a>(req: Request, next: RequestPipeline<'a>) -> BoxFuture<'a, Result> { +//! Box::pin(async move { +//! println!("sending request to {}", req.url()); +//! let now = std::time::Instant::now(); +//! let res = next.run(req).await?; +//! println!("request completed ({:?})", now.elapsed()); +//! Ok(res) +//! }) +//! } +//! +//! # #[tokio::main] +//! # async fn main() { +//! # use opensearch::http::{ +//! # transport::{SingleNodeConnectionPool, TransportBuilder}, +//! # Url, +//! # }; +//! # let conn_pool = SingleNodeConnectionPool::new(Url::parse("http://localhost:9200").unwrap()); +//! # let _ = TransportBuilder::new(conn_pool).with_handler(logger); +//! # } +//! ``` +//! +//! ```no_run +//! use futures_util::future::BoxFuture; +//! use opensearch::http::{ +//! middleware::{async_trait, RequestHandler, RequestPipeline, RequestPipelineError}, +//! reqwest::{Request, Response}, +//! }; +//! +//! struct Logger; +//! +//! #[async_trait] +//! impl RequestHandler for Logger { +//! async fn handle(&self, request: Request, next: RequestPipeline<'_>) -> Result { +//! println!("sending request to {}", request.url()); +//! let now = std::time::Instant::now(); +//! let res = next.run(request).await?; +//! println!("request completed ({:?})", now.elapsed()); +//! Ok(res) +//! } +//! } +//! +//! # #[tokio::main] +//! # async fn main() { +//! # use opensearch::http::{ +//! # transport::{SingleNodeConnectionPool, TransportBuilder}, +//! # Url, +//! # }; +//! # let conn_pool = SingleNodeConnectionPool::new(Url::parse("http://localhost:9200").unwrap()); +//! # let _ = TransportBuilder::new(conn_pool).with_handler(Logger); +//! # } +//! ``` + +use super::{async_trait, shared_middleware}; +use crate::BoxError; +use futures_util::future::BoxFuture; +use reqwest::{Client, Request, Response}; +use std::fmt::Debug; + +#[derive(Debug, thiserror::Error)] +pub enum RequestPipelineError { + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), + + #[error("Pipeline error: {0}")] + Pipeline(#[from] BoxError<'static>), +} + +#[async_trait] +pub trait RequestHandler: Send + Sync + 'static { + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result; +} + +#[async_trait] +impl RequestHandler for F +where + F: for<'a> Fn( + Request, + RequestPipeline<'a>, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, +{ + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result { + self(request, next).await + } +} + +shared_middleware!(SharedRequestHandler(RequestHandler)); + +pub struct RequestPipeline<'a> { + pub client: &'a Client, + pipeline: &'a [SharedRequestHandler], +} + +impl<'a> RequestPipeline<'a> { + pub(crate) fn new(client: &'a Client, pipeline: &'a [SharedRequestHandler]) -> Self { + Self { client, pipeline } + } + + pub fn run( + mut self, + request: Request, + ) -> BoxFuture<'a, Result> { + if let Some((head, tail)) = self.pipeline.split_first() { + self.pipeline = tail; + head.handle(request, self) + } else { + Box::pin(async move { self.client.execute(request).await.map_err(Into::into) }) + } + } +} diff --git a/opensearch/src/http/mod.rs b/opensearch/src/http/mod.rs index 9e8a6417..7ae54293 100644 --- a/opensearch/src/http/mod.rs +++ b/opensearch/src/http/mod.rs @@ -30,18 +30,16 @@ //! HTTP components -#[cfg(feature = "aws-auth")] -pub(crate) mod aws_auth; - pub mod headers; +pub mod middleware; pub mod request; pub mod response; pub mod transport; -pub use reqwest::StatusCode; +pub use reqwest::{self, Request, StatusCode}; pub use url::Url; -/// Http methods supported by Elasticsearch +/// Http methods supported by OpenSearch #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Method { /// get diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index f70c6eb1..d37a3543 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -30,8 +30,9 @@ //! HTTP transport and connection components -#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] -use crate::auth::ClientCertificate; +use super::middleware::{ + RequestPipeline, SharedClientInitializer, SharedRequestHandler, SharedRequestInitializer, +}; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use crate::cert::CertificateValidation; use crate::{ @@ -39,73 +40,32 @@ use crate::{ error::Error, http::{ headers::{ - HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, - DEFAULT_ACCEPT, DEFAULT_CONTENT_TYPE, DEFAULT_USER_AGENT, USER_AGENT, + HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE, DEFAULT_ACCEPT, + DEFAULT_CONTENT_TYPE, DEFAULT_USER_AGENT, USER_AGENT, }, request::Body, response::Response, Method, }, + BoxError, }; -use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; use bytes::BytesMut; use dyn_clone::clone_trait_object; use lazy_static::lazy_static; use serde::Serialize; -use std::{ - error, fmt, - fmt::Debug, - io::{self, Write}, - time::Duration, -}; +use std::{fmt::Debug, sync::Arc, time::Duration}; use url::Url; /// Error that can occur when building a [Transport] -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum BuildError { - /// IO error - Io(io::Error), - - /// Certificate error - Cert(reqwest::Error), -} - -impl From for BuildError { - fn from(err: io::Error) -> BuildError { - BuildError::Io(err) - } -} + /// Reqwest error + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), -impl From for BuildError { - fn from(err: reqwest::Error) -> BuildError { - BuildError::Cert(err) - } -} - -impl error::Error for BuildError { - #[allow(warnings)] - fn description(&self) -> &str { - match *self { - BuildError::Io(ref err) => err.description(), - BuildError::Cert(ref err) => err.description(), - } - } - - fn cause(&self) -> Option<&dyn error::Error> { - match *self { - BuildError::Io(ref err) => Some(err as &dyn error::Error), - BuildError::Cert(ref err) => Some(err as &dyn error::Error), - } - } -} - -impl fmt::Display for BuildError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - BuildError::Io(ref err) => fmt::Display::fmt(err, f), - BuildError::Cert(ref err) => fmt::Display::fmt(err, f), - } - } + /// Client initializer error + #[error("Client initializer error: {0}")] + ClientInitializer(#[from] BoxError<'static>), } /// Default address to OpenSearch running on `http://localhost:9200` @@ -146,7 +106,9 @@ fn build_meta() -> String { pub struct TransportBuilder { client_builder: reqwest::ClientBuilder, conn_pool: Box, - credentials: Option, + init_stack: Vec, + req_init_stack: Vec, + req_handler_stack: Vec, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] cert_validation: Option, proxy: Option, @@ -154,8 +116,6 @@ pub struct TransportBuilder { disable_proxy: bool, headers: HeaderMap, timeout: Option, - #[cfg(feature = "aws-auth")] - service_name: String, } impl TransportBuilder { @@ -168,7 +128,9 @@ impl TransportBuilder { Self { client_builder: reqwest::ClientBuilder::new(), conn_pool: Box::new(conn_pool), - credentials: None, + init_stack: Vec::new(), + req_init_stack: Vec::new(), + req_handler_stack: Vec::new(), #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] cert_validation: None, proxy: None, @@ -176,11 +138,24 @@ impl TransportBuilder { disable_proxy: false, headers: HeaderMap::new(), timeout: None, - #[cfg(feature = "aws-auth")] - service_name: "es".to_string(), } } + pub fn with_init(mut self, init: impl Into) -> Self { + self.init_stack.push(init.into()); + self + } + + pub fn with_req_init(mut self, init: impl Into) -> Self { + self.req_init_stack.push(init.into()); + self + } + + pub fn with_handler(mut self, handler: impl Into) -> Self { + self.req_handler_stack.push(handler.into()); + self + } + /// Configures a proxy. /// /// An optional username and password will be used to set the @@ -204,9 +179,10 @@ impl TransportBuilder { } /// Credentials for the client to use for authentication to OpenSearch. - pub fn auth(mut self, credentials: Credentials) -> Self { - self.credentials = Some(credentials); - self + pub fn auth(self, credentials: Credentials) -> Self { + let credentials = Arc::new(credentials); + self.with_init(credentials.clone()) + .with_req_init(credentials) } /// Validation applied to the certificate provided to establish a HTTPS connection. @@ -245,16 +221,7 @@ impl TransportBuilder { self } - /// Sets a global AWS service name. - /// - /// Default is "es". Other supported services are "aoss" for OpenSearch Serverless. - #[cfg(feature = "aws-auth")] - pub fn service_name(mut self, service_name: &str) -> Self { - self.service_name = service_name.to_string(); - self - } - - /// Builds a [Transport] to use to send API calls to Elasticsearch. + /// Builds a [Transport] to use to send API calls to OpenSearch. pub fn build(self) -> Result { let mut client_builder = self.client_builder; @@ -266,28 +233,6 @@ impl TransportBuilder { client_builder = client_builder.timeout(t); } - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - { - if let Some(Credentials::Certificate(cert)) = &self.credentials { - client_builder = match cert { - #[cfg(feature = "native-tls")] - ClientCertificate::Pkcs12(b, p) => { - let password = match p { - Some(pass) => pass.as_str(), - None => "", - }; - let pkcs12 = reqwest::Identity::from_pkcs12_der(b, password)?; - client_builder.identity(pkcs12) - } - #[cfg(feature = "rustls-tls")] - ClientCertificate::Pem(b) => { - let pem = reqwest::Identity::from_pem(b)?; - client_builder.identity(pem) - } - } - } - } - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] if let Some(v) = self.cert_validation { client_builder = match v { @@ -321,13 +266,19 @@ impl TransportBuilder { client_builder = client_builder.proxy(proxy); } + client_builder = self + .init_stack + .into_iter() + .try_fold(client_builder, |client_builder, init| { + init.init(client_builder) + })?; + let client = client_builder.build()?; Ok(Transport { client, conn_pool: self.conn_pool, - credentials: self.credentials, - #[cfg(feature = "aws-auth")] - service_name: self.service_name, + req_init_stack: self.req_init_stack.into_boxed_slice(), + req_handler_stack: self.req_handler_stack.into_boxed_slice(), }) } } @@ -339,7 +290,7 @@ impl Default for TransportBuilder { } } -/// A connection to an Elasticsearch node, used to send an API request +/// A connection to an OpenSearch node, used to send an API request #[derive(Debug, Clone)] pub struct Connection { url: Url, @@ -360,15 +311,14 @@ impl Connection { } } -/// A HTTP transport responsible for making the API requests to Elasticsearch, +/// A HTTP transport responsible for making the API requests to OpenSearch, /// using a [Connection] selected from a [ConnectionPool] -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Transport { client: reqwest::Client, - credentials: Option, + req_init_stack: Box<[SharedRequestInitializer]>, + req_handler_stack: Box<[SharedRequestHandler]>, conn_pool: Box, - #[cfg(feature = "aws-auth")] - service_name: String, } impl Transport { @@ -413,38 +363,20 @@ impl Transport { let connection = self.conn_pool.next(); let url = connection.url.join(path.trim_start_matches('/'))?; let reqwest_method = self.method(method); - let mut request_builder = self.client.request(reqwest_method, url); + + let mut request_builder = self + .req_init_stack + .iter() + .try_fold( + self.client.request(reqwest_method, url), + |request_builder, init| init.init(request_builder), + ) + .map_err(Error::RequestInitializer)?; if let Some(t) = timeout { request_builder = request_builder.timeout(t); } - // set credentials before any headers, as credentials append to existing headers in reqwest, - // whilst setting headers() overwrites, so if an Authorization header has been specified - // on a specific request, we want it to overwrite. - if let Some(c) = &self.credentials { - request_builder = match c { - Credentials::Basic(u, p) => request_builder.basic_auth(u, Some(p)), - Credentials::Bearer(t) => request_builder.bearer_auth(t), - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - Credentials::Certificate(_) => request_builder, - Credentials::ApiKey(i, k) => { - let mut header_value = b"ApiKey ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "{}:", i).unwrap(); - write!(encoder, "{}", k).unwrap(); - } - request_builder.header( - AUTHORIZATION, - HeaderValue::from_bytes(&header_value).unwrap(), - ) - } - #[cfg(feature = "aws-auth")] - Credentials::AwsSigV4(_, _) => request_builder, - } - } - // default headers first, overwrite with any provided let mut request_headers = HeaderMap::with_capacity(4 + headers.len()); request_headers.insert(CONTENT_TYPE, HeaderValue::from_static(DEFAULT_CONTENT_TYPE)); @@ -472,26 +404,11 @@ impl Transport { request_builder = request_builder.query(q); } - #[cfg_attr(not(feature = "aws-auth"), allow(unused_mut))] - let mut request = request_builder.build()?; + let response = RequestPipeline::new(&self.client, &self.req_handler_stack) + .run(request_builder.build()?) + .await?; - #[cfg(feature = "aws-auth")] - if let Some(Credentials::AwsSigV4(credentials_provider, region)) = &self.credentials { - super::aws_auth::sign_request( - &mut request, - credentials_provider, - &self.service_name, - region, - ) - .await - .map_err(|e| crate::error::lib(format!("AWSV4 Signing Failed: {}", e)))?; - } - - let response = self.client.execute(request).await; - match response { - Ok(r) => Ok(Response::new(r, method)), - Err(e) => Err(e.into()), - } + Ok(Response::new(response, method)) } } @@ -501,7 +418,18 @@ impl Default for Transport { } } -/// A pool of [Connection]s, used to make API calls to Elasticsearch. +impl std::fmt::Debug for Transport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Transport") + .field("client", &self.client) + // .field("req_init_stack", &self.req_init_stack) + // .field("req_handler_stack", &self.req_handler_stack) + .field("conn_pool", &self.conn_pool) + .finish() + } +} + +/// A pool of [Connection]s, used to make API calls to OpenSearch. /// /// A [ConnectionPool] manages the connections, with different implementations determining how /// to get the next [Connection]. The simplest type of [ConnectionPool] is [SingleNodeConnectionPool], @@ -514,7 +442,7 @@ pub trait ConnectionPool: Debug + dyn_clone::DynClone + Sync + Send { clone_trait_object!(ConnectionPool); -/// A connection pool that manages the single connection to an Elasticsearch cluster. +/// A connection pool that manages the single connection to an OpenSearch cluster. #[derive(Debug, Clone)] pub struct SingleNodeConnectionPool { connection: Connection, diff --git a/opensearch/src/lib.rs b/opensearch/src/lib.rs index 4e858eaa..ef83e0eb 100644 --- a/opensearch/src/lib.rs +++ b/opensearch/src/lib.rs @@ -58,8 +58,6 @@ //! - **experimental-apis**: Enables experimental APIs. Experimental APIs are just that - an experiment. An experimental //! API might have breaking changes in any future version, or it might even be removed entirely. This feature also //! enables `beta-apis`. -//! - **aws-auth**: Enables authentication with Amazon OpenSearch and OpenSearch Serverless. -//! Performs AWS SigV4 signing using credential types from `aws-types`. //! //! # Getting started //! @@ -327,43 +325,6 @@ //! # } //! ``` //! -//! ## Amazon OpenSearch and OpenSearch Serverless -//! -//! For authenticating against an Amazon OpenSearch or OpenSearch Serverless endpoint using AWS SigV4 request signing, -//! you must enable the `aws-auth` feature, then pass the AWS credentials to the [TransportBuilder](http::transport::TransportBuilder). -//! The easiest way to retrieve AWS credentials in the required format is to use [aws-config](https://docs.rs/aws-config/latest/aws_config/). -//! -//! ```toml,no_run -//! [dependencies] -//! opensearch = { version = "1", features = ["aws-auth"] } -//! aws-config = "0.10" -//! ``` -//! -//! ```rust,no_run -//! # use aws_config::meta::region::RegionProviderChain; -//! # use opensearch::{ -//! # Error, OpenSearch, -//! # http::transport::{TransportBuilder,SingleNodeConnectionPool}, -//! # }; -//! # use url::Url; -//! # use std::convert::TryInto; -//! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { -//! # #[cfg(feature = "aws-auth")] { -//! let creds = aws_config::load_from_env().await; -//! let url = Url::parse("https://...")?; -//! let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); -//! let aws_config = aws_config::from_env().region(region_provider).load().await.clone(); -//! let conn_pool = SingleNodeConnectionPool::new(url); -//! let transport = TransportBuilder::new(conn_pool) -//! .auth(aws_config.clone().try_into()?) -//! .service_name("es") // use "aoss" for OpenSearch Serverless -//! .build()?; -//! let client = OpenSearch::new(transport); -//! # } -//! # Ok(()) -//! # } -//! ``` #![doc( html_logo_url = "https://github.com/opensearch-project/opensearch-rs/raw/main/OpenSearch.svg" diff --git a/opensearch/tests/auth.rs b/opensearch/tests/auth.rs index 031be9c2..47d44556 100644 --- a/opensearch/tests/auth.rs +++ b/opensearch/tests/auth.rs @@ -33,71 +33,61 @@ use common::*; use opensearch::auth::Credentials; -use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; -use std::io::Write; +use wiremock::MockServer; #[tokio::test] async fn basic_auth_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let mut header_value = b"Basic ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "username:password").unwrap(); - } - - assert_eq!( - req.headers()["authorization"], - String::from_utf8(header_value).unwrap() - ); - http::Response::default() - }); - - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) + let server = MockServer::start().await; + + let builder = client::create_builder(&server.uri()) .auth(Credentials::Basic("username".into(), "password".into())); - let client = client::create(builder); - let _response = client.ping().send().await?; + let _ = client::create(builder).ping().send().await?; + + let requests = server.received_requests().await.unwrap(); + + assert_eq!( + &header_values(&requests[0], "authorization"), + &["Basic dXNlcm5hbWU6cGFzc3dvcmQ="] + ); Ok(()) } #[tokio::test] async fn api_key_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let mut header_value = b"ApiKey ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "id:api_key").unwrap(); - } - - assert_eq!( - req.headers()["authorization"], - String::from_utf8(header_value).unwrap() - ); - http::Response::default() - }); - - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) + let server = MockServer::start().await; + + let builder = client::create_builder(&server.uri()) .auth(Credentials::ApiKey("id".into(), "api_key".into())); - let client = client::create(builder); - let _response = client.ping().send().await?; + let _ = client::create(builder).ping().send().await?; + + let requests = server.received_requests().await.unwrap(); + + assert_eq!( + &header_values(&requests[0], "authorization"), + &["ApiKey aWQ6YXBpX2tleQ=="] + ); Ok(()) } #[tokio::test] async fn bearer_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.headers()["authorization"], "Bearer access_token"); - http::Response::default() - }); + let server = MockServer::start().await; + + let builder = + client::create_builder(&server.uri()).auth(Credentials::Bearer("access_token".into())); + + let _ = client::create(builder).ping().send().await?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .auth(Credentials::Bearer("access_token".into())); + let requests = server.received_requests().await.unwrap(); - let client = client::create(builder); - let _response = client.ping().send().await?; + assert_eq!( + &header_values(&requests[0], "authorization"), + &["Bearer access_token"] + ); Ok(()) } diff --git a/opensearch/tests/aws_auth.rs b/opensearch/tests/aws_auth.rs deleted file mode 100644 index 72d6d13b..00000000 --- a/opensearch/tests/aws_auth.rs +++ /dev/null @@ -1,86 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -#![cfg(feature = "aws-auth")] - -pub mod common; -use common::*; -use opensearch::OpenSearch; -use regex::Regex; - -use aws_config::SdkConfig; -use aws_credential_types::provider::SharedCredentialsProvider; -use aws_credential_types::Credentials; -use aws_types::region::Region; -use std::convert::TryInto; - -#[tokio::test] -async fn aws_auth_get() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let authorization_header = req.headers()["authorization"].to_str().unwrap(); - let re = Regex::new(r"^AWS4-HMAC-SHA256 Credential=id/\d*/us-west-1/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=[a-f,0-9].*$").unwrap(); - assert!( - re.is_match(authorization_header), - "{}", - authorization_header - ); - let amz_content_sha256_header = req.headers()["x-amz-content-sha256"].to_str().unwrap(); - assert_eq!( - amz_content_sha256_header, - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - ); // SHA of empty string - http::Response::default() - }); - - let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?; - let _response = client.ping().send().await?; - - Ok(()) -} - -#[tokio::test] -async fn aws_auth_post() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let amz_content_sha256_header = req.headers()["x-amz-content-sha256"].to_str().unwrap(); - assert_eq!( - amz_content_sha256_header, - "f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9" - ); // SHA of the JSON - http::Response::default() - }); - - let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?; - client - .index(opensearch::IndexParts::Index("movies")) - .body(serde_json::json!({ - "title": "Moneyball", - "director": "Bennett Miller", - "year": 2011 - } - )) - .send() - .await?; - - Ok(()) -} - -fn create_aws_client(addr: &str) -> anyhow::Result { - let aws_creds = Credentials::new("id", "secret", None, None, "token"); - let creds_provider = SharedCredentialsProvider::new(aws_creds); - let aws_config = SdkConfig::builder() - .credentials_provider(creds_provider) - .region(Region::new("us-west-1")) - .build(); - let builder = client::create_builder(addr) - .auth(aws_config.clone().try_into()?) - .service_name("custom"); - Ok(client::create(builder)) -} diff --git a/opensearch/tests/client.rs b/opensearch/tests/client.rs index 9cf68dcb..7c7f61b7 100644 --- a/opensearch/tests/client.rs +++ b/opensearch/tests/client.rs @@ -31,6 +31,8 @@ pub mod common; use common::*; +use crate::common::client::index_documents; +use bytes::Bytes; use opensearch::{ http::{ headers::{ @@ -42,60 +44,64 @@ use opensearch::{ params::TrackTotalHits, SearchParts, }; - -use crate::common::client::index_documents; -use bytes::Bytes; -use hyper::Method; use serde_json::{json, Value}; use std::time::Duration; +use wiremock::{ + http::Method, + matchers::{method, path}, + Mock, MockServer, ResponseTemplate, +}; #[tokio::test] async fn default_user_agent_content_type_accept_headers() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.headers()["user-agent"], DEFAULT_USER_AGENT); - assert_eq!(req.headers()["content-type"], "application/json"); - assert_eq!(req.headers()["accept"], "application/json"); - http::Response::default() - }); + let server = MockServer::start().await; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client.ping().send().await?; + let _ = client::create_for_url(&server.uri()).ping().send().await?; + + let requests = server.received_requests().await.unwrap(); + assert_eq!( + &header_values(&requests[0], "user-agent"), + &[DEFAULT_USER_AGENT] + ); + assert_eq!( + &header_values(&requests[0], "content-type"), + &["application/json"] + ); + assert_eq!( + &header_values(&requests[0], "accept"), + &["application/json"] + ); Ok(()) } #[tokio::test] async fn default_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.headers()["x-opaque-id"], "foo"); - http::Response::default() - }); + let server = MockServer::start().await; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()).header( + let builder = client::create_builder(&server.uri()).header( HeaderName::from_static(X_OPAQUE_ID), HeaderValue::from_static("foo"), ); - let client = client::create(builder); - let _response = client.ping().send().await?; + let _ = client::create(builder).ping().send().await?; + + let requests = server.received_requests().await.unwrap(); + assert_eq!(&header_values(&requests[0], "x-opaque-id"), &["foo"]); Ok(()) } #[tokio::test] async fn override_default_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.headers()["x-opaque-id"], "bar"); - http::Response::default() - }); + let server = MockServer::start().await; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()).header( + let builder = client::create_builder(&server.uri()).header( HeaderName::from_static(X_OPAQUE_ID), HeaderValue::from_static("foo"), ); - let client = client::create(builder); - let _response = client + let _ = client::create(builder) .ping() .header( HeaderName::from_static(X_OPAQUE_ID), @@ -104,18 +110,17 @@ async fn override_default_header() -> anyhow::Result<()> { .send() .await?; + let requests = server.received_requests().await.unwrap(); + assert_eq!(&header_values(&requests[0], "x-opaque-id"), &["bar"]); + Ok(()) } #[tokio::test] async fn x_opaque_id_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.headers()["x-opaque-id"], "foo"); - http::Response::default() - }); + let server = MockServer::start().await; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client + let _ = client::create_for_url(&server.uri()) .ping() .header( HeaderName::from_static(X_OPAQUE_ID), @@ -124,21 +129,26 @@ async fn x_opaque_id_header() -> anyhow::Result<()> { .send() .await?; + let requests = server.received_requests().await.unwrap(); + assert_eq!(&header_values(&requests[0], "x-opaque-id"), &["foo"]); + Ok(()) } #[tokio::test] async fn uses_global_request_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - http::Response::default() - }); + let server = MockServer::start().await; + + Mock::given(method("HEAD")) + .and(path("/")) + .respond_with(DelayedResponse::new(1, ResponseTemplate::new(200))) + .mount(&server) + .await; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_millis(500)); + let builder = + client::create_builder(&server.uri()).timeout(std::time::Duration::from_millis(500)); - let client = client::create(builder); - let response = client.ping().send().await; + let response = client::create(builder).ping().send().await; match response { Ok(_) => panic!("Expected timeout error, but response received"), @@ -148,16 +158,17 @@ async fn uses_global_request_timeout() { #[tokio::test] async fn uses_call_request_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - http::Response::default() - }); + let server = MockServer::start().await; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_secs(2)); + Mock::given(method("HEAD")) + .and(path("/")) + .respond_with(DelayedResponse::new(1, ResponseTemplate::new(200))) + .mount(&server) + .await; - let client = client::create(builder); - let response = client + let builder = client::create_builder(&server.uri()).timeout(std::time::Duration::from_secs(2)); + + let response = client::create(builder) .ping() .request_timeout(Duration::from_millis(500)) .send() @@ -171,16 +182,18 @@ async fn uses_call_request_timeout() { #[tokio::test] async fn call_request_timeout_supersedes_global_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - http::Response::default() - }); + let server = MockServer::start().await; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_millis(500)); + Mock::given(method("HEAD")) + .and(path("/")) + .respond_with(DelayedResponse::new(1, ResponseTemplate::new(200))) + .mount(&server) + .await; - let client = client::create(builder); - let response = client + let builder = + client::create_builder(&server.uri()).timeout(std::time::Duration::from_millis(500)); + + let response = client::create(builder) .ping() .request_timeout(Duration::from_secs(2)) .send() @@ -239,18 +252,9 @@ async fn deprecation_warning_headers() -> anyhow::Result<()> { #[tokio::test] async fn serialize_querystring() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.method(), Method::GET); - assert_eq!(req.uri().path(), "/_search"); - assert_eq!( - req.uri().query(), - Some("filter_path=took%2C_shards&pretty=true&q=title%3AOpenSearch&track_total_hits=100000") - ); - http::Response::default() - }); - - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client + let server = MockServer::start().await; + + let _ = client::create_for_url(&server.uri()) .search(SearchParts::None) .pretty(true) .filter_path(&["took", "_shards"]) @@ -259,6 +263,14 @@ async fn serialize_querystring() -> anyhow::Result<()> { .send() .await?; + let requests = server.received_requests().await.unwrap(); + assert_eq!(requests[0].method, Method::Get); + assert_eq!(requests[0].url.path(), "/_search"); + assert_eq!( + requests[0].url.query(), + Some("filter_path=took%2C_shards&pretty=true&q=title%3AOpenSearch&track_total_hits=100000") + ); + Ok(()) } diff --git a/opensearch/tests/common/mod.rs b/opensearch/tests/common/mod.rs index dd1dadcf..41717b28 100644 --- a/opensearch/tests/common/mod.rs +++ b/opensearch/tests/common/mod.rs @@ -28,8 +28,37 @@ * GitHub history for details. */ +use std::time::Duration; + pub mod client; -pub mod server; #[allow(unused)] pub static DEFAULT_USER_AGENT: &str = concat!("opensearch-rs/", env!("CARGO_PKG_VERSION")); + +pub fn header_values<'r>(request: &'r wiremock::Request, header_name: &str) -> Vec<&'r str> { + request + .headers + .get(&header_name.into()) + .into_iter() + .flatten() + .map(wiremock::http::HeaderValue::as_str) + .collect::>() +} + +pub struct DelayedResponse { + n_secs: u64, + resp: wiremock::ResponseTemplate, +} + +impl DelayedResponse { + pub fn new(n_secs: u64, resp: wiremock::ResponseTemplate) -> Self { + Self { n_secs, resp } + } +} + +impl wiremock::Respond for DelayedResponse { + fn respond(&self, _: &wiremock::Request) -> wiremock::ResponseTemplate { + std::thread::sleep(Duration::from_secs(self.n_secs)); + self.resp.clone() + } +} diff --git a/opensearch/tests/common/server.rs b/opensearch/tests/common/server.rs deleted file mode 100644 index f3dcd8af..00000000 --- a/opensearch/tests/common/server.rs +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to Elasticsearch B.V. under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch B.V. licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -// From reqwest crate -// Licensed under Apache License, Version 2.0 -// https://github.com/seanmonstar/reqwest/blob/master/LICENSE-APACHE - -use std::{ - convert::Infallible, future::Future, net, sync::mpsc as std_mpsc, thread, time::Duration, -}; - -use tokio::sync::oneshot; - -pub use http::Response; -use tokio::runtime; - -pub struct Server { - addr: net::SocketAddr, - panic_rx: std_mpsc::Receiver<()>, - shutdown_tx: Option>, -} - -impl Server { - pub fn addr(&self) -> net::SocketAddr { - self.addr - } -} - -impl Drop for Server { - fn drop(&mut self) { - if let Some(tx) = self.shutdown_tx.take() { - let _ = tx.send(()); - } - - if !::std::thread::panicking() { - self.panic_rx - .recv_timeout(Duration::from_secs(3)) - .expect("test server should not panic"); - } - } -} - -pub fn http(func: F) -> Server -where - F: Fn(http::Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, -{ - //Spawn new runtime in thread to prevent reactor execution context conflict - thread::spawn(move || { - let rt = runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("new rt"); - - let srv = { - let _guard = rt.enter(); - hyper::Server::bind(&([127, 0, 0, 1], 0).into()).serve(hyper::service::make_service_fn( - move |_| { - let func = func.clone(); - async move { - Ok::<_, Infallible>(hyper::service::service_fn(move |req| { - let fut = func(req); - async move { Ok::<_, Infallible>(fut.await) } - })) - } - }, - )) - }; - - let addr = srv.local_addr(); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let srv = srv.with_graceful_shutdown(async move { - let _ = shutdown_rx.await; - }); - - let (panic_tx, panic_rx) = std_mpsc::channel(); - let tname = format!( - "test({})-support-server", - thread::current().name().unwrap_or("") - ); - thread::Builder::new() - .name(tname) - .spawn(move || { - rt.block_on(srv).unwrap(); - let _ = panic_tx.send(()); - }) - .expect("thread spawn"); - - Server { - addr, - panic_rx, - shutdown_tx: Some(shutdown_tx), - } - }) - .join() - .unwrap() -}