Skip to content

Commit

Permalink
client/redirect: expose previous and next request methods
Browse files Browse the repository at this point in the history
This augments `redirect::Attempt` in order to expose two
relevant HTTP methods during redirects:
 - the methods for the previous request.
 - the method to be performed on the next request upon
   following a redirection.

The two methods can different, notably when redirecting `POST`
requests.

Ref: https://en.wikipedia.org/wiki/Post/Redirect/Get
  • Loading branch information
lucab committed Sep 23, 2024
1 parent d85f44b commit 79522a0
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 10 deletions.
12 changes: 8 additions & 4 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2599,6 +2599,7 @@ impl Future for PendingRequest {
}
}
}
let previous_method = self.method.clone();
let should_redirect = match res.status() {
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => {
self.body = None;
Expand Down Expand Up @@ -2660,10 +2661,13 @@ impl Future for PendingRequest {
}
let url = self.url.clone();
self.as_mut().urls().push(url);
let action = self
.client
.redirect_policy
.check(res.status(), &loc, &self.urls);
let action = self.client.redirect_policy.check(
res.status(),
&self.method,
&loc,
&previous_method,
&self.urls,
);

match action {
redirect::ActionKind::Follow => {
Expand Down
72 changes: 66 additions & 6 deletions src/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::error::Error as StdError;
use std::fmt;

use crate::header::{HeaderMap, AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, WWW_AUTHENTICATE};
use crate::Method;
use hyper::StatusCode;

use crate::Url;
Expand All @@ -30,7 +31,9 @@ pub struct Policy {
#[derive(Debug)]
pub struct Attempt<'a> {
status: StatusCode,
next_method: &'a Method,
next: &'a Url,
previous_method: &'a Method,
previous: &'a [Url],
}

Expand Down Expand Up @@ -138,10 +141,19 @@ impl Policy {
}
}

pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
pub(crate) fn check(
&self,
status: StatusCode,
next_method: &Method,
next: &Url,
previous_method: &Method,
previous: &[Url],
) -> ActionKind {
self.redirect(Attempt {
status,
next_method,
next,
previous_method,
previous,
})
.inner
Expand All @@ -165,15 +177,26 @@ impl<'a> Attempt<'a> {
self.status
}

/// Get the method for the next request, after applying redirection logic.
pub fn next_method(&self) -> &Method {
self.next_method
}

/// Get the next URL to redirect to.
pub fn url(&self) -> &Url {
self.next
}

/// Get the method for the previous request, before redirection.
pub fn previous_method(&self) -> &Method {
self.previous_method
}

/// Get the list of previous URLs that have already been requested in this chain.
pub fn previous(&self) -> &[Url] {
self.previous
}

/// Returns an action meaning reqwest should follow the next URL.
pub fn follow(self) -> Action {
Action {
Expand Down Expand Up @@ -264,14 +287,26 @@ fn test_redirect_policy_limit() {
.map(|i| Url::parse(&format!("http://a.b/c/{i}")).unwrap())
.collect::<Vec<_>>();

match policy.check(StatusCode::FOUND, &next, &previous) {
match policy.check(
StatusCode::FOUND,
&Method::GET,
&next,
&Method::GET,
&previous,
) {
ActionKind::Follow => (),
other => panic!("unexpected {other:?}"),
}

previous.push(Url::parse("http://a.b.d/e/33").unwrap());

match policy.check(StatusCode::FOUND, &next, &previous) {
match policy.check(
StatusCode::FOUND,
&Method::GET,
&next,
&Method::GET,
&previous,
) {
ActionKind::Error(err) if err.is::<TooManyRedirects>() => (),
other => panic!("unexpected {other:?}"),
}
Expand All @@ -283,7 +318,13 @@ fn test_redirect_policy_limit_to_0() {
let next = Url::parse("http://x.y/z").unwrap();
let previous = vec![Url::parse("http://a.b/c").unwrap()];

match policy.check(StatusCode::FOUND, &next, &previous) {
match policy.check(
StatusCode::FOUND,
&Method::GET,
&next,
&Method::GET,
&previous,
) {
ActionKind::Error(err) if err.is::<TooManyRedirects>() => (),
other => panic!("unexpected {other:?}"),
}
Expand All @@ -300,13 +341,13 @@ fn test_redirect_policy_custom() {
});

let next = Url::parse("http://bar/baz").unwrap();
match policy.check(StatusCode::FOUND, &next, &[]) {
match policy.check(StatusCode::FOUND, &Method::GET, &next, &Method::GET, &[]) {
ActionKind::Follow => (),
other => panic!("unexpected {other:?}"),
}

let next = Url::parse("http://foo/baz").unwrap();
match policy.check(StatusCode::FOUND, &next, &[]) {
match policy.check(StatusCode::FOUND, &Method::GET, &next, &Method::GET, &[]) {
ActionKind::Stop => (),
other => panic!("unexpected {other:?}"),
}
Expand Down Expand Up @@ -335,3 +376,22 @@ fn test_remove_sensitive_headers() {
remove_sensitive_headers(&mut headers, &next, &prev);
assert_eq!(headers, filtered_headers);
}

#[test]
fn test_redirect_custom_policy_methods() {
let policy = Policy::custom(|attempt| {
let next = attempt.next_method();
if next != Method::HEAD {
panic!("unexpected next method {:?}", next);
}
let prev = attempt.previous_method();
if prev != Method::PUT {
panic!("unexpected previous method {:?}", prev);
}
attempt.stop()
});

let next = Url::parse("http://bar/baz").unwrap();
let res = policy.check(StatusCode::FOUND, &Method::HEAD, &next, &Method::PUT, &[]);
assert!(matches!(res, ActionKind::Stop));
}
46 changes: 46 additions & 0 deletions tests/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,52 @@ async fn test_redirect_307_and_308_tries_to_post_again() {
}
}

#[tokio::test]
async fn test_redirect_custom_policy_previous_next_methods() {
use reqwest::{Method, StatusCode};

let codes = [301u16, 307];
for &code in &codes {
let redirect = server::http(move |req| async move {
if req.method() == "POST" {
assert_eq!(req.uri(), &*format!("/{code}"));
http::Response::builder()
.status(code)
.header("location", "/dst")
.header("server", "test-redirect")
.body(Body::default())
.unwrap()
} else {
panic!("unexpected method {}", req.method());
}
});

let policy = reqwest::redirect::Policy::custom(|attempt| {
if attempt.previous_method() == Method::POST
&& attempt.status() == StatusCode::MOVED_PERMANENTLY
{
assert_eq!(attempt.next_method(), Method::GET);
} else if attempt.previous_method() == Method::POST
&& attempt.status() == StatusCode::TEMPORARY_REDIRECT
{
assert_eq!(attempt.next_method(), Method::POST);
} else {
panic!("unexpected attempt: {:?}", attempt);
}
attempt.stop()
});

let url = format!("http://{}/{}", redirect.addr(), code);
let client = reqwest::ClientBuilder::new()
.redirect(policy)
.build()
.unwrap();
let res = client.post(&url).send().await.unwrap();
assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), StatusCode::from_u16(code).unwrap());
}
}

#[cfg(feature = "blocking")]
#[test]
fn test_redirect_307_does_not_try_if_reader_cannot_reset() {
Expand Down

0 comments on commit 79522a0

Please sign in to comment.