From b175b341c943e4151ffcf4474f36b278c4d7915e Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Thu, 18 Jan 2024 11:17:36 +1300 Subject: [PATCH] Implement middleware types to allow intercepting client & request handling (#232) * Implement middleware types to allow intercepting client & request handling Signed-off-by: Thomas Farr * Add changelog entry Signed-off-by: Thomas Farr * Relax position of clone bounds Signed-off-by: Thomas Farr * Move is_send_sync check behind test cfg Signed-off-by: Thomas Farr * Rename RequestPipeline to RequestHandlerChain Signed-off-by: Thomas Farr --------- Signed-off-by: Thomas Farr --- CHANGELOG.md | 1 + opensearch/Cargo.toml | 1 + opensearch/src/client.rs | 10 + opensearch/src/error.rs | 28 ++- .../http/middleware/initializers/client.rs | 47 ++++ .../src/http/middleware/initializers/mod.rs | 44 ++++ .../http/middleware/initializers/request.rs | 86 +++++++ opensearch/src/http/middleware/mod.rs | 20 ++ .../src/http/middleware/request_handler.rs | 137 +++++++++++ opensearch/src/http/mod.rs | 1 + opensearch/src/http/transport.rs | 217 +++++++++++++++++- opensearch/tests/common/mod.rs | 2 + opensearch/tests/middleware.rs | 106 +++++++++ 13 files changed, 693 insertions(+), 7 deletions(-) create mode 100644 opensearch/src/http/middleware/initializers/client.rs create mode 100644 opensearch/src/http/middleware/initializers/mod.rs create mode 100644 opensearch/src/http/middleware/initializers/request.rs create mode 100644 opensearch/src/http/middleware/mod.rs create mode 100644 opensearch/src/http/middleware/request_handler.rs create mode 100644 opensearch/tests/middleware.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index ae5dae44..fad46740 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Internalized the `BuildError` type, consolidating on the `Error` type ([#228](https://github.com/opensearch-project/opensearch-rs/pull/228)) ### Added +- Added middleware types to allow intercepting construction and handling of the underlying `reqwest` client & requests ([#232](https://github.com/opensearch-project/opensearch-rs/pull/232)) ### Dependencies - Bumps `aws-*` dependencies to `1` ([#219](https://github.com/opensearch-project/opensearch-rs/pull/219)) diff --git a/opensearch/Cargo.toml b/opensearch/Cargo.toml index 76a388f8..4c1e67c5 100644 --- a/opensearch/Cargo.toml +++ b/opensearch/Cargo.toml @@ -29,6 +29,7 @@ rustls-tls = ["reqwest/rustls-tls"] aws-auth = ["aws-credential-types", "aws-sigv4", "aws-smithy-runtime-api", "aws-types"] [dependencies] +async-trait = "0.1" base64 = "0.21" bytes = "1.0" dyn-clone = "1" diff --git a/opensearch/src/client.rs b/opensearch/src/client.rs index 21c937aa..121a344b 100644 --- a/opensearch/src/client.rs +++ b/opensearch/src/client.rs @@ -104,3 +104,13 @@ impl OpenSearch { .await } } + +#[cfg(test)] +mod test { + #[test] + fn client_is_send_sync() { + // Ensure that the client is `Send` and `Sync` + fn is_send_sync() {} + is_send_sync::() + } +} diff --git a/opensearch/src/error.rs b/opensearch/src/error.rs index 4fe10ae4..d6c7ee4a 100644 --- a/opensearch/src/error.rs +++ b/opensearch/src/error.rs @@ -35,7 +35,10 @@ use crate::{ cert::CertificateError, - http::{transport, StatusCode}, + http::{ + middleware::{RequestHandlerError, RequestHandlerErrorKind}, + transport, StatusCode, + }, }; pub(crate) type BoxError<'a> = Box; @@ -53,7 +56,7 @@ where Kind: From, { fn from(error: E) -> Self { - Self(Kind::from(error)) + Self(error.into()) } } @@ -80,11 +83,32 @@ enum Kind { #[cfg(feature = "aws-auth")] #[error("AwsSigV4 error: {0}")] AwsSigV4(#[from] crate::http::aws_auth::AwsSigV4Error), + + #[error("request initializer error: {0}")] + RequestInitializer(#[source] BoxError<'static>), + + #[error("request handler error: {0}")] + RequestHandler(#[source] BoxError<'static>), +} + +impl From for Kind { + fn from(err: RequestHandlerError) -> Self { + use RequestHandlerErrorKind::*; + + match err.0 { + Handler(err) => Self::RequestHandler(err), + Http(err) => Self::Http(err), + } + } } use Kind::*; impl Error { + pub(crate) fn request_initializer(err: BoxError<'static>) -> Self { + Self(RequestInitializer(err)) + } + /// The status code, if the error was generated from a response pub fn status_code(&self) -> Option { match &self.0 { diff --git a/opensearch/src/http/middleware/initializers/client.rs b/opensearch/src/http/middleware/initializers/client.rs new file mode 100644 index 00000000..006752c1 --- /dev/null +++ b/opensearch/src/http/middleware/initializers/client.rs @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::InitializerResult; +use crate::BoxError; +use reqwest::ClientBuilder; + +pub trait ClientInitializer: 'static { + type Result: InitializerResult; + + fn init(self, client: ClientBuilder) -> Self::Result; +} + +impl ClientInitializer for F +where + F: FnOnce(ClientBuilder) -> R + 'static, + R: InitializerResult, +{ + type Result = R; + + fn init(self, client: ClientBuilder) -> Self::Result { + self(client) + } +} + +pub(crate) trait BoxedClientInitializer { + fn init(self: Box, client: ClientBuilder) -> Result>; +} + +impl BoxedClientInitializer for T +where + T: ClientInitializer + Sized, +{ + fn init(self: Box, client: ClientBuilder) -> Result> { + ClientInitializer::init(*self, client) + .into_result() + .map_err(Into::into) + } +} diff --git a/opensearch/src/http/middleware/initializers/mod.rs b/opensearch/src/http/middleware/initializers/mod.rs new file mode 100644 index 00000000..d39e67c4 --- /dev/null +++ b/opensearch/src/http/middleware/initializers/mod.rs @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod client; +mod request; + +use crate::BoxError; +use std::convert::Infallible; + +pub use client::*; +pub use request::*; + +pub trait InitializerResult { + type Error: Into>; + + fn into_result(self) -> Result; +} + +impl InitializerResult for Result +where + E: Into>, +{ + type Error = E; + + fn into_result(self) -> Result { + self + } +} + +impl InitializerResult for T { + type Error = Infallible; + + fn into_result(self) -> Result { + Ok(self) + } +} diff --git a/opensearch/src/http/middleware/initializers/request.rs b/opensearch/src/http/middleware/initializers/request.rs new file mode 100644 index 00000000..f657588c --- /dev/null +++ b/opensearch/src/http/middleware/initializers/request.rs @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::InitializerResult; +use crate::BoxError; +use reqwest::RequestBuilder; + +pub trait RequestInitializer: std::fmt::Debug + Send + Sync + 'static { + type Result: InitializerResult; + + fn init(&self, request: RequestBuilder) -> Self::Result; +} + +#[derive(Clone)] +pub struct RequestInitializerFn(F); + +pub fn request_initializer_fn(f: F) -> RequestInitializerFn { + RequestInitializerFn(f) +} + +impl std::fmt::Debug for RequestInitializerFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(stringify!(RequestInitializerFn)).finish() + } +} + +impl RequestInitializer for RequestInitializerFn +where + F: Fn(RequestBuilder) -> R + Send + Sync + 'static, + R: InitializerResult, +{ + type Result = R; + + fn init(&self, request: RequestBuilder) -> Self::Result { + self.0(request) + } +} + +impl RequestInitializer for std::sync::Arc +where + R: RequestInitializer, +{ + type Result = R::Result; + + fn init(&self, request: RequestBuilder) -> Self::Result { + self.as_ref().init(request) + } +} + +impl RequestInitializer for std::sync::Arc> +where + R: InitializerResult + 'static, +{ + type Result = R; + + fn init(&self, request: RequestBuilder) -> Self::Result { + self.as_ref().init(request) + } +} + +pub(crate) trait BoxedRequestInitializer: + dyn_clone::DynClone + std::fmt::Debug + Send + Sync + 'static +{ + fn init(&self, request: RequestBuilder) -> Result>; +} + +impl BoxedRequestInitializer for T +where + T: RequestInitializer + Clone, +{ + fn init(&self, request: RequestBuilder) -> Result> { + RequestInitializer::init(self, request) + .into_result() + .map_err(Into::into) + } +} + +dyn_clone::clone_trait_object!(BoxedRequestInitializer); diff --git a/opensearch/src/http/middleware/mod.rs b/opensearch/src/http/middleware/mod.rs new file mode 100644 index 00000000..afbaed04 --- /dev/null +++ b/opensearch/src/http/middleware/mod.rs @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod initializers; +mod request_handler; + +pub use async_trait::async_trait; +pub use initializers::*; +pub use request_handler::*; + +pub(crate) type BoxFuture<'a, T> = + std::pin::Pin + Send + 'a>>; diff --git a/opensearch/src/http/middleware/request_handler.rs b/opensearch/src/http/middleware/request_handler.rs new file mode 100644 index 00000000..ac9f9ee7 --- /dev/null +++ b/opensearch/src/http/middleware/request_handler.rs @@ -0,0 +1,137 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::{async_trait, BoxFuture}; +use crate::BoxError; +use reqwest::{Client, Request, Response}; + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct RequestHandlerError(pub(crate) RequestHandlerErrorKind); + +impl RequestHandlerError { + pub fn new(err: impl Into>) -> Self { + Self(RequestHandlerErrorKind::Handler(err.into())) + } + + fn http(err: reqwest::Error) -> Self { + Self(RequestHandlerErrorKind::Http(err)) + } +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RequestHandlerErrorKind { + #[error("http error: {0}")] + Http(#[source] reqwest::Error), + + #[error("handler error: {0}")] + Handler(#[source] BoxError<'static>), +} + +#[async_trait] +pub trait RequestHandler: std::fmt::Debug + Send + Sync + 'static { + async fn handle( + &self, + request: Request, + next: RequestHandlerChain<'_>, + ) -> Result; +} + +#[derive(Clone)] +pub struct RequestHandlerFn(F); + +pub fn request_handler_fn(f: F) -> RequestHandlerFn { + RequestHandlerFn(f) +} + +impl std::fmt::Debug for RequestHandlerFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(stringify!(RequestHandlerFn)).finish() + } +} + +#[async_trait] +impl RequestHandler for RequestHandlerFn +where + F: for<'a> Fn( + Request, + RequestHandlerChain<'a>, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, +{ + async fn handle( + &self, + request: Request, + next: RequestHandlerChain<'_>, + ) -> Result { + self.0(request, next).await + } +} + +#[async_trait] +impl RequestHandler for std::sync::Arc +where + R: RequestHandler, +{ + async fn handle( + &self, + request: Request, + next: RequestHandlerChain<'_>, + ) -> Result { + self.as_ref().handle(request, next).await + } +} + +#[async_trait] +impl RequestHandler for std::sync::Arc { + async fn handle( + &self, + request: Request, + next: RequestHandlerChain<'_>, + ) -> Result { + self.as_ref().handle(request, next).await + } +} + +pub(crate) trait BoxedRequestHandler: RequestHandler + dyn_clone::DynClone {} + +impl BoxedRequestHandler for T where T: RequestHandler + Clone {} + +dyn_clone::clone_trait_object!(BoxedRequestHandler); + +pub struct RequestHandlerChain<'a> { + client: &'a Client, + chain: &'a [Box], +} + +impl<'a> RequestHandlerChain<'a> { + pub(crate) fn new(client: &'a Client, chain: &'a [Box]) -> Self { + Self { client, chain } + } + + pub fn client(&self) -> &'a Client { + self.client + } + + pub async fn run(mut self, request: Request) -> Result { + if let Some((head, tail)) = self.chain.split_first() { + self.chain = tail; + head.handle(request, self).await + } else { + self.client + .execute(request) + .await + .map_err(RequestHandlerError::http) + } + } +} diff --git a/opensearch/src/http/mod.rs b/opensearch/src/http/mod.rs index 7e131c8f..be4401af 100644 --- a/opensearch/src/http/mod.rs +++ b/opensearch/src/http/mod.rs @@ -34,6 +34,7 @@ pub(crate) mod aws_auth; pub mod headers; +pub mod middleware; pub mod request; pub mod response; pub mod transport; diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index e8cc4822..881b03eb 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -37,12 +37,13 @@ use crate::cert::CertificateValidation; use crate::{ auth::Credentials, cert::CertificateError, - error::Error, + error::{BoxError, Error}, http::{ headers::{ HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, DEFAULT_ACCEPT, DEFAULT_CONTENT_TYPE, DEFAULT_USER_AGENT, USER_AGENT, }, + middleware::*, request::Body, response::Response, Method, @@ -54,6 +55,7 @@ use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; use bytes::BytesMut; use dyn_clone::clone_trait_object; use lazy_static::lazy_static; +use reqwest::ClientBuilder; use serde::Serialize; use std::{fmt::Debug, io::Write, time::Duration}; use url::Url; @@ -64,6 +66,8 @@ pub(crate) enum BuildError { Proxy(#[source] reqwest::Error), #[error("client configuration error: {0}")] ClientBuilder(#[source] reqwest::Error), + #[error("client initializer error: {0}")] + ClientInitializer(#[source] BoxError<'static>), } /// Default address to OpenSearch running on `http://localhost:9200` @@ -102,7 +106,6 @@ fn build_meta() -> String { /// Builds a HTTP transport to make API calls to OpenSearch pub struct TransportBuilder { - client_builder: reqwest::ClientBuilder, conn_pool: Box, credentials: Option, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] @@ -116,6 +119,9 @@ pub struct TransportBuilder { sigv4_service_name: String, #[cfg(feature = "aws-auth")] sigv4_time_source: Option, + client_initializers: Vec>, + request_initializers: Vec>, + request_handlers: Vec>, } impl TransportBuilder { @@ -126,7 +132,6 @@ impl TransportBuilder { P: ConnectionPool + Debug + Clone + Send + 'static, { Self { - client_builder: reqwest::ClientBuilder::new(), conn_pool: Box::new(conn_pool), credentials: None, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] @@ -140,6 +145,9 @@ impl TransportBuilder { sigv4_service_name: "es".to_string(), #[cfg(feature = "aws-auth")] sigv4_time_source: None, + client_initializers: Vec::new(), + request_initializers: Vec::new(), + request_handlers: Vec::new(), } } @@ -225,9 +233,186 @@ impl TransportBuilder { self } + /// Adds a [ClientInitializer] to the stack of initializers that will be called when the underlying [reqwest::Client] is being constructed. + /// + /// Initializers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{middleware::*, transport::*}; + /// + /// struct Initializer; + /// + /// impl ClientInitializer for Initializer { + /// type Result = Result; + /// + /// fn init(self, builder: reqwest::ClientBuilder) -> Self::Result { + /// let addr = "12.4.1.8".parse::()?; + /// Ok(builder.local_address(addr)) + /// } + /// } + /// + /// fn might_fail(builder: reqwest::ClientBuilder) -> Result> { + /// let url = std::env::var("PROXY_URL")?; + /// Ok(builder.proxy(reqwest::Proxy::all(url)?)) + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_client_initializer(Initializer) + /// .with_client_initializer(might_fail) + /// .with_client_initializer(|client_builder: reqwest::ClientBuilder| client_builder.redirect(reqwest::redirect::Policy::limited(1))) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_client_initializer(mut self, init: impl ClientInitializer) -> Self { + self.client_initializers.push(Box::new(init)); + self + } + + /// Adds a [RequestInitializer] to the stack of initializers that will be called when an underlying [reqwest::Request] is being constructed. + /// + /// Initializers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{middleware::*, transport::*}; + /// use std::sync::{Arc, atomic::{AtomicUsize, Ordering}}; + /// + /// #[derive(Debug, Clone)] + /// struct Counter(Arc); + /// + /// impl Counter { + /// fn new() -> Self { + /// Self(Arc::new(AtomicUsize::new(0))) + /// } + /// } + /// + /// impl RequestInitializer for Counter { + /// type Result = reqwest::RequestBuilder; + /// + /// fn init(&self, request: reqwest::RequestBuilder) -> Self::Result { + /// let counter = self.0.fetch_add(1, Ordering::SeqCst); + /// request.header("x-request-id", format!("req-{}", counter)) + /// } + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_initializer(Counter::new()) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_initializer(mut self, init: impl RequestInitializer + Clone) -> Self { + self.request_initializers.push(Box::new(init)); + self + } + + /// Adds a [RequestInitializer] to the stack of initializers that will be called when an underlying [reqwest::Request] is being constructed. + /// + /// Initializers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{middleware::*, transport::*}; + /// use std::sync::{Arc, atomic::{AtomicUsize, Ordering}}; + /// + /// let counter = Arc::new(AtomicUsize::new(0)); + /// let transport: Transport = TransportBuilder::default() + /// .with_initializer_fn(move |request_builder: reqwest::RequestBuilder| { + /// let counter = counter.fetch_add(1, Ordering::SeqCst); + /// request_builder.header("x-request-id", format!("req-{}", counter)) + /// }) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_initializer_fn(self, init: F) -> Self + where + F: Fn(reqwest::RequestBuilder) -> R + Clone + Send + Sync + 'static, + R: InitializerResult, + { + self.with_initializer(request_initializer_fn(init)) + } + + /// Adds a [RequestHandler] to the stack of handlers that will be called when an underlying [reqwest::Request] is being sent. + /// + /// Handlers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{ + /// middleware::{async_trait, RequestHandler, RequestHandlerChain, RequestHandlerError}, + /// reqwest::{Request, Response}, + /// transport::{Transport, TransportBuilder}, + /// }; + /// + /// #[derive(Debug, Clone)] + /// struct Logger; + /// + /// #[async_trait] + /// impl RequestHandler for Logger { + /// async fn handle(&self, request: Request, next: RequestHandlerChain<'_>) -> Result { + /// println!("sending request to {}", request.url()); + /// let now = std::time::Instant::now(); + /// let res = next.run(request).await?; + /// println!("request completed ({:?})", now.elapsed()); + /// Ok(res) + /// } + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_handler(Logger) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_handler(mut self, handler: impl RequestHandler + Clone) -> Self { + self.request_handlers.push(Box::new(handler)); + self + } + + /// Adds a [RequestHandler] to the stack of handlers that will be called when an underlying [reqwest::Request] is being sent. + /// + /// Handlers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{ + /// middleware::{RequestHandler, RequestHandlerChain, RequestHandlerError}, + /// reqwest::{Request, Response}, + /// transport::{Transport, TransportBuilder}, + /// }; + /// use std::{future::Future, pin::Pin}; + /// + /// fn logger(req: Request, next: RequestHandlerChain<'_>) -> Pin> + Send + '_>> { + /// Box::pin(async move { + /// println!("sending request to {}", req.url()); + /// let now = std::time::Instant::now(); + /// let res = next.run(req).await?; + /// println!("request completed ({:?})", now.elapsed()); + /// Ok(res) + /// }) + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_handler_fn(logger) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_handler_fn(self, handler: F) -> Self + where + F: for<'a> Fn( + reqwest::Request, + RequestHandlerChain<'a>, + ) -> BoxFuture<'a, Result> + + Clone + + Send + + Sync + + 'static, + { + self.with_handler(request_handler_fn(handler)) + } + /// Builds a [Transport] to use to send API calls to OpenSearch. pub fn build(self) -> Result { - let mut client_builder = self.client_builder; + let mut client_builder = ClientBuilder::new(); if let Some(t) = self.timeout { client_builder = client_builder.timeout(t); @@ -290,6 +475,14 @@ impl TransportBuilder { client_builder = client_builder.proxy(proxy); } + client_builder = self + .client_initializers + .into_iter() + .try_fold(client_builder, |client_builder, init| { + init.init(client_builder) + }) + .map_err(BuildError::ClientInitializer)?; + let client = client_builder.build().map_err(BuildError::ClientBuilder)?; Ok(Transport { client, @@ -300,6 +493,8 @@ impl TransportBuilder { sigv4_service_name: self.sigv4_service_name, #[cfg(feature = "aws-auth")] sigv4_time_source: self.sigv4_time_source.unwrap_or_default(), + request_initializers: self.request_initializers.into_boxed_slice(), + request_handlers: self.request_handlers.into_boxed_slice(), }) } } @@ -344,6 +539,8 @@ pub struct Transport { sigv4_service_name: String, #[cfg(feature = "aws-auth")] sigv4_time_source: SharedTimeSource, + request_initializers: Box<[Box]>, + request_handlers: Box<[Box]>, } impl Transport { @@ -451,6 +648,14 @@ impl Transport { request_builder = request_builder.query(q); } + request_builder = self + .request_initializers + .iter() + .try_fold(request_builder, |request_builder, init| { + init.init(request_builder) + }) + .map_err(Error::request_initializer)?; + #[cfg_attr(not(feature = "aws-auth"), allow(unused_mut))] let mut request = request_builder.build()?; @@ -466,7 +671,9 @@ impl Transport { .await?; } - let response = self.client.execute(request).await?; + let response = RequestHandlerChain::new(&self.client, &self.request_handlers) + .run(request) + .await?; Ok(Response::new(response, method)) } diff --git a/opensearch/tests/common/mod.rs b/opensearch/tests/common/mod.rs index 5c879006..e002a676 100644 --- a/opensearch/tests/common/mod.rs +++ b/opensearch/tests/common/mod.rs @@ -28,6 +28,8 @@ * GitHub history for details. */ +#![allow(unused)] + pub mod client; pub mod server; diff --git a/opensearch/tests/middleware.rs b/opensearch/tests/middleware.rs new file mode 100644 index 00000000..c155a1f6 --- /dev/null +++ b/opensearch/tests/middleware.rs @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod common; + +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use opensearch::http::middleware::{ + async_trait, RequestHandler, RequestHandlerChain, RequestHandlerError, RequestInitializer, +}; +use reqwest::RequestBuilder; + +use crate::common::{server::MockServer, tracing_init}; + +#[tokio::test] +async fn request_initializer() -> anyhow::Result<()> { + #[derive(Debug, Clone)] + struct Counter(Arc); + + impl RequestInitializer for Counter { + type Result = RequestBuilder; + + fn init(&self, request: RequestBuilder) -> Self::Result { + let counter = self.0.fetch_add(1, Ordering::SeqCst); + request.header("x-counter", counter.to_string()) + } + } + + tracing_init(); + + let mut server = MockServer::start()?; + + let counter_fn = { + let counter = Arc::new(AtomicUsize::new(1)); + move |request_builder: RequestBuilder| { + let counter = counter.fetch_add(1, Ordering::SeqCst); + request_builder.header("x-counter-fn", counter.to_string()) + } + }; + + let client = server.client_with(|b| { + b.with_initializer(Counter(Arc::new(AtomicUsize::new(101)))) + .with_initializer_fn(counter_fn) + }); + + client.ping().send().await?; + client.ping().send().await?; + + let req1 = server.received_request().await?; + let req2 = server.received_request().await?; + + assert_eq!(req1.header("x-counter-fn"), Some("1")); + assert_eq!(req1.header("x-counter"), Some("101")); + + assert_eq!(req2.header("x-counter-fn"), Some("2")); + assert_eq!(req2.header("x-counter"), Some("102")); + + Ok(()) +} + +#[tokio::test] +async fn request_handler() -> anyhow::Result<()> { + #[derive(Debug, Clone)] + struct Handler(Arc); + + #[async_trait] + impl RequestHandler for Handler { + async fn handle( + &self, + request: reqwest::Request, + next: RequestHandlerChain<'_>, + ) -> Result { + self.0.fetch_add(1, Ordering::SeqCst); + next.run(request).await + } + } + + tracing_init(); + + let server = MockServer::start()?; + + let handler_called = Arc::new(AtomicUsize::new(0)); + + let client = server.client_with(|b| b.with_handler(Handler(handler_called.clone()))); + + client.ping().send().await?; + + assert_eq!(handler_called.load(Ordering::SeqCst), 1); + + client.ping().send().await?; + + assert_eq!(handler_called.load(Ordering::SeqCst), 2); + + Ok(()) +}