diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 957186a0d..8965dd072 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,6 +71,7 @@ jobs: - "feat.: default-tls disabled" - "feat.: rustls-tls" - "feat.: rustls-tls-manual-roots" + - "feat.: rustls-tls-native-roots" - "feat.: native-tls" - "feat.: default-tls and rustls-tls" - "feat.: cookies" @@ -101,23 +102,23 @@ jobs: - name: windows / stable-x86_64-msvc os: windows-latest target: x86_64-pc-windows-msvc - features: "--features blocking,gzip,brotli,deflate,json,multipart" + features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" - name: windows / stable-i686-msvc os: windows-latest target: i686-pc-windows-msvc - features: "--features blocking,gzip,brotli,deflate,json,multipart" + features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" - name: windows / stable-x86_64-gnu os: windows-latest rust: stable-x86_64-pc-windows-gnu target: x86_64-pc-windows-gnu - features: "--features blocking,gzip,brotli,deflate,json,multipart" + features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" package_name: mingw-w64-x86_64-gcc mingw64_path: "C:\\msys64\\mingw64\\bin" - name: windows / stable-i686-gnu os: windows-latest rust: stable-i686-pc-windows-gnu target: i686-pc-windows-gnu - features: "--features blocking,gzip,brotli,deflate,json,multipart" + features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" package_name: mingw-w64-i686-gcc mingw64_path: "C:\\msys64\\mingw32\\bin" @@ -127,6 +128,8 @@ jobs: features: "--no-default-features --features rustls-tls" - name: "feat.: rustls-tls-manual-roots" features: "--no-default-features --features rustls-tls-manual-roots" + - name: "feat.: rustls-tls-native-roots" + features: "--no-default-features --features rustls-tls-native-roots" - name: "feat.: native-tls" features: "--features native-tls" - name: "feat.: default-tls and rustls-tls" @@ -136,11 +139,11 @@ jobs: - name: "feat.: blocking" features: "--features blocking" - name: "feat.: gzip" - features: "--features gzip" + features: "--features gzip,stream" - name: "feat.: brotli" - features: "--features brotli" + features: "--features brotli,stream" - name: "feat.: deflate" - features: "--features deflate" + features: "--features deflate,stream" - name: "feat.: json" features: "--features json" - name: "feat.: multipart" @@ -204,11 +207,12 @@ jobs: with: toolchain: 'stable' - - name: Check - run: RUSTFLAGS="--cfg reqwest_unstable" cargo check --features http3 + #- name: Check + # run: RUSTFLAGS="--cfg reqwest_unstable" cargo check --features http3 docs: name: Docs + needs: [test] runs-on: ubuntu-latest steps: diff --git a/Cargo.toml b/Cargo.toml index 951afd4e6..2ca66c043 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ default-tls = ["hyper-tls", "native-tls-crate", "__tls", "tokio-native-tls"] # Enables native-tls specific functionality not available by default. native-tls = ["default-tls"] -native-tls-alpn = ["native-tls", "native-tls-crate/alpn"] +native-tls-alpn = ["native-tls", "native-tls-crate/alpn", "hyper-tls/alpn"] native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] rustls-tls = ["rustls-tls-webpki-roots"] @@ -43,7 +43,7 @@ rustls-tls-manual-roots = ["__rustls"] rustls-tls-webpki-roots = ["webpki-roots", "__rustls"] rustls-tls-native-roots = ["rustls-native-certs", "__rustls"] -blocking = ["futures-util/io", "tokio/sync"] +blocking = ["futures-channel/sink", "futures-util/io", "futures-util/sink", "tokio/rt-multi-thread", "tokio/sync"] cookies = ["cookie_crate", "cookie_store"] @@ -64,7 +64,8 @@ stream = ["tokio/fs", "tokio-util", "wasm-streams"] socks = ["tokio-socks"] # Experimental HTTP/3 client. -http3 = ["rustls-tls-manual-roots", "h3", "h3-quinn", "quinn", "futures-channel"] +# Disabled while waiting for quinn to upgrade. +#http3 = ["rustls-tls-manual-roots", "h3", "h3-quinn", "quinn", "futures-channel"] # Internal (PRIVATE!) features used to aid testing. # Don't rely on these whatsoever. They may disappear at anytime. @@ -74,14 +75,14 @@ __tls = ["dep:rustls-pemfile"] # Enables common rustls code. # Equivalent to rustls-tls-manual-roots but shorter :) -__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls"] +__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls", "dep:rustls-pemfile", "rustls-pki-types"] # When enabled, disable using the cached SYS_PROXIES. __internal_proxy_sys_no_cache = [] [dependencies] base64 = "0.21" -http = "0.2" +http = "1" url = "2.2" bytes = "1.0" serde = "1.0" @@ -100,9 +101,11 @@ mime_guess = { version = "2.0", default-features = false, optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] encoding_rs = "0.8" -http-body = "0.4.0" -hyper = { version = "0.14.21", default-features = false, features = ["tcp", "http1", "http2", "client", "runtime"] } -h2 = "0.3.14" +http-body = "1" +http-body-util = "0.1" +hyper = { version = "1", features = ["http1", "http2", "client"] } +hyper-util = { version = "0.1.3", features = ["http1", "http2", "client", "client-legacy", "tokio"] } +h2 = "0.4" once_cell = "1" log = "0.4" mime = "0.3.16" @@ -115,16 +118,17 @@ ipnet = "2.3" rustls-pemfile = { version = "1.0", optional = true } ## default-tls -hyper-tls = { version = "0.5", optional = true } +hyper-tls = { version = "0.6", optional = true } native-tls-crate = { version = "0.2.10", optional = true, package = "native-tls" } tokio-native-tls = { version = "0.3.0", optional = true } # rustls-tls -hyper-rustls = { version = "0.24.0", default-features = false, optional = true } -rustls = { version = "0.21.6", features = ["dangerous_configuration"], optional = true } -tokio-rustls = { version = "0.24", optional = true } -webpki-roots = { version = "0.25", optional = true } -rustls-native-certs = { version = "0.6", optional = true } +hyper-rustls = { version = "0.26.0", default-features = false, optional = true } +rustls = { version = "0.22.2", optional = true } +rustls-pki-types = { version = "1.1.0", features = ["alloc"] ,optional = true } +tokio-rustls = { version = "0.25", optional = true } +webpki-roots = { version = "0.26.0", optional = true } +rustls-native-certs = { version = "0.7", optional = true } ## cookies cookie_crate = { version = "0.17.0", package = "cookie", optional = true } @@ -141,15 +145,16 @@ tokio-socks = { version = "0.5.1", optional = true } trust-dns-resolver = { version = "0.23", optional = true, features = ["tokio-runtime"] } # HTTP/3 experimental support -h3 = { version = "0.0.3", optional = true } -h3-quinn = { version = "0.0.4", optional = true } +h3 = { version = "0.0.4", optional = true } +h3-quinn = { version = "0.0.5", optional = true } quinn = { version = "0.10", default-features = false, features = ["tls-rustls", "ring", "runtime-tokio"], optional = true } futures-channel = { version = "0.3", optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" -hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "http1", "http2", "client", "server", "runtime"] } +hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } +hyper-util = { version = "0.1", features = ["http1", "http2", "client", "client-legacy", "server-auto", "tokio"] } serde = { version = "1.0", features = ["derive"] } libflate = "1.0" brotli_crate = { package = "brotli", version = "3.3.0" } @@ -237,17 +242,17 @@ required-features = ["cookies"] [[test]] name = "gzip" path = "tests/gzip.rs" -required-features = ["gzip"] +required-features = ["gzip", "stream"] [[test]] name = "brotli" path = "tests/brotli.rs" -required-features = ["brotli"] +required-features = ["brotli", "stream"] [[test]] name = "deflate" path = "tests/deflate.rs" -required-features = ["deflate"] +required-features = ["deflate", "stream"] [[test]] name = "multipart" diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index 0d0357cb6..ff5446e53 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -4,10 +4,9 @@ use std::pin::Pin; use std::task::{Context, Poll}; use bytes::Bytes; -use futures_core::Stream; use http_body::Body as HttpBody; -use pin_project_lite::pin_project; -use sync_wrapper::SyncWrapper; +use http_body_util::combinators::BoxBody; +//use sync_wrapper::SyncWrapper; #[cfg(feature = "stream")] use tokio::fs::File; use tokio::time::Sleep; @@ -19,31 +18,22 @@ pub struct Body { inner: Inner, } -// The `Stream` trait isn't stable, so the impl isn't public. -pub(crate) struct ImplStream(Body); - enum Inner { Reusable(Bytes), - Streaming { - body: Pin< - Box< - dyn HttpBody> - + Send - + Sync, - >, - >, - timeout: Option>>, - }, + Streaming(BoxBody>), } -pin_project! { - struct WrapStream { - #[pin] - inner: SyncWrapper, - } +/// A body with a total timeout. +/// +/// The timeout does not reset upon each chunk, but rather requires the whole +/// body be streamed before the deadline is reached. +pub(crate) struct TotalTimeoutBody { + inner: B, + timeout: Pin>, } -struct WrapHyper(hyper::Body); +/// Converts any `impl Body` into a `impl Stream` of just its DATA frames. +pub(crate) struct DataStream(pub(crate) B); impl Body { /// Returns a reference to the internal data of the `Body`. @@ -52,7 +42,7 @@ impl Body { pub fn as_bytes(&self) -> Option<&[u8]> { match &self.inner { Inner::Reusable(bytes) => Some(bytes.as_ref()), - Inner::Streaming { .. } => None, + Inner::Streaming(..) => None, } } @@ -83,50 +73,44 @@ impl Body { #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] pub fn wrap_stream(stream: S) -> Body where - S: futures_core::stream::TryStream + Send + 'static, + S: futures_core::stream::TryStream + Send + Sync + 'static, S::Error: Into>, Bytes: From, { Body::stream(stream) } + #[cfg(any(feature = "stream", feature = "multipart", feature = "blocking"))] pub(crate) fn stream(stream: S) -> Body where - S: futures_core::stream::TryStream + Send + 'static, + S: futures_core::stream::TryStream + Send + Sync + 'static, S::Error: Into>, Bytes: From, { use futures_util::TryStreamExt; - - let body = Box::pin(WrapStream { - inner: SyncWrapper::new(stream.map_ok(Bytes::from).map_err(Into::into)), - }); + use http_body::Frame; + use http_body_util::StreamBody; + + let body = http_body_util::BodyExt::boxed(StreamBody::new( + stream + .map_ok(|d| Frame::data(Bytes::from(d))) + .map_err(Into::into), + )); Body { - inner: Inner::Streaming { - body, - timeout: None, - }, - } - } - - pub(crate) fn response(body: hyper::Body, timeout: Option>>) -> Body { - Body { - inner: Inner::Streaming { - body: Box::pin(WrapHyper(body)), - timeout, - }, + inner: Inner::Streaming(body), } } + /* #[cfg(feature = "blocking")] pub(crate) fn wrap(body: hyper::Body) -> Body { Body { inner: Inner::Streaming { body: Box::pin(WrapHyper(body)), - timeout: None, }, } } + */ pub(crate) fn empty() -> Body { Body::reusable(Bytes::new()) @@ -138,6 +122,25 @@ impl Body { } } + // pub? + pub(crate) fn streaming(inner: B) -> Body + where + B: HttpBody + Send + Sync + 'static, + B::Data: Into, + B::Error: Into>, + { + use http_body_util::BodyExt; + + let boxed = inner + .map_frame(|f| f.map_data(Into::into)) + .map_err(Into::into) + .boxed(); + + Body { + inner: Inner::Streaming(boxed), + } + } + pub(crate) fn try_reuse(self) -> (Option, Self) { let reuse = match self.inner { Inner::Reusable(ref chunk) => Some(chunk.clone()), @@ -154,30 +157,39 @@ impl Body { } } - pub(crate) fn into_stream(self) -> ImplStream { - ImplStream(self) + #[cfg(feature = "multipart")] + pub(crate) fn into_stream(self) -> DataStream { + DataStream(self) } #[cfg(feature = "multipart")] pub(crate) fn content_length(&self) -> Option { match self.inner { Inner::Reusable(ref bytes) => Some(bytes.len() as u64), - Inner::Streaming { ref body, .. } => body.size_hint().exact(), + Inner::Streaming(ref body) => body.size_hint().exact(), } } } +impl Default for Body { + #[inline] + fn default() -> Body { + Body::empty() + } +} + +/* impl From for Body { #[inline] fn from(body: hyper::Body) -> Body { Self { inner: Inner::Streaming { body: Box::pin(WrapHyper(body)), - timeout: None, }, } } } +*/ impl From for Body { #[inline] @@ -229,132 +241,112 @@ impl fmt::Debug for Body { } } -// ===== impl ImplStream ===== - -impl HttpBody for ImplStream { +impl HttpBody for Body { type Data = Bytes; type Error = crate::Error; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context, - ) -> Poll>> { - let opt_try_chunk = match self.0.inner { - Inner::Streaming { - ref mut body, - ref mut timeout, - } => { - if let Some(ref mut timeout) = timeout { - if let Poll::Ready(()) = timeout.as_mut().poll(cx) { - return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut)))); - } - } - futures_core::ready!(Pin::new(body).poll_data(cx)) - .map(|opt_chunk| opt_chunk.map(Into::into).map_err(crate::error::body)) - } + ) -> Poll, Self::Error>>> { + match self.inner { Inner::Reusable(ref mut bytes) => { - if bytes.is_empty() { - None + let out = bytes.split_off(0); + if out.is_empty() { + Poll::Ready(None) } else { - Some(Ok(std::mem::replace(bytes, Bytes::new()))) + Poll::Ready(Some(Ok(hyper::body::Frame::data(out)))) } } - }; - - Poll::Ready(opt_try_chunk) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } - - fn is_end_stream(&self) -> bool { - match self.0.inner { - Inner::Streaming { ref body, .. } => body.is_end_stream(), - Inner::Reusable(ref bytes) => bytes.is_empty(), + Inner::Streaming(ref mut body) => Poll::Ready( + futures_core::ready!(Pin::new(body).poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(crate::error::body)), + ), } } fn size_hint(&self) -> http_body::SizeHint { - match self.0.inner { - Inner::Streaming { ref body, .. } => body.size_hint(), + match self.inner { Inner::Reusable(ref bytes) => { let mut hint = http_body::SizeHint::default(); hint.set_exact(bytes.len() as u64); hint } + Inner::Streaming(ref body) => body.size_hint(), } } } -impl Stream for ImplStream { - type Item = Result; +// ===== impl TotalTimeoutBody ===== - fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.poll_data(cx) +pub(crate) fn total_timeout(body: B, timeout: Pin>) -> TotalTimeoutBody { + TotalTimeoutBody { + inner: body, + timeout, } } -// ===== impl WrapStream ===== - -impl HttpBody for WrapStream +impl hyper::body::Body for TotalTimeoutBody where - S: Stream>, - D: Into, - E: Into>, + B: hyper::body::Body + Unpin, + B::Error: Into>, { - type Data = Bytes; - type Error = E; + type Data = B::Data; + type Error = crate::Error; - fn poll_data( - self: Pin<&mut Self>, + fn poll_frame( + mut self: Pin<&mut Self>, cx: &mut Context, - ) -> Poll>> { - let item = futures_core::ready!(self.project().inner.get_pin_mut().poll_next(cx)?); - - Poll::Ready(item.map(|val| Ok(val.into()))) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + ) -> Poll, Self::Error>>> { + if let Poll::Ready(()) = self.timeout.as_mut().poll(cx) { + return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut)))); + } + Poll::Ready( + futures_core::ready!(Pin::new(&mut self.inner).poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(crate::error::body)), + ) } } -// ===== impl WrapHyper ===== +pub(crate) type ResponseBody = + http_body_util::combinators::BoxBody>; -impl HttpBody for WrapHyper { - type Data = Bytes; - type Error = Box; +pub(crate) fn response( + body: hyper::body::Incoming, + timeout: Option>>, +) -> ResponseBody { + use http_body_util::BodyExt; - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll>> { - // safe pin projection - Pin::new(&mut self.0) - .poll_data(cx) - .map(|opt| opt.map(|res| res.map_err(Into::into))) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + if let Some(timeout) = timeout { + total_timeout(body, timeout).map_err(Into::into).boxed() + } else { + body.map_err(Into::into).boxed() } +} - fn is_end_stream(&self) -> bool { - self.0.is_end_stream() - } +// ===== impl DataStream ===== - fn size_hint(&self) -> http_body::SizeHint { - HttpBody::size_hint(&self.0) +impl futures_core::Stream for DataStream +where + B: HttpBody + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + loop { + return match futures_core::ready!(Pin::new(&mut self.0).poll_frame(cx)) { + Some(Ok(frame)) => { + // skip non-data frames + if let Ok(buf) = frame.into_data() { + Poll::Ready(Some(Ok(buf))) + } else { + continue; + } + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + }; + } } } diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 1c3cb19f1..88b5c97be 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -13,7 +13,7 @@ use http::header::{ }; use http::uri::Scheme; use http::Uri; -use hyper::client::{HttpConnector, ResponseFuture as HyperResponseFuture}; +use hyper_util::client::legacy::connect::HttpConnector; #[cfg(feature = "native-tls-crate")] use native_tls_crate::TlsConnector; use pin_project_lite::pin_project; @@ -52,6 +52,8 @@ use quinn::TransportConfig; #[cfg(feature = "http3")] use quinn::VarInt; +type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; + /// An asynchronous `Client` to make Requests with. /// /// The Client has various configuration values to tweak, but the defaults @@ -466,18 +468,7 @@ impl ClientBuilder { #[cfg(feature = "rustls-tls-webpki-roots")] if config.tls_built_in_root_certs { - use rustls::OwnedTrustAnchor; - - let trust_anchors = - webpki_roots::TLS_SERVER_ROOTS.iter().map(|trust_anchor| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - trust_anchor.subject, - trust_anchor.spki, - trust_anchor.name_constraints, - ) - }); - - root_cert_store.add_trust_anchors(trust_anchors); + root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); } #[cfg(feature = "rustls-tls-native-roots")] @@ -487,17 +478,14 @@ impl ClientBuilder { for cert in rustls_native_certs::load_native_certs() .map_err(crate::error::builder)? { - let cert = rustls::Certificate(cert.0); // Continue on parsing errors, as native stores often include ancient or syntactically // invalid certificates, like root certificates without any X509 extensions. // Inspiration: https://github.com/rustls/rustls/blob/633bf4ba9d9521a95f68766d04c22e2b01e68318/rustls/src/anchors.rs#L105-L112 - match root_cert_store.add(&cert) { + match root_cert_store.add(cert.into()) { Ok(_) => valid_count += 1, Err(err) => { invalid_count += 1; - log::warn!( - "rustls failed to parse DER certificate {err:?} {cert:?}" - ); + log::debug!("rustls failed to parse DER certificate: {err:?}"); } } } @@ -532,12 +520,8 @@ impl ClientBuilder { } // Build TLS config - let config_builder = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&versions) - .map_err(crate::error::builder)? - .with_root_certificates(root_cert_store); + let config_builder = + rustls::ClientConfig::builder().with_root_certificates(root_cert_store); // Finalize TLS config let mut tls = if let Some(id) = config.identity { @@ -612,7 +596,8 @@ impl ClientBuilder { connector.set_timeout(config.connect_timeout); connector.set_verbose(config.connection_verbose); - let mut builder = hyper::Client::builder(); + let mut builder = + hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); if matches!(config.http_version_pref, HttpVersionPref::Http2) { builder.http2_only(true); } @@ -641,6 +626,8 @@ impl ClientBuilder { builder.http2_keep_alive_while_idle(true); } + #[cfg(not(target_arch = "wasm32"))] + builder.timer(hyper_util::rt::TokioTimer::new()); builder.pool_idle_timeout(config.pool_idle_timeout); builder.pool_max_idle_per_host(config.pool_max_idle_per_host); connector.set_keepalive(config.tcp_keepalive); @@ -1666,7 +1653,7 @@ impl ClientBuilder { } } -type HyperClient = hyper::Client; +type HyperClient = hyper_util::client::legacy::Client; impl Default for Client { fn default() -> Self { @@ -1846,9 +1833,7 @@ impl Client { ResponseFuture::H3(self.inner.h3_client.as_ref().unwrap().request(req)) } _ => { - let mut req = builder - .body(body.into_stream()) - .expect("valid request parts"); + let mut req = builder.body(body).expect("valid request parts"); *req.headers_mut() = headers.clone(); ResponseFuture::Default(self.inner.hyper.request(req)) } @@ -2218,7 +2203,7 @@ impl PendingRequest { let mut req = hyper::Request::builder() .method(self.method.clone()) .uri(uri) - .body(body.into_stream()) + .body(body) .expect("valid request parts"); *req.headers_mut() = self.headers.clone(); ResponseFuture::Default(self.client.hyper.request(req)) @@ -2230,6 +2215,13 @@ impl PendingRequest { } fn is_retryable_error(err: &(dyn std::error::Error + 'static)) -> bool { + // pop the legacy::Error + let err = if let Some(err) = err.source() { + err + } else { + return false; + }; + #[cfg(feature = "http3")] if let Some(cause) = err.source() { if let Some(err) = cause.downcast_ref::() { @@ -2454,7 +2446,7 @@ impl Future for PendingRequest { let mut req = hyper::Request::builder() .method(self.method.clone()) .uri(uri.clone()) - .body(body.into_stream()) + .body(body) .expect("valid request parts"); *req.headers_mut() = headers.clone(); std::mem::swap(self.as_mut().headers(), &mut headers); diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index 86eb6e5d9..128f77ecb 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -16,14 +16,15 @@ use bytes::Bytes; use futures_core::Stream; use futures_util::stream::Peekable; use http::HeaderMap; -use hyper::body::HttpBody; +use hyper::body::Body as HttpBody; +use hyper::body::Frame; #[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] use tokio_util::codec::{BytesCodec, FramedRead}; #[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] use tokio_util::io::StreamReader; -use super::super::Body; +use super::body::ResponseBody; use crate::error; #[derive(Clone, Copy, Debug)] @@ -36,6 +37,19 @@ pub(super) struct Accepts { pub(super) deflate: bool, } +impl Accepts { + pub fn none() -> Self { + Self { + #[cfg(feature = "gzip")] + gzip: false, + #[cfg(feature = "brotli")] + brotli: false, + #[cfg(feature = "deflate")] + deflate: false, + } + } +} + /// A response decompressor over a non-blocking stream of chunks. /// /// The inner decoder may be constructed asynchronously. @@ -50,7 +64,7 @@ type PeekableIoStreamReader = StreamReader; enum Inner { /// A `PlainText` decoder just returns the response content as is. - PlainText(super::body::ImplStream), + PlainText(ResponseBody), /// A `Gzip` decoder will uncompress the gzipped response content before returning it. #[cfg(feature = "gzip")] @@ -72,7 +86,7 @@ enum Inner { /// A future attempt to poll the response body for EOF so we know whether to use gzip or not. struct Pending(PeekableIoStream, DecoderType); -struct IoStream(super::body::ImplStream); +pub(crate) struct IoStream(B); enum DecoderType { #[cfg(feature = "gzip")] @@ -93,16 +107,21 @@ impl Decoder { #[cfg(feature = "blocking")] pub(crate) fn empty() -> Decoder { Decoder { - inner: Inner::PlainText(Body::empty().into_stream()), + inner: Inner::PlainText(empty()), } } + #[cfg(feature = "blocking")] + pub(crate) fn into_stream(self) -> IoStream { + IoStream(self) + } + /// A plain text decoder. /// /// This decoder will emit the underlying chunks as-is. - fn plain_text(body: Body) -> Decoder { + fn plain_text(body: ResponseBody) -> Decoder { Decoder { - inner: Inner::PlainText(body.into_stream()), + inner: Inner::PlainText(body), } } @@ -110,12 +129,12 @@ impl Decoder { /// /// This decoder will buffer and decompress chunks that are gzipped. #[cfg(feature = "gzip")] - fn gzip(body: Body) -> Decoder { + fn gzip(body: ResponseBody) -> Decoder { use futures_util::StreamExt; Decoder { inner: Inner::Pending(Box::pin(Pending( - IoStream(body.into_stream()).peekable(), + IoStream(body).peekable(), DecoderType::Gzip, ))), } @@ -125,12 +144,12 @@ impl Decoder { /// /// This decoder will buffer and decompress chunks that are brotlied. #[cfg(feature = "brotli")] - fn brotli(body: Body) -> Decoder { + fn brotli(body: ResponseBody) -> Decoder { use futures_util::StreamExt; Decoder { inner: Inner::Pending(Box::pin(Pending( - IoStream(body.into_stream()).peekable(), + IoStream(body).peekable(), DecoderType::Brotli, ))), } @@ -140,12 +159,12 @@ impl Decoder { /// /// This decoder will buffer and decompress chunks that are deflated. #[cfg(feature = "deflate")] - fn deflate(body: Body) -> Decoder { + fn deflate(body: ResponseBody) -> Decoder { use futures_util::StreamExt; Decoder { inner: Inner::Pending(Box::pin(Pending( - IoStream(body.into_stream()).peekable(), + IoStream(body).peekable(), DecoderType::Deflate, ))), } @@ -187,7 +206,11 @@ impl Decoder { /// how to decode the content body of the request. /// /// Uses the correct variant by inspecting the Content-Encoding header. - pub(super) fn detect(_headers: &mut HeaderMap, body: Body, _accepts: Accepts) -> Decoder { + pub(super) fn detect( + _headers: &mut HeaderMap, + body: ResponseBody, + _accepts: Accepts, + ) -> Decoder { #[cfg(feature = "gzip")] { if _accepts.gzip && Decoder::detect_encoding(_headers, "gzip") { @@ -213,26 +236,35 @@ impl Decoder { } } -impl Stream for Decoder { - type Item = Result; +impl HttpBody for Decoder { + type Data = Bytes; + type Error = crate::Error; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - // Do a read or poll for a pending decoder value. + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { match self.inner { #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) { Poll::Ready(Ok(inner)) => { self.inner = inner; - self.poll_next(cx) + self.poll_frame(cx) } Poll::Ready(Err(e)) => Poll::Ready(Some(Err(crate::error::decode_io(e)))), Poll::Pending => Poll::Pending, }, - Inner::PlainText(ref mut body) => Pin::new(body).poll_next(cx), + Inner::PlainText(ref mut body) => { + match futures_core::ready!(Pin::new(body).poll_frame(cx)) { + Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode(err)))), + None => Poll::Ready(None), + } + } #[cfg(feature = "gzip")] Inner::Gzip(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), } @@ -240,7 +272,7 @@ impl Stream for Decoder { #[cfg(feature = "brotli")] Inner::Brotli(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), } @@ -248,32 +280,13 @@ impl Stream for Decoder { #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), } } } } -} - -impl HttpBody for Decoder { - type Data = Bytes; - type Error = crate::Error; - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll>> { - self.poll_next(cx) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } fn size_hint(&self) -> http_body::SizeHint { match self.inner { @@ -285,6 +298,11 @@ impl HttpBody for Decoder { } } +fn empty() -> ResponseBody { + use http_body_util::{combinators::BoxBody, BodyExt, Empty}; + BoxBody::new(Empty::new().map_err(|never| match never {})) +} + impl Future for Pending { type Output = Result; @@ -303,13 +321,10 @@ impl Future for Pending { .expect("just peeked Some") .unwrap_err())); } - None => return Poll::Ready(Ok(Inner::PlainText(Body::empty().into_stream()))), + None => return Poll::Ready(Ok(Inner::PlainText(empty()))), }; - let _body = std::mem::replace( - &mut self.0, - IoStream(Body::empty().into_stream()).peekable(), - ); + let _body = std::mem::replace(&mut self.0, IoStream(empty()).peekable()); match self.1 { #[cfg(feature = "brotli")] @@ -331,14 +346,27 @@ impl Future for Pending { } } -impl Stream for IoStream { +impl Stream for IoStream +where + B: HttpBody + Unpin, + B::Error: Into>, +{ type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match futures_core::ready!(Pin::new(&mut self.0).poll_next(cx)) { - Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk))), - Some(Err(err)) => Poll::Ready(Some(Err(err.into_io()))), - None => Poll::Ready(None), + loop { + return match futures_core::ready!(Pin::new(&mut self.0).poll_frame(cx)) { + Some(Ok(frame)) => { + // skip non-data frames + if let Ok(buf) = frame.into_data() { + Poll::Ready(Some(Ok(buf))) + } else { + continue; + } + } + Some(Err(err)) => Poll::Ready(Some(Err(error::into_io(err.into())))), + None => Poll::Ready(None), + }; } } } @@ -346,6 +374,7 @@ impl Stream for IoStream { // ===== impl Accepts ===== impl Accepts { + /* pub(super) fn none() -> Self { Accepts { #[cfg(feature = "gzip")] @@ -356,6 +385,7 @@ impl Accepts { deflate: false, } } + */ pub(super) fn as_str(&self) -> Option<&'static str> { match (self.is_gzip(), self.is_brotli(), self.is_deflate()) { diff --git a/src/async_impl/h3_client/connect.rs b/src/async_impl/h3_client/connect.rs index 968704713..ec732f66a 100644 --- a/src/async_impl/h3_client/connect.rs +++ b/src/async_impl/h3_client/connect.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use h3::client::SendRequest; use h3_quinn::{Connection, OpenStreams}; use http::Uri; -use hyper::client::connect::dns::Name; +use hyper_util::client::legacy::connect::dns::Name; use quinn::{ClientConfig, Endpoint, TransportConfig}; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; diff --git a/src/async_impl/h3_client/dns.rs b/src/async_impl/h3_client/dns.rs index 9cb50d1e3..bd59daaed 100644 --- a/src/async_impl/h3_client/dns.rs +++ b/src/async_impl/h3_client/dns.rs @@ -1,5 +1,5 @@ use core::task; -use hyper::client::connect::dns::Name; +use hyper_util::client::legacy::connect::dns::Name; use std::future::Future; use std::net::SocketAddr; use std::task::Poll; diff --git a/src/async_impl/h3_client/pool.rs b/src/async_impl/h3_client/pool.rs index d9ca3a661..d6442c81a 100644 --- a/src/async_impl/h3_client/pool.rs +++ b/src/async_impl/h3_client/pool.rs @@ -13,7 +13,7 @@ use h3::client::SendRequest; use h3_quinn::{Connection, OpenStreams}; use http::uri::{Authority, Scheme}; use http::{Request, Response, Uri}; -use hyper::Body as HyperBody; +use hyper::body as HyperBody; use log::trace; pub(super) type Key = (Scheme, Authority); diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index 77a3e53aa..a947b5151 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -4,9 +4,9 @@ use std::pin::Pin; use bytes::Bytes; use encoding_rs::{Encoding, UTF_8}; -use futures_util::stream::StreamExt; -use hyper::client::connect::HttpInfo; +use http_body_util::BodyExt; use hyper::{HeaderMap, StatusCode, Version}; +use hyper_util::client::legacy::connect::HttpInfo; use mime::Mime; #[cfg(feature = "json")] use serde::de::DeserializeOwned; @@ -17,9 +17,9 @@ use url::Url; use super::body::Body; use super::decoder::{Accepts, Decoder}; +use crate::async_impl::body::ResponseBody; #[cfg(feature = "cookies")] use crate::cookie; -use crate::response::ResponseUrl; /// A Response to a submitted `Request`. pub struct Response { @@ -31,13 +31,17 @@ pub struct Response { impl Response { pub(super) fn new( - res: hyper::Response, + res: hyper::Response, url: Url, accepts: Accepts, timeout: Option>>, ) -> Response { let (mut parts, body) = res.into_parts(); - let decoder = Decoder::detect(&mut parts.headers, Body::response(body, timeout), accepts); + let decoder = Decoder::detect( + &mut parts.headers, + super::body::response(body, timeout), + accepts, + ); let res = hyper::Response::from_parts(parts, decoder); Response { @@ -78,9 +82,9 @@ impl Response { /// - The response is compressed and automatically decoded (thus changing /// the actual decoded length). pub fn content_length(&self) -> Option { - use hyper::body::HttpBody; + use hyper::body::Body; - HttpBody::size_hint(self.res.body()).exact() + Body::size_hint(self.res.body()).exact() } /// Retrieve the cookies contained in the response. @@ -256,7 +260,11 @@ impl Response { /// # } /// ``` pub async fn bytes(self) -> crate::Result { - hyper::body::to_bytes(self.res.into_body()).await + use http_body_util::BodyExt; + + BodyExt::collect(self.res.into_body()) + .await + .map(|buf| buf.to_bytes()) } /// Stream a chunk of the response body. @@ -276,10 +284,19 @@ impl Response { /// # } /// ``` pub async fn chunk(&mut self) -> crate::Result> { - if let Some(item) = self.res.body_mut().next().await { - Ok(Some(item?)) - } else { - Ok(None) + use http_body_util::BodyExt; + + // loop to ignore unrecognized frames + loop { + if let Some(res) = self.res.body_mut().frame().await { + let frame = res?; + if let Ok(buf) = frame.into_data() { + return Ok(Some(buf)); + } + // else continue + } else { + return Ok(None); + } } } @@ -308,7 +325,7 @@ impl Response { #[cfg(feature = "stream")] #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] pub fn bytes_stream(self) -> impl futures_core::Stream> { - self.res.into_body() + super::body::DataStream(self.res.into_body()) } // util methods @@ -396,11 +413,26 @@ impl fmt::Debug for Response { } } +/// A `Response` can be piped as the `Body` of another request. +impl From for Body { + fn from(r: Response) -> Body { + Body::streaming(r.res.into_body()) + } +} + +// I'm not sure this conversion is that useful... People should be encouraged +// to use `http::Resposne`, not `reqwest::Response`. impl> From> for Response { fn from(r: http::Response) -> Response { + use crate::response::ResponseUrl; + let (mut parts, body) = r.into_parts(); - let body = body.into(); - let decoder = Decoder::detect(&mut parts.headers, body, Accepts::none()); + let body: crate::async_impl::body::Body = body.into(); + let decoder = Decoder::detect( + &mut parts.headers, + ResponseBody::new(body.map_err(Into::into)), + Accepts::none(), + ); let url = parts .extensions .remove::() @@ -414,13 +446,6 @@ impl> From> for Response { } } -/// A `Response` can be piped as the `Body` of another request. -impl From for Body { - fn from(r: Response) -> Body { - Body::stream(r.res.into_body()) - } -} - #[cfg(test)] mod tests { use super::Response; diff --git a/src/async_impl/upgrade.rs b/src/async_impl/upgrade.rs index 4a69b4db5..3b599d0ad 100644 --- a/src/async_impl/upgrade.rs +++ b/src/async_impl/upgrade.rs @@ -3,11 +3,12 @@ use std::task::{self, Poll}; use std::{fmt, io}; use futures_util::TryFutureExt; +use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// An upgraded HTTP connection. pub struct Upgraded { - inner: hyper::upgrade::Upgraded, + inner: TokioIo, } impl AsyncRead for Upgraded { @@ -58,7 +59,9 @@ impl fmt::Debug for Upgraded { impl From for Upgraded { fn from(inner: hyper::upgrade::Upgraded) -> Self { - Upgraded { inner } + Upgraded { + inner: TokioIo::new(inner), + } } } diff --git a/src/blocking/body.rs b/src/blocking/body.rs index db46cde05..dd44c6fa2 100644 --- a/src/blocking/body.rs +++ b/src/blocking/body.rs @@ -9,6 +9,7 @@ use std::ptr; use bytes::buf::UninitSlice; use bytes::Bytes; +use futures_channel::mpsc; use crate::async_impl; @@ -133,12 +134,12 @@ impl Body { pub(crate) fn into_async(self) -> (Option, async_impl::Body, Option) { match self.kind { Kind::Reader(read, len) => { - let (tx, rx) = hyper::Body::channel(); + let (tx, rx) = mpsc::channel(0); let tx = Sender { body: (read, len), tx, }; - (Some(tx), async_impl::Body::wrap(rx), len) + (Some(tx), async_impl::Body::stream(rx), len) } Kind::Bytes(chunk) => { let len = chunk.len() as u64; @@ -257,11 +258,23 @@ impl Read for Reader { pub(crate) struct Sender { body: (Box, Option), - tx: hyper::body::Sender, + tx: mpsc::Sender>, } +#[derive(Debug)] +struct Abort; + +impl fmt::Display for Abort { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("abort request body") + } +} + +impl std::error::Error for Abort {} + async fn send_future(sender: Sender) -> Result<(), crate::Error> { use bytes::{BufMut, BytesMut}; + use futures_util::SinkExt; use std::cmp; let con_len = sender.body.1; @@ -312,7 +325,11 @@ async fn send_future(sender: Sender) -> Result<(), crate::Error> { buf.advance_mut(n); }, Err(e) => { - tx.take().expect("tx only taken on error").abort(); + let _ = tx + .take() + .expect("tx only taken on error") + .clone() + .try_send(Err(Abort)); return Err(crate::error::body(e)); } } @@ -324,7 +341,7 @@ async fn send_future(sender: Sender) -> Result<(), crate::Error> { let buf_len = buf.len() as u64; tx.as_mut() .expect("tx only taken on error") - .send_data(buf.split().freeze()) + .send(Ok(buf.split().freeze())) .await .map_err(crate::error::body)?; diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 689e6a0d8..83d1f5675 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -1172,7 +1172,7 @@ impl Default for Timeout { } } -pub(crate) struct KeepCoreThreadAlive(#[allow(unused)] Option>); +pub(crate) struct KeepCoreThreadAlive(#[allow(dead_code)] Option>); impl KeepCoreThreadAlive { pub(crate) fn empty() -> KeepCoreThreadAlive { diff --git a/src/blocking/response.rs b/src/blocking/response.rs index 2da634f68..6ece95ba6 100644 --- a/src/blocking/response.rs +++ b/src/blocking/response.rs @@ -397,7 +397,7 @@ impl Response { if self.body.is_none() { let body = mem::replace(self.inner.body_mut(), async_impl::Decoder::empty()); - let body = body.map_err(crate::error::into_io).into_async_read(); + let body = body.into_stream().into_async_read(); self.body = Some(Box::pin(body)); } diff --git a/src/connect.rs b/src/connect.rs index b6b51130e..3ad374021 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -2,11 +2,13 @@ use http::header::HeaderValue; use http::uri::{Authority, Scheme}; use http::Uri; -use hyper::client::connect::{Connected, Connection}; -use hyper::service::Service; +use hyper::rt::{Read, ReadBufCursor, Write}; +use hyper_util::client::legacy::connect::{Connected, Connection}; +#[cfg(feature = "__tls")] +use hyper_util::rt::TokioIo; #[cfg(feature = "native-tls-crate")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tower_service::Service; use pin_project_lite::pin_project; use std::future::Future; @@ -25,7 +27,7 @@ use crate::dns::DynResolver; use crate::error::BoxError; use crate::proxy::{Proxy, ProxyScheme}; -pub(crate) type HttpConnector = hyper::client::HttpConnector; +pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector; #[derive(Clone)] pub(crate) struct Connector { @@ -193,8 +195,11 @@ impl Connector { if dst.scheme() == Some(&Scheme::HTTPS) { let host = dst.host().ok_or("no host in url")?.to_string(); let conn = socks::connect(proxy, dst, dns).await?; + let conn = hyper_util::rt::TokioIo::new(conn); + let conn = hyper_util::rt::TokioIo::new(conn); let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let io = tls_connector.connect(&host, conn).await?; + let io = hyper_util::rt::TokioIo::new(io); return Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: io }), is_proxy: false, @@ -211,11 +216,15 @@ impl Connector { let tls = tls_proxy.clone(); let host = dst.host().ok_or("no host in url")?.to_string(); let conn = socks::connect(proxy, dst, dns).await?; - let server_name = rustls::ServerName::try_from(host.as_str()) - .map_err(|_| "Invalid Server Name")?; + let conn = hyper_util::rt::TokioIo::new(conn); + let conn = hyper_util::rt::TokioIo::new(conn); + let server_name = + rustls_pki_types::ServerName::try_from(host.as_str().to_owned()) + .map_err(|_| "Invalid Server Name")?; let io = RustlsConnector::from(tls) .connect(server_name, conn) .await?; + let io = hyper_util::rt::TokioIo::new(io); return Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: io }), is_proxy: false, @@ -228,7 +237,7 @@ impl Connector { } socks::connect(proxy, dst, dns).await.map(|tcp| Conn { - inner: self.verbose.wrap(tcp), + inner: self.verbose.wrap(hyper_util::rt::TokioIo::new(tcp)), is_proxy: false, tls_info: false, }) @@ -262,7 +271,14 @@ impl Connector { if let hyper_tls::MaybeHttpsStream::Https(stream) = io { if !self.nodelay { - stream.get_ref().get_ref().get_ref().set_nodelay(false)?; + stream + .inner() + .get_ref() + .get_ref() + .get_ref() + .inner() + .inner() + .set_nodelay(false)?; } Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: stream }), @@ -293,8 +309,8 @@ impl Connector { if let hyper_rustls::MaybeHttpsStream::Https(stream) = io { if !self.nodelay { - let (io, _) = stream.get_ref(); - io.set_nodelay(false)?; + let (io, _) = stream.inner().get_ref(); + io.inner().inner().set_nodelay(false)?; } Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: stream }), @@ -350,10 +366,12 @@ impl Connector { .await?; let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let io = tls_connector - .connect(host.ok_or("no host in url")?, tunneled) + .connect(host.ok_or("no host in url")?, TokioIo::new(tunneled)) .await?; return Ok(Conn { - inner: self.verbose.wrap(NativeTlsConn { inner: io }), + inner: self.verbose.wrap(NativeTlsConn { + inner: TokioIo::new(io), + }), is_proxy: false, tls_info: false, }); @@ -366,7 +384,7 @@ impl Connector { tls_proxy, } => { if dst.scheme() == Some(&Scheme::HTTPS) { - use rustls::ServerName; + use rustls_pki_types::ServerName; use std::convert::TryFrom; use tokio_rustls::TlsConnector as RustlsConnector; @@ -377,16 +395,18 @@ impl Connector { let tls = tls.clone(); let conn = http.call(proxy_dst).await?; log::trace!("tunneling HTTPS over proxy"); - let maybe_server_name = - ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name"); + let maybe_server_name = ServerName::try_from(host.as_str().to_owned()) + .map_err(|_| "Invalid Server Name"); let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; let server_name = maybe_server_name?; let io = RustlsConnector::from(tls) - .connect(server_name, tunneled) + .connect(server_name, TokioIo::new(tunneled)) .await?; return Ok(Conn { - inner: self.verbose.wrap(RustlsTlsConn { inner: io }), + inner: self.verbose.wrap(RustlsTlsConn { + inner: TokioIo::new(io), + }), is_proxy: false, tls_info: false, }); @@ -476,18 +496,15 @@ impl TlsInfoFactory for tokio::net::TcpStream { } } -#[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::MaybeHttpsStream { +#[cfg(feature = "__tls")] +impl TlsInfoFactory for TokioIo { fn tls_info(&self) -> Option { - match self { - hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), - hyper_tls::MaybeHttpsStream::Http(_) => None, - } + self.inner().tls_info() } } #[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::TlsStream> { +impl TlsInfoFactory for tokio_native_tls::TlsStream>> { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -500,7 +517,11 @@ impl TlsInfoFactory for hyper_tls::TlsStream { +impl TlsInfoFactory + for tokio_native_tls::TlsStream< + TokioIo>>, + > +{ fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -512,32 +533,35 @@ impl TlsInfoFactory for tokio_native_tls::TlsStream { } } -#[cfg(feature = "__rustls")] -impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { +#[cfg(feature = "default-tls")] +impl TlsInfoFactory for hyper_tls::MaybeHttpsStream> { fn tls_info(&self) -> Option { match self { - hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), - hyper_rustls::MaybeHttpsStream::Http(_) => None, + hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_tls::MaybeHttpsStream::Http(_) => None, } } } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::TlsStream { +impl TlsInfoFactory for tokio_rustls::client::TlsStream>> { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() .1 .peer_certificates() .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); + .map(|c| c.first()) + .and_then(|c| c.map(|cc| vec![*cc])); Some(crate::tls::TlsInfo { peer_certificate }) } } #[cfg(feature = "__rustls")] impl TlsInfoFactory - for tokio_rustls::client::TlsStream> + for tokio_rustls::client::TlsStream< + TokioIo>>, + > { fn tls_info(&self) -> Option { let peer_certificate = self @@ -545,30 +569,28 @@ impl TlsInfoFactory .1 .peer_certificates() .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); + .map(|c| c.first()) + .and_then(|c| c.map(|cc| vec![*cc])); Some(crate::tls::TlsInfo { peer_certificate }) } } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::client::TlsStream { +impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream> { fn tls_info(&self) -> Option { - let peer_certificate = self - .get_ref() - .1 - .peer_certificates() - .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); - Some(crate::tls::TlsInfo { peer_certificate }) + match self { + hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_rustls::MaybeHttpsStream::Http(_) => None, + } } } pub(crate) trait AsyncConn: - AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static + Read + Write + Connection + Send + Sync + Unpin + 'static { } -impl AsyncConn for T {} +impl AsyncConn for T {} #[cfg(feature = "__tls")] trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {} @@ -614,25 +636,25 @@ impl Connection for Conn { } } -impl AsyncRead for Conn { +impl Read for Conn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { let this = self.project(); - AsyncRead::poll_read(this.inner, cx, buf) + Read::poll_read(this.inner, cx, buf) } } -impl AsyncWrite for Conn { +impl Write for Conn { fn poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write(this.inner, cx, buf) + Write::poll_write(this.inner, cx, buf) } fn poll_write_vectored( @@ -641,7 +663,7 @@ impl AsyncWrite for Conn { bufs: &[IoSlice<'_>], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + Write::poll_write_vectored(this.inner, cx, bufs) } fn is_write_vectored(&self) -> bool { @@ -650,12 +672,12 @@ impl AsyncWrite for Conn { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); - AsyncWrite::poll_flush(this.inner, cx) + Write::poll_flush(this.inner, cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); - AsyncWrite::poll_shutdown(this.inner, cx) + Write::poll_shutdown(this.inner, cx) } } @@ -670,8 +692,9 @@ async fn tunnel( auth: Option, ) -> Result where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { + use hyper_util::rt::TokioIo; use tokio::io::{AsyncReadExt, AsyncWriteExt}; let mut buf = format!( @@ -700,13 +723,15 @@ where // headers end buf.extend_from_slice(b"\r\n"); - conn.write_all(&buf).await?; + let mut tokio_conn = TokioIo::new(&mut conn); + + tokio_conn.write_all(&buf).await?; let mut buf = [0; 8192]; let mut pos = 0; loop { - let n = conn.read(&mut buf[pos..]).await?; + let n = tokio_conn.read(&mut buf[pos..]).await?; if n == 0 { return Err(tunnel_eof()); @@ -738,62 +763,69 @@ fn tunnel_eof() -> BoxError { #[cfg(feature = "default-tls")] mod native_tls_conn { use super::TlsInfoFactory; - use hyper::client::connect::{Connected, Connection}; + use hyper::rt::{Read, ReadBufCursor, Write}; + use hyper_tls::MaybeHttpsStream; + use hyper_util::client::legacy::connect::{Connected, Connection}; + use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; use std::{ io::{self, IoSlice}, pin::Pin, task::{Context, Poll}, }; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::TcpStream; use tokio_native_tls::TlsStream; pin_project! { pub(super) struct NativeTlsConn { - #[pin] pub(super) inner: TlsStream, + #[pin] pub(super) inner: TokioIo>, } } - impl Connection for NativeTlsConn { - #[cfg(feature = "native-tls-alpn")] + impl Connection for NativeTlsConn>> { fn connected(&self) -> Connected { - match self.inner.get_ref().negotiated_alpn().ok() { - Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => self - .inner - .get_ref() - .get_ref() - .get_ref() - .connected() - .negotiated_h2(), - _ => self.inner.get_ref().get_ref().get_ref().connected(), - } + self.inner + .inner() + .get_ref() + .get_ref() + .get_ref() + .inner() + .connected() } + } - #[cfg(not(feature = "native-tls-alpn"))] + impl Connection for NativeTlsConn>>> { fn connected(&self) -> Connected { - self.inner.get_ref().get_ref().get_ref().connected() + self.inner + .inner() + .get_ref() + .get_ref() + .get_ref() + .inner() + .connected() } } - impl AsyncRead for NativeTlsConn { + impl Read for NativeTlsConn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { let this = self.project(); - AsyncRead::poll_read(this.inner, cx, buf) + Read::poll_read(this.inner, cx, buf) } } - impl AsyncWrite for NativeTlsConn { + impl Write for NativeTlsConn { fn poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write(this.inner, cx, buf) + Write::poll_write(this.inner, cx, buf) } fn poll_write_vectored( @@ -802,7 +834,7 @@ mod native_tls_conn { bufs: &[IoSlice<'_>], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + Write::poll_write_vectored(this.inner, cx, bufs) } fn is_write_vectored(&self) -> bool { @@ -814,7 +846,7 @@ mod native_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_flush(this.inner, cx) + Write::poll_flush(this.inner, cx) } fn poll_shutdown( @@ -822,17 +854,14 @@ mod native_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_shutdown(this.inner, cx) + Write::poll_shutdown(this.inner, cx) } } - impl TlsInfoFactory for NativeTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() - } - } - - impl TlsInfoFactory for NativeTlsConn> { + impl TlsInfoFactory for NativeTlsConn + where + TokioIo>: TlsInfoFactory, + { fn tls_info(&self) -> Option { self.inner.tls_info() } @@ -842,51 +871,76 @@ mod native_tls_conn { #[cfg(feature = "__rustls")] mod rustls_tls_conn { use super::TlsInfoFactory; - use hyper::client::connect::{Connected, Connection}; + use hyper::rt::{Read, ReadBufCursor, Write}; + use hyper_rustls::MaybeHttpsStream; + use hyper_util::client::legacy::connect::{Connected, Connection}; + use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; use std::{ io::{self, IoSlice}, pin::Pin, task::{Context, Poll}, }; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; pin_project! { pub(super) struct RustlsTlsConn { - #[pin] pub(super) inner: TlsStream, + #[pin] pub(super) inner: TokioIo>, } } - impl Connection for RustlsTlsConn { + impl Connection for RustlsTlsConn>> { + fn connected(&self) -> Connected { + if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") { + self.inner + .inner() + .get_ref() + .0 + .inner() + .connected() + .negotiated_h2() + } else { + self.inner.inner().get_ref().0.inner().connected() + } + } + } + impl Connection for RustlsTlsConn>>> { fn connected(&self) -> Connected { - if self.inner.get_ref().1.alpn_protocol() == Some(b"h2") { - self.inner.get_ref().0.connected().negotiated_h2() + if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") { + self.inner + .inner() + .get_ref() + .0 + .inner() + .connected() + .negotiated_h2() } else { - self.inner.get_ref().0.connected() + self.inner.inner().get_ref().0.inner().connected() } } } - impl AsyncRead for RustlsTlsConn { + impl Read for RustlsTlsConn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { let this = self.project(); - AsyncRead::poll_read(this.inner, cx, buf) + Read::poll_read(this.inner, cx, buf) } } - impl AsyncWrite for RustlsTlsConn { + impl Write for RustlsTlsConn { fn poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write(this.inner, cx, buf) + Write::poll_write(this.inner, cx, buf) } fn poll_write_vectored( @@ -895,7 +949,7 @@ mod rustls_tls_conn { bufs: &[IoSlice<'_>], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + Write::poll_write_vectored(this.inner, cx, bufs) } fn is_write_vectored(&self) -> bool { @@ -907,7 +961,7 @@ mod rustls_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_flush(this.inner, cx) + Write::poll_flush(this.inner, cx) } fn poll_shutdown( @@ -915,17 +969,13 @@ mod rustls_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_shutdown(this.inner, cx) + Write::poll_shutdown(this.inner, cx) } } - - impl TlsInfoFactory for RustlsTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() - } - } - - impl TlsInfoFactory for RustlsTlsConn> { + impl TlsInfoFactory for RustlsTlsConn + where + TokioIo>: TlsInfoFactory, + { fn tls_info(&self) -> Option { self.inner.tls_info() } @@ -998,13 +1048,13 @@ mod socks { } mod verbose { - use hyper::client::connect::{Connected, Connection}; + use hyper::rt::{Read, ReadBufCursor, Write}; + use hyper_util::client::legacy::connect::{Connected, Connection}; use std::cmp::min; use std::fmt; use std::io::{self, IoSlice}; use std::pin::Pin; use std::task::{Context, Poll}; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub(super) const OFF: Wrapper = Wrapper(false); @@ -1030,21 +1080,24 @@ mod verbose { inner: T, } - impl Connection for Verbose { + impl Connection for Verbose { fn connected(&self) -> Connected { self.inner.connected() } } - impl AsyncRead for Verbose { + impl Read for Verbose { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { match Pin::new(&mut self.inner).poll_read(cx, buf) { Poll::Ready(Ok(())) => { + /* log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled())); + */ + log::trace!("TODO: verbose poll_read"); Poll::Ready(Ok(())) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), @@ -1053,7 +1106,7 @@ mod verbose { } } - impl AsyncWrite for Verbose { + impl Write for Verbose { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context, @@ -1169,6 +1222,7 @@ mod verbose { mod tests { use super::tunnel; use crate::proxy; + use hyper_util::rt::TokioIo; use std::io::{Read, Write}; use std::net::TcpListener; use std::thread; @@ -1231,7 +1285,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1249,7 +1303,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1267,7 +1321,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1291,7 +1345,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1313,7 +1367,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel( diff --git a/src/dns/gai.rs b/src/dns/gai.rs index f32f3b0e0..00c981f0a 100644 --- a/src/dns/gai.rs +++ b/src/dns/gai.rs @@ -1,6 +1,6 @@ use futures_util::future::FutureExt; -use hyper::client::connect::dns::{GaiResolver as HyperGaiResolver, Name}; -use hyper::service::Service; +use hyper_util::client::legacy::connect::dns::{GaiResolver as HyperGaiResolver, Name}; +use tower_service::Service; use crate::dns::{Addrs, Resolve, Resolving}; use crate::error::BoxError; diff --git a/src/dns/resolve.rs b/src/dns/resolve.rs index 3686765a0..4c36f30ec 100644 --- a/src/dns/resolve.rs +++ b/src/dns/resolve.rs @@ -1,5 +1,5 @@ -use hyper::client::connect::dns::Name; -use hyper::service::Service; +use hyper_util::client::legacy::connect::dns::Name; +use tower_service::Service; use std::collections::HashMap; use std::future::Future; diff --git a/src/dns/trust_dns.rs b/src/dns/trust_dns.rs index a25326085..fc93f08b1 100644 --- a/src/dns/trust_dns.rs +++ b/src/dns/trust_dns.rs @@ -1,6 +1,6 @@ //! DNS resolution via the [trust_dns_resolver](https://github.com/bluejekyll/trust-dns) crate -use hyper::client::connect::dns::Name; +use hyper_util::client::legacy::connect::dns::Name; use once_cell::sync::OnceCell; use trust_dns_resolver::{lookup_ip::LookupIpIntoIter, system_conf, TokioAsyncResolver}; diff --git a/src/error.rs b/src/error.rs index 9ffb6ed17..c558ebbac 100644 --- a/src/error.rs +++ b/src/error.rs @@ -127,7 +127,7 @@ impl Error { let mut source = self.source(); while let Some(err) = source { - if let Some(hyper_err) = err.downcast_ref::() { + if let Some(hyper_err) = err.downcast_ref::() { if hyper_err.is_connect() { return true; } @@ -291,9 +291,8 @@ pub(crate) fn upgrade>(e: E) -> Error { // io::Error helpers -#[allow(unused)] -pub(crate) fn into_io(e: Error) -> io::Error { - e.into_io() +pub(crate) fn into_io(e: BoxError) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) } #[allow(unused)] diff --git a/src/tls.rs b/src/tls.rs index 3f53d875f..27101d733 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -46,9 +46,11 @@ #[cfg(feature = "__rustls")] use rustls::{ - client::HandshakeSignatureValid, client::ServerCertVerified, client::ServerCertVerifier, - DigitallySignedStruct, Error as TLSError, ServerName, + client::danger::HandshakeSignatureValid, client::danger::ServerCertVerified, + client::danger::ServerCertVerifier, DigitallySignedStruct, Error as TLSError, SignatureScheme, }; +#[cfg(feature = "__rustls")] +use rustls_pki_types::{ServerName, UnixTime}; use std::{ fmt, io::{BufRead, BufReader}, @@ -77,7 +79,6 @@ pub struct Identity { inner: ClientCert, } -#[derive(Clone)] enum ClientCert { #[cfg(feature = "native-tls")] Pkcs12(native_tls_crate::Identity), @@ -85,11 +86,32 @@ enum ClientCert { Pkcs8(native_tls_crate::Identity), #[cfg(feature = "__rustls")] Pem { - key: rustls::PrivateKey, - certs: Vec, + key: rustls_pki_types::PrivateKeyDer<'static>, + certs: Vec>, }, } +impl Clone for ClientCert { + fn clone(&self) -> Self { + match self { + #[cfg(feature = "native-tls")] + Self::Pkcs8(i) => Self::Pkcs8(i.clone()), + #[cfg(feature = "native-tls")] + Self::Pkcs12(i) => Self::Pkcs12(i.clone()), + #[cfg(feature = "__rustls")] + ClientCert::Pem { key, certs } => ClientCert::Pem { + key: key.clone_key(), + certs: certs.clone(), + }, + #[cfg_attr( + any(feature = "native-tls", feature = "__rustls"), + allow(unreachable_patterns) + )] + _ => unreachable!(), + } + } +} + impl Certificate { /// Create a `Certificate` from a binary DER encoded certificate /// @@ -181,14 +203,14 @@ impl Certificate { match self.original { Cert::Der(buf) => root_cert_store - .add(&rustls::Certificate(buf)) + .add(buf.into()) .map_err(crate::error::builder)?, Cert::Pem(buf) => { let mut reader = Cursor::new(buf); let certs = Self::read_pem_certs(&mut reader)?; for c in certs { root_cert_store - .add(&rustls::Certificate(c)) + .add(c.into()) .map_err(crate::error::builder)?; } } @@ -308,8 +330,8 @@ impl Identity { let (key, certs) = { let mut pem = Cursor::new(buf); - let mut sk = Vec::::new(); - let mut certs = Vec::::new(); + let mut sk = Vec::::new(); + let mut certs = Vec::::new(); for item in std::iter::from_fn(|| rustls_pemfile::read_one(&mut pem).transpose()) { match item.map_err(|_| { @@ -317,12 +339,16 @@ impl Identity { "Invalid identity PEM file", ))) })? { - rustls_pemfile::Item::X509Certificate(cert) => { - certs.push(rustls::Certificate(cert)) + rustls_pemfile::Item::X509Certificate(cert) => certs.push(cert.into()), + rustls_pemfile::Item::PKCS8Key(key) => { + sk.push(rustls_pki_types::PrivateKeyDer::Pkcs8(key.into())) + } + rustls_pemfile::Item::RSAKey(key) => { + sk.push(rustls_pki_types::PrivateKeyDer::Pkcs1(key.into())) + } + rustls_pemfile::Item::ECKey(key) => { + sk.push(rustls_pki_types::PrivateKeyDer::Sec1(key.into())) } - rustls_pemfile::Item::PKCS8Key(key) => sk.push(rustls::PrivateKey(key)), - rustls_pemfile::Item::RSAKey(key) => sk.push(rustls::PrivateKey(key)), - rustls_pemfile::Item::ECKey(key) => sk.push(rustls::PrivateKey(key)), _ => { return Err(crate::error::builder(TLSError::General(String::from( "No valid certificate was found", @@ -365,7 +391,8 @@ impl Identity { self, config_builder: rustls::ConfigBuilder< rustls::ClientConfig, - rustls::client::WantsTransparencyPolicyOrClientCert, + // Not sure here + rustls::client::WantsClientCert, >, ) -> crate::Result { match self.inner { @@ -491,18 +518,18 @@ impl Default for TlsBackend { } #[cfg(feature = "__rustls")] +#[derive(Debug)] pub(crate) struct NoVerifier; #[cfg(feature = "__rustls")] impl ServerCertVerifier for NoVerifier { fn verify_server_cert( &self, - _end_entity: &rustls::Certificate, - _intermediates: &[rustls::Certificate], + _end_entity: &rustls_pki_types::CertificateDer, + _intermediates: &[rustls_pki_types::CertificateDer], _server_name: &ServerName, - _scts: &mut dyn Iterator, _ocsp_response: &[u8], - _now: std::time::SystemTime, + _now: UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } @@ -510,7 +537,7 @@ impl ServerCertVerifier for NoVerifier { fn verify_tls12_signature( &self, _message: &[u8], - _cert: &rustls::Certificate, + _cert: &rustls_pki_types::CertificateDer, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) @@ -519,11 +546,29 @@ impl ServerCertVerifier for NoVerifier { fn verify_tls13_signature( &self, _message: &[u8], - _cert: &rustls::Certificate, + _cert: &rustls_pki_types::CertificateDer, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + SignatureScheme::RSA_PKCS1_SHA1, + SignatureScheme::ECDSA_SHA1_Legacy, + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::RSA_PKCS1_SHA512, + SignatureScheme::ECDSA_NISTP521_SHA512, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::ED25519, + SignatureScheme::ED448, + ] + } } /// Hyper extension carrying extra TLS layer information. diff --git a/tests/blocking.rs b/tests/blocking.rs index fa6c8d01c..314b3e504 100644 --- a/tests/blocking.rs +++ b/tests/blocking.rs @@ -1,7 +1,9 @@ mod support; +#[cfg(feature = "json")] use http::header::CONTENT_TYPE; -use http::HeaderValue; +use http_body_util::BodyExt; +#[cfg(feature = "json")] use std::collections::HashMap; use support::server; @@ -88,7 +90,7 @@ fn test_post() { assert_eq!(req.method(), "POST"); assert_eq!(req.headers()["content-length"], "5"); - let data = hyper::body::to_bytes(req.into_body()).await.unwrap(); + let data = req.into_body().collect().await.unwrap().to_bytes(); assert_eq!(&*data, b"Hello"); http::Response::default() @@ -115,7 +117,7 @@ fn test_post_form() { "application/x-www-form-urlencoded" ); - let data = hyper::body::to_bytes(req.into_body()).await.unwrap(); + let data = req.into_body().collect().await.unwrap().to_bytes(); assert_eq!(&*data, b"hello=world&sean=monstar"); http::Response::default() @@ -336,6 +338,8 @@ fn test_body_from_bytes() { #[test] #[cfg(feature = "json")] fn blocking_add_json_default_content_type_if_not_set_manually() { + use http::header::HeaderValue; + let mut map = HashMap::new(); map.insert("body", "json"); let content_type = HeaderValue::from_static("application/vnd.api+json"); diff --git a/tests/brotli.rs b/tests/brotli.rs index dc7d6d767..5c2b01849 100644 --- a/tests/brotli.rs +++ b/tests/brotli.rs @@ -19,7 +19,6 @@ async fn test_brotli_empty_body() { http::Response::builder() .header("content-encoding", "br") - .header("content-length", 100) .body(Default::default()) .unwrap() }); @@ -125,7 +124,7 @@ async fn brotli_case(response_size: usize, chunk_size: usize) { Some((chunk, (brotlied, pos + 1))) }); - let body = hyper::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); + let body = reqwest::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); http::Response::builder() .header("content-encoding", "br") diff --git a/tests/client.rs b/tests/client.rs index c09415c0a..f144763ee 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1,7 +1,6 @@ #![cfg(not(target_arch = "wasm32"))] mod support; -use futures_util::stream::StreamExt; use support::delay_server; use support::server; @@ -181,19 +180,23 @@ async fn response_json() { #[tokio::test] async fn body_pipe_response() { + use http_body_util::BodyExt; let _ = env_logger::try_init(); - let server = server::http(move |mut req| async move { + let server = server::http(move |req| async move { if req.uri() == "/get" { http::Response::new("pipe me".into()) } else { assert_eq!(req.uri(), "/pipe"); assert_eq!(req.headers()["transfer-encoding"], "chunked"); - let mut full: Vec = Vec::new(); - while let Some(item) = req.body_mut().next().await { - full.extend(&*item.unwrap()); - } + let full: Vec = req + .into_body() + .collect() + .await + .expect("must succeed") + .to_bytes() + .to_vec(); assert_eq!(full, b"pipe me"); @@ -370,7 +373,6 @@ fn use_preconfigured_rustls_default() { let root_cert_store = rustls::RootCertStore::empty(); let tls = rustls::ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_cert_store) .with_no_client_auth(); @@ -501,7 +503,9 @@ async fn highly_concurrent_requests_to_http2_server_with_low_max_concurrent_stre assert_eq!(req.version(), http::Version::HTTP_2); http::Response::default() }, - |builder| builder.http2_only(true).http2_max_concurrent_streams(1), + |builder| { + builder.http2().max_concurrent_streams(1); + }, ); let url = format!("http://{}", server.addr()); @@ -529,9 +533,8 @@ async fn highly_concurrent_requests_to_slow_http2_server_with_low_max_concurrent assert_eq!(req.version(), http::Version::HTTP_2); http::Response::default() }, - |mut http| { - http.http2_only(true).http2_max_concurrent_streams(1); - http + |http| { + http.http2().max_concurrent_streams(1); }, std::time::Duration::from_secs(2), ) diff --git a/tests/deflate.rs b/tests/deflate.rs index 3b8d9e021..ec27ba180 100644 --- a/tests/deflate.rs +++ b/tests/deflate.rs @@ -19,7 +19,6 @@ async fn test_deflate_empty_body() { http::Response::builder() .header("content-encoding", "deflate") - .header("content-length", 100) .body(Default::default()) .unwrap() }); @@ -128,7 +127,7 @@ async fn deflate_case(response_size: usize, chunk_size: usize) { Some((chunk, (deflated, pos + 1))) }); - let body = hyper::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); + let body = reqwest::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); http::Response::builder() .header("content-encoding", "deflate") diff --git a/tests/gzip.rs b/tests/gzip.rs index 66e1b7f25..57189e0ac 100644 --- a/tests/gzip.rs +++ b/tests/gzip.rs @@ -20,7 +20,6 @@ async fn test_gzip_empty_body() { http::Response::builder() .header("content-encoding", "gzip") - .header("content-length", 100) .body(Default::default()) .unwrap() }); @@ -129,7 +128,7 @@ async fn gzip_case(response_size: usize, chunk_size: usize) { Some((chunk, (gzipped, pos + 1))) }); - let body = hyper::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); + let body = reqwest::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); http::Response::builder() .header("content-encoding", "gzip") diff --git a/tests/multipart.rs b/tests/multipart.rs index 59ada280d..425c830a7 100644 --- a/tests/multipart.rs +++ b/tests/multipart.rs @@ -1,6 +1,6 @@ #![cfg(not(target_arch = "wasm32"))] mod support; -use futures_util::stream::StreamExt; +use http_body_util::BodyExt; use support::server; #[tokio::test] @@ -33,8 +33,8 @@ async fn text_part() { ); let mut full: Vec = Vec::new(); - while let Some(item) = req.body_mut().next().await { - full.extend(&*item.unwrap()); + while let Some(item) = req.body_mut().frame().await { + full.extend(&*item.unwrap().into_data().unwrap()); } assert_eq!(full, expected_body.as_bytes()); @@ -97,10 +97,7 @@ async fn stream_part() { assert_eq!(req.headers()["content-type"], ct); assert_eq!(req.headers()["transfer-encoding"], "chunked"); - let mut full: Vec = Vec::new(); - while let Some(item) = req.body_mut().next().await { - full.extend(&*item.unwrap()); - } + let full = req.collect().await.unwrap().to_bytes(); assert_eq!(full, expected_body.as_bytes()); @@ -159,10 +156,7 @@ fn blocking_file_part() { expected_body.len().to_string() ); - let mut full: Vec = Vec::new(); - while let Some(item) = req.body_mut().next().await { - full.extend(&*item.unwrap()); - } + let full = req.collect().await.unwrap().to_bytes(); assert_eq!(full, expected_body.as_bytes()); diff --git a/tests/redirect.rs b/tests/redirect.rs index 9df6265a4..c98c799ef 100644 --- a/tests/redirect.rs +++ b/tests/redirect.rs @@ -1,7 +1,7 @@ #![cfg(not(target_arch = "wasm32"))] mod support; -use futures_util::stream::StreamExt; -use hyper::Body; +use http_body_util::BodyExt; +use reqwest::Body; use support::server; #[tokio::test] @@ -87,7 +87,14 @@ async fn test_redirect_307_and_308_tries_to_post_again() { assert_eq!(req.method(), "POST"); assert_eq!(req.headers()["content-length"], "5"); - let data = req.body_mut().next().await.unwrap().unwrap(); + let data = req + .body_mut() + .frame() + .await + .unwrap() + .unwrap() + .into_data() + .unwrap(); assert_eq!(&*data, b"Hello"); if req.uri() == &*format!("/{code}") { @@ -130,7 +137,14 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() { assert_eq!(req.uri(), &*format!("/{code}")); assert_eq!(req.headers()["transfer-encoding"], "chunked"); - let data = req.body_mut().next().await.unwrap().unwrap(); + let data = req + .body_mut() + .frame() + .await + .unwrap() + .unwrap() + .into_data() + .unwrap(); assert_eq!(&*data, b"Hello"); http::Response::builder() diff --git a/tests/support/delay_server.rs b/tests/support/delay_server.rs index 08f421598..f79c2a4df 100644 --- a/tests/support/delay_server.rs +++ b/tests/support/delay_server.rs @@ -1,14 +1,13 @@ #![cfg(not(target_arch = "wasm32"))] +#![allow(unused)] use std::convert::Infallible; use std::future::Future; use std::net; -use std::sync::Arc; use std::time::Duration; use futures_util::FutureExt; use http::{Request, Response}; use hyper::service::service_fn; -use hyper::Body; use tokio::net::TcpListener; use tokio::select; use tokio::sync::oneshot; @@ -29,12 +28,14 @@ pub struct Server { server_terminated_rx: oneshot::Receiver<()>, } +type Builder = hyper_util::server::conn::auto::Builder; + impl Server { - pub async fn new(func: F1, apply_config: F2, delay: Duration) -> Self + pub async fn new(func: F1, apply_config: F2, delay: Duration) -> Self where - F1: Fn(Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, - F2: FnOnce(hyper::server::conn::Http) -> hyper::server::conn::Http + Send + 'static, + F1: Fn(Request) -> Fut + Clone + Send + 'static, + Fut: Future> + Send + 'static, + F2: FnOnce(&mut Builder) -> Bu + Send + 'static, { let (shutdown_tx, shutdown_rx) = oneshot::channel(); let (server_terminated_tx, server_terminated_rx) = oneshot::channel(); @@ -43,9 +44,12 @@ impl Server { let addr = tcp_listener.local_addr().unwrap(); tokio::spawn(async move { - let http = Arc::new(apply_config(hyper::server::conn::Http::new())); + let mut builder = + hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + apply_config(&mut builder); tokio::spawn(async move { + let builder = builder; let (connection_shutdown_tx, connection_shutdown_rx) = oneshot::channel(); let connection_shutdown_rx = connection_shutdown_rx.shared(); let mut shutdown_rx = std::pin::pin!(shutdown_rx); @@ -59,24 +63,24 @@ impl Server { } res = tcp_listener.accept() => { let (stream, _) = res.unwrap(); + let io = hyper_util::rt::TokioIo::new(stream); let handle = tokio::spawn({ let connection_shutdown_rx = connection_shutdown_rx.clone(); - let http = http.clone(); let func = func.clone(); + let svc = service_fn(move |req| { + let fut = func(req); + async move { + Ok::<_, Infallible>(fut.await) + }}); + let builder = builder.clone(); async move { + let fut = builder.serve_connection_with_upgrades(io, svc); tokio::time::sleep(delay).await; - let mut conn = std::pin::pin!(http.serve_connection( - stream, - service_fn(move |req| { - let fut = func(req); - async move { - Ok::<_, Infallible>(fut.await) - }}) - )); + let mut conn = std::pin::pin!(fut); select! { _ = conn.as_mut() => {} diff --git a/tests/support/server.rs b/tests/support/server.rs index 5193a5fbe..f9c45b4d2 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -1,12 +1,11 @@ #![cfg(not(target_arch = "wasm32"))] -use std::convert::{identity, Infallible}; +use std::convert::Infallible; use std::future::Future; use std::net; use std::sync::mpsc as std_mpsc; use std::thread; use std::time::Duration; -use hyper::server::conn::AddrIncoming; use tokio::runtime; use tokio::sync::oneshot; @@ -38,19 +37,19 @@ impl Drop for Server { pub fn http(func: F) -> Server where - F: Fn(http::Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, + F: Fn(http::Request) -> Fut + Clone + Send + 'static, + Fut: Future> + Send + 'static, { - http_with_config(func, identity) + http_with_config(func, |_builder| {}) } -pub fn http_with_config(func: F1, apply_config: F2) -> Server +type Builder = hyper_util::server::conn::auto::Builder; + +pub fn http_with_config(func: F1, apply_config: F2) -> Server where - F1: Fn(http::Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, - F2: FnOnce(hyper::server::Builder) -> hyper::server::Builder - + Send - + 'static, + F1: Fn(http::Request) -> Fut + Clone + Send + 'static, + Fut: Future> + Send + 'static, + F2: FnOnce(&mut Builder) -> Bu + Send + 'static, { // Spawn new runtime in thread to prevent reactor execution context conflict thread::spawn(move || { @@ -58,26 +57,14 @@ where .enable_all() .build() .expect("new rt"); - let srv = rt.block_on(async move { - let builder = hyper::Server::bind(&([127, 0, 0, 1], 0).into()); - - apply_config(builder).serve(hyper::service::make_service_fn(move |_| { - let func = func.clone(); - async move { - Ok::<_, Infallible>(hyper::service::service_fn(move |req| { - let fut = func(req); - async move { Ok::<_, Infallible>(fut.await) } - })) - } - })) - }); - - let addr = srv.local_addr(); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let srv = srv.with_graceful_shutdown(async move { - let _ = shutdown_rx.await; + let listener = rt.block_on(async move { + tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap() }); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); let (panic_tx, panic_rx) = std_mpsc::channel(); let tname = format!( "test({})-support-server", @@ -86,11 +73,34 @@ where thread::Builder::new() .name(tname) .spawn(move || { - rt.block_on(srv).unwrap(); - let _ = panic_tx.send(()); + rt.block_on(async move { + let mut builder = + hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + apply_config(&mut builder); + + loop { + tokio::select! { + _ = &mut shutdown_rx => { + break; + } + accepted = listener.accept() => { + let (io, _) = accepted.expect("accepted"); + let func = func.clone(); + let svc = hyper::service::service_fn(move |req| { + let fut = func(req); + async move { Ok::<_, Infallible>(fut.await) } + }); + let builder = builder.clone(); + tokio::spawn(async move { + let _ = builder.serve_connection_with_upgrades(hyper_util::rt::TokioIo::new(io), svc).await; + }); + } + } + } + let _ = panic_tx.send(()); + }); }) .expect("thread spawn"); - Server { addr, panic_rx, diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 6f6b0d588..ee690933e 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -143,6 +143,7 @@ async fn connect_many_timeout() { assert!(err.is_connect() && err.is_timeout()); } +#[cfg(feature = "stream")] #[tokio::test] async fn response_timeout() { let _ = env_logger::try_init(); @@ -150,7 +151,7 @@ async fn response_timeout() { let server = server::http(move |_req| { async { // immediate response, but delayed body - let body = hyper::Body::wrap_stream(futures_util::stream::once(async { + let body = reqwest::Body::wrap_stream(futures_util::stream::once(async { tokio::time::sleep(Duration::from_secs(2)).await; Ok::<_, std::convert::Infallible>("Hello") })); @@ -232,6 +233,7 @@ fn timeout_blocking_request() { } #[cfg(feature = "blocking")] +#[cfg(feature = "stream")] #[test] fn blocking_request_timeout_body() { let _ = env_logger::try_init(); @@ -247,7 +249,7 @@ fn blocking_request_timeout_body() { let server = server::http(move |_req| { async { // immediate response, but delayed body - let body = hyper::Body::wrap_stream(futures_util::stream::once(async { + let body = reqwest::Body::wrap_stream(futures_util::stream::once(async { tokio::time::sleep(Duration::from_secs(1)).await; Ok::<_, std::convert::Infallible>("Hello") })); diff --git a/tests/upgrade.rs b/tests/upgrade.rs index de5c2904d..5ea72acc2 100644 --- a/tests/upgrade.rs +++ b/tests/upgrade.rs @@ -11,7 +11,7 @@ async fn http_upgrade() { assert_eq!(req.headers()["upgrade"], "foobar"); tokio::spawn(async move { - let mut upgraded = hyper::upgrade::on(req).await.unwrap(); + let mut upgraded = hyper_util::rt::TokioIo::new(hyper::upgrade::on(req).await.unwrap()); let mut buf = vec![0; 7]; upgraded.read_exact(&mut buf).await.unwrap(); @@ -25,7 +25,7 @@ async fn http_upgrade() { .status(http::StatusCode::SWITCHING_PROTOCOLS) .header(http::header::CONNECTION, "upgrade") .header(http::header::UPGRADE, "foobar") - .body(hyper::Body::empty()) + .body(reqwest::Body::default()) .unwrap() } });