From a5c60b24dec550ab65e26b948960a421d2e60b3d Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 8 Dec 2021 17:38:41 -0800 Subject: [PATCH] Fix poll_capacity to wake in combination with max_send_buffer_size --- src/proto/streams/prioritize.rs | 5 ++ src/proto/streams/stream.rs | 11 +++++ tests/h2-tests/tests/flow_control.rs | 74 ++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+) diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index eaaee162b..2347f6f0b 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -741,6 +741,11 @@ impl Prioritize { stream.buffered_send_data -= len as usize; stream.requested_send_capacity -= len; + // If the capacity was limited because of the + // max_send_buffer_size, then consider waking + // the send task again... + stream.notify_if_can_buffer_more(); + // Assign the capacity back to the connection that // was just consumed from the stream in the previous // line. diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index 79de47a9a..f67cc3642 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -279,6 +279,17 @@ impl Stream { } } + /// If the capacity was limited because of the max_send_buffer_size, + /// then consider waking the send task again... + pub fn notify_if_can_buffer_more(&mut self) { + // Only notify if the capacity exceeds the amount of buffered data + if self.send_flow.available() > self.buffered_send_data { + self.send_capacity_inc = true; + tracing::trace!(" notifying task"); + self.notify_send(); + } + } + /// Returns `Err` when the decrement cannot be completed due to overflow. pub fn dec_content_length(&mut self, len: usize) -> Result<(), ()> { match self.content_length { diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index 1a6018f73..7adb3d730 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -1668,3 +1668,77 @@ async fn max_send_buffer_size_overflow() { join(srv, client).await; } + +#[tokio::test] +async fn max_send_buffer_size_poll_capacity_wakes_task() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[][..]).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = conn.drive(response).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!(stream.capacity(), 0); + const TO_SEND: usize = 20; + stream.reserve_capacity(TO_SEND); + assert_eq!( + stream.capacity(), + 5, + "polled capacity not over max buffer size" + ); + + let t1 = tokio::spawn(async move { + let mut sent = 0; + let buf = [0; TO_SEND]; + loop { + match poll_fn(|cx| stream.poll_capacity(cx)).await { + None => panic!("no cap"), + Some(Err(e)) => panic!("cap error: {:?}", e), + Some(Ok(cap)) => { + stream + .send_data(buf[sent..(sent + cap)].to_vec().into(), false) + .unwrap(); + sent += cap; + if sent >= TO_SEND { + break; + } + } + } + } + stream.send_data(Bytes::new(), true).unwrap(); + }); + + // Wait for the connection to close + conn.await.unwrap(); + t1.await.unwrap(); + }; + + join(srv, client).await; +}