diff --git a/src/insecure.rs b/src/insecure.rs index 5e87fb9..a0ea914 100644 --- a/src/insecure.rs +++ b/src/insecure.rs @@ -1,6 +1,6 @@ use crate::rudimental::{ - CfConnectingIp, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, TrueClientIp, - XForwardedFor, XRealIp, + CfConnectingIp, CloudFrontViewerAddress, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, + TrueClientIp, XForwardedFor, XRealIp, }; use axum::{ async_trait, @@ -49,6 +49,7 @@ impl InsecureClientIp { .or_else(|| FlyClientIp::maybe_ip_from_headers(headers)) .or_else(|| TrueClientIp::maybe_ip_from_headers(headers)) .or_else(|| CfConnectingIp::maybe_ip_from_headers(headers)) + .or_else(|| CloudFrontViewerAddress::maybe_ip_from_headers(headers)) .or_else(|| maybe_connect_info(extensions)) .map(Self) .ok_or(( diff --git a/src/rudimental.rs b/src/rudimental.rs index 024839c..041f98f 100644 --- a/src/rudimental.rs +++ b/src/rudimental.rs @@ -68,6 +68,12 @@ pub struct TrueClientIp(pub IpAddr); #[derive(Debug)] pub struct CfConnectingIp(pub IpAddr); +/// Extracts a valid IP from `CloudFront-Viewer-Address` (AWS CloudFront) header +/// +/// Rejects with a 500 error if the header is absent or the IP isn't valid +#[derive(Debug)] +pub struct CloudFrontViewerAddress(pub IpAddr); + pub(crate) trait SingleIpHeader { const HEADER: &'static str; @@ -162,6 +168,38 @@ impl_single_header!(FlyClientIp, "Fly-Client-IP"); impl_single_header!(TrueClientIp, "True-Client-IP"); impl_single_header!(CfConnectingIp, "CF-Connecting-IP"); +impl SingleIpHeader for CloudFrontViewerAddress { + const HEADER: &'static str = "cloudfront-viewer-address"; + + fn maybe_ip_from_headers(headers: &HeaderMap) -> Option { + headers + .get(Self::HEADER) + .and_then(|hv| hv.to_str().ok()) + // Spec: https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/adding-cloudfront-headers.html#cloudfront-headers-viewer-location + // Note: Both IPv4 and IPv6 addresses (in the specified format) do not contain + // non-ascii characters, so no need to handle percent-encoding. + // + // CloudFront does not use `[::]:12345` style notation for IPv6 (unfortunately), + // otherwise parsing via `SocketAddr` would be possible. + .and_then(|hv| hv.rsplit_once(':').map(|(ip, _port)| ip)) + .and_then(|s| s.parse::().ok()) + } +} + +#[async_trait] +impl FromRequestParts for CloudFrontViewerAddress +where + S: Sync, +{ + type Rejection = StringRejection; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + Ok(Self( + Self::maybe_ip_from_headers(&parts.headers).ok_or_else(Self::rejection)?, + )) + } +} + impl MultiIpHeader for XForwardedFor { const HEADER: &'static str = "X-Forwarded-For"; @@ -532,4 +570,54 @@ mod tests { let res = app().oneshot(req).await.unwrap(); assert_eq!(body_string(res.into_body()).await, "192.0.2.60"); } + + #[tokio::test] + async fn cloudfront_viewer_address_ipv4() { + fn app() -> Router { + Router::new().route( + "/", + get(|ip: super::CloudFrontViewerAddress| async move { ip.0.to_string() }), + ) + } + + let req = Request::builder().uri("/").body(Body::empty()).unwrap(); + let res = app().oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let req = Request::builder() + .uri("/") + .header("CloudFront-Viewer-Address", "198.51.100.10:46532") + .body(Body::empty()) + .unwrap(); + let res = app().oneshot(req).await.unwrap(); + assert_eq!(body_string(res.into_body()).await, "198.51.100.10"); + } + + #[tokio::test] + async fn cloudfront_viewer_address_ipv6() { + fn app() -> Router { + Router::new().route( + "/", + get(|ip: super::CloudFrontViewerAddress| async move { ip.0.to_string() }), + ) + } + + let req = Request::builder().uri("/").body(Body::empty()).unwrap(); + let res = app().oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let req = Request::builder() + .uri("/") + .header( + "CloudFront-Viewer-Address", + "2a09:bac1:3b20:38::17e:7:51786", + ) + .body(Body::empty()) + .unwrap(); + let res = app().oneshot(req).await.unwrap(); + assert_eq!( + body_string(res.into_body()).await, + "2a09:bac1:3b20:38::17e:7" + ); + } } diff --git a/src/secure.rs b/src/secure.rs index 786b03e..9e095d9 100644 --- a/src/secure.rs +++ b/src/secure.rs @@ -1,6 +1,6 @@ use crate::rudimental::{ - CfConnectingIp, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, StringRejection, - TrueClientIp, XForwardedFor, XRealIp, + CfConnectingIp, CloudFrontViewerAddress, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, + StringRejection, TrueClientIp, XForwardedFor, XRealIp, }; use axum::async_trait; use axum::extract::{ConnectInfo, Extension, FromRequestParts}; @@ -44,6 +44,8 @@ pub enum SecureClientIpSource { CfConnectingIp, /// IP from the [`axum::extract::ConnectInfo`] ConnectInfo, + /// IP from the `CloudFront-Viewer-Address` header + CloudFrontViewerAddress, } impl SecureClientIpSource { @@ -77,6 +79,7 @@ impl FromStr for SecureClientIpSource { "TrueClientIp" => Self::TrueClientIp, "CfConnectingIp" => Self::CfConnectingIp, "ConnectInfo" => Self::ConnectInfo, + "CloudFrontViewerAddress" => Self::CloudFrontViewerAddress, _ => return Err(ParseSecureClientIpSourceError(s.to_string())), }) } @@ -100,6 +103,9 @@ impl SecureClientIp { SecureClientIpSource::FlyClientIp => FlyClientIp::ip_from_headers(headers), SecureClientIpSource::TrueClientIp => TrueClientIp::ip_from_headers(headers), SecureClientIpSource::CfConnectingIp => CfConnectingIp::ip_from_headers(headers), + SecureClientIpSource::CloudFrontViewerAddress => { + CloudFrontViewerAddress::ip_from_headers(headers) + } SecureClientIpSource::ConnectInfo => extensions .get::>() .map(|ConnectInfo(addr)| addr.ip())