diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index c5b46c46..46fa90cc 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -62,6 +62,7 @@ full = [ "auth", "catch-panic", "compression-full", + "concurrency-limit", "cors", "decompression-full", "follow-redirect", @@ -86,6 +87,7 @@ full = [ add-extension = [] auth = ["base64", "validate-request"] catch-panic = ["tracing", "futures-util/std"] +concurrency-limit = ["tokio/sync", "tokio-util"] cors = [] follow-redirect = ["futures-util", "iri-string", "tower/util"] fs = ["futures-util", "tokio/fs", "tokio-util/io", "tokio/io-util", "dep:http-range-header", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc", "tracing"] diff --git a/tower-http/src/concurrency_limit.rs b/tower-http/src/concurrency_limit.rs new file mode 100644 index 00000000..227467e9 --- /dev/null +++ b/tower-http/src/concurrency_limit.rs @@ -0,0 +1,312 @@ +//! Limit the max number of concurrently processed requests. +//! +//! The service sets a maximum limit to the number of concurrently processed requests. The +//! processing of a request starts when it is received by the service (`tower::Service::call` is +//! called) and is considered complete when the response body is dropped (typically after it has +//! been fully consumed), or when the response future fails or is dropped. +//! +//! Internally, it uses a semaphore to track and limit the number of in-flight requests. +//! +//! # Relation to `ConcurrencyLimit` from `tower` crate +//! +//! The [`tower::limit::concurrency::ConcurrencyLimit`] service uses a different definition of +//! 'request processing'. It starts when a request is received by `tower::Service::call`, and ends +//! immediately after the response is produced. +//! +//! In some cases it may not work properly with [`http::Response`], as it does not account for +//! the process of consuming the response body. +//! +//! When a stream is used as the response body, consuming it (i.e. streaming to the caller) may +//!
take longer and use more resources than just producing the response itself. And often it is the +//! number of streams we are processing concurrently that we want to limit. +//! +//! The service version from [`tower-http`](crate) takes response body consumption into +//! consideration and *will* limit number of concurrent streams correctly. +//! +//! ``` +//! use std::convert::Infallible; +//! use bytes::Bytes; +//! use http::{Request, Response}; +//! use http_body_util::Full; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::concurrency_limit::ConcurrencyLimitLayer; +//! +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut service = ServiceBuilder::new() +//! // limit number of concurrent requests to 3 +//! .layer(ConcurrencyLimitLayer::new(3)) +//! .service_fn(handle); +//! +//! // Call the service. +//! let response = service +//! .ready() +//! .await? +//! .call(Request::new(Full::default())) +//! .await?; +//! # Ok(()) +//! # } +//! ``` +//! + +use http::{Request, Response}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_util::sync::PollSemaphore; + +/// Limit max number of concurrent requests (per service) +/// +/// The layer enforces the same concurrency limit for each inner service separately. In other words, +/// [`ConcurrencyLimit`] middleware constructed from this layer for each service will have its own +/// semaphore and will track and limit requests separately. +/// +/// To track and limit multiple services together, see [`SharedConcurrencyLimitLayer`]. +/// +/// See the [module docs](crate::concurrency_limit) for more details. 
+#[derive(Clone, Debug)] +pub struct ConcurrencyLimitLayer { + max: usize, +} + +impl ConcurrencyLimitLayer { + /// Create a new [`ConcurrencyLimitLayer`] with `max` permits per wrapped service + pub fn new(max: usize) -> Self { + Self { max } + } +} + +impl tower_layer::Layer for ConcurrencyLimitLayer { + type Service = ConcurrencyLimit; + + fn layer(&self, service: S) -> Self::Service { + ConcurrencyLimit::new(service, Arc::new(Semaphore::new(self.max))) + } +} + +/// Limit max number of concurrent requests (shared) +/// +/// The layer enforces a single concurrency limit for multiple inner services at once. In other +/// words, [`ConcurrencyLimit`] middleware constructed from this layer for each service will have +/// one shared semaphore and will track and limit requests together. +/// +/// To track and limit each service separately, see [`ConcurrencyLimitLayer`]. +/// +/// See the [module docs](crate::concurrency_limit) for more details. +#[derive(Clone, Debug)] +pub struct SharedConcurrencyLimitLayer { + semaphore: Arc, +} + +impl SharedConcurrencyLimitLayer { + /// Create a new [`SharedConcurrencyLimitLayer`] with a shared semaphore of `max` permits + pub fn new(max: usize) -> Self { + Self { + semaphore: Arc::new(Semaphore::new(max)), + } + } + + /// Create a [`SharedConcurrencyLimitLayer`] from an existing semaphore + pub fn from_semaphore(semaphore: Arc) -> Self { + Self { semaphore } + } +} + +impl tower_layer::Layer for SharedConcurrencyLimitLayer { + type Service = ConcurrencyLimit; + + fn layer(&self, service: S) -> Self::Service { + ConcurrencyLimit::new(service, self.semaphore.clone()) + } +} + +/// Middleware that limits the max number of concurrent in-flight requests. +/// +/// See the [module docs](crate::concurrency_limit) for more details. 
+#[derive(Debug)] +pub struct ConcurrencyLimit { + inner: S, + semaphore: PollSemaphore, + permit: Option, +} + +impl ConcurrencyLimit { + /// Create a new [`ConcurrencyLimit`] with the associated semaphore + pub fn new(inner: S, semaphore: Arc) -> Self { + Self { + inner, + semaphore: PollSemaphore::new(semaphore), + permit: None, + } + } + + define_inner_service_accessors!(); +} + +// Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`. Instead, when cloning the +// service, create a new service with the same semaphore, but with the permit in the un-acquired +// state. +impl Clone for ConcurrencyLimit { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + semaphore: self.semaphore.clone(), + permit: None, + } + } +} + +impl tower_service::Service> for ConcurrencyLimit +where + S: tower_service::Service, Response = Response>, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.permit.is_none() { + self.permit = ready!(self.semaphore.poll_acquire(cx)); + } + + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let permit = self + .permit + .take() + .expect("max requests in-flight; poll_ready must be called first"); + + let future = self.inner.call(request); + ResponseFuture { + inner: future, + permit: Some(permit), + } + } +} + +pin_project! { + + /// Response future for [`ConcurrencyLimit`] + pub struct ResponseFuture { + #[pin] + inner: F, + + // The permit is kept in an `Option` so that it can be taken out of the future on its + // completion and passed to the `ResponseBody`. The permit must be dropped only after the + // `ResponseBody` has been consumed and dropped. 
+ permit: Option, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = Result>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let response = ready!(this.inner.poll(cx))?; + + let permit = this.permit.take().unwrap(); + let response = response.map(move |body| ResponseBody { + inner: body, + permit, + }); + + Poll::Ready(Ok(response)) + } +} + +pin_project! { + + /// Response body for [`ConcurrencyLimit`] + /// + /// It owns a semaphore permit for as long as it exists; the permit is freed when it is dropped. + pub struct ResponseBody { + #[pin] + inner: B, + permit: OwnedSemaphorePermit, + } +} + +impl http_body::Body for ResponseBody +where + B: http_body::Body, +{ + type Data = B::Data; + type Error = B::Error; + + #[inline] + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } +} + +#[cfg(test)] +mod tests { + use http::Request; + use tower::{BoxError, ServiceBuilder}; + use tower_service::Service; + use crate::test_helpers::Body; + use super::*; + + #[tokio::test] + async fn basic() { + let semaphore = Arc::new(Semaphore::new(1)); + assert_eq!(1, semaphore.available_permits()); + + let mut service = ServiceBuilder::new() + .layer(SharedConcurrencyLimitLayer::from_semaphore( + semaphore.clone(), + )) + .service_fn(echo); + + // driving the service to ready pre-acquires a permit, decreasing the available count + std::future::poll_fn(|cx| service.poll_ready(cx)) + .await + .unwrap(); + assert_eq!(0, semaphore.available_permits()); + + // creating the response future moves the pre-acquired permit from the service into it + let response_future = service.call(Request::new(Body::empty())); + + // awaiting the response future moves the permit to the response body, count is unchanged + let 
response = response_future.await.unwrap(); + assert_eq!(0, semaphore.available_permits()); + + // consuming and dropping the response body frees the permit and increases the count + let body = response.into_body(); + crate::test_helpers::to_bytes(body).await.unwrap(); + assert_eq!(1, semaphore.available_permits()); + } + + async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 4c731e83..33d018ed 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -294,6 +294,9 @@ mod compression_utils; ))] pub use compression_utils::CompressionLevel; +#[cfg(feature = "concurrency-limit")] +pub mod concurrency_limit; + #[cfg(feature = "map-response-body")] pub mod map_response_body;