Skip to content

Commit

Permalink
feat: axum v0.8 support
Browse files Browse the repository at this point in the history
  • Loading branch information
devmaxde committed Jan 3, 2025
1 parent 2e18cab commit 647398f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 114 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "axum-test-helper"
version = "0.4.0"
version = "0.5.0"
edition = "2021"
categories = ["development-tools::testing"]
description = "Extra utilities for axum"
Expand All @@ -11,16 +11,17 @@ readme = "README.md"
repository = "https://github.com/cloudwalk/axum-test-helper"

[dependencies]
axum = "0.7"
reqwest = { version = "0.11.23", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false }
axum = "0.8"
reqwest = { version = "0.12.12", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false }
http = "1.0.0"
http-body = "0.4"
bytes = "1.4.0"
tower = "0.4.13"
tower = "0.5.2"
tower-service = "0.3"
serde = "1.0"
tokio = "1"
hyper = "1"
futures-util = "0.3.31"

[dev-dependencies]
serde = { version = "1", features = ["serde_derive"] }
Expand Down
201 changes: 91 additions & 110 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,72 +1,49 @@
//! # Axum Test Helper
//! This is a hard copy from TestClient at axum
//!
//! ## Features
//! - `cookies` - Enables support for cookies in the test client.
//! - `withouttrace` - Disables tracing for the test client.
//!
//! ## Example
//! ```rust
//! use axum::Router;
//! use axum::http::StatusCode;
//! use axum::routing::get;
//! use axum_test_helper::TestClient;
//!
//! fn main() {
//! let async_block = async {
//! // you can replace this Router with your own app
//! let app = Router::new().route("/", get(|| async {}));
//!
//! // initiate the TestClient with the previous declared Router
//! let client = TestClient::new(app);
//!
//! let res = client.get("/").send().await;
//! assert_eq!(res.status(), StatusCode::OK);
//! };
//!
//! // Create a runtime for executing the async block. This runtime is local
//! // to the main function and does not require any global setup.
//! let runtime = tokio::runtime::Builder::new_current_thread()
//! .enable_all()
//! .build()
//! .unwrap();
//!
//! // Use the local runtime to block on the async block.
//! runtime.block_on(async_block);
//! }
use bytes::Bytes;
use http::StatusCode;
use std::net::SocketAddr;
use futures_util::future::BoxFuture;
use http::header::{HeaderName, HeaderValue};
use std::ops::Deref;
use std::{convert::Infallible, future::IntoFuture, net::SocketAddr};
use axum::extract::Request;
use axum::response::Response;
use axum::serve;
use tokio::net::TcpListener;
use tower::make::Shared;
use tower_service::Service;

pub(crate) fn spawn_service<S>(svc: S) -> SocketAddr
where
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
std_listener.set_nonblocking(true).unwrap();
let listener = TcpListener::from_std(std_listener).unwrap();

let addr = listener.local_addr().unwrap();
println!("Listening on {addr}");

tokio::spawn(async move {
serve(listener, Shared::new(svc))
.await
.expect("server error")
});

addr
}

pub struct TestClient {
client: reqwest::Client,
addr: SocketAddr,
}

impl TestClient {
pub async fn new(svc: axum::Router) -> Self {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("Could not bind ephemeral socket");
let addr = listener.local_addr().unwrap();
#[cfg(feature = "withtrace")]
println!("Listening on {}", addr);

tokio::spawn(async move {
let server = axum::serve(listener, svc);
server.await.expect("server error");
});

#[cfg(feature = "cookies")]
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.cookie_store(true)
.build()
.unwrap();
pub fn new<S>(svc: S) -> Self
where
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
let addr = spawn_service(svc);

#[cfg(not(feature = "cookies"))]
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
Expand All @@ -75,48 +52,41 @@ impl TestClient {
TestClient { client, addr }
}

/// returns the base URL (http://ip:port) for this TestClient
///
/// this is useful when trying to check if Location headers in responses
/// are generated correctly as Location contains an absolute URL
pub fn base_url(&self) -> String {
format!("http://{}", self.addr)
}

pub fn get(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.get(format!("http://{}{}", self.addr, url)),
builder: self.client.get(format!("http://{}{url}", self.addr)),
}
}

pub fn head(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.head(format!("http://{}{}", self.addr, url)),
builder: self.client.head(format!("http://{}{url}", self.addr)),
}
}

