diff --git a/Cargo.toml b/Cargo.toml index f59bab9..b6511dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "axum-test-helper" -version = "0.4.0" +version = "0.5.0" edition = "2021" categories = ["development-tools::testing"] description = "Extra utilities for axum" @@ -11,16 +11,17 @@ readme = "README.md" repository = "https://github.com/cloudwalk/axum-test-helper" [dependencies] -axum = "0.7" -reqwest = { version = "0.11.23", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false } +axum = "0.8" +reqwest = { version = "0.12.12", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false } http = "1.0.0" http-body = "0.4" bytes = "1.4.0" -tower = "0.4.13" +tower = "0.5.2" tower-service = "0.3" serde = "1.0" tokio = "1" hyper = "1" +futures-util = "0.3.31" [dev-dependencies] serde = { version = "1", features = ["serde_derive"] } diff --git a/src/lib.rs b/src/lib.rs index ad8d617..ab6268b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,44 +1,35 @@ -//! # Axum Test Helper -//! This is a hard copy from TestClient at axum -//! -//! ## Features -//! - `cookies` - Enables support for cookies in the test client. -//! - `withouttrace` - Disables tracing for the test client. -//! -//! ## Example -//! ```rust -//! use axum::Router; -//! use axum::http::StatusCode; -//! use axum::routing::get; -//! use axum_test_helper::TestClient; -//! -//! fn main() { -//! let async_block = async { -//! // you can replace this Router with your own app -//! let app = Router::new().route("/", get(|| async {})); -//! -//! // initiate the TestClient with the previous declared Router -//! let client = TestClient::new(app); -//! -//! let res = client.get("/").send().await; -//! assert_eq!(res.status(), StatusCode::OK); -//! }; -//! -//! // Create a runtime for executing the async block. This runtime is local -//! // to the main function and does not require any global setup. -//! let runtime = tokio::runtime::Builder::new_current_thread() -//! .enable_all() -//! .build() -//! .unwrap(); -//! -//! // Use the local runtime to block on the async block. -//! runtime.block_on(async_block); -//! } - use bytes::Bytes; -use http::StatusCode; -use std::net::SocketAddr; +use futures_util::future::BoxFuture; +use http::header::{HeaderName, HeaderValue}; +use std::ops::Deref; +use std::{convert::Infallible, future::IntoFuture, net::SocketAddr}; +use axum::extract::Request; +use axum::response::Response; +use axum::serve; use tokio::net::TcpListener; +use tower::make::Shared; +use tower_service::Service; + +pub(crate) fn spawn_service(svc: S) -> SocketAddr +where + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + std_listener.set_nonblocking(true).unwrap(); + let listener = TcpListener::from_std(std_listener).unwrap(); + + let addr = listener.local_addr().unwrap(); + println!("Listening on {addr}"); + + tokio::spawn(async move { + serve(listener, Shared::new(svc)) + .await + .expect("server error") + }); + + addr +} pub struct TestClient { client: reqwest::Client, @@ -46,27 +37,13 @@ pub struct TestClient { } impl TestClient { - pub async fn new(svc: axum::Router) -> Self { - let listener = TcpListener::bind("127.0.0.1:0") - .await - .expect("Could not bind ephemeral socket"); - let addr = listener.local_addr().unwrap(); - #[cfg(feature = "withtrace")] - println!("Listening on {}", addr); - - tokio::spawn(async move { - let server = axum::serve(listener, svc); - server.await.expect("server error"); - }); - - #[cfg(feature = "cookies")] - let client = reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::none()) - .cookie_store(true) - .build() - .unwrap(); + pub fn new(svc: S) -> Self + where + S: Service + Clone + Send + 'static, + S::Future: Send, + { + let addr = spawn_service(svc); - #[cfg(not(feature = "cookies"))] let client = reqwest::Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() @@ -75,48 +52,41 @@ impl TestClient { TestClient { client, addr } } - /// returns the base URL (http://ip:port) for this TestClient - /// - /// this is useful when trying to check if Location headers in responses - /// are generated correctly as Location contains an absolute URL - pub fn base_url(&self) -> String { - format!("http://{}", self.addr) - } - pub fn get(&self, url: &str) -> RequestBuilder { RequestBuilder { - builder: self.client.get(format!("http://{}{}", self.addr, url)), + builder: self.client.get(format!("http://{}{url}", self.addr)), } } pub fn head(&self, url: &str) -> RequestBuilder { RequestBuilder { - builder: self.client.head(format!("http://{}{}", self.addr, url)), + builder: self.client.head(format!("http://{}{url}", self.addr)), } } pub fn post(&self, url: &str) -> RequestBuilder { RequestBuilder { - builder: self.client.post(format!("http://{}{}", self.addr, url)), + builder: self.client.post(format!("http://{}{url}", self.addr)), } } + #[allow(dead_code)] pub fn put(&self, url: &str) -> RequestBuilder { RequestBuilder { - builder: self.client.put(format!("http://{}{}", self.addr, url)), + builder: self.client.put(format!("http://{}{url}", self.addr)), } } + #[allow(dead_code)] pub fn patch(&self, url: &str) -> RequestBuilder { RequestBuilder { - builder: self.client.patch(format!("http://{}{}", self.addr, url)), + builder: self.client.patch(format!("http://{}{url}", self.addr)), } } - pub fn delete(&self, url: &str) -> RequestBuilder { - RequestBuilder { - builder: self.client.delete(format!("http://{}{}", self.addr, url)), - } + #[allow(dead_code)] + pub fn server_port(&self) -> u16 { + self.addr.port() } } @@ -125,22 +95,11 @@ pub struct RequestBuilder { } impl RequestBuilder { - pub async fn send(self) -> TestResponse { - TestResponse { - response: self.builder.send().await.unwrap(), - } - } - pub fn body(mut self, body: impl Into) -> Self { self.builder = self.builder.body(body); self } - pub fn form(mut self, form: &T) -> Self { - self.builder = self.builder.form(&form); - self - } - pub fn json(mut self, json: &T) -> Self where T: serde::Serialize, @@ -149,36 +108,67 @@ impl RequestBuilder { self } - pub fn header(mut self, key: &str, value: &str) -> Self { + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, + { self.builder = self.builder.header(key, value); self } + #[allow(dead_code)] pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self { self.builder = self.builder.multipart(form); self } + + #[allow(dead_code)] + pub fn form(mut self, form: &[(&str, &str)]) -> Self { + self.builder = self.builder.form(form); + self + } } -/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s. -/// -/// This is conventient for tests where panics are what you want. For access to -/// non-panicking versions or the complete `Response` API use `into_inner()` or -/// `as_ref()`. +impl IntoFuture for RequestBuilder { + type Output = TestResponse; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async { + TestResponse { + response: self.builder.send().await.unwrap(), + } + }) + } +} + +#[derive(Debug)] pub struct TestResponse { response: reqwest::Response, } -impl TestResponse { - pub async fn text(self) -> String { - self.response.text().await.unwrap() +impl Deref for TestResponse { + type Target = reqwest::Response; + + fn deref(&self) -> &Self::Target { + &self.response } +} +impl TestResponse { #[allow(dead_code)] pub async fn bytes(self) -> Bytes { self.response.bytes().await.unwrap() } + pub async fn text(self) -> String { + self.response.text().await.unwrap() + } + + #[allow(dead_code)] pub async fn json(self) -> T where T: serde::de::DeserializeOwned, @@ -186,14 +176,6 @@ impl TestResponse { self.response.json().await.unwrap() } - pub fn status(&self) -> StatusCode { - StatusCode::from_u16(self.response.status().as_u16()).unwrap() - } - - pub fn headers(&self) -> &reqwest::header::HeaderMap { - self.response.headers() - } - pub async fn chunk(&mut self) -> Option { self.response.chunk().await.unwrap() } @@ -234,17 +216,17 @@ mod tests { #[tokio::test] async fn test_get_request() { let app = Router::new().route("/", get(|| async {})); - let client = super::TestClient::new(app).await; - let res = client.get("/").send().await; + let client = super::TestClient::new(app); + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn test_post_form_request() { let app = Router::new().route("/", post(handle_form)); - let client = super::TestClient::new(app).await; + let client = super::TestClient::new(app); let form = [("val", "bar"), ("baz", "quux")]; - let res = client.post("/").form(&form).send().await; + let res = client.post("/").form(&form).await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "bar"); } @@ -261,7 +243,7 @@ mod tests { "/", post(|json_value: Json| async { json_value }), ); - let client = super::TestClient::new(app).await; + let client = super::TestClient::new(app); let payload = TestPayload { name: "Alice".to_owned(), age: 30, @@ -270,7 +252,6 @@ mod tests { .post("/") .header("Content-Type", "application/json") .json(&payload) - .send() .await; assert_eq!(res.status(), StatusCode::OK); let response_body: TestPayload = serde_json::from_str(&res.text().await).unwrap();