diff --git a/examples/single_threaded.rs b/examples/single_threaded.rs index 263df69..ce555cf 100644 --- a/examples/single_threaded.rs +++ b/examples/single_threaded.rs @@ -29,6 +29,7 @@ use hyperdriver::info::HasConnectionInfo; use hyperdriver::server::Accept; use hyperdriver::service::{make_service_fn, RequestExecutor}; use hyperdriver::stream::TcpStream; +use hyperdriver::IntoRequestParts; use pin_project::pin_project; use tokio::io::{self, AsyncWriteExt}; use tokio::net::TcpListener; @@ -391,8 +392,8 @@ impl Transport for TransportNotSend { type Future = Pin> + Send>>; - fn connect(&mut self, uri: http::Uri) -> ::Future { - self.tcp.connect(uri).boxed() + fn connect(&mut self, req: R) -> ::Future { + self.tcp.connect(req.into_request_parts()).boxed() } fn poll_ready( diff --git a/src/client/conn/transport/duplex.rs b/src/client/conn/transport/duplex.rs index 2e9ccd4..73a2bf0 100644 --- a/src/client/conn/transport/duplex.rs +++ b/src/client/conn/transport/duplex.rs @@ -4,7 +4,6 @@ use std::io; use std::task::{Context, Poll}; use crate::BoxFuture; -use http::Uri; use crate::stream::duplex::DuplexStream as Stream; @@ -25,7 +24,7 @@ impl DuplexTransport { } } -impl tower::Service for DuplexTransport { +impl tower::Service for DuplexTransport { type Response = Stream; type Error = io::Error; @@ -36,7 +35,7 @@ impl tower::Service for DuplexTransport { Poll::Ready(Ok(())) } - fn call(&mut self, _req: Uri) -> Self::Future { + fn call(&mut self, _req: http::request::Parts) -> Self::Future { let client = self.client.clone(); let max_buf_size = self.max_buf_size; let fut = async move { @@ -66,7 +65,13 @@ mod tests { let (io, _) = tokio::join!( async { transport - .oneshot("https://example.com".parse().unwrap()) + .oneshot( + http::Request::get("https://example.com") + .body(()) + .unwrap() + .into_parts() + .0, + ) .await .unwrap() }, diff --git a/src/client/conn/transport/mock.rs b/src/client/conn/transport/mock.rs index c89bd0b..ca9e1f5 100644 --- a/src/client/conn/transport/mock.rs +++ b/src/client/conn/transport/mock.rs @@ -2,7 +2,6 @@ use std::future::ready; -use http::Uri; use thiserror::Error; use crate::client::conn::protocol::mock::MockProtocol; @@ -81,14 +80,14 @@ impl MockTransport { /// Create a new connector for the transport. pub fn connector( self, - uri: Uri, + parts: http::request::Parts, version: HttpProtocol, ) -> pool::Connector { - pool::Connector::new(self, MockProtocol::default(), uri, version) + pool::Connector::new(self, MockProtocol::default(), parts, version) } } -impl tower::Service for MockTransport { +impl tower::Service for MockTransport { type Response = MockStream; type Error = MockConnectionError; @@ -102,7 +101,7 @@ impl tower::Service for MockTransport { std::task::Poll::Ready(Ok(())) } - fn call(&mut self, _req: http::Uri) -> Self::Future { + fn call(&mut self, _req: http::request::Parts) -> Self::Future { let reuse = match &mut self.mode { TransportMode::SingleUse => false, TransportMode::Reusable => true, diff --git a/src/client/conn/transport/mod.rs b/src/client/conn/transport/mod.rs index 81644c4..d99e6f7 100644 --- a/src/client/conn/transport/mod.rs +++ b/src/client/conn/transport/mod.rs @@ -6,7 +6,6 @@ use std::future::Future; #[cfg(feature = "tls")] use std::sync::Arc; -use ::http::Uri; #[cfg(feature = "tls")] use rustls::client::ClientConfig; #[cfg(feature = "tls")] @@ -26,6 +25,7 @@ use crate::client::default_tls_config; #[cfg(feature = "stream")] use crate::info::BraidAddr; use crate::info::HasConnectionInfo; +use crate::IntoRequestParts; #[cfg(feature = "stream")] pub mod duplex; @@ -53,7 +53,9 @@ pub trait Transport: Clone + Send { type Future: Future::Error>> + Send + 'static; /// Connect to a remote server and return a stream. - fn connect(&mut self, uri: Uri) -> ::Future; + fn connect(&mut self, req: R) -> ::Future + where + R: IntoRequestParts; /// Poll the transport to see if it is ready to accept a new connection. fn poll_ready( @@ -64,7 +66,7 @@ pub trait Transport: Clone + Send { impl Transport for T where - T: Service, + T: Service, T: Clone + Send + Sync + 'static, T::Error: std::error::Error + Send + Sync + 'static, T::Future: Send + 'static, @@ -75,8 +77,11 @@ where type Error = T::Error; type Future = T::Future; - fn connect(&mut self, uri: Uri) -> >::Future { - self.call(uri) + fn connect(&mut self, req: R) -> >::Future + where + R: IntoRequestParts, + { + self.call(req.into_request_parts()) } fn poll_ready( @@ -87,6 +92,42 @@ where } } +/// A wrapper type which converts any service that accepts a URI +/// into a `hyperdriver` tranport type. Hyperdriver uses http::request::Parts +/// for transports, but many implementations only require the http::Uri +/// in order to function. +#[derive(Debug, Clone)] +pub struct UriTransport(T); + +impl Transport for UriTransport +where + T: Service, + T: Clone + Send + Sync + 'static, + T::Error: std::error::Error + Send + Sync + 'static, + T::Future: Send + 'static, + IO: HasConnectionInfo + Send + 'static, + IO::Addr: Send, +{ + type IO = IO; + type Error = T::Error; + type Future = T::Future; + + fn connect(&mut self, req: R) -> ::Future + where + R: IntoRequestParts, + { + let parts = req.into_request_parts(); + self.0.call(parts.uri) + } + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll::Error>> { + Service::poll_ready(&mut self.0, cx) + } +} + /// Extension trait for Transports to provide additional configuration options. pub trait TransportExt: Transport { #[cfg(feature = "stream")] @@ -277,7 +318,7 @@ impl TlsTransport { } } -impl Service for TlsTransport +impl Service for TlsTransport where T: Transport, ::IO: HasConnectionInfo + AsyncRead + AsyncWrite + Unpin, @@ -311,26 +352,27 @@ where } } - fn call(&mut self, req: Uri) -> Self::Future { + fn call(&mut self, parts: http::request::Parts) -> Self::Future { #[cfg_attr(not(feature = "tls"), allow(unused_variables))] - let use_tls = req + let use_tls = parts + .uri .scheme_str() .is_some_and(|s| matches!(s, "https" | "wss")); match &mut self.braid { InnerBraid::Plain(inner) => { - tracing::trace!(scheme=?req.scheme_str(), "connecting without TLS"); - self::future::TransportBraidFuture::from_plain(inner.connect(req)) + tracing::trace!(scheme=?parts.uri.scheme_str(), "connecting without TLS"); + self::future::TransportBraidFuture::from_plain(inner.connect(parts)) } #[cfg(feature = "tls")] InnerBraid::Tls(inner) if use_tls => { - tracing::trace!(scheme=?req.scheme_str(), "connecting with TLS"); - self::future::TransportBraidFuture::from_tls(inner.call(req)) + tracing::trace!(scheme=?parts.uri.scheme_str(), "connecting with TLS"); + self::future::TransportBraidFuture::from_tls(inner.call(parts)) } #[cfg(feature = "tls")] InnerBraid::Tls(inner) => { - tracing::trace!(scheme=?req.scheme_str(), "connecting without TLS"); - self::future::TransportBraidFuture::from_plain(inner.transport_mut().connect(req)) + tracing::trace!(scheme=?parts.uri.scheme_str(), "connecting without TLS"); + self::future::TransportBraidFuture::from_plain(inner.transport_mut().connect(parts)) } } } diff --git a/src/client/conn/transport/stream.rs b/src/client/conn/transport/stream.rs index fc4e119..75451c3 100644 --- a/src/client/conn/transport/stream.rs +++ b/src/client/conn/transport/stream.rs @@ -1,4 +1,3 @@ -use ::http::Uri; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; @@ -23,7 +22,7 @@ impl IntoStream { } } -impl Service for IntoStream +impl Service for IntoStream where T: Transport, T::IO: Into + AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -40,7 +39,7 @@ where self.transport.poll_ready(cx) } - fn call(&mut self, req: Uri) -> Self::Future { + fn call(&mut self, req: http::request::Parts) -> Self::Future { fut::ConnectFuture::new(self.transport.connect(req)) } } @@ -97,6 +96,7 @@ mod tests { use crate::client::conn::transport::duplex::DuplexTransport; use crate::client::conn::transport::TransportExt as _; use crate::server::conn::AcceptExt as _; + use crate::IntoRequestParts; use tower::ServiceExt as _; #[tokio::test] @@ -108,7 +108,7 @@ mod tests { let (io, _) = tokio::join!( async { transport - .oneshot("https://example.com".parse().unwrap()) + .oneshot("https://example.com".into_request_parts()) .await .unwrap() }, diff --git a/src/client/conn/transport/tcp.rs b/src/client/conn/transport/tcp.rs index d6eb8b6..86377ad 100644 --- a/src/client/conn/transport/tcp.rs +++ b/src/client/conn/transport/tcp.rs @@ -58,12 +58,13 @@ use crate::BoxError; /// # use hyperdriver::client::conn::transport::tcp::TcpTransport; /// # use hyperdriver::client::conn::dns::GaiResolver; /// # use hyperdriver::stream::tcp::TcpStream; +/// # use hyperdriver::IntoRequestParts; /// # use tower::ServiceExt as _; /// /// # async fn run() { /// let transport: TcpTransport = TcpTransport::default(); /// -/// let uri = "http://example.com".parse().unwrap(); +/// let uri = "http://example.com".into_request_parts(); /// let stream = transport.oneshot(uri).await.unwrap(); /// # } /// ``` @@ -169,7 +170,7 @@ impl TcpTransport { type BoxFuture<'a, T, E> = crate::BoxFuture<'a, Result>; -impl tower::Service for TcpTransport +impl tower::Service for TcpTransport where R: tower::Service, Response = SocketAddrs, Error = io::Error> + Clone @@ -191,8 +192,8 @@ where .map_err(TcpConnectionError::msg("dns poll_ready")) } - fn call(&mut self, req: Uri) -> Self::Future { - let (host, port) = match get_host_and_port(&req) { + fn call(&mut self, req: http::request::Parts) -> Self::Future { + let (host, port) = match get_host_and_port(&req.uri) { Ok((host, port)) => (host, port), Err(e) => return Box::pin(std::future::ready(Err(e))), }; @@ -623,7 +624,7 @@ mod test { use tokio::net::TcpListener; use tower::Service; - use crate::client::conn::Transport; + use crate::{client::conn::Transport, IntoRequestParts}; use super::*; @@ -720,13 +721,16 @@ mod test { listener: TcpListener, ) -> (T::IO, TcpStream) where - T: Transport + Service, - >::Error: std::fmt::Debug, + T: Transport + Service, + >::Error: std::fmt::Debug, { - tokio::join!(async { transport.oneshot(uri).await.unwrap() }, async { - let (stream, addr) = listener.accept().await.unwrap(); - TcpStream::server(stream, addr) - }) + tokio::join!( + async { transport.oneshot(uri.into_request_parts()).await.unwrap() }, + async { + let (stream, addr) = listener.accept().await.unwrap(); + TcpStream::server(stream, addr) + } + ) } #[tokio::test] @@ -742,7 +746,7 @@ mod test { .with_resolver(Resolver(0)) .build::(); - let result = transport.oneshot(uri).await; + let result = transport.oneshot(uri.into_request_parts()).await; assert!(result.is_err()); } @@ -784,7 +788,7 @@ mod test { .with_resolver(EmptyResolver) .build::(); - let result = transport.oneshot(uri).await; + let result = transport.oneshot(uri.into_request_parts()).await; assert!(result.is_err()); let err = result.unwrap_err(); @@ -796,7 +800,7 @@ mod test { async fn test_transport_error() { let _ = tracing_subscriber::fmt::try_init(); - let uri: Uri = "http://example.com".parse().unwrap(); + let parts = "http://example.com".into_request_parts(); let config = TcpTransportConfig::default(); @@ -805,7 +809,7 @@ mod test { .with_resolver(ErrorResolver) .build::(); - let result = transport.oneshot(uri).await; + let result = transport.oneshot(parts).await; assert!(result.is_err()); let err = result.unwrap_err(); diff --git a/src/client/conn/transport/tls.rs b/src/client/conn/transport/tls.rs index e209731..61e6feb 100644 --- a/src/client/conn/transport/tls.rs +++ b/src/client/conn/transport/tls.rs @@ -4,7 +4,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use http::Uri; use rustls::ClientConfig as TlsClientConfig; use tokio::io::{AsyncRead, AsyncWrite}; @@ -46,7 +45,7 @@ impl TlsTransportWrapper { } } -impl tower::Service for TlsTransportWrapper +impl tower::Service for TlsTransportWrapper where T: Transport, ::IO: HasConnectionInfo + AsyncRead + AsyncWrite + Unpin, @@ -62,9 +61,9 @@ where .map_err(TlsConnectionError::Connection) } - fn call(&mut self, req: Uri) -> Self::Future { + fn call(&mut self, req: http::request::Parts) -> Self::Future { let config = self.config.clone(); - let Some(host) = req.host().map(String::from) else { + let Some(host) = req.uri.host().map(String::from) else { return future::TlsConnectionFuture::error(TlsConnectionError::NoDomain); }; @@ -230,6 +229,7 @@ mod tests { use crate::{ fixtures, + helpers::IntoRequestParts, info::HasTlsConnectionInfo as _, server::conn::AcceptExt, stream::tls::{TlsHandshakeExt, TlsHandshakeStream as _}, @@ -252,11 +252,10 @@ mod tests { config.alpn_protocols.push(b"h2".to_vec()); let accept = crate::server::conn::Acceptor::new(server).with_tls(config.into()); - let uri = "https://example.com/".parse().unwrap(); - + let parts = "https://example.com/".into_request_parts(); let (stream, _) = tokio::join!( async { - let mut stream = transport.oneshot(uri).await.unwrap(); + let mut stream = transport.oneshot(parts).await.unwrap(); stream.finish_handshake().await.unwrap(); stream }, diff --git a/src/client/pool/checkout.rs b/src/client/pool/checkout.rs index d1bc7b1..371fb62 100644 --- a/src/client/pool/checkout.rs +++ b/src/client/pool/checkout.rs @@ -6,7 +6,6 @@ use std::task::ready; use std::task::Context; use std::task::Poll; -use http::Uri; use pin_project::pin_project; use pin_project::pinned_drop; use thiserror::Error; @@ -135,7 +134,7 @@ where P::Connection: PoolableConnection, { PollReadyTransport { - address: Uri, + parts: http::request::Parts, transport: Option, protocol: Option

, }, @@ -163,9 +162,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ConnectorState::PollReadyTransport { address, .. } => f + ConnectorState::PollReadyTransport { parts, .. } => f .debug_struct("PollReadyTransport") - .field("address", address) + .field("address", &parts.uri) .finish(), ConnectorState::Connect { .. } => f.debug_tuple("Connect").finish(), ConnectorState::PollReadyHandshake { .. } => { @@ -212,13 +211,18 @@ where P::Connection: PoolableConnection, { /// Create a new connection from a transport connector and a handshake function. - pub fn new(transport: T, protocol: P, address: Uri, version: HttpProtocol) -> Self { + pub fn new( + transport: T, + protocol: P, + parts: http::request::Parts, + version: HttpProtocol, + ) -> Self { //TODO: Fix this let shareable = false; Self { state: ConnectorState::PollReadyTransport { - address, + parts, transport: Some(transport), protocol: Some(protocol), }, @@ -250,7 +254,7 @@ where loop { match connector_projected.state.as_mut().project() { ConnectorStateProjected::PollReadyTransport { - address, + parts, transport, protocol, } => { @@ -265,7 +269,7 @@ where let mut transport = transport .take() .expect("future polled in invalid state: transport is None"); - let future = transport.connect(address.clone()); + let future = transport.connect(parts.clone()); let protocol = protocol.take(); tracing::trace!("transport ready"); @@ -736,6 +740,9 @@ where mod test { use super::*; + #[cfg(feature = "mocks")] + use crate::IntoRequestParts; + use static_assertions::assert_impl_all; assert_impl_all!(Error: std::error::Error, Send, Sync, Into); @@ -759,7 +766,7 @@ mod test { let transport = MockTransport::single(); let checkout = Checkout::detached( - transport.connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + transport.connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); assert!(checkout.token.is_zero()); diff --git a/src/client/pool/mod.rs b/src/client/pool/mod.rs index 1244835..992687e 100644 --- a/src/client/pool/mod.rs +++ b/src/client/pool/mod.rs @@ -532,6 +532,7 @@ mod tests { use crate::client::conn::protocol::HttpProtocol; use crate::client::conn::transport::mock::MockConnectionError; + use crate::helpers::IntoRequestParts; use super::*; use crate::client::conn::protocol::mock::MockSender; @@ -577,7 +578,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -591,7 +592,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -606,7 +607,7 @@ mod tests { key, false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -632,7 +633,7 @@ mod tests { key.clone(), true, MockTransport::reusable() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -646,7 +647,7 @@ mod tests { key.clone(), true, MockTransport::reusable() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -661,7 +662,7 @@ mod tests { key.clone(), true, MockTransport::reusable() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -686,7 +687,7 @@ mod tests { key.clone(), true, MockTransport::channel(rx) - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1) + .connector("mock://address".into_request_parts(), HttpProtocol::Http1) )); assert!(futures_util::poll!(&mut checkout_a).is_pending()); @@ -695,7 +696,7 @@ mod tests { key.clone(), true, MockTransport::reusable() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), )); assert!(futures_util::poll!(&mut checkout_b).is_pending()); @@ -730,7 +731,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); let token = checkout.token(); @@ -768,7 +769,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); let token = checkout.token(); @@ -813,14 +814,14 @@ mod tests { key.clone(), true, MockTransport::reusable() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); let checkout = pool.checkout( key.clone(), true, MockTransport::reusable() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); drop(start); @@ -845,7 +846,7 @@ mod tests { key.clone(), true, MockTransport::reusable() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); drop(pool); @@ -869,7 +870,7 @@ mod tests { key.clone(), true, MockTransport::error() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); let outcome = checkout.now_or_never().unwrap(); @@ -895,7 +896,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -909,7 +910,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -924,7 +925,7 @@ mod tests { key, false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -950,7 +951,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ) .await .unwrap(); @@ -962,7 +963,7 @@ mod tests { key.clone(), false, MockTransport::single() - .connector("mock://address".parse().unwrap(), HttpProtocol::Http1), + .connector("mock://address".into_request_parts(), HttpProtocol::Http1), ); let token = checkout.token(); diff --git a/src/client/pool/service.rs b/src/client/pool/service.rs index a7d4cc5..2d1a2e3 100644 --- a/src/client/pool/service.rs +++ b/src/client/pool/service.rs @@ -231,12 +231,7 @@ where let transport = self.transport.clone(); let http_protocol = request_parts.version.into(); - let connector = Connector::new( - transport, - protocol, - request_parts.uri.clone(), - http_protocol, - ); + let connector = Connector::new(transport, protocol, request_parts.clone(), http_protocol); if let Some(pool) = self.pool.as_ref() { tracing::trace!(?key, "checking out connection"); diff --git a/src/lib.rs b/src/lib.rs index da97e15..addf4b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -135,6 +135,8 @@ pub use client::Client; #[cfg(feature = "server")] pub use server::Server; +pub use helpers::IntoRequestParts; + type BoxFuture<'a, T> = Pin + Send + 'a>>; type BoxError = Box; @@ -167,6 +169,34 @@ pub(crate) mod private { pub trait Sealed {} } +/// Helpers to turn URI-like items into empty request parts +pub(crate) mod helpers { + + /// Turn the item into http::request::Parts infallibly + pub trait IntoRequestParts { + /// Produce http::request::Parts + fn into_request_parts(self) -> http::request::Parts; + } + + impl IntoRequestParts for &str { + fn into_request_parts(self) -> http::request::Parts { + http::Request::get(self).body(()).unwrap().into_parts().0 + } + } + + impl IntoRequestParts for http::Uri { + fn into_request_parts(self) -> http::request::Parts { + http::Request::get(self).body(()).unwrap().into_parts().0 + } + } + + impl IntoRequestParts for http::request::Parts { + fn into_request_parts(self) -> http::request::Parts { + self + } + } +} + /// Test fixtures for the `hyperdriver` crate. #[cfg(test)] #[cfg(feature = "tls")] diff --git a/src/server/conn/tls/info.rs b/src/server/conn/tls/info.rs index b2a9d9a..e0dbf60 100644 --- a/src/server/conn/tls/info.rs +++ b/src/server/conn/tls/info.rs @@ -185,7 +185,7 @@ mod tests { use tower::make::Shared; use tower::Service; - use crate::fixtures; + use crate::{fixtures, IntoRequestParts}; use crate::client::conn::transport::duplex::DuplexTransport; use crate::client::conn::transport::TransportExt as _; @@ -218,15 +218,11 @@ mod tests { .with_tls(crate::fixtures::tls_client_config().into()); let client = async move { - let mut stream = client - .connect("https://example.com".parse().unwrap()) - .await - .unwrap(); + let parts = "https://example.com".into_request_parts(); + let mut stream = client.connect(parts).await.unwrap(); tracing::debug!("client connected"); - stream.finish_handshake().await.unwrap(); - tracing::debug!("client handshake finished"); stream @@ -237,25 +233,25 @@ mod tests { let mut conn = acceptor.accept().await.unwrap(); tracing::debug!("server accepted"); - let mut make_service = TlsConnectionInfoLayer::new().layer(Shared::new(service)); - conn.finish_handshake().await.unwrap(); tracing::debug!("server handshake finished"); let mut svc = Service::call(&mut make_service, &conn).await.unwrap(); tracing::debug!("server created"); - let _ = tower::Service::call(&mut svc, http::Request::new(crate::Body::empty())) .await .unwrap(); - tracing::debug!("server request handled"); conn } .instrument(tracing::info_span!("server")); - let (stream, conn) = tokio::join!(client, server); + let (stream, conn) = tokio::time::timeout(std::time::Duration::from_secs(60), async { + tokio::join!(client, server) + }) + .await + .unwrap(); drop((stream, conn)); } } diff --git a/src/server/conn/tls/mod.rs b/src/server/conn/tls/mod.rs index 411d5da..b06c23d 100644 --- a/src/server/conn/tls/mod.rs +++ b/src/server/conn/tls/mod.rs @@ -193,7 +193,7 @@ mod tests { use tracing::Instrument as _; - use crate::fixtures; + use crate::{fixtures, IntoRequestParts}; use crate::client::conn::transport::duplex::DuplexTransport; use crate::client::conn::transport::TransportExt as _; @@ -223,7 +223,7 @@ mod tests { let client = async move { let mut stream = client - .connect("https://example.com".parse().unwrap()) + .connect("https://example.com".into_request_parts()) .await .unwrap();