From 591d9c9ca57e8d75dfed584c8f135eef31e3bef5 Mon Sep 17 00:00:00 2001 From: Yann Simon Date: Sun, 20 Oct 2024 12:52:32 +0200 Subject: [PATCH] PoC of introducing SpoofableValue PoC to check which solution to pick for https://github.com/tokio-rs/axum/issues/2998 --- axum-extra/src/extract/host.rs | 17 ++++++++++------- axum-extra/src/extract/mod.rs | 13 +++++++++++++ axum-extra/src/extract/scheme.rs | 12 +++++++----- examples/tls-graceful-shutdown/src/main.rs | 2 +- examples/tls-rustls/src/main.rs | 2 +- 5 files changed, 32 insertions(+), 14 deletions(-) diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index 477bc4fa94..65f2398317 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -1,4 +1,7 @@ -use super::rejection::{FailedToResolveHost, HostRejection}; +use super::{ + rejection::{FailedToResolveHost, HostRejection}, + SpoofableValue, +}; use axum::extract::FromRequestParts; use http::{ header::{HeaderMap, FORWARDED}, @@ -18,7 +21,7 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; /// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make /// sure to validate them to avoid security issues. #[derive(Debug, Clone)] -pub struct Host(pub String); +pub struct Host(pub SpoofableValue); impl FromRequestParts for Host where @@ -28,7 +31,7 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(host) = parse_forwarded(&parts.headers) { - return Ok(Host(host.to_owned())); + return Ok(Host(SpoofableValue::new(host.to_owned()))); } if let Some(host) = parts @@ -36,7 +39,7 @@ where .get(X_FORWARDED_HOST_HEADER_KEY) .and_then(|host| host.to_str().ok()) { - return Ok(Host(host.to_owned())); + return Ok(Host(SpoofableValue::new(host.to_owned()))); } if let Some(host) = parts @@ -44,11 +47,11 @@ where .get(http::header::HOST) .and_then(|host| host.to_str().ok()) { - return Ok(Host(host.to_owned())); + return Ok(Host(SpoofableValue::new(host.to_owned()))); } if let Some(host) = parts.uri.host() { - return Ok(Host(host.to_owned())); + return Ok(Host(SpoofableValue::new(host.to_owned()))); } Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) @@ -81,7 +84,7 @@ mod tests { fn test_client() -> TestClient { async fn host_as_body(Host(host): Host) -> String { - host + host.spoofable_value() } TestClient::new(Router::new().route("/", get(host_as_body))) diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index 7d2a5b2433..29b2b9d84d 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -63,3 +63,16 @@ pub use crate::json_lines::JsonLines; #[cfg(feature = "typed-header")] #[doc(no_inline)] pub use crate::typed_header::TypedHeader; + +#[derive(Debug, Clone)] +pub struct SpoofableValue(String); + +impl SpoofableValue { + pub fn new(value: String) -> Self { + Self(value) + } + + pub fn spoofable_value(self) -> String { + self.0 + } +} diff --git a/axum-extra/src/extract/scheme.rs b/axum-extra/src/extract/scheme.rs index 891d5c0bdd..c66d11852a 100644 --- a/axum-extra/src/extract/scheme.rs +++ b/axum-extra/src/extract/scheme.rs @@ -9,6 +9,8 @@ use http::{ header::{HeaderMap, FORWARDED}, request::Parts, }; + +use super::SpoofableValue; const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto"; /// Extractor that resolves the scheme / protocol of a request. @@ -21,7 +23,7 @@ const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto"; /// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make /// sure to validate them to avoid security issues. #[derive(Debug, Clone)] -pub struct Scheme(pub String); +pub struct Scheme(pub SpoofableValue); /// Rejection type used if the [`Scheme`] extractor is unable to /// resolve a scheme. @@ -43,7 +45,7 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { // Within Forwarded header if let Some(scheme) = parse_forwarded(&parts.headers) { - return Ok(Scheme(scheme.to_owned())); + return Ok(Scheme(SpoofableValue::new(scheme.to_owned()))); } // X-Forwarded-Proto @@ -52,12 +54,12 @@ where .get(X_FORWARDED_PROTO_HEADER_KEY) .and_then(|scheme| scheme.to_str().ok()) { - return Ok(Scheme(scheme.to_owned())); + return Ok(Scheme(SpoofableValue::new(scheme.to_owned()))); } // From parts of an HTTP/2 request if let Some(scheme) = parts.uri.scheme_str() { - return Ok(Scheme(scheme.to_owned())); + return Ok(Scheme(SpoofableValue::new(scheme.to_owned()))); } Err(SchemeMissing) @@ -89,7 +91,7 @@ mod tests { fn test_client() -> TestClient { async fn scheme_as_body(Scheme(scheme): Scheme) -> String { - scheme + scheme.spoofable_value() } TestClient::new(Router::new().route("/", get(scheme_as_body))) diff --git a/examples/tls-graceful-shutdown/src/main.rs b/examples/tls-graceful-shutdown/src/main.rs index 344ecf9d05..8087964d8b 100644 --- a/examples/tls-graceful-shutdown/src/main.rs +++ b/examples/tls-graceful-shutdown/src/main.rs @@ -122,7 +122,7 @@ where } let redirect = move |Host(host): Host, uri: Uri| async move { - match make_https(host, uri, ports) { + match make_https(host.spoofable_value(), uri, ports) { Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), Err(error) => { tracing::warn!(%error, "failed to convert URI to HTTPS"); diff --git a/examples/tls-rustls/src/main.rs b/examples/tls-rustls/src/main.rs index bd8b07fcfd..7852f134a0 100644 --- a/examples/tls-rustls/src/main.rs +++ b/examples/tls-rustls/src/main.rs @@ -88,7 +88,7 @@ async fn redirect_http_to_https(ports: Ports) { } let redirect = move |Host(host): Host, uri: Uri| async move { - match make_https(host, uri, ports) { + match make_https(host.spoofable_value(), uri, ports) { Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), Err(error) => { tracing::warn!(%error, "failed to convert URI to HTTPS");