pub fn post(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.post(format!("http://{}{}", self.addr, url)),
builder: self.client.post(format!("http://{}{url}", self.addr)),
}
}

#[allow(dead_code)]
pub fn put(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.put(format!("http://{}{}", self.addr, url)),
builder: self.client.put(format!("http://{}{url}", self.addr)),
}
}

#[allow(dead_code)]
pub fn patch(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.patch(format!("http://{}{}", self.addr, url)),
builder: self.client.patch(format!("http://{}{url}", self.addr)),
}
}

pub fn delete(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.delete(format!("http://{}{}", self.addr, url)),
}
#[allow(dead_code)]
pub fn server_port(&self) -> u16 {
self.addr.port()
}
}

Expand All @@ -125,22 +95,11 @@ pub struct RequestBuilder {
}

impl RequestBuilder {
pub async fn send(self) -> TestResponse {
TestResponse {
response: self.builder.send().await.unwrap(),
}
}

pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
self.builder = self.builder.body(body);
self
}

pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
self.builder = self.builder.form(&form);
self
}

pub fn json<T>(mut self, json: &T) -> Self
where
T: serde::Serialize,
Expand All @@ -149,51 +108,74 @@ impl RequestBuilder {
self
}

pub fn header(mut self, key: &str, value: &str) -> Self {
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.builder = self.builder.header(key, value);
self
}

#[allow(dead_code)]
pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
self.builder = self.builder.multipart(form);
self
}

#[allow(dead_code)]
pub fn form(mut self, form: &[(&str, &str)]) -> Self {
self.builder = self.builder.form(form);
self
}
}

/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s.
///
/// This is conventient for tests where panics are what you want. For access to
/// non-panicking versions or the complete `Response` API use `into_inner()` or
/// `as_ref()`.
impl IntoFuture for RequestBuilder {
type Output = TestResponse;
type IntoFuture = BoxFuture<'static, Self::Output>;

fn into_future(self) -> Self::IntoFuture {
Box::pin(async {
TestResponse {
response: self.builder.send().await.unwrap(),
}
})
}
}

#[derive(Debug)]
pub struct TestResponse {
response: reqwest::Response,
}

impl TestResponse {
pub async fn text(self) -> String {
self.response.text().await.unwrap()
impl Deref for TestResponse {
type Target = reqwest::Response;

fn deref(&self) -> &Self::Target {
&self.response
}
}

impl TestResponse {
#[allow(dead_code)]
pub async fn bytes(self) -> Bytes {
self.response.bytes().await.unwrap()
}

pub async fn text(self) -> String {
self.response.text().await.unwrap()
}

#[allow(dead_code)]
pub async fn json<T>(self) -> T
where
T: serde::de::DeserializeOwned,
{
self.response.json().await.unwrap()
}

pub fn status(&self) -> StatusCode {
StatusCode::from_u16(self.response.status().as_u16()).unwrap()
}

pub fn headers(&self) -> &reqwest::header::HeaderMap {
self.response.headers()
}

pub async fn chunk(&mut self) -> Option<Bytes> {
self.response.chunk().await.unwrap()
}
Expand Down Expand Up @@ -234,17 +216,17 @@ mod tests {
#[tokio::test]
async fn test_get_request() {
let app = Router::new().route("/", get(|| async {}));
let client = super::TestClient::new(app).await;
let res = client.get("/").send().await;
let client = super::TestClient::new(app);
let res = client.get("/").await;
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn test_post_form_request() {
let app = Router::new().route("/", post(handle_form));
let client = super::TestClient::new(app).await;
let client = super::TestClient::new(app);
let form = [("val", "bar"), ("baz", "quux")];
let res = client.post("/").form(&form).send().await;
let res = client.post("/").form(&form).await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "bar");
}
Expand All @@ -261,7 +243,7 @@ mod tests {
"/",
post(|json_value: Json<serde_json::Value>| async { json_value }),
);
let client = super::TestClient::new(app).await;
let client = super::TestClient::new(app);
let payload = TestPayload {
name: "Alice".to_owned(),
age: 30,
Expand All @@ -270,7 +252,6 @@ mod tests {
.post("/")
.header("Content-Type", "application/json")
.json(&payload)
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let response_body: TestPayload = serde_json::from_str(&res.text().await).unwrap();
Expand Down

0 comments on commit 647398f

Please sign in to comment.