diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 9a34f3fb6..54750a6f7 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -2599,6 +2599,7 @@ impl Future for PendingRequest { } } } + let previous_method = self.method.clone(); let should_redirect = match res.status() { StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => { self.body = None; @@ -2660,10 +2661,13 @@ impl Future for PendingRequest { } let url = self.url.clone(); self.as_mut().urls().push(url); - let action = self - .client - .redirect_policy - .check(res.status(), &loc, &self.urls); + let action = self.client.redirect_policy.check( + res.status(), + &self.method, + &loc, + &previous_method, + &self.urls, + ); match action { redirect::ActionKind::Follow => { diff --git a/src/redirect.rs b/src/redirect.rs index 50a096f9f..b2fa6db2a 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -8,6 +8,7 @@ use std::error::Error as StdError; use std::fmt; use crate::header::{HeaderMap, AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, WWW_AUTHENTICATE}; +use crate::Method; use hyper::StatusCode; use crate::Url; @@ -30,7 +31,9 @@ pub struct Policy { #[derive(Debug)] pub struct Attempt<'a> { status: StatusCode, + next_method: &'a Method, next: &'a Url, + previous_method: &'a Method, previous: &'a [Url], } @@ -138,10 +141,19 @@ impl Policy { } } - pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind { + pub(crate) fn check( + &self, + status: StatusCode, + next_method: &Method, + next: &Url, + previous_method: &Method, + previous: &[Url], + ) -> ActionKind { self.redirect(Attempt { status, + next_method, next, + previous_method, previous, }) .inner @@ -165,15 +177,26 @@ impl<'a> Attempt<'a> { self.status } + /// Get the method for the next request, after applying redirection logic. + pub fn next_method(&self) -> &Method { + self.next_method + } + /// Get the next URL to redirect to. pub fn url(&self) -> &Url { self.next } + /// Get the method for the previous request, before redirection. + pub fn previous_method(&self) -> &Method { + self.previous_method + } + /// Get the list of previous URLs that have already been requested in this chain. pub fn previous(&self) -> &[Url] { self.previous } + /// Returns an action meaning reqwest should follow the next URL. pub fn follow(self) -> Action { Action { @@ -264,14 +287,26 @@ fn test_redirect_policy_limit() { .map(|i| Url::parse(&format!("http://a.b/c/{i}")).unwrap()) .collect::>(); - match policy.check(StatusCode::FOUND, &next, &previous) { + match policy.check( + StatusCode::FOUND, + &Method::GET, + &next, + &Method::GET, + &previous, + ) { ActionKind::Follow => (), other => panic!("unexpected {other:?}"), } previous.push(Url::parse("http://a.b.d/e/33").unwrap()); - match policy.check(StatusCode::FOUND, &next, &previous) { + match policy.check( + StatusCode::FOUND, + &Method::GET, + &next, + &Method::GET, + &previous, + ) { ActionKind::Error(err) if err.is::() => (), other => panic!("unexpected {other:?}"), } @@ -283,7 +318,13 @@ fn test_redirect_policy_limit_to_0() { let next = Url::parse("http://x.y/z").unwrap(); let previous = vec![Url::parse("http://a.b/c").unwrap()]; - match policy.check(StatusCode::FOUND, &next, &previous) { + match policy.check( + StatusCode::FOUND, + &Method::GET, + &next, + &Method::GET, + &previous, + ) { ActionKind::Error(err) if err.is::() => (), other => panic!("unexpected {other:?}"), } @@ -300,13 +341,13 @@ fn test_redirect_policy_custom() { }); let next = Url::parse("http://bar/baz").unwrap(); - match policy.check(StatusCode::FOUND, &next, &[]) { + match policy.check(StatusCode::FOUND, &Method::GET, &next, &Method::GET, &[]) { ActionKind::Follow => (), other => panic!("unexpected {other:?}"), } let next = Url::parse("http://foo/baz").unwrap(); - match policy.check(StatusCode::FOUND, &next, &[]) { + match policy.check(StatusCode::FOUND, &Method::GET, &next, &Method::GET, &[]) { ActionKind::Stop => (), other => panic!("unexpected {other:?}"), } @@ -335,3 +376,22 @@ fn test_remove_sensitive_headers() { remove_sensitive_headers(&mut headers, &next, &prev); assert_eq!(headers, filtered_headers); } + +#[test] +fn test_redirect_custom_policy_methods() { + let policy = Policy::custom(|attempt| { + let next = attempt.next_method(); + if next != Method::HEAD { + panic!("unexpected next method {:?}", next); + } + let prev = attempt.previous_method(); + if prev != Method::PUT { + panic!("unexpected previous method {:?}", prev); + } + attempt.stop() + }); + + let next = Url::parse("http://bar/baz").unwrap(); + let res = policy.check(StatusCode::FOUND, &Method::HEAD, &next, &Method::PUT, &[]); + assert!(matches!(res, ActionKind::Stop)); +} diff --git a/tests/redirect.rs b/tests/redirect.rs index c496d90d3..efb398fb2 100644 --- a/tests/redirect.rs +++ b/tests/redirect.rs @@ -127,6 +127,52 @@ async fn test_redirect_307_and_308_tries_to_post_again() { } } +#[tokio::test] +async fn test_redirect_custom_policy_previous_next_methods() { + use reqwest::{Method, StatusCode}; + + let codes = [301u16, 307]; + for &code in &codes { + let redirect = server::http(move |req| async move { + if req.method() == "POST" { + assert_eq!(req.uri(), &*format!("/{code}")); + http::Response::builder() + .status(code) + .header("location", "/dst") + .header("server", "test-redirect") + .body(Body::default()) + .unwrap() + } else { + panic!("unexpected method {}", req.method()); + } + }); + + let policy = reqwest::redirect::Policy::custom(|attempt| { + if attempt.previous_method() == Method::POST + && attempt.status() == StatusCode::MOVED_PERMANENTLY + { + assert_eq!(attempt.next_method(), Method::GET); + } else if attempt.previous_method() == Method::POST + && attempt.status() == StatusCode::TEMPORARY_REDIRECT + { + assert_eq!(attempt.next_method(), Method::POST); + } else { + panic!("unexpected attempt: {:?}", attempt); + } + attempt.stop() + }); + + let url = format!("http://{}/{}", redirect.addr(), code); + let client = reqwest::ClientBuilder::new() + .redirect(policy) + .build() + .unwrap(); + let res = client.post(&url).send().await.unwrap(); + assert_eq!(res.url().as_str(), url); + assert_eq!(res.status(), StatusCode::from_u16(code).unwrap()); + } +} + #[cfg(feature = "blocking")] #[test] fn test_redirect_307_does_not_try_if_reader_cannot_reset() {