Skip to content

Commit

Permalink
feat: add CloudFront header to Insecure and Secure
Browse files Browse the repository at this point in the history
  • Loading branch information
jreppnow committed Apr 8, 2024
1 parent dcbe03e commit ed4052a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/insecure.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "aws-cloudfront")]
use crate::rudimental::CloudFrontViewerAddress;
use crate::rudimental::{
CfConnectingIp, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, TrueClientIp,
XForwardedFor, XRealIp,
Expand Down Expand Up @@ -43,12 +45,17 @@ impl InsecureClientIp {
headers: &HeaderMap<HeaderValue>,
extensions: &Extensions,
) -> Result<Self, Rejection> {
XForwardedFor::maybe_leftmost_ip(headers)
let maybe_ip = XForwardedFor::maybe_leftmost_ip(headers)
.or_else(|| Forwarded::maybe_leftmost_ip(headers))
.or_else(|| XRealIp::maybe_ip_from_headers(headers))
.or_else(|| FlyClientIp::maybe_ip_from_headers(headers))
.or_else(|| TrueClientIp::maybe_ip_from_headers(headers))
.or_else(|| CfConnectingIp::maybe_ip_from_headers(headers))
.or_else(|| CfConnectingIp::maybe_ip_from_headers(headers));

#[cfg(feature = "aws-cloudfront")]
let maybe_ip = maybe_ip.or_else(|| CloudFrontViewerAddress::maybe_ip_from_headers(headers));

maybe_ip
.or_else(|| maybe_connect_info(extensions))
.map(Self)
.ok_or((
Expand Down
11 changes: 11 additions & 0 deletions src/secure.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "aws-cloudfront")]
use crate::rudimental::CloudFrontViewerAddress;
use crate::rudimental::{
CfConnectingIp, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, StringRejection,
TrueClientIp, XForwardedFor, XRealIp,
Expand Down Expand Up @@ -44,6 +46,9 @@ pub enum SecureClientIpSource {
CfConnectingIp,
/// IP from the [`axum::extract::ConnectInfo`]
ConnectInfo,
/// IP from the `CloudFront-Viewer-Address` header
#[cfg(feature = "aws-cloudfront")]
CloudFrontViewerAddress,
}

impl SecureClientIpSource {
Expand Down Expand Up @@ -77,6 +82,8 @@ impl FromStr for SecureClientIpSource {
"TrueClientIp" => Self::TrueClientIp,
"CfConnectingIp" => Self::CfConnectingIp,
"ConnectInfo" => Self::ConnectInfo,
#[cfg(feature = "aws-cloudfront")]
"CloudFrontViewerAddress" => Self::CloudFrontViewerAddress,
_ => return Err(ParseSecureClientIpSourceError(s.to_string())),
})
}
Expand All @@ -100,6 +107,10 @@ impl SecureClientIp {
SecureClientIpSource::FlyClientIp => FlyClientIp::ip_from_headers(headers),
SecureClientIpSource::TrueClientIp => TrueClientIp::ip_from_headers(headers),
SecureClientIpSource::CfConnectingIp => CfConnectingIp::ip_from_headers(headers),
#[cfg(feature = "aws-cloudfront")]
SecureClientIpSource::CloudFrontViewerAddress => {
CloudFrontViewerAddress::ip_from_headers(headers)
}
SecureClientIpSource::ConnectInfo => extensions
.get::<ConnectInfo<SocketAddr>>()
.map(|ConnectInfo(addr)| addr.ip())
Expand Down

0 comments on commit ed4052a

Please sign in to comment.