Skip to content

Commit

Permalink
PoC of introducing SpoofableValue
Browse files Browse the repository at this point in the history
PoC to check which solution to pick for #2998
  • Loading branch information
yanns committed Oct 20, 2024
1 parent ffeb4f9 commit 591d9c9
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 14 deletions.
17 changes: 10 additions & 7 deletions axum-extra/src/extract/host.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::rejection::{FailedToResolveHost, HostRejection};
use super::{
rejection::{FailedToResolveHost, HostRejection},
SpoofableValue,
};
use axum::extract::FromRequestParts;
use http::{
header::{HeaderMap, FORWARDED},
Expand All @@ -18,7 +21,7 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
/// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make
/// sure to validate them to avoid security issues.
#[derive(Debug, Clone)]
pub struct Host(pub String);
pub struct Host(pub SpoofableValue);

impl<S> FromRequestParts<S> for Host
where
Expand All @@ -28,27 +31,27 @@ where

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(&parts.headers) {
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

if let Some(host) = parts
.headers
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

if let Some(host) = parts
.headers
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

if let Some(host) = parts.uri.host() {
return Ok(Host(host.to_owned()));
return Ok(Host(SpoofableValue::new(host.to_owned())));
}

Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
Expand Down Expand Up @@ -81,7 +84,7 @@ mod tests {

fn test_client() -> TestClient {
async fn host_as_body(Host(host): Host) -> String {
host
host.spoofable_value()
}

TestClient::new(Router::new().route("/", get(host_as_body)))
Expand Down
13 changes: 13 additions & 0 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,16 @@ pub use crate::json_lines::JsonLines;
#[cfg(feature = "typed-header")]
#[doc(no_inline)]
pub use crate::typed_header::TypedHeader;

#[derive(Debug, Clone)]
pub struct SpoofableValue(String);

impl SpoofableValue {
pub fn new(value: String) -> Self {
Self(value)
}

pub fn spoofable_value(self) -> String {
self.0
}
}
12 changes: 7 additions & 5 deletions axum-extra/src/extract/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use http::{
header::{HeaderMap, FORWARDED},
request::Parts,
};

use super::SpoofableValue;
const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto";

/// Extractor that resolves the scheme / protocol of a request.
Expand All @@ -21,7 +23,7 @@ const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto";
/// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make
/// sure to validate them to avoid security issues.
#[derive(Debug, Clone)]
pub struct Scheme(pub String);
pub struct Scheme(pub SpoofableValue);

/// Rejection type used if the [`Scheme`] extractor is unable to
/// resolve a scheme.
Expand All @@ -43,7 +45,7 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Within Forwarded header
if let Some(scheme) = parse_forwarded(&parts.headers) {
return Ok(Scheme(scheme.to_owned()));
return Ok(Scheme(SpoofableValue::new(scheme.to_owned())));
}

// X-Forwarded-Proto
Expand All @@ -52,12 +54,12 @@ where
.get(X_FORWARDED_PROTO_HEADER_KEY)
.and_then(|scheme| scheme.to_str().ok())
{
return Ok(Scheme(scheme.to_owned()));
return Ok(Scheme(SpoofableValue::new(scheme.to_owned())));
}

// From parts of an HTTP/2 request
if let Some(scheme) = parts.uri.scheme_str() {
return Ok(Scheme(scheme.to_owned()));
return Ok(Scheme(SpoofableValue::new(scheme.to_owned())));
}

Err(SchemeMissing)
Expand Down Expand Up @@ -89,7 +91,7 @@ mod tests {

fn test_client() -> TestClient {
async fn scheme_as_body(Scheme(scheme): Scheme) -> String {
scheme
scheme.spoofable_value()
}

TestClient::new(Router::new().route("/", get(scheme_as_body)))
Expand Down
2 changes: 1 addition & 1 deletion examples/tls-graceful-shutdown/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ where
}

let redirect = move |Host(host): Host, uri: Uri| async move {
match make_https(host, uri, ports) {
match make_https(host.spoofable_value(), uri, ports) {
Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
Err(error) => {
tracing::warn!(%error, "failed to convert URI to HTTPS");
Expand Down
2 changes: 1 addition & 1 deletion examples/tls-rustls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async fn redirect_http_to_https(ports: Ports) {
}

let redirect = move |Host(host): Host, uri: Uri| async move {
match make_https(host, uri, ports) {
match make_https(host.spoofable_value(), uri, ports) {
Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
Err(error) => {
tracing::warn!(%error, "failed to convert URI to HTTPS");
Expand Down

0 comments on commit 591d9c9

Please sign in to comment.