Skip to content

Commit

Permalink
Merge pull request #6 from plabayo/sync/tower-http-0.4.2
Browse files Browse the repository at this point in the history
port tower-http 0.4.3 to tower-async-http (came from 0.4.1)
  • Loading branch information
GlenDC authored Jul 20, 2023
2 parents 3ac41fd + 6700221 commit 891c9c7
Show file tree
Hide file tree
Showing 20 changed files with 477 additions and 388 deletions.
27 changes: 27 additions & 0 deletions tower-async-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 0.1.2 (July 20, 2023)

Sync with original `tower-http` codebase from [`0.4.1`](https://github.com/tower-rs/tower-http/releases/tag/tower-http-0.4.1)
to [`0.4.3`](https://github.com/tower-rs/tower-http/releases/tag/tower-http-0.4.3).

## Added

- **cors:** Add support for private network preflights ([tower-rs/tower-http#373])
- **compression:** Implement `Default` for `DecompressionBody` ([tower-rs/tower-http#370])

## Changed

- **compression:** Update to async-compression 0.4 ([tower-rs/tower-http#371])

## Fixed

- **compression:** Override default brotli compression level 11 -> 4 ([tower-rs/tower-http#356])
- **trace:** Simplify dynamic tracing level application ([tower-rs/tower-http#380])
- **normalize_path:** Fix path normalization for preceding slashes ([tower-rs/tower-http#359])

[tower-rs/tower-http#356]: https://github.com/tower-rs/tower-http/pull/356
[tower-rs/tower-http#359]: https://github.com/tower-rs/tower-http/pull/359
[tower-rs/tower-http#370]: https://github.com/tower-rs/tower-http/pull/370
[tower-rs/tower-http#371]: https://github.com/tower-rs/tower-http/pull/371
[tower-rs/tower-http#373]: https://github.com/tower-rs/tower-http/pull/373
[tower-rs/tower-http#380]: https://github.com/tower-rs/tower-http/pull/380

## 0.1.1 (July 18, 2023)

- Improve, expand and fix documentation;
Expand Down
4 changes: 2 additions & 2 deletions tower-async-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ description = """
Tower Async middleware and utilities for HTTP clients and servers.
An "Async Trait" fork from the original Tower Library.
"""
version = "0.1.1"
version = "0.1.2"
authors = ["Glen De Cauwsemaecker <[email protected]>"]
edition = "2021"
license = "MIT"
Expand All @@ -25,7 +25,7 @@ tower-async-layer = { version = "0.1", path = "../tower-async-layer" }
tower-async-service = { version = "0.1", path = "../tower-async-service" }

# optional dependencies
async-compression = { version = "0.3", optional = true, features = ["tokio"] }
async-compression = { version = "0.4", optional = true, features = ["tokio"] }
base64 = { version = "0.21", optional = true }
http-range-header = "0.3.0"
iri-string = { version = "0.7.0", optional = true }
Expand Down
6 changes: 4 additions & 2 deletions tower-async-http/src/auth/add_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@
//! # }
//! ```

use base64::{engine::general_purpose::STANDARD as base64, Engine};
use base64::Engine as _;
use http::{HeaderValue, Request, Response};
use std::convert::TryFrom;
use tower_async_layer::Layer;
use tower_async_service::Service;

const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;

/// Layer that applies [`AddAuthorization`] which adds authorization to all requests using the
/// [`Authorization`] header.
///
Expand All @@ -67,7 +69,7 @@ impl AddAuthorizationLayer {
/// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
/// with this method. However use of HTTPS/TLS is not enforced by this middleware.
pub fn basic(username: &str, password: &str) -> Self {
let encoded = base64.encode(format!("{}:{}", username, password));
let encoded = BASE64.encode(format!("{}:{}", username, password));
let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap();
Self { value }
}
Expand Down
14 changes: 8 additions & 6 deletions tower-async-http/src/auth/require_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,16 @@
//! Custom validation can be made by implementing [`ValidateRequest`].

use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
use base64::{engine::general_purpose::STANDARD as base64, Engine};
use base64::Engine as _;
use http::{
header::{self, HeaderValue},
Request, Response, StatusCode,
};
use http_body::Body;
use std::{fmt, marker::PhantomData};

const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;

impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
/// Authorize requests using a username and password pair.
///
Expand Down Expand Up @@ -194,7 +196,7 @@ impl<ResBody> Basic<ResBody> {
where
ResBody: Body + Default,
{
let encoded = base64.encode(format!("{}:{}", username, password));
let encoded = BASE64.encode(format!("{}:{}", username, password));
let header_value = format!("Basic {}", encoded).parse().unwrap();
Self {
header_value,
Expand Down Expand Up @@ -260,7 +262,7 @@ mod tests {
let request = Request::get("/")
.header(
header::AUTHORIZATION,
format!("Basic {}", base64.encode("foo:bar")),
format!("Basic {}", BASE64.encode("foo:bar")),
)
.body(Body::empty())
.unwrap();
Expand All @@ -279,7 +281,7 @@ mod tests {
let request = Request::get("/")
.header(
header::AUTHORIZATION,
format!("Basic {}", base64.encode("wrong:credentials")),
format!("Basic {}", BASE64.encode("wrong:credentials")),
)
.body(Body::empty())
.unwrap();
Expand Down Expand Up @@ -317,7 +319,7 @@ mod tests {
let request = Request::get("/")
.header(
header::AUTHORIZATION,
format!("basic {}", base64.encode("foo:bar")),
format!("basic {}", BASE64.encode("foo:bar")),
)
.body(Body::empty())
.unwrap();
Expand All @@ -336,7 +338,7 @@ mod tests {
let request = Request::get("/")
.header(
header::AUTHORIZATION,
format!("Basic {}", base64.encode("Foo:bar")),
format!("Basic {}", BASE64.encode("Foo:bar")),
)
.body(Body::empty())
.unwrap();
Expand Down
11 changes: 10 additions & 1 deletion tower-async-http/src/compression/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,16 @@ where
type Output = BrotliEncoder<Self::Input>;

fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
BrotliEncoder::with_quality(input, quality.into_async_compression())
// The brotli crate used under the hood here has a default compression level of 11,
// which is the max for brotli. This causes extremely slow compression times, so we
// manually set a default of 4 here.
//
// This is the same default used by NGINX for on-the-fly brotli compression.
let level = match quality {
CompressionLevel::Default => async_compression::Level::Precise(4),
other => other.into_async_compression(),
};
BrotliEncoder::with_quality(input, level)
}

fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
Expand Down
10 changes: 8 additions & 2 deletions tower-async-http/src/compression/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ mod tests {

#[tokio::test]
async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> {
let deflate_only_layer = CompressionLayer::new().no_br().no_gzip();
let deflate_only_layer = CompressionLayer::new()
.quality(CompressionLevel::Best)
.no_br()
.no_gzip();

let mut service = ServiceBuilder::new()
// Compress responses based on the `Accept-Encoding` header.
Expand All @@ -173,7 +176,10 @@ mod tests {

let deflate_bytes_len = bytes.len();

let br_only_layer = CompressionLayer::new().no_gzip().no_deflate();
let br_only_layer = CompressionLayer::new()
.quality(CompressionLevel::Best)
.no_gzip()
.no_deflate();

let mut service = ServiceBuilder::new()
// Compress responses based on the `Accept-Encoding` header.
Expand Down
4 changes: 3 additions & 1 deletion tower-async-http/src/compression_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,9 @@ impl CompressionLevel {
CompressionLevel::Fastest => AsyncCompressionLevel::Fastest,
CompressionLevel::Best => AsyncCompressionLevel::Best,
CompressionLevel::Default => AsyncCompressionLevel::Default,
CompressionLevel::Precise(quality) => AsyncCompressionLevel::Precise(quality),
CompressionLevel::Precise(quality) => {
AsyncCompressionLevel::Precise(quality.try_into().unwrap_or(i32::MAX))
}
}
}
}
2 changes: 2 additions & 0 deletions tower-async-http/src/cors/allow_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ impl AllowCredentials {

/// Allow credentials for some requests, based on a given predicate
///
/// The first argument to the predicate is the request origin.
///
/// See [`CorsLayer::allow_credentials`] for more details.
///
/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials
Expand Down
196 changes: 196 additions & 0 deletions tower-async-http/src/cors/allow_private_network.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use std::{fmt, sync::Arc};

use http::{
header::{HeaderName, HeaderValue},
request::Parts as RequestParts,
};

/// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header.
///
/// See [`CorsLayer::allow_private_network`] for more details.
///
/// [wicg]: https://wicg.github.io/private-network-access/
/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
#[derive(Clone, Default)]
#[must_use]
pub struct AllowPrivateNetwork(AllowPrivateNetworkInner);

static TRUE: HeaderValue = HeaderValue::from_static("true");

impl AllowPrivateNetwork {
/// Allow requests via a more private network than the one used to access the origin
///
/// See [`CorsLayer::allow_private_network`] for more details.
///
/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
pub fn yes() -> Self {
Self(AllowPrivateNetworkInner::Yes)
}

/// Allow requests via private network for some requests, based on a given predicate
///
/// The first argument to the predicate is the request origin.
///
/// See [`CorsLayer::allow_private_network`] for more details.
///
/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
pub fn predicate<F>(f: F) -> Self
where
F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
{
Self(AllowPrivateNetworkInner::Predicate(Arc::new(f)))
}

pub(super) fn to_header(
&self,
origin: Option<&HeaderValue>,
parts: &RequestParts,
) -> Option<(HeaderName, HeaderValue)> {
#[allow(clippy::declare_interior_mutable_const)]
const REQUEST_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-request-private-network");

#[allow(clippy::declare_interior_mutable_const)]
const ALLOW_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-allow-private-network");

// Cheapest fallback: allow_private_network hasn't been set
if let AllowPrivateNetworkInner::No = &self.0 {
return None;
}

// Access-Control-Allow-Private-Network is only relevant if the request
// has the Access-Control-Request-Private-Network header set, else skip
if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) {
return None;
}

let allow_private_network = match &self.0 {
AllowPrivateNetworkInner::Yes => true,
AllowPrivateNetworkInner::No => false, // unreachable, but not harmful
AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts),
};

allow_private_network.then(|| (ALLOW_PRIVATE_NETWORK, TRUE.clone()))
}
}

impl From<bool> for AllowPrivateNetwork {
fn from(v: bool) -> Self {
match v {
true => Self(AllowPrivateNetworkInner::Yes),
false => Self(AllowPrivateNetworkInner::No),
}
}
}

impl fmt::Debug for AllowPrivateNetwork {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(),
AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(),
AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
}
}
}

#[derive(Clone)]
enum AllowPrivateNetworkInner {
Yes,
No,
Predicate(
Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
),
}

impl Default for AllowPrivateNetworkInner {
fn default() -> Self {
Self::No
}
}

#[cfg(test)]
mod tests {
use super::AllowPrivateNetwork;
use crate::cors::CorsLayer;

use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response};
use hyper::Body;
use tower_async::{BoxError, ServiceBuilder};
use tower_async_service::Service;

const REQUEST_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-request-private-network");

const ALLOW_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-allow-private-network");

const TRUE: HeaderValue = HeaderValue::from_static("true");

#[tokio::test]
async fn cors_private_network_header_is_added_correctly() {
let mut service = ServiceBuilder::new()
.layer(CorsLayer::new().allow_private_network(true))
.service_fn(echo);

let req = Request::builder()
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.body(Body::empty())
.unwrap();
let res = service.call(req).await.unwrap();

assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

let req = Request::builder().body(Body::empty()).unwrap();
let res = service.call(req).await.unwrap();

assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
}

#[tokio::test]
async fn cors_private_network_header_is_added_correctly_with_predicate() {
let allow_private_network =
AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| {
parts.uri.path() == "/allow-private" && origin == "localhost"
});
let mut service = ServiceBuilder::new()
.layer(CorsLayer::new().allow_private_network(allow_private_network))
.service_fn(echo);

let req = Request::builder()
.header(ORIGIN, "localhost")
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.uri("/allow-private")
.body(Body::empty())
.unwrap();

let res = service.call(req).await.unwrap();
assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

let req = Request::builder()
.header(ORIGIN, "localhost")
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.uri("/other")
.body(Body::empty())
.unwrap();

let res = service.call(req).await.unwrap();

assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

let req = Request::builder()
.header(ORIGIN, "not-localhost")
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.uri("/allow-private")
.body(Body::empty())
.unwrap();

let res = service.call(req).await.unwrap();

assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
}

async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
}
Loading

0 comments on commit 891c9c7

Please sign in to comment.