From 818dc3df9191c685bc854dd878fe27cdca709a6c Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 3 Dec 2024 22:44:34 +0800 Subject: [PATCH] fix: return write-zero error when write return 0 --- src/common/mod.rs | 10 ++++-- src/common/test_stream.rs | 66 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 5dc6a05..0925e40 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -130,6 +130,8 @@ where while self.session.wants_write() { match self.write_io(cx) { + Poll::Ready(Ok(0)) => + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), Poll::Ready(Ok(n)) => { wrlen += n; need_flush = true; @@ -322,14 +324,18 @@ where fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.session.writer().flush()?; while self.session.wants_write() { - ready!(self.write_io(cx))?; + if ready!(self.write_io(cx))? == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } } Pin::new(&mut self.io).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { while self.session.wants_write() { - ready!(self.write_io(cx))?; + if ready!(self.write_io(cx))? == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } } Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 847de8e..438213e 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -122,6 +122,36 @@ impl AsyncWrite for Expected { } } +struct Eof; + +impl AsyncRead for Eof { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for Eof { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Ok(0)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + #[tokio::test] async fn stream_good() -> io::Result<()> { stream_good_impl(false).await @@ -254,6 +284,23 @@ async fn stream_handshake_eof() -> io::Result<()> { Ok(()) as io::Result<()> } +#[tokio::test] +async fn stream_handshake_write_eof() -> io::Result<()> { + let (_, mut client) = make_pair(); + + let mut io = Eof; + let mut stream = Stream::new(&mut io, &mut client); + + let mut cx = Context::from_waker(noop_waker_ref()); + let r = stream.handshake(&mut cx); + assert_eq!( + r.map_err(|err| err.kind()), + Poll::Ready(Err(io::ErrorKind::WriteZero)) + ); + + Ok(()) as io::Result<()> +} + // see https://github.com/tokio-rs/tls/issues/77 #[tokio::test] async fn stream_handshake_regression_issues_77() -> io::Result<()> { @@ -291,6 +338,25 @@ async fn stream_eof() -> io::Result<()> { Ok(()) as io::Result<()> } +#[tokio::test] +async fn stream_write_zero() -> io::Result<()> { + let (server, mut client) = make_pair(); + let mut server = Connection::from(server); + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + + let mut io = Eof; + let mut stream = Stream::new(&mut io, &mut client); + + stream.write(b"1").await.unwrap(); + let result = stream.shutdown().await; + assert_eq!( + result.err().map(|e| e.kind()), + Some(io::ErrorKind::WriteZero) + ); + + Ok(()) as io::Result<()> +} + fn make_pair() -> (ServerConnection, ClientConnection) { let (sconfig, cconfig) = utils::make_configs(); let server = ServerConnection::new(Arc::new(sconfig)).unwrap();