From b8286c8fe2074535762e188436c972048b7a34c1 Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Thu, 14 Dec 2023 10:07:40 +1300 Subject: [PATCH 1/5] Implement middleware types to allow intercepting client & request handling Signed-off-by: Thomas Farr --- opensearch/Cargo.toml | 1 + opensearch/src/client.rs | 6 + opensearch/src/error.rs | 28 ++- .../http/middleware/initializers/client.rs | 47 ++++ .../src/http/middleware/initializers/mod.rs | 44 ++++ .../http/middleware/initializers/request.rs | 64 +++++ opensearch/src/http/middleware/mod.rs | 20 ++ .../src/http/middleware/request_pipeline.rs | 105 +++++++++ opensearch/src/http/mod.rs | 1 + opensearch/src/http/transport.rs | 218 +++++++++++++++++- opensearch/tests/common/mod.rs | 2 + opensearch/tests/middleware.rs | 106 +++++++++ 12 files changed, 635 insertions(+), 7 deletions(-) create mode 100644 opensearch/src/http/middleware/initializers/client.rs create mode 100644 opensearch/src/http/middleware/initializers/mod.rs create mode 100644 opensearch/src/http/middleware/initializers/request.rs create mode 100644 opensearch/src/http/middleware/mod.rs create mode 100644 opensearch/src/http/middleware/request_pipeline.rs create mode 100644 opensearch/tests/middleware.rs diff --git a/opensearch/Cargo.toml b/opensearch/Cargo.toml index 5d78c6fc..fc730c82 100644 --- a/opensearch/Cargo.toml +++ b/opensearch/Cargo.toml @@ -29,6 +29,7 @@ rustls-tls = ["reqwest/rustls-tls"] aws-auth = ["aws-credential-types", "aws-sigv4", "aws-smithy-runtime-api", "aws-types"] [dependencies] +async-trait = "0.1" base64 = "0.21" bytes = "1.0" dyn-clone = "1" diff --git a/opensearch/src/client.rs b/opensearch/src/client.rs index 21c937aa..2e2d15ea 100644 --- a/opensearch/src/client.rs +++ b/opensearch/src/client.rs @@ -104,3 +104,9 @@ impl OpenSearch { .await } } + +// Ensure that the client is `Send` and `Sync` +const _: () = { + const fn is_send_sync() {} + is_send_sync::() +}; diff --git a/opensearch/src/error.rs b/opensearch/src/error.rs index 4fe10ae4..ae34e781 100644 --- a/opensearch/src/error.rs +++ b/opensearch/src/error.rs @@ -35,7 +35,10 @@ use crate::{ cert::CertificateError, - http::{transport, StatusCode}, + http::{ + middleware::{RequestPipelineError, RequestPipelineErrorKind}, + transport, StatusCode, + }, }; pub(crate) type BoxError<'a> = Box; @@ -53,7 +56,7 @@ where Kind: From, { fn from(error: E) -> Self { - Self(Kind::from(error)) + Self(error.into()) } } @@ -80,11 +83,32 @@ enum Kind { #[cfg(feature = "aws-auth")] #[error("AwsSigV4 error: {0}")] AwsSigV4(#[from] crate::http::aws_auth::AwsSigV4Error), + + #[error("request initializer error: {0}")] + RequestInitializer(#[source] BoxError<'static>), + + #[error("request pipeline error: {0}")] + RequestPipeline(#[source] BoxError<'static>), +} + +impl From for Kind { + fn from(err: RequestPipelineError) -> Self { + use RequestPipelineErrorKind::*; + + match err.0 { + Pipeline(err) => Self::RequestPipeline(err), + Http(err) => Self::Http(err), + } + } } use Kind::*; impl Error { + pub(crate) fn request_initializer(err: BoxError<'static>) -> Self { + Self(RequestInitializer(err)) + } + /// The status code, if the error was generated from a response pub fn status_code(&self) -> Option { match &self.0 { diff --git a/opensearch/src/http/middleware/initializers/client.rs b/opensearch/src/http/middleware/initializers/client.rs new file mode 100644 index 00000000..006752c1 --- /dev/null +++ b/opensearch/src/http/middleware/initializers/client.rs @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::InitializerResult; +use crate::BoxError; +use reqwest::ClientBuilder; + +pub trait ClientInitializer: 'static { + type Result: InitializerResult; + + fn init(self, client: ClientBuilder) -> Self::Result; +} + +impl ClientInitializer for F +where + F: FnOnce(ClientBuilder) -> R + 'static, + R: InitializerResult, +{ + type Result = R; + + fn init(self, client: ClientBuilder) -> Self::Result { + self(client) + } +} + +pub(crate) trait BoxedClientInitializer { + fn init(self: Box, client: ClientBuilder) -> Result>; +} + +impl BoxedClientInitializer for T +where + T: ClientInitializer + Sized, +{ + fn init(self: Box, client: ClientBuilder) -> Result> { + ClientInitializer::init(*self, client) + .into_result() + .map_err(Into::into) + } +} diff --git a/opensearch/src/http/middleware/initializers/mod.rs b/opensearch/src/http/middleware/initializers/mod.rs new file mode 100644 index 00000000..d39e67c4 --- /dev/null +++ b/opensearch/src/http/middleware/initializers/mod.rs @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod client; +mod request; + +use crate::BoxError; +use std::convert::Infallible; + +pub use client::*; +pub use request::*; + +pub trait InitializerResult { + type Error: Into>; + + fn into_result(self) -> Result; +} + +impl InitializerResult for Result +where + E: Into>, +{ + type Error = E; + + fn into_result(self) -> Result { + self + } +} + +impl InitializerResult for T { + type Error = Infallible; + + fn into_result(self) -> Result { + Ok(self) + } +} diff --git a/opensearch/src/http/middleware/initializers/request.rs b/opensearch/src/http/middleware/initializers/request.rs new file mode 100644 index 00000000..906f3103 --- /dev/null +++ b/opensearch/src/http/middleware/initializers/request.rs @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::InitializerResult; +use crate::BoxError; +use reqwest::RequestBuilder; + +pub trait RequestInitializer: Clone + std::fmt::Debug + Send + Sync + 'static { + type Result: InitializerResult; + + fn init(&self, request: RequestBuilder) -> Self::Result; +} + +#[derive(Clone)] +pub struct RequestInitializerFn(F); + +pub fn request_initializer_fn(f: F) -> RequestInitializerFn { + RequestInitializerFn(f) +} + +impl std::fmt::Debug for RequestInitializerFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(stringify!(RequestInitializerFn)).finish() + } +} + +impl RequestInitializer for RequestInitializerFn +where + F: Fn(RequestBuilder) -> R + Clone + Send + Sync + 'static, + R: InitializerResult, +{ + type Result = R; + + fn init(&self, request: RequestBuilder) -> Self::Result { + self.0(request) + } +} + +pub(crate) trait BoxedRequestInitializer: + dyn_clone::DynClone + std::fmt::Debug + Send + Sync + 'static +{ + fn init(&self, request: RequestBuilder) -> Result>; +} + +impl BoxedRequestInitializer for T +where + T: RequestInitializer, +{ + fn init(&self, request: RequestBuilder) -> Result> { + RequestInitializer::init(self, request) + .into_result() + .map_err(Into::into) + } +} + +dyn_clone::clone_trait_object!(BoxedRequestInitializer); diff --git a/opensearch/src/http/middleware/mod.rs b/opensearch/src/http/middleware/mod.rs new file mode 100644 index 00000000..c26528cc --- /dev/null +++ b/opensearch/src/http/middleware/mod.rs @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod initializers; +mod request_pipeline; + +pub use async_trait::async_trait; +pub use initializers::*; +pub use request_pipeline::*; + +pub(crate) type BoxFuture<'a, T> = + std::pin::Pin + Send + 'a>>; diff --git a/opensearch/src/http/middleware/request_pipeline.rs b/opensearch/src/http/middleware/request_pipeline.rs new file mode 100644 index 00000000..6e7ecda9 --- /dev/null +++ b/opensearch/src/http/middleware/request_pipeline.rs @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use super::{async_trait, BoxFuture}; +use crate::BoxError; +use reqwest::{Client, Request, Response}; + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct RequestPipelineError(pub(crate) RequestPipelineErrorKind); + +impl RequestPipelineError { + pub fn new(err: impl Into>) -> Self { + Self(RequestPipelineErrorKind::Pipeline(err.into())) + } + + fn http(err: reqwest::Error) -> Self { + Self(RequestPipelineErrorKind::Http(err)) + } +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RequestPipelineErrorKind { + #[error("http error: {0}")] + Http(#[source] reqwest::Error), + + #[error("pipeline error: {0}")] + Pipeline(#[source] BoxError<'static>), +} + +#[async_trait] +pub trait RequestHandler: dyn_clone::DynClone + std::fmt::Debug + Send + Sync + 'static { + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result; +} + +dyn_clone::clone_trait_object!(RequestHandler); + +#[derive(Clone)] +pub struct RequestHandlerFn(F); + +pub fn request_handler_fn(f: F) -> RequestHandlerFn { + RequestHandlerFn(f) +} + +impl std::fmt::Debug for RequestHandlerFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(stringify!(RequestHandlerFn)).finish() + } +} + +#[async_trait] +impl RequestHandler for RequestHandlerFn +where + F: for<'a> Fn( + Request, + RequestPipeline<'a>, + ) -> BoxFuture<'a, Result> + + Clone + + Send + + Sync + + 'static, +{ + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result { + self.0(request, next).await + } +} + +pub struct RequestPipeline<'a> { + pub client: &'a Client, + pipeline: &'a [Box], +} + +impl<'a> RequestPipeline<'a> { + pub(crate) fn new(client: &'a Client, pipeline: &'a [Box]) -> Self { + Self { client, pipeline } + } + + pub async fn run(mut self, request: Request) -> Result { + if let Some((head, tail)) = self.pipeline.split_first() { + self.pipeline = tail; + head.handle(request, self).await + } else { + self.client + .execute(request) + .await + .map_err(RequestPipelineError::http) + } + } +} diff --git a/opensearch/src/http/mod.rs b/opensearch/src/http/mod.rs index 7e131c8f..be4401af 100644 --- a/opensearch/src/http/mod.rs +++ b/opensearch/src/http/mod.rs @@ -34,6 +34,7 @@ pub(crate) mod aws_auth; pub mod headers; +pub mod middleware; pub mod request; pub mod response; pub mod transport; diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index e8cc4822..bf12edd9 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -37,12 +37,13 @@ use crate::cert::CertificateValidation; use crate::{ auth::Credentials, cert::CertificateError, - error::Error, + error::{BoxError, Error}, http::{ headers::{ HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, DEFAULT_ACCEPT, DEFAULT_CONTENT_TYPE, DEFAULT_USER_AGENT, USER_AGENT, }, + middleware::*, request::Body, response::Response, Method, @@ -54,6 +55,7 @@ use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; use bytes::BytesMut; use dyn_clone::clone_trait_object; use lazy_static::lazy_static; +use reqwest::ClientBuilder; use serde::Serialize; use std::{fmt::Debug, io::Write, time::Duration}; use url::Url; @@ -64,6 +66,8 @@ pub(crate) enum BuildError { Proxy(#[source] reqwest::Error), #[error("client configuration error: {0}")] ClientBuilder(#[source] reqwest::Error), + #[error("client initializer error: {0}")] + ClientInitializer(#[source] BoxError<'static>), } /// Default address to OpenSearch running on `http://localhost:9200` @@ -102,7 +106,6 @@ fn build_meta() -> String { /// Builds a HTTP transport to make API calls to OpenSearch pub struct TransportBuilder { - client_builder: reqwest::ClientBuilder, conn_pool: Box, credentials: Option, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] @@ -116,6 +119,9 @@ pub struct TransportBuilder { sigv4_service_name: String, #[cfg(feature = "aws-auth")] sigv4_time_source: Option, + init_stack: Vec>, + req_init_stack: Vec>, + req_handler_stack: Vec>, } impl TransportBuilder { @@ -126,7 +132,6 @@ impl TransportBuilder { P: ConnectionPool + Debug + Clone + Send + 'static, { Self { - client_builder: reqwest::ClientBuilder::new(), conn_pool: Box::new(conn_pool), credentials: None, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] @@ -140,6 +145,9 @@ impl TransportBuilder { sigv4_service_name: "es".to_string(), #[cfg(feature = "aws-auth")] sigv4_time_source: None, + init_stack: Vec::new(), + req_init_stack: Vec::new(), + req_handler_stack: Vec::new(), } } @@ -225,9 +233,187 @@ impl TransportBuilder { self } + /// Adds a [ClientInitializer] to the stack of initializers that will be called when the underlying [reqwest::Client] is being constructed. + /// + /// Initializers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{middleware::*, transport::*}; + /// + /// struct Initializer; + /// + /// impl ClientInitializer for Initializer { + /// type Result = Result; + /// + /// fn init(self, builder: reqwest::ClientBuilder) -> Self::Result { + /// let addr = "12.4.1.8".parse::()?; + /// Ok(builder.local_address(addr)) + /// } + /// } + /// + /// fn might_fail(builder: reqwest::ClientBuilder) -> Result> { + /// let url = std::env::var("PROXY_URL")?; + /// Ok(builder.proxy(reqwest::Proxy::all(url)?)) + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_init(Initializer) + /// .with_init(might_fail) + /// .with_init(|client_builder: reqwest::ClientBuilder| client_builder.redirect(reqwest::redirect::Policy::limited(1))) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_init(mut self, init: impl ClientInitializer) -> Self { + self.init_stack.push(Box::new(init)); + self + } + + /// Adds a [RequestInitializer] to the stack of initializers that will be called when an underlying [reqwest::Request] is being constructed. + /// + /// Initializers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{middleware::*, transport::*}; + /// use std::sync::{Arc, atomic::{AtomicUsize, Ordering}}; + /// + /// #[derive(Debug, Clone)] + /// struct Counter(Arc); + /// + /// impl Counter { + /// fn new() -> Self { + /// Self(Arc::new(AtomicUsize::new(0))) + /// } + /// } + /// + /// impl RequestInitializer for Counter { + /// type Result = reqwest::RequestBuilder; + /// + /// fn init(&self, request: reqwest::RequestBuilder) -> Self::Result { + /// let counter = self.0.fetch_add(1, Ordering::SeqCst); + /// request.header("x-request-id", format!("req-{}", counter)) + /// } + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_req_init(Counter::new()) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_req_init(mut self, init: impl RequestInitializer) -> Self { + self.req_init_stack.push(Box::new(init)); + self + } + + /// Adds a [RequestInitializer] to the stack of initializers that will be called when an underlying [reqwest::Request] is being constructed. + /// + /// Initializers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{middleware::*, transport::*}; + /// use std::sync::{Arc, atomic::{AtomicUsize, Ordering}}; + /// + /// let counter = Arc::new(AtomicUsize::new(0)); + /// let transport: Transport = TransportBuilder::default() + /// .with_req_init_fn(move |request_builder: reqwest::RequestBuilder| { + /// let counter = counter.fetch_add(1, Ordering::SeqCst); + /// request_builder.header("x-request-id", format!("req-{}", counter)) + /// }) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_req_init_fn(self, init: F) -> Self + where + F: Fn(reqwest::RequestBuilder) -> R + Clone + Send + Sync + 'static, + R: InitializerResult, + { + self.with_req_init(request_initializer_fn(init)) + } + + /// Adds a [RequestHandler] to the stack of handlers that will be called when an underlying [reqwest::Request] is being sent. + /// + /// Handlers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{ + /// middleware::{async_trait, RequestHandler, RequestPipeline, RequestPipelineError}, + /// reqwest::{Request, Response}, + /// transport::{Transport, TransportBuilder}, + /// }; + /// + /// #[derive(Debug, Clone)] + /// struct Logger; + /// + /// #[async_trait] + /// impl RequestHandler for Logger { + /// async fn handle(&self, request: Request, next: RequestPipeline<'_>) -> Result { + /// println!("sending request to {}", request.url()); + /// let now = std::time::Instant::now(); + /// let res = next.run(request).await?; + /// println!("request completed ({:?})", now.elapsed()); + /// Ok(res) + /// } + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_handler(Logger) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_handler(mut self, handler: impl RequestHandler) -> Self { + self.req_handler_stack.push(Box::new(handler)); + self + } + + /// Adds a [RequestHandler] to the stack of handlers that will be called when an underlying [reqwest::Request] is being sent. + /// + /// Handlers are called in the order they are added. + /// + /// # Example + /// ```rust,no_run + /// use opensearch::http::{ + /// middleware::{RequestHandler, RequestPipeline, RequestPipelineError}, + /// reqwest::{Request, Response}, + /// transport::{Transport, TransportBuilder}, + /// }; + /// use std::{future::Future, pin::Pin}; + /// + /// fn logger(req: Request, next: RequestPipeline<'_>) -> Pin> + Send + '_>> { + /// Box::pin(async move { + /// println!("sending request to {}", req.url()); + /// let now = std::time::Instant::now(); + /// let res = next.run(req).await?; + /// println!("request completed ({:?})", now.elapsed()); + /// Ok(res) + /// }) + /// } + /// + /// let transport: Transport = TransportBuilder::default() + /// .with_handler_fn(logger) + /// .build()?; + /// # Ok::<(), opensearch::Error>(()) + /// ``` + pub fn with_handler_fn(self, handler: F) -> Self + where + F: for<'a> Fn( + reqwest::Request, + RequestPipeline<'a>, + ) + -> BoxFuture<'a, Result> + + Clone + + Send + + Sync + + 'static, + { + self.with_handler(request_handler_fn(handler)) + } + /// Builds a [Transport] to use to send API calls to OpenSearch. pub fn build(self) -> Result { - let mut client_builder = self.client_builder; + let mut client_builder = ClientBuilder::new(); if let Some(t) = self.timeout { client_builder = client_builder.timeout(t); @@ -290,6 +476,14 @@ impl TransportBuilder { client_builder = client_builder.proxy(proxy); } + client_builder = self + .init_stack + .into_iter() + .try_fold(client_builder, |client_builder, init| { + init.init(client_builder) + }) + .map_err(BuildError::ClientInitializer)?; + let client = client_builder.build().map_err(BuildError::ClientBuilder)?; Ok(Transport { client, @@ -300,6 +494,8 @@ impl TransportBuilder { sigv4_service_name: self.sigv4_service_name, #[cfg(feature = "aws-auth")] sigv4_time_source: self.sigv4_time_source.unwrap_or_default(), + req_init_stack: self.req_init_stack.into_boxed_slice(), + req_handler_stack: self.req_handler_stack.into_boxed_slice(), }) } } @@ -344,6 +540,8 @@ pub struct Transport { sigv4_service_name: String, #[cfg(feature = "aws-auth")] sigv4_time_source: SharedTimeSource, + req_init_stack: Box<[Box]>, + req_handler_stack: Box<[Box]>, } impl Transport { @@ -451,6 +649,14 @@ impl Transport { request_builder = request_builder.query(q); } + request_builder = self + .req_init_stack + .iter() + .try_fold(request_builder, |request_builder, init| { + init.init(request_builder) + }) + .map_err(Error::request_initializer)?; + #[cfg_attr(not(feature = "aws-auth"), allow(unused_mut))] let mut request = request_builder.build()?; @@ -466,7 +672,9 @@ impl Transport { .await?; } - let response = self.client.execute(request).await?; + let response = RequestPipeline::new(&self.client, &self.req_handler_stack) + .run(request) + .await?; Ok(Response::new(response, method)) } diff --git a/opensearch/tests/common/mod.rs b/opensearch/tests/common/mod.rs index 5c879006..e002a676 100644 --- a/opensearch/tests/common/mod.rs +++ b/opensearch/tests/common/mod.rs @@ -28,6 +28,8 @@ * GitHub history for details. */ +#![allow(unused)] + pub mod client; pub mod server; diff --git a/opensearch/tests/middleware.rs b/opensearch/tests/middleware.rs new file mode 100644 index 00000000..ba01c64e --- /dev/null +++ b/opensearch/tests/middleware.rs @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +mod common; + +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use opensearch::http::middleware::{ + async_trait, RequestHandler, RequestInitializer, RequestPipeline, RequestPipelineError, +}; +use reqwest::RequestBuilder; + +use crate::common::{server::MockServer, tracing_init}; + +#[tokio::test] +async fn request_initializer() -> anyhow::Result<()> { + #[derive(Debug, Clone)] + struct Counter(Arc); + + impl RequestInitializer for Counter { + type Result = RequestBuilder; + + fn init(&self, request: RequestBuilder) -> Self::Result { + let counter = self.0.fetch_add(1, Ordering::SeqCst); + request.header("x-counter", counter.to_string()) + } + } + + tracing_init(); + + let mut server = MockServer::start()?; + + let counter_fn = { + let counter = Arc::new(AtomicUsize::new(1)); + move |request_builder: RequestBuilder| { + let counter = counter.fetch_add(1, Ordering::SeqCst); + request_builder.header("x-counter-fn", counter.to_string()) + } + }; + + let client = server.client_with(|b| { + b.with_req_init(Counter(Arc::new(AtomicUsize::new(101)))) + .with_req_init_fn(counter_fn) + }); + + client.ping().send().await?; + client.ping().send().await?; + + let req1 = server.received_request().await?; + let req2 = server.received_request().await?; + + assert_eq!(req1.header("x-counter-fn"), Some("1")); + assert_eq!(req1.header("x-counter"), Some("101")); + + assert_eq!(req2.header("x-counter-fn"), Some("2")); + assert_eq!(req2.header("x-counter"), Some("102")); + + Ok(()) +} + +#[tokio::test] +async fn request_handler() -> anyhow::Result<()> { + #[derive(Debug, Clone)] + struct Handler(Arc); + + #[async_trait] + impl RequestHandler for Handler { + async fn handle( + &self, + request: reqwest::Request, + next: RequestPipeline<'_>, + ) -> Result { + self.0.fetch_add(1, Ordering::SeqCst); + next.run(request).await + } + } + + tracing_init(); + + let server = MockServer::start()?; + + let handler_called = Arc::new(AtomicUsize::new(0)); + + let client = server.client_with(|b| b.with_handler(Handler(handler_called.clone()))); + + client.ping().send().await?; + + assert_eq!(handler_called.load(Ordering::SeqCst), 1); + + client.ping().send().await?; + + assert_eq!(handler_called.load(Ordering::SeqCst), 2); + + Ok(()) +} From 4d3c50dd15c594405942a6563701f1311a5a351c Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Thu, 14 Dec 2023 10:12:18 +1300 Subject: [PATCH 2/5] Add changelog entry Signed-off-by: Thomas Farr --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bcd4c69c..af1c06ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Internalized the `BuildError` type, consolidating on the `Error` type ([#228](https://github.com/opensearch-project/opensearch-rs/pull/228)) ### Added +- Added middleware types to allow intercepting construction and handling of the underlying `reqwest` client & requests ([#232](https://github.com/opensearch-project/opensearch-rs/pull/232)) ### Dependencies - Bumps `aws-*` dependencies to `1` ([#219](https://github.com/opensearch-project/opensearch-rs/pull/219)) From bc41b3d2f5597f482d5436dfa6ca91aa075db898 Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Thu, 14 Dec 2023 12:58:37 +1300 Subject: [PATCH 3/5] Relax position of clone bounds Signed-off-by: Thomas Farr --- .../http/middleware/initializers/request.rs | 28 +++++++++++-- .../src/http/middleware/request_pipeline.rs | 40 ++++++++++++++++--- opensearch/src/http/transport.rs | 8 ++-- 3 files changed, 63 insertions(+), 13 deletions(-) diff --git a/opensearch/src/http/middleware/initializers/request.rs b/opensearch/src/http/middleware/initializers/request.rs index 906f3103..f657588c 100644 --- a/opensearch/src/http/middleware/initializers/request.rs +++ b/opensearch/src/http/middleware/initializers/request.rs @@ -13,7 +13,7 @@ use super::InitializerResult; use crate::BoxError; use reqwest::RequestBuilder; -pub trait RequestInitializer: Clone + std::fmt::Debug + Send + Sync + 'static { +pub trait RequestInitializer: std::fmt::Debug + Send + Sync + 'static { type Result: InitializerResult; fn init(&self, request: RequestBuilder) -> Self::Result; @@ -34,7 +34,7 @@ impl std::fmt::Debug for RequestInitializerFn { impl RequestInitializer for RequestInitializerFn where - F: Fn(RequestBuilder) -> R + Clone + Send + Sync + 'static, + F: Fn(RequestBuilder) -> R + Send + Sync + 'static, R: InitializerResult, { type Result = R; @@ -44,6 +44,28 @@ where } } +impl RequestInitializer for std::sync::Arc +where + R: RequestInitializer, +{ + type Result = R::Result; + + fn init(&self, request: RequestBuilder) -> Self::Result { + self.as_ref().init(request) + } +} + +impl RequestInitializer for std::sync::Arc> +where + R: InitializerResult + 'static, +{ + type Result = R; + + fn init(&self, request: RequestBuilder) -> Self::Result { + self.as_ref().init(request) + } +} + pub(crate) trait BoxedRequestInitializer: dyn_clone::DynClone + std::fmt::Debug + Send + Sync + 'static { @@ -52,7 +74,7 @@ pub(crate) trait BoxedRequestInitializer: impl BoxedRequestInitializer for T where - T: RequestInitializer, + T: RequestInitializer + Clone, { fn init(&self, request: RequestBuilder) -> Result> { RequestInitializer::init(self, request) diff --git a/opensearch/src/http/middleware/request_pipeline.rs b/opensearch/src/http/middleware/request_pipeline.rs index 6e7ecda9..563a2696 100644 --- a/opensearch/src/http/middleware/request_pipeline.rs +++ b/opensearch/src/http/middleware/request_pipeline.rs @@ -37,7 +37,7 @@ pub(crate) enum RequestPipelineErrorKind { } #[async_trait] -pub trait RequestHandler: dyn_clone::DynClone + std::fmt::Debug + Send + Sync + 'static { +pub trait RequestHandler: std::fmt::Debug + Send + Sync + 'static { async fn handle( &self, request: Request, @@ -45,8 +45,6 @@ pub trait RequestHandler: dyn_clone::DynClone + std::fmt::Debug + Send + Sync + ) -> Result; } -dyn_clone::clone_trait_object!(RequestHandler); - #[derive(Clone)] pub struct RequestHandlerFn(F); @@ -67,7 +65,6 @@ where Request, RequestPipeline<'a>, ) -> BoxFuture<'a, Result> - + Clone + Send + Sync + 'static, @@ -81,13 +78,44 @@ where } } +#[async_trait] +impl RequestHandler for std::sync::Arc +where + R: RequestHandler, +{ + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result { + self.as_ref().handle(request, next).await + } +} + +#[async_trait] +impl RequestHandler for std::sync::Arc { + async fn handle( + &self, + request: Request, + next: RequestPipeline<'_>, + ) -> Result { + self.as_ref().handle(request, next).await + } +} + +pub(crate) trait BoxedRequestHandler: RequestHandler + dyn_clone::DynClone {} + +impl BoxedRequestHandler for T where T: RequestHandler + Clone {} + +dyn_clone::clone_trait_object!(BoxedRequestHandler); + pub struct RequestPipeline<'a> { pub client: &'a Client, - pipeline: &'a [Box], + pipeline: &'a [Box], } impl<'a> RequestPipeline<'a> { - pub(crate) fn new(client: &'a Client, pipeline: &'a [Box]) -> Self { + pub(crate) fn new(client: &'a Client, pipeline: &'a [Box]) -> Self { Self { client, pipeline } } diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index bf12edd9..c595267a 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -121,7 +121,7 @@ pub struct TransportBuilder { sigv4_time_source: Option, init_stack: Vec>, req_init_stack: Vec>, - req_handler_stack: Vec>, + req_handler_stack: Vec>, } impl TransportBuilder { @@ -301,7 +301,7 @@ impl TransportBuilder { /// .build()?; /// # Ok::<(), opensearch::Error>(()) /// ``` - pub fn with_req_init(mut self, init: impl RequestInitializer) -> Self { + pub fn with_req_init(mut self, init: impl RequestInitializer + Clone) -> Self { self.req_init_stack.push(Box::new(init)); self } @@ -363,7 +363,7 @@ impl TransportBuilder { /// .build()?; /// # Ok::<(), opensearch::Error>(()) /// ``` - pub fn with_handler(mut self, handler: impl RequestHandler) -> Self { + pub fn with_handler(mut self, handler: impl RequestHandler + Clone) -> Self { self.req_handler_stack.push(Box::new(handler)); self } @@ -541,7 +541,7 @@ pub struct Transport { #[cfg(feature = "aws-auth")] sigv4_time_source: SharedTimeSource, req_init_stack: Box<[Box]>, - req_handler_stack: Box<[Box]>, + req_handler_stack: Box<[Box]>, } impl Transport { From 9d4efc907c798e3f7098f5ec8f03ee0cd40235d0 Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Fri, 15 Dec 2023 12:41:11 +1300 Subject: [PATCH 4/5] Move is_send_sync check behind test cfg Signed-off-by: Thomas Farr --- opensearch/src/client.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/opensearch/src/client.rs b/opensearch/src/client.rs index 2e2d15ea..121a344b 100644 --- a/opensearch/src/client.rs +++ b/opensearch/src/client.rs @@ -105,8 +105,12 @@ impl OpenSearch { } } -// Ensure that the client is `Send` and `Sync` -const _: () = { - const fn is_send_sync() {} - is_send_sync::() -}; +#[cfg(test)] +mod test { + #[test] + fn client_is_send_sync() { + // Ensure that the client is `Send` and `Sync` + fn is_send_sync() {} + is_send_sync::() + } +} From 2d8d0e822cccf393d5573d688552a5317a98293a Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Fri, 15 Dec 2023 12:42:07 +1300 Subject: [PATCH 5/5] Rename RequestPipeline to RequestHandlerChain Signed-off-by: Thomas Farr --- opensearch/src/error.rs | 14 ++-- opensearch/src/http/middleware/mod.rs | 4 +- ...request_pipeline.rs => request_handler.rs} | 58 ++++++++-------- opensearch/src/http/transport.rs | 67 +++++++++---------- opensearch/tests/middleware.rs | 10 +-- 5 files changed, 78 insertions(+), 75 deletions(-) rename opensearch/src/http/middleware/{request_pipeline.rs => request_handler.rs} (63%) diff --git a/opensearch/src/error.rs b/opensearch/src/error.rs index ae34e781..d6c7ee4a 100644 --- a/opensearch/src/error.rs +++ b/opensearch/src/error.rs @@ -36,7 +36,7 @@ use crate::{ cert::CertificateError, http::{ - middleware::{RequestPipelineError, RequestPipelineErrorKind}, + middleware::{RequestHandlerError, RequestHandlerErrorKind}, transport, StatusCode, }, }; @@ -87,16 +87,16 @@ enum Kind { #[error("request initializer error: {0}")] RequestInitializer(#[source] BoxError<'static>), - #[error("request pipeline error: {0}")] - RequestPipeline(#[source] BoxError<'static>), + #[error("request handler error: {0}")] + RequestHandler(#[source] BoxError<'static>), } -impl From for Kind { - fn from(err: RequestPipelineError) -> Self { - use RequestPipelineErrorKind::*; +impl From for Kind { + fn from(err: RequestHandlerError) -> Self { + use RequestHandlerErrorKind::*; match err.0 { - Pipeline(err) => Self::RequestPipeline(err), + Handler(err) => Self::RequestHandler(err), Http(err) => Self::Http(err), } } diff --git a/opensearch/src/http/middleware/mod.rs b/opensearch/src/http/middleware/mod.rs index c26528cc..afbaed04 100644 --- a/opensearch/src/http/middleware/mod.rs +++ b/opensearch/src/http/middleware/mod.rs @@ -10,11 +10,11 @@ */ mod initializers; -mod request_pipeline; +mod request_handler; pub use async_trait::async_trait; pub use initializers::*; -pub use request_pipeline::*; +pub use request_handler::*; pub(crate) type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; diff --git a/opensearch/src/http/middleware/request_pipeline.rs b/opensearch/src/http/middleware/request_handler.rs similarity index 63% rename from opensearch/src/http/middleware/request_pipeline.rs rename to opensearch/src/http/middleware/request_handler.rs index 563a2696..ac9f9ee7 100644 --- a/opensearch/src/http/middleware/request_pipeline.rs +++ b/opensearch/src/http/middleware/request_handler.rs @@ -15,25 +15,25 @@ use reqwest::{Client, Request, Response}; #[derive(Debug, thiserror::Error)] #[error(transparent)] -pub struct RequestPipelineError(pub(crate) RequestPipelineErrorKind); +pub struct RequestHandlerError(pub(crate) RequestHandlerErrorKind); -impl RequestPipelineError { +impl RequestHandlerError { pub fn new(err: impl Into>) -> Self { - Self(RequestPipelineErrorKind::Pipeline(err.into())) + Self(RequestHandlerErrorKind::Handler(err.into())) } fn http(err: reqwest::Error) -> Self { - Self(RequestPipelineErrorKind::Http(err)) + Self(RequestHandlerErrorKind::Http(err)) } } #[derive(Debug, thiserror::Error)] -pub(crate) enum RequestPipelineErrorKind { +pub(crate) enum RequestHandlerErrorKind { #[error("http error: {0}")] Http(#[source] reqwest::Error), - #[error("pipeline error: {0}")] - Pipeline(#[source] BoxError<'static>), + #[error("handler error: {0}")] + Handler(#[source] BoxError<'static>), } #[async_trait] @@ -41,8 +41,8 @@ pub trait RequestHandler: std::fmt::Debug + Send + Sync + 'static { async fn handle( &self, request: Request, - next: RequestPipeline<'_>, - ) -> Result; + next: RequestHandlerChain<'_>, + ) -> Result; } #[derive(Clone)] @@ -63,8 +63,8 @@ impl RequestHandler for RequestHandlerFn where F: for<'a> Fn( Request, - RequestPipeline<'a>, - ) -> BoxFuture<'a, Result> + RequestHandlerChain<'a>, + ) -> BoxFuture<'a, Result> + Send + Sync + 'static, @@ -72,8 +72,8 @@ where async fn handle( &self, request: Request, - next: RequestPipeline<'_>, - ) -> Result { + next: RequestHandlerChain<'_>, + ) -> Result { self.0(request, next).await } } @@ -86,8 +86,8 @@ where async fn handle( &self, request: Request, - next: RequestPipeline<'_>, - ) -> Result { + next: RequestHandlerChain<'_>, + ) -> Result { self.as_ref().handle(request, next).await } } @@ -97,8 +97,8 @@ impl RequestHandler for std::sync::Arc { async fn handle( &self, request: Request, - next: RequestPipeline<'_>, - ) -> Result { + next: RequestHandlerChain<'_>, + ) -> Result { self.as_ref().handle(request, next).await } } @@ -109,25 +109,29 @@ impl BoxedRequestHandler for T where T: RequestHandler + Clone {} dyn_clone::clone_trait_object!(BoxedRequestHandler); -pub struct RequestPipeline<'a> { - pub client: &'a Client, - pipeline: &'a [Box], +pub struct RequestHandlerChain<'a> { + client: &'a Client, + chain: &'a [Box], } -impl<'a> RequestPipeline<'a> { - pub(crate) fn new(client: &'a Client, pipeline: &'a [Box]) -> Self { - Self { client, pipeline } +impl<'a> RequestHandlerChain<'a> { + pub(crate) fn new(client: &'a Client, chain: &'a [Box]) -> Self { + Self { client, chain } } - pub async fn run(mut self, request: Request) -> Result { - if let Some((head, tail)) = self.pipeline.split_first() { - self.pipeline = tail; + pub fn client(&self) -> &'a Client { + self.client + } + + pub async fn run(mut self, request: Request) -> Result { + if let Some((head, tail)) = self.chain.split_first() { + self.chain = tail; head.handle(request, self).await } else { self.client .execute(request) .await - .map_err(RequestPipelineError::http) + .map_err(RequestHandlerError::http) } } } diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index c595267a..881b03eb 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -119,9 +119,9 @@ pub struct TransportBuilder { sigv4_service_name: String, #[cfg(feature = "aws-auth")] sigv4_time_source: Option, - init_stack: Vec>, - req_init_stack: Vec>, - req_handler_stack: Vec>, + client_initializers: Vec>, + request_initializers: Vec>, + request_handlers: Vec>, } impl TransportBuilder { @@ -145,9 +145,9 @@ impl TransportBuilder { sigv4_service_name: "es".to_string(), #[cfg(feature = "aws-auth")] sigv4_time_source: None, - init_stack: Vec::new(), - req_init_stack: Vec::new(), - req_handler_stack: Vec::new(), + client_initializers: Vec::new(), + request_initializers: Vec::new(), + request_handlers: Vec::new(), } } @@ -258,14 +258,14 @@ impl TransportBuilder { /// } /// /// let transport: Transport = TransportBuilder::default() - /// .with_init(Initializer) - /// .with_init(might_fail) - /// .with_init(|client_builder: reqwest::ClientBuilder| client_builder.redirect(reqwest::redirect::Policy::limited(1))) + /// .with_client_initializer(Initializer) + /// .with_client_initializer(might_fail) + /// .with_client_initializer(|client_builder: reqwest::ClientBuilder| client_builder.redirect(reqwest::redirect::Policy::limited(1))) /// .build()?; /// # Ok::<(), opensearch::Error>(()) /// ``` - pub fn with_init(mut self, init: impl ClientInitializer) -> Self { - self.init_stack.push(Box::new(init)); + pub fn with_client_initializer(mut self, init: impl ClientInitializer) -> Self { + self.client_initializers.push(Box::new(init)); self } @@ -297,12 +297,12 @@ impl TransportBuilder { /// } /// /// let transport: Transport = TransportBuilder::default() - /// .with_req_init(Counter::new()) + /// .with_initializer(Counter::new()) /// .build()?; /// # Ok::<(), opensearch::Error>(()) /// ``` - pub fn with_req_init(mut self, init: impl RequestInitializer + Clone) -> Self { - self.req_init_stack.push(Box::new(init)); + pub fn with_initializer(mut self, init: impl RequestInitializer + Clone) -> Self { + self.request_initializers.push(Box::new(init)); self } @@ -317,29 +317,29 @@ impl TransportBuilder { /// /// let counter = Arc::new(AtomicUsize::new(0)); /// let transport: Transport = TransportBuilder::default() - /// .with_req_init_fn(move |request_builder: reqwest::RequestBuilder| { + /// .with_initializer_fn(move |request_builder: reqwest::RequestBuilder| { /// let counter = counter.fetch_add(1, Ordering::SeqCst); /// request_builder.header("x-request-id", format!("req-{}", counter)) /// }) /// .build()?; /// # Ok::<(), opensearch::Error>(()) /// ``` - pub fn with_req_init_fn(self, init: F) -> Self + pub fn with_initializer_fn(self, init: F) -> Self where F: Fn(reqwest::RequestBuilder) -> R + Clone + Send + Sync + 'static, R: InitializerResult, { - self.with_req_init(request_initializer_fn(init)) + self.with_initializer(request_initializer_fn(init)) } /// Adds a [RequestHandler] to the stack of handlers that will be called when an underlying [reqwest::Request] is being sent. /// /// Handlers are called in the order they are added. - /// + /// /// # Example /// ```rust,no_run /// use opensearch::http::{ - /// middleware::{async_trait, RequestHandler, RequestPipeline, RequestPipelineError}, + /// middleware::{async_trait, RequestHandler, RequestHandlerChain, RequestHandlerError}, /// reqwest::{Request, Response}, /// transport::{Transport, TransportBuilder}, /// }; @@ -349,7 +349,7 @@ impl TransportBuilder { /// /// #[async_trait] /// impl RequestHandler for Logger { - /// async fn handle(&self, request: Request, next: RequestPipeline<'_>) -> Result { + /// async fn handle(&self, request: Request, next: RequestHandlerChain<'_>) -> Result { /// println!("sending request to {}", request.url()); /// let now = std::time::Instant::now(); /// let res = next.run(request).await?; @@ -364,24 +364,24 @@ impl TransportBuilder { /// # Ok::<(), opensearch::Error>(()) /// ``` pub fn with_handler(mut self, handler: impl RequestHandler + Clone) -> Self { - self.req_handler_stack.push(Box::new(handler)); + self.request_handlers.push(Box::new(handler)); self } /// Adds a [RequestHandler] to the stack of handlers that will be called when an underlying [reqwest::Request] is being sent. /// /// Handlers are called in the order they are added. - /// + /// /// # Example /// ```rust,no_run /// use opensearch::http::{ - /// middleware::{RequestHandler, RequestPipeline, RequestPipelineError}, + /// middleware::{RequestHandler, RequestHandlerChain, RequestHandlerError}, /// reqwest::{Request, Response}, /// transport::{Transport, TransportBuilder}, /// }; /// use std::{future::Future, pin::Pin}; /// - /// fn logger(req: Request, next: RequestPipeline<'_>) -> Pin> + Send + '_>> { + /// fn logger(req: Request, next: RequestHandlerChain<'_>) -> Pin> + Send + '_>> { /// Box::pin(async move { /// println!("sending request to {}", req.url()); /// let now = std::time::Instant::now(); @@ -400,9 +400,8 @@ impl TransportBuilder { where F: for<'a> Fn( reqwest::Request, - RequestPipeline<'a>, - ) - -> BoxFuture<'a, Result> + RequestHandlerChain<'a>, + ) -> BoxFuture<'a, Result> + Clone + Send + Sync @@ -477,7 +476,7 @@ impl TransportBuilder { } client_builder = self - .init_stack + .client_initializers .into_iter() .try_fold(client_builder, |client_builder, init| { init.init(client_builder) @@ -494,8 +493,8 @@ impl TransportBuilder { sigv4_service_name: self.sigv4_service_name, #[cfg(feature = "aws-auth")] sigv4_time_source: self.sigv4_time_source.unwrap_or_default(), - req_init_stack: self.req_init_stack.into_boxed_slice(), - req_handler_stack: self.req_handler_stack.into_boxed_slice(), + request_initializers: self.request_initializers.into_boxed_slice(), + request_handlers: self.request_handlers.into_boxed_slice(), }) } } @@ -540,8 +539,8 @@ pub struct Transport { sigv4_service_name: String, #[cfg(feature = "aws-auth")] sigv4_time_source: SharedTimeSource, - req_init_stack: Box<[Box]>, - req_handler_stack: Box<[Box]>, + request_initializers: Box<[Box]>, + request_handlers: Box<[Box]>, } impl Transport { @@ -650,7 +649,7 @@ impl Transport { } request_builder = self - .req_init_stack + .request_initializers .iter() .try_fold(request_builder, |request_builder, init| { init.init(request_builder) @@ -672,7 +671,7 @@ impl Transport { .await?; } - let response = RequestPipeline::new(&self.client, &self.req_handler_stack) + let response = RequestHandlerChain::new(&self.client, &self.request_handlers) .run(request) .await?; diff --git a/opensearch/tests/middleware.rs b/opensearch/tests/middleware.rs index ba01c64e..c155a1f6 100644 --- a/opensearch/tests/middleware.rs +++ b/opensearch/tests/middleware.rs @@ -17,7 +17,7 @@ use std::sync::{ }; use opensearch::http::middleware::{ - async_trait, RequestHandler, RequestInitializer, RequestPipeline, RequestPipelineError, + async_trait, RequestHandler, RequestHandlerChain, RequestHandlerError, RequestInitializer, }; use reqwest::RequestBuilder; @@ -50,8 +50,8 @@ async fn request_initializer() -> anyhow::Result<()> { }; let client = server.client_with(|b| { - b.with_req_init(Counter(Arc::new(AtomicUsize::new(101)))) - .with_req_init_fn(counter_fn) + b.with_initializer(Counter(Arc::new(AtomicUsize::new(101)))) + .with_initializer_fn(counter_fn) }); client.ping().send().await?; @@ -79,8 +79,8 @@ async fn request_handler() -> anyhow::Result<()> { async fn handle( &self, request: reqwest::Request, - next: RequestPipeline<'_>, - ) -> Result { + next: RequestHandlerChain<'_>, + ) -> Result { self.0.fetch_add(1, Ordering::SeqCst); next.run(request).await }