Skip to content

Commit

Permalink
feat: client Transport trait now accepts http::request::Parts
Browse files Browse the repository at this point in the history
This is so that clients can support things like session pinning when used as part
of a load balancer.

To use a Service<Uri> as a tranpsort, wrap it in the UriTransport type.

BREAKING: The signature of the connect() method on Transport has changed.
  • Loading branch information
alexrudy committed Dec 12, 2024
1 parent 2839592 commit 9323a19
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 99 deletions.
5 changes: 3 additions & 2 deletions examples/single_threaded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use hyperdriver::info::HasConnectionInfo;
use hyperdriver::server::Accept;
use hyperdriver::service::{make_service_fn, RequestExecutor};
use hyperdriver::stream::TcpStream;
use hyperdriver::IntoRequestParts;
use pin_project::pin_project;
use tokio::io::{self, AsyncWriteExt};
use tokio::net::TcpListener;
Expand Down Expand Up @@ -391,8 +392,8 @@ impl Transport for TransportNotSend {

type Future = Pin<Box<dyn Future<Output = Result<Self::IO, Self::Error>> + Send>>;

fn connect(&mut self, uri: http::Uri) -> <Self as Transport>::Future {
self.tcp.connect(uri).boxed()
fn connect<R: IntoRequestParts>(&mut self, req: R) -> <Self as Transport>::Future {
self.tcp.connect(req.into_request_parts()).boxed()
}

fn poll_ready(
Expand Down
13 changes: 9 additions & 4 deletions src/client/conn/transport/duplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::io;
use std::task::{Context, Poll};

use crate::BoxFuture;
use http::Uri;

use crate::stream::duplex::DuplexStream as Stream;

Expand All @@ -25,7 +24,7 @@ impl DuplexTransport {
}
}

impl tower::Service<Uri> for DuplexTransport {
impl tower::Service<http::request::Parts> for DuplexTransport {
type Response = Stream;

type Error = io::Error;
Expand All @@ -36,7 +35,7 @@ impl tower::Service<Uri> for DuplexTransport {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: Uri) -> Self::Future {
fn call(&mut self, _req: http::request::Parts) -> Self::Future {
let client = self.client.clone();
let max_buf_size = self.max_buf_size;
let fut = async move {
Expand Down Expand Up @@ -66,7 +65,13 @@ mod tests {
let (io, _) = tokio::join!(
async {
transport
.oneshot("https://example.com".parse().unwrap())
.oneshot(
http::Request::get("https://example.com")
.body(())
.unwrap()
.into_parts()
.0,
)
.await
.unwrap()
},
Expand Down
9 changes: 4 additions & 5 deletions src/client/conn/transport/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
use std::future::ready;

use http::Uri;
use thiserror::Error;

use crate::client::conn::protocol::mock::MockProtocol;
Expand Down Expand Up @@ -81,14 +80,14 @@ impl MockTransport {
/// Create a new connector for the transport.
pub fn connector(
self,
uri: Uri,
parts: http::request::Parts,
version: HttpProtocol,
) -> pool::Connector<Self, MockProtocol, crate::Body> {
pool::Connector::new(self, MockProtocol::default(), uri, version)
pool::Connector::new(self, MockProtocol::default(), parts, version)
}
}

impl tower::Service<http::Uri> for MockTransport {
impl tower::Service<http::request::Parts> for MockTransport {
type Response = MockStream;

type Error = MockConnectionError;
Expand All @@ -102,7 +101,7 @@ impl tower::Service<http::Uri> for MockTransport {
std::task::Poll::Ready(Ok(()))
}

fn call(&mut self, _req: http::Uri) -> Self::Future {
fn call(&mut self, _req: http::request::Parts) -> Self::Future {
let reuse = match &mut self.mode {
TransportMode::SingleUse => false,
TransportMode::Reusable => true,
Expand Down
70 changes: 56 additions & 14 deletions src/client/conn/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::future::Future;
#[cfg(feature = "tls")]
use std::sync::Arc;

use ::http::Uri;
#[cfg(feature = "tls")]
use rustls::client::ClientConfig;
#[cfg(feature = "tls")]
Expand All @@ -26,6 +25,7 @@ use crate::client::default_tls_config;
#[cfg(feature = "stream")]
use crate::info::BraidAddr;
use crate::info::HasConnectionInfo;
use crate::IntoRequestParts;

#[cfg(feature = "stream")]
pub mod duplex;
Expand Down Expand Up @@ -53,7 +53,9 @@ pub trait Transport: Clone + Send {
type Future: Future<Output = Result<Self::IO, <Self as Transport>::Error>> + Send + 'static;

/// Connect to a remote server and return a stream.
fn connect(&mut self, uri: Uri) -> <Self as Transport>::Future;
fn connect<R>(&mut self, req: R) -> <Self as Transport>::Future
where
R: IntoRequestParts;

/// Poll the transport to see if it is ready to accept a new connection.
fn poll_ready(
Expand All @@ -64,7 +66,7 @@ pub trait Transport: Clone + Send {

impl<T, IO> Transport for T
where
T: Service<Uri, Response = IO>,
T: Service<http::request::Parts, Response = IO>,
T: Clone + Send + Sync + 'static,
T::Error: std::error::Error + Send + Sync + 'static,
T::Future: Send + 'static,
Expand All @@ -75,8 +77,11 @@ where
type Error = T::Error;
type Future = T::Future;

fn connect(&mut self, uri: Uri) -> <Self as Service<Uri>>::Future {
self.call(uri)
fn connect<R>(&mut self, req: R) -> <Self as Service<http::request::Parts>>::Future
where
R: IntoRequestParts,
{
self.call(req.into_request_parts())
}

fn poll_ready(
Expand All @@ -87,6 +92,42 @@ where
}
}

/// A wrapper type which converts any service that accepts a URI
/// into a `hyperdriver` tranport type. Hyperdriver uses http::request::Parts
/// for transports, but many implementations only require the http::Uri
/// in order to function.
#[derive(Debug, Clone)]
pub struct UriTransport<T>(T);

impl<T, IO> Transport for UriTransport<T>
where
T: Service<http::Uri, Response = IO>,
T: Clone + Send + Sync + 'static,
T::Error: std::error::Error + Send + Sync + 'static,
T::Future: Send + 'static,
IO: HasConnectionInfo + Send + 'static,
IO::Addr: Send,
{
type IO = IO;
type Error = T::Error;
type Future = T::Future;

fn connect<R>(&mut self, req: R) -> <Self as Transport>::Future
where
R: IntoRequestParts,
{
let parts = req.into_request_parts();
self.0.call(parts.uri)
}

fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), <Self as Transport>::Error>> {
Service::poll_ready(&mut self.0, cx)
}
}

/// Extension trait for Transports to provide additional configuration options.
pub trait TransportExt: Transport {
#[cfg(feature = "stream")]
Expand Down Expand Up @@ -277,7 +318,7 @@ impl<T> TlsTransport<T> {
}
}

impl<T> Service<Uri> for TlsTransport<T>
impl<T> Service<http::request::Parts> for TlsTransport<T>
where
T: Transport,
<T as Transport>::IO: HasConnectionInfo + AsyncRead + AsyncWrite + Unpin,
Expand Down Expand Up @@ -311,26 +352,27 @@ where
}
}

fn call(&mut self, req: Uri) -> Self::Future {
fn call(&mut self, parts: http::request::Parts) -> Self::Future {
#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
let use_tls = req
let use_tls = parts
.uri
.scheme_str()
.is_some_and(|s| matches!(s, "https" | "wss"));

match &mut self.braid {
InnerBraid::Plain(inner) => {
tracing::trace!(scheme=?req.scheme_str(), "connecting without TLS");
self::future::TransportBraidFuture::from_plain(inner.connect(req))
tracing::trace!(scheme=?parts.uri.scheme_str(), "connecting without TLS");
self::future::TransportBraidFuture::from_plain(inner.connect(parts))
}
#[cfg(feature = "tls")]
InnerBraid::Tls(inner) if use_tls => {
tracing::trace!(scheme=?req.scheme_str(), "connecting with TLS");
self::future::TransportBraidFuture::from_tls(inner.call(req))
tracing::trace!(scheme=?parts.uri.scheme_str(), "connecting with TLS");
self::future::TransportBraidFuture::from_tls(inner.call(parts))
}
#[cfg(feature = "tls")]
InnerBraid::Tls(inner) => {
tracing::trace!(scheme=?req.scheme_str(), "connecting without TLS");
self::future::TransportBraidFuture::from_plain(inner.transport_mut().connect(req))
tracing::trace!(scheme=?parts.uri.scheme_str(), "connecting without TLS");
self::future::TransportBraidFuture::from_plain(inner.transport_mut().connect(parts))
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/client/conn/transport/stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use ::http::Uri;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;

Expand All @@ -23,7 +22,7 @@ impl<T> IntoStream<T> {
}
}

impl<T> Service<Uri> for IntoStream<T>
impl<T> Service<http::request::Parts> for IntoStream<T>
where
T: Transport,
T::IO: Into<Stream> + AsyncRead + AsyncWrite + Unpin + Send + 'static,
Expand All @@ -40,7 +39,7 @@ where
self.transport.poll_ready(cx)
}

fn call(&mut self, req: Uri) -> Self::Future {
fn call(&mut self, req: http::request::Parts) -> Self::Future {
fut::ConnectFuture::new(self.transport.connect(req))
}
}
Expand Down Expand Up @@ -97,6 +96,7 @@ mod tests {
use crate::client::conn::transport::duplex::DuplexTransport;
use crate::client::conn::transport::TransportExt as _;
use crate::server::conn::AcceptExt as _;
use crate::IntoRequestParts;
use tower::ServiceExt as _;

#[tokio::test]
Expand All @@ -108,7 +108,7 @@ mod tests {
let (io, _) = tokio::join!(
async {
transport
.oneshot("https://example.com".parse().unwrap())
.oneshot("https://example.com".into_request_parts())
.await
.unwrap()
},
Expand Down
34 changes: 19 additions & 15 deletions src/client/conn/transport/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ use crate::BoxError;
/// # use hyperdriver::client::conn::transport::tcp::TcpTransport;
/// # use hyperdriver::client::conn::dns::GaiResolver;
/// # use hyperdriver::stream::tcp::TcpStream;
/// # use hyperdriver::IntoRequestParts;
/// # use tower::ServiceExt as _;
///
/// # async fn run() {
/// let transport: TcpTransport<GaiResolver, TcpStream> = TcpTransport::default();
///
/// let uri = "http://example.com".parse().unwrap();
/// let uri = "http://example.com".into_request_parts();
/// let stream = transport.oneshot(uri).await.unwrap();
/// # }
/// ```
Expand Down Expand Up @@ -169,7 +170,7 @@ impl<R, IO> TcpTransport<R, IO> {

type BoxFuture<'a, T, E> = crate::BoxFuture<'a, Result<T, E>>;

impl<R, IO> tower::Service<Uri> for TcpTransport<R, IO>
impl<R, IO> tower::Service<http::request::Parts> for TcpTransport<R, IO>
where
R: tower::Service<Box<str>, Response = SocketAddrs, Error = io::Error>
+ Clone
Expand All @@ -191,8 +192,8 @@ where
.map_err(TcpConnectionError::msg("dns poll_ready"))
}

fn call(&mut self, req: Uri) -> Self::Future {
let (host, port) = match get_host_and_port(&req) {
fn call(&mut self, req: http::request::Parts) -> Self::Future {
let (host, port) = match get_host_and_port(&req.uri) {
Ok((host, port)) => (host, port),
Err(e) => return Box::pin(std::future::ready(Err(e))),
};
Expand Down Expand Up @@ -623,7 +624,7 @@ mod test {
use tokio::net::TcpListener;
use tower::Service;

use crate::client::conn::Transport;
use crate::{client::conn::Transport, IntoRequestParts};

use super::*;

Expand Down Expand Up @@ -720,13 +721,16 @@ mod test {
listener: TcpListener,
) -> (T::IO, TcpStream)
where
T: Transport + Service<Uri, Response = T::IO>,
<T as Service<Uri>>::Error: std::fmt::Debug,
T: Transport + Service<http::request::Parts, Response = T::IO>,
<T as Service<http::request::Parts>>::Error: std::fmt::Debug,
{
tokio::join!(async { transport.oneshot(uri).await.unwrap() }, async {
let (stream, addr) = listener.accept().await.unwrap();
TcpStream::server(stream, addr)
})
tokio::join!(
async { transport.oneshot(uri.into_request_parts()).await.unwrap() },
async {
let (stream, addr) = listener.accept().await.unwrap();
TcpStream::server(stream, addr)
}
)
}

#[tokio::test]
Expand All @@ -742,7 +746,7 @@ mod test {
.with_resolver(Resolver(0))
.build::<TcpStream>();

let result = transport.oneshot(uri).await;
let result = transport.oneshot(uri.into_request_parts()).await;
assert!(result.is_err());
}

Expand Down Expand Up @@ -784,7 +788,7 @@ mod test {
.with_resolver(EmptyResolver)
.build::<TcpStream>();

let result = transport.oneshot(uri).await;
let result = transport.oneshot(uri.into_request_parts()).await;
assert!(result.is_err());

let err = result.unwrap_err();
Expand All @@ -796,7 +800,7 @@ mod test {
async fn test_transport_error() {
let _ = tracing_subscriber::fmt::try_init();

let uri: Uri = "http://example.com".parse().unwrap();
let parts = "http://example.com".into_request_parts();

let config = TcpTransportConfig::default();

Expand All @@ -805,7 +809,7 @@ mod test {
.with_resolver(ErrorResolver)
.build::<TcpStream>();

let result = transport.oneshot(uri).await;
let result = transport.oneshot(parts).await;
assert!(result.is_err());

let err = result.unwrap_err();
Expand Down
Loading

0 comments on commit 9323a19

Please sign in to comment.