From 0c60de05bcc006f93bd19c3d29f55f62b3bea3bb Mon Sep 17 00:00:00 2001 From: stackinspector Date: Tue, 3 Dec 2024 23:43:46 +0800 Subject: [PATCH] Add a simple `WebSocketStream::send` method to replace `Sink` trait usage And also bump MSRV to 1.64. Fixes https://github.com/sdroege/async-tungstenite/issues/142 --- .github/workflows/ci.yml | 2 +- Cargo.lock.msrv | 2 + Cargo.toml | 20 +++++--- examples/autobahn-client.rs | 1 + examples/autobahn-server.rs | 4 +- examples/server-headers.rs | 2 +- src/compat.rs | 26 +++++----- src/lib.rs | 98 +++++++++++++++++++++++++++++++++---- 8 files changed, 125 insertions(+), 30 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 404f277..f5b93a0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -97,7 +97,7 @@ jobs: strategy: matrix: rust: - - 1.63.0 + - 1.64.0 steps: - name: Checkout sources diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv index 63187f6..d987867 100644 --- a/Cargo.lock.msrv +++ b/Cargo.lock.msrv @@ -247,9 +247,11 @@ dependencies = [ "async-native-tls", "async-std", "async-tls", + "atomic-waker", "env_logger", "futures", "futures-channel", + "futures-core", "futures-io", "futures-util", "gio", diff --git a/Cargo.toml b/Cargo.toml index c423da1..d580945 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,10 +12,11 @@ version = "0.28.0" edition = "2018" readme = "README.md" include = ["examples/**/*", "src/**/*", "LICENSE", "README.md", "CHANGELOG.md"] -rust-version = "1.63" +rust-version = "1.64" [features] -default = ["handshake"] +default = ["handshake", "futures-03-sink"] +futures-03-sink = ["futures-util"] handshake = ["tungstenite/handshake"] async-std-runtime = ["async-std", "handshake"] tokio-runtime = ["tokio", "handshake"] @@ -37,10 +38,17 @@ features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "a [dependencies] log = "0.4" -futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } +futures-core = { version = "0.3", default-features = false } +atomic-waker = { version = "1.1", default-features = false } futures-io = { version = "0.3", default-features = false, features = ["std"] } pin-project-lite = "0.2" +[dependencies.futures-util] +optional = true +version = "0.3" +default-features = false +features = ["sink"] + [dependencies.tungstenite] version = "0.24" default-features = false @@ -141,7 +149,7 @@ required-features = ["async-std-runtime"] [[example]] name = "autobahn-server" -required-features = ["async-std-runtime"] +required-features = ["async-std-runtime", "futures-03-sink"] [[example]] name = "server" @@ -153,7 +161,7 @@ required-features = ["async-std-runtime"] [[example]] name = "server-headers" -required-features = ["async-std-runtime", "handshake"] +required-features = ["async-std-runtime", "handshake", "futures-util"] [[example]] name = "interval-server" @@ -173,4 +181,4 @@ required-features = ["tokio-runtime"] [[example]] name = "server-custom-accept" -required-features = ["tokio-runtime"] +required-features = ["tokio-runtime", "futures-util"] diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 3c49b06..389caa7 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -32,6 +32,7 @@ async fn run_test(case: u32) -> Result<()> { while let Some(msg) = ws_stream.next().await { let msg = msg?; if msg.is_text() || msg.is_binary() { + // for Sink of futures 0.3, see autobahn-server example ws_stream.send(msg).await?; } } diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 3f570e8..4f9c965 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -23,7 +23,9 @@ async fn handle_connection(peer: SocketAddr, stream: TcpStream) -> Result<()> { while let Some(msg) = ws_stream.next().await { let msg = msg?; if msg.is_text() || msg.is_binary() { - ws_stream.send(msg).await?; + // here we explicitly using futures 0.3's Sink implementation for send message + // for WebSocketStream::send, see autobahn-client example + futures::SinkExt::send(&mut ws_stream, msg).await?; } } diff --git a/examples/server-headers.rs b/examples/server-headers.rs index 5dd69a1..71740a4 100644 --- a/examples/server-headers.rs +++ b/examples/server-headers.rs @@ -24,7 +24,7 @@ use async_tungstenite::{ use url::Url; #[macro_use] extern crate log; -use futures_util::{SinkExt, StreamExt}; +use futures_util::StreamExt; #[async_std::main] async fn main() { diff --git a/src/compat.rs b/src/compat.rs index 2fec932..c2fe19a 100644 --- a/src/compat.rs +++ b/src/compat.rs @@ -2,10 +2,10 @@ use log::*; use std::io::{Read, Write}; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, Wake, Waker}; +use atomic_waker::AtomicWaker; use futures_io::{AsyncRead, AsyncWrite}; -use futures_util::task; use std::sync::Arc; use tungstenite::Error as WsError; @@ -49,18 +49,20 @@ pub(crate) struct AllowStd { // read waker slot for this, but any would do. // // Don't ever use this from multiple tasks at the same time! +#[cfg(feature = "handshake")] pub(crate) trait SetWaker { - fn set_waker(&self, waker: &task::Waker); + fn set_waker(&self, waker: &Waker); } +#[cfg(feature = "handshake")] impl SetWaker for AllowStd { - fn set_waker(&self, waker: &task::Waker) { + fn set_waker(&self, waker: &Waker) { self.set_waker(ContextWaker::Read, waker); } } impl AllowStd { - pub(crate) fn new(inner: S, waker: &task::Waker) -> Self { + pub(crate) fn new(inner: S, waker: &Waker) -> Self { let res = Self { inner, write_waker_proxy: Default::default(), @@ -83,7 +85,7 @@ impl AllowStd { // // Write: this is only supposde to be called by write operations, i.e. the Sink impl on the // WebSocketStream. - pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) { + pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) { match kind { ContextWaker::Read => { self.write_waker_proxy.read_waker.register(waker); @@ -103,11 +105,11 @@ impl AllowStd { // reads and writes, and the same for writes. #[derive(Debug, Default)] struct WakerProxy { - read_waker: task::AtomicWaker, - write_waker: task::AtomicWaker, + read_waker: AtomicWaker, + write_waker: AtomicWaker, } -impl std::task::Wake for WakerProxy { +impl Wake for WakerProxy { fn wake(self: Arc) { self.wake_by_ref() } @@ -129,10 +131,10 @@ where #[cfg(feature = "verbose-logging")] trace!("{}:{} AllowStd.with_context", file!(), line!()); let waker = match kind { - ContextWaker::Read => task::Waker::from(self.read_waker_proxy.clone()), - ContextWaker::Write => task::Waker::from(self.write_waker_proxy.clone()), + ContextWaker::Read => Waker::from(self.read_waker_proxy.clone()), + ContextWaker::Write => Waker::from(self.write_waker_proxy.clone()), }; - let mut context = task::Context::from_waker(&waker); + let mut context = Context::from_waker(&waker); f(&mut context, Pin::new(&mut self.inner)) } diff --git a/src/lib.rs b/src/lib.rs index 7c80e7d..7114934 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,17 +58,16 @@ mod handshake; ))] pub mod stream; -use std::io::{Read, Write}; +use std::{ + io::{Read, Write}, + pin::Pin, + task::{ready, Context, Poll}, +}; use compat::{cvt, AllowStd, ContextWaker}; +use futures_core::stream::{FusedStream, Stream}; use futures_io::{AsyncRead, AsyncWrite}; -use futures_util::{ - sink::{Sink, SinkExt}, - stream::{FusedStream, Stream}, -}; use log::*; -use std::pin::Pin; -use std::task::{Context, Poll}; #[cfg(feature = "handshake")] use tungstenite::{ @@ -227,6 +226,7 @@ where #[derive(Debug)] pub struct WebSocketStream { inner: WebSocket>, + #[cfg(feature = "futures-03-sink")] closing: bool, ended: bool, /// Tungstenite is probably ready to receive more data. @@ -269,6 +269,7 @@ impl WebSocketStream { pub(crate) fn new(ws: WebSocket>) -> Self { Self { inner: ws, + #[cfg(feature = "futures-03-sink")] closing: false, ended: false, ready: true, @@ -337,7 +338,7 @@ where return Poll::Ready(None); } - match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| { + match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| { #[cfg(feature = "verbose-logging")] trace!( "{}:{} Stream.with_context poll_next -> read()", @@ -368,7 +369,8 @@ where } } -impl Sink for WebSocketStream +#[cfg(feature = "futures-03-sink")] +impl futures_util::Sink for WebSocketStream where T: AsyncRead + AsyncWrite + Unpin, { @@ -446,6 +448,84 @@ where } } +impl WebSocketStream { + /// Simple send method to replace `futures_sink::Sink` (till v0.3). + pub async fn send(&mut self, msg: Message) -> Result<(), WsError> + where + S: AsyncRead + AsyncWrite + Unpin, + { + Send::new(self, msg).await + } +} + +struct Send<'a, S> { + ws: &'a mut WebSocketStream, + msg: Option, +} + +impl<'a, S> Send<'a, S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn new(ws: &'a mut WebSocketStream, msg: Message) -> Self { + Self { ws, msg: Some(msg) } + } +} + +impl std::future::Future for Send<'_, S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(), WsError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.msg.is_some() { + if !self.ws.ready { + // Currently blocked so try to flush the blockage away + let polled = self + .ws + .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())) + .map(|r| { + self.ws.ready = true; + r + }); + ready!(polled)? + } + + let msg = self.msg.take().expect("unreachable"); + match self.ws.with_context(None, |s| s.write(msg)) { + Ok(_) => Ok(()), + Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => { + // the message was accepted and queued so not an error + // + // set to false here for cancellation safety of *this* Future + self.ws.ready = false; + Ok(()) + } + Err(e) => { + debug!("websocket start_send error: {}", e); + Err(e) + } + }?; + } + + let polled = self + .ws + .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())) + .map(|r| { + self.ws.ready = true; + match r { + // WebSocket connection has just been closed. Flushing completed, not an error. + Err(WsError::ConnectionClosed) => Ok(()), + other => other, + } + }); + ready!(polled)?; + + Poll::Ready(Ok(())) + } +} + #[cfg(any( feature = "async-tls", feature = "async-std-runtime",