Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC of introducing SpoofableValue #3000

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions axum-extra/src/extract/host.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::rejection::{FailedToResolveHost, HostRejection};
use super::{
rejection::{FailedToResolveHost, HostRejection},
SpoofableValue,
};
use axum::extract::FromRequestParts;
use http::{
header::{HeaderMap, FORWARDED},
Expand All @@ -18,7 +21,7 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
/// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make
/// sure to validate them to avoid security issues.
#[derive(Debug, Clone)]
pub struct Host(pub String);
pub struct Host(pub SpoofableValue);

impl<S> FromRequestParts<S> for Host
where
Expand All @@ -28,27 +31,27 @@ where

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(&parts.headers) {
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

if let Some(host) = parts
.headers
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

if let Some(host) = parts
.headers
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

if let Some(host) = parts.uri.host() {
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
Expand Down Expand Up @@ -81,7 +84,7 @@ mod tests {

fn test_client() -> TestClient {
async fn host_as_body(Host(host): Host) -> String {
host
host.spoofable_value()
}

TestClient::new(Router::new().route("/", get(host_as_body)))
Expand Down
13 changes: 13 additions & 0 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,16 @@ pub use crate::json_lines::JsonLines;
#[cfg(feature = "typed-header")]
#[doc(no_inline)]
pub use crate::typed_header::TypedHeader;

#[derive(Debug, Clone)]
pub struct SpoofableValue(String);

impl SpoofableValue {
pub fn new(value: String) -> Self {
Self(value)
}

pub fn spoofable_value(self) -> String {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we could add documentation to make users aware of risks using this method.

self.0
}
}
12 changes: 7 additions & 5 deletions axum-extra/src/extract/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use http::{
header::{HeaderMap, FORWARDED},
request::Parts,
};

use super::SpoofableValue;
const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto";

/// Extractor that resolves the scheme / protocol of a request.
Expand All @@ -21,7 +23,7 @@ const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto";
/// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make
/// sure to validate them to avoid security issues.
#[derive(Debug, Clone)]
pub struct Scheme(pub String);
pub struct Scheme(pub SpoofableValue);

/// Rejection type used if the [`Scheme`] extractor is unable to
/// resolve a scheme.
Expand All @@ -43,7 +45,7 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Within Forwarded header
if let Some(scheme) = parse_forwarded(&parts.headers) {
return Ok(Scheme(scheme.to_owned()));
return Ok(Scheme(SpoofableValue::new(scheme.to_owned())));
}

// X-Forwarded-Proto
Expand All @@ -52,12 +54,12 @@ where
.get(X_FORWARDED_PROTO_HEADER_KEY)
.and_then(|scheme| scheme.to_str().ok())
{
return Ok(Scheme(scheme.to_owned()));
return Ok(Scheme(SpoofableValue::new(scheme.to_owned())));
}

// From parts of an HTTP/2 request
if let Some(scheme) = parts.uri.scheme_str() {
return Ok(Scheme(scheme.to_owned()));
return Ok(Scheme(SpoofableValue::new(scheme.to_owned())));
}

Err(SchemeMissing)
Expand Down Expand Up @@ -89,7 +91,7 @@ mod tests {

fn test_client() -> TestClient {
async fn scheme_as_body(Scheme(scheme): Scheme) -> String {
scheme
scheme.spoofable_value()
}

TestClient::new(Router::new().route("/", get(scheme_as_body)))
Expand Down
2 changes: 1 addition & 1 deletion examples/tls-graceful-shutdown/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ where
}

let redirect = move |Host(host): Host, uri: Uri| async move {
match make_https(host, uri, ports) {
match make_https(host.spoofable_value(), uri, ports) {
Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
Err(error) => {
tracing::warn!(%error, "failed to convert URI to HTTPS");
Expand Down
2 changes: 1 addition & 1 deletion examples/tls-rustls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async fn redirect_http_to_https(ports: Ports) {
}

let redirect = move |Host(host): Host, uri: Uri| async move {
match make_https(host, uri, ports) {
match make_https(host.spoofable_value(), uri, ports) {
Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
Err(error) => {
tracing::warn!(%error, "failed to convert URI to HTTPS");
Expand Down
Loading