From 06a0264c85b2f798728e91643b209b7c79502b60 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 19 Sep 2023 15:38:57 -0600 Subject: [PATCH 1/8] Split websocket --- src/error.rs | 2 + src/fragment.rs | 55 +++++++++++++ src/lib.rs | 215 +++++++++++++++++++++++++++++++++++++++++++----- tests/split.rs | 161 ++++++++++++++++++++++++++++++++++++ 4 files changed, 414 insertions(+), 19 deletions(-) create mode 100644 tests/split.rs diff --git a/src/error.rs b/src/error.rs index f0a0d4b..46db8fe 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,4 +41,6 @@ pub enum WebSocketError { #[cfg(feature = "upgrade")] #[error(transparent)] HTTPError(#[from] hyper::Error), + #[error("Failed to send frame")] + SendError(#[from] Box), } diff --git a/src/fragment.rs b/src/fragment.rs index 9f6d3e9..17e12c9 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::future::Future; + use crate::error::WebSocketError; use crate::frame::Frame; use crate::recv::SharedRecv; use crate::OpCode; use crate::ReadHalf; use crate::WebSocket; +use crate::WebSocketRead; use crate::WriteHalf; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; @@ -136,6 +139,58 @@ impl<'f, S> FragmentCollector { } } +pub struct FragmentCollectorRead { + stream: S, + read_half: ReadHalf, + fragments: Fragments, + // !Sync marker + _marker: std::marker::PhantomData, +} + +impl<'f, S> FragmentCollectorRead { + /// Creates a new `FragmentCollector` with the provided `WebSocket`. + pub fn new(ws: WebSocketRead) -> FragmentCollectorRead + where + S: AsyncReadExt + Unpin, + { + let (stream, read_half) = ws.into_parts_internal(); + FragmentCollectorRead { + stream, + read_half, + fragments: Fragments::new(), + _marker: std::marker::PhantomData, + } + } + + /// Reads a WebSocket frame, collecting fragmented messages until the final frame is received and returns the completed message. + /// + /// Text frames payload is guaranteed to be valid UTF-8. + pub async fn read_frame( + &mut self, + send_fn: &mut impl FnMut(Frame<'f>) -> R, + ) -> Result, WebSocketError> + where + S: AsyncReadExt + Unpin, + E: Into>, + R: Future>, + { + loop { + let (res, obligated_send) = + self.read_half.read_frame_inner(&mut self.stream).await; + if let Some(frame) = obligated_send { + let res = send_fn(frame).await; + res.map_err(|e| WebSocketError::SendError(e.into()))?; + } + let Some(frame) = res? else { + continue; + }; + if let Some(frame) = self.fragments.accumulate(frame)? { + return Ok(frame); + } + } + } +} + /// Accumulates potentially fragmented [`Frame`]s to defragment the incoming WebSocket stream. struct Fragments { fragments: Option, diff --git a/src/lib.rs b/src/lib.rs index 7d994ca..23a6008 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -162,18 +162,24 @@ mod recv; #[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))] pub mod upgrade; +use std::future::Future; + use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; pub use crate::close::CloseCode; pub use crate::error::WebSocketError; pub use crate::fragment::FragmentCollector; +pub use crate::fragment::FragmentCollectorRead; pub use crate::frame::Frame; pub use crate::frame::OpCode; pub use crate::frame::Payload; pub use crate::mask::unmask; use crate::recv::SharedRecv; +#[derive(Copy, Clone, Default)] +struct UnsendMarker(std::marker::PhantomData); + #[derive(Copy, Clone, PartialEq)] pub enum Role { Server, @@ -199,13 +205,145 @@ pub(crate) struct ReadHalf { max_message_size: usize, } +pub struct WebSocketRead { + stream: S, + read_half: ReadHalf, + _marker: UnsendMarker, +} + +pub struct WebSocketWrite { + stream: S, + write_half: WriteHalf, + _marker: UnsendMarker, +} + +/// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake. +pub fn after_handshake_split( + read: R, + write: W, + role: Role, +) -> (WebSocketRead, WebSocketWrite) +where + R: AsyncWriteExt + Unpin, + W: AsyncWriteExt + Unpin, +{ + ( + WebSocketRead { + stream: read, + read_half: ReadHalf::after_handshake(role), + _marker: UnsendMarker::default(), + }, + WebSocketWrite { + stream: write, + write_half: WriteHalf::after_handshake(role), + _marker: UnsendMarker::default(), + }, + ) +} + +impl<'f, S> WebSocketRead { + /// Consumes the `WebSocketRead` and returns the underlying stream. + #[inline] + pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) { + (self.stream, self.read_half) + } + + pub fn set_writev_threshold(&mut self, threshold: usize) { + self.read_half.writev_threshold = threshold; + } + + /// Sets whether to automatically close the connection when a close frame is received. When set to `false`, the application will have to manually send close frames. + /// + /// Default: `true` + pub fn set_auto_close(&mut self, auto_close: bool) { + self.read_half.auto_close = auto_close; + } + + /// Sets whether to automatically send a pong frame when a ping frame is received. + /// + /// Default: `true` + pub fn set_auto_pong(&mut self, auto_pong: bool) { + self.read_half.auto_pong = auto_pong; + } + + /// Sets the maximum message size in bytes. If a message is received that is larger than this, the connection will be closed. + /// + /// Default: 64 MiB + pub fn set_max_message_size(&mut self, max_message_size: usize) { + self.read_half.max_message_size = max_message_size; + } + + /// Sets whether to automatically apply the mask to the frame payload. + /// + /// Default: `true` + pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) { + self.read_half.auto_apply_mask = auto_apply_mask; + } + + /// Reads a frame from the stream. + pub async fn read_frame( + &mut self, + send_fn: &mut impl FnMut(Frame<'f>) -> R, + ) -> Result + where + S: AsyncReadExt + Unpin, + E: Into>, + R: Future>, + { + loop { + let (res, obligated_send) = + self.read_half.read_frame_inner(&mut self.stream).await; + if let Some(frame) = obligated_send { + let res = send_fn(frame).await; + res.map_err(|e| WebSocketError::SendError(e.into()))?; + } + if let Some(frame) = res? { + break Ok(frame); + } + } + } +} + +impl<'f, S> WebSocketWrite { + /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used. + /// + /// Default: `true` + pub fn set_writev(&mut self, vectored: bool) { + self.write_half.vectored = vectored; + } + + pub fn set_writev_threshold(&mut self, threshold: usize) { + self.write_half.writev_threshold = threshold; + } + + /// Sets whether to automatically apply the mask to the frame payload. + /// + /// Default: `true` + pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) { + self.write_half.auto_apply_mask = auto_apply_mask; + } + + pub fn is_closed(&self) -> bool { + self.write_half.closed + } + + pub async fn write_frame( + &mut self, + frame: Frame<'f>, + ) -> Result<(), WebSocketError> + where + S: AsyncWriteExt + Unpin, + { + self.write_half.write_frame(&mut self.stream, frame).await + } +} + /// WebSocket protocol implementation over an async stream. pub struct WebSocket { stream: S, write_half: WriteHalf, read_half: ReadHalf, - // !Sync marker - _marker: std::marker::PhantomData, + _marker: UnsendMarker, } impl<'f, S> WebSocket { @@ -235,25 +373,35 @@ impl<'f, S> WebSocket { recv::init_once(); Self { stream, - write_half: WriteHalf { - role, - closed: false, - auto_apply_mask: true, - vectored: true, - writev_threshold: 1024, - write_buffer: Vec::with_capacity(2), + write_half: WriteHalf::after_handshake(role), + read_half: ReadHalf::after_handshake(role), + _marker: UnsendMarker::default(), + } + } + + pub fn split( + self, + split_fn: impl Fn(S) -> (R, W), + ) -> (WebSocketRead, WebSocketWrite) + where + S: AsyncReadExt + AsyncWriteExt + Unpin, + R: AsyncReadExt + Unpin, + W: AsyncWriteExt + Unpin, + { + let (stream, read, write) = self.into_parts_internal(); + let (r, w) = split_fn(stream); + ( + WebSocketRead { + stream: r, + read_half: read, + _marker: UnsendMarker::default(), }, - read_half: ReadHalf { - role, - spill: None, - auto_apply_mask: true, - auto_close: true, - auto_pong: true, - writev_threshold: 1024, - max_message_size: 64 << 20, + WebSocketWrite { + stream: w, + write_half: write, + _marker: UnsendMarker::default(), }, - _marker: std::marker::PhantomData, - } + ) } /// Consumes the `WebSocket` and returns the underlying stream. @@ -310,6 +458,10 @@ impl<'f, S> WebSocket { self.write_half.auto_apply_mask = auto_apply_mask; } + pub fn is_closed(&self) -> bool { + self.write_half.closed + } + /// Writes a frame to the stream. /// /// # Example @@ -388,6 +540,18 @@ impl<'f, S> WebSocket { } impl ReadHalf { + pub fn after_handshake(role: Role) -> Self { + Self { + role, + spill: None, + auto_apply_mask: true, + auto_close: true, + auto_pong: true, + writev_threshold: 1024, + max_message_size: 64 << 20, + } + } + /// Attempt to read a single frame from from the incoming stream, returning any send obligations if /// `auto_close` or `auto_pong` are enabled. Callers to this function are obligated to send the /// frame in the latter half of the tuple if one is specified, unless the write half of this socket @@ -573,6 +737,17 @@ impl ReadHalf { } impl WriteHalf { + pub fn after_handshake(role: Role) -> Self { + Self { + role, + closed: false, + auto_apply_mask: true, + vectored: true, + writev_threshold: 1024, + write_buffer: Vec::with_capacity(2), + } + } + /// Writes a frame to the provided stream. pub async fn write_frame<'a, S>( &'a mut self, @@ -588,6 +763,8 @@ impl WriteHalf { if frame.opcode == OpCode::Close { self.closed = true; + } else if self.closed { + return Err(WebSocketError::ConnectionClosed); } if self.vectored && frame.payload.len() > self.writev_threshold { diff --git a/tests/split.rs b/tests/split.rs new file mode 100644 index 0000000..ee50900 --- /dev/null +++ b/tests/split.rs @@ -0,0 +1,161 @@ +// Copyright 2023 Divy Srivastava +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Result; +use fastwebsockets::upgrade; +use fastwebsockets::Frame; +use fastwebsockets::OpCode; +use hyper::server::conn::Http; +use hyper::service::service_fn; +use hyper::Body; +use hyper::Request; +use hyper::Response; +use tokio::net::TcpListener; + +use fastwebsockets::handshake; +use fastwebsockets::WebSocketRead; +use fastwebsockets::WebSocketWrite; +use hyper::header::CONNECTION; +use hyper::header::UPGRADE; +use hyper::upgrade::Upgraded; + +use std::future::Future; + +use tokio::net::TcpStream; + +use tokio::sync::mpsc::unbounded_channel; + +const N_CLIENTS: usize = 20; + +async fn handle_client( + client_id: usize, + fut: upgrade::UpgradeFut, +) -> Result<()> { + let mut ws = fut.await?; + ws.set_writev(false); + let mut ws = fastwebsockets::FragmentCollector::new(ws); + + ws.write_frame(Frame::binary(client_id.to_ne_bytes().as_ref().into())) + .await + .unwrap(); + + Ok(()) +} + +async fn server_upgrade(mut req: Request) -> Result> { + let (response, fut) = upgrade::upgrade(&mut req)?; + + let client_id: usize = req + .headers() + .get("CLIENT-ID") + .unwrap() + .to_str() + .unwrap() + .parse() + .unwrap(); + tokio::spawn(async move { + handle_client(client_id, fut).await.unwrap(); + }); + + Ok(response) +} + +async fn connect( + client_id: usize, +) -> Result<( + WebSocketRead>, + WebSocketWrite>, +)> { + let stream = TcpStream::connect("localhost:8080").await?; + + let req = Request::builder() + .method("GET") + .uri("http://localhost:8080/") + .header("Host", "localhost:8080") + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header("CLIENT-ID", &format!("{}", client_id)) + .header( + "Sec-WebSocket-Key", + fastwebsockets::handshake::generate_key(), + ) + .header("Sec-WebSocket-Version", "13") + .body(Body::empty())?; + + let (ws, _) = handshake::client(&SpawnExecutor, req, stream).await?; + Ok(ws.split(|s| tokio::io::split(s))) +} + +async fn start_client(client_id: usize) -> Result<()> { + let (mut r, _w) = connect(client_id).await.unwrap(); + let (write_queue_tx, _write_queue_rx) = unbounded_channel(); + let frame = r + .read_frame(&mut |frame| { + let res = write_queue_tx.send(frame).map_err(|_| { + Box::::from( + "Failed to send frame".to_owned(), + ) + }); + async { res } + }) + .await?; + match frame.opcode { + OpCode::Close => {} + OpCode::Binary => { + let n = usize::from_ne_bytes(frame.payload[..].try_into().unwrap()); + assert_eq!(n, client_id); + } + _ => { + panic!("Unexpected"); + } + } + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:8080").await?; + println!("Server started, listening on {}", "127.0.0.1:8080"); + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + tokio::spawn(async move { + let conn_fut = Http::new() + .serve_connection(stream, service_fn(server_upgrade)) + .with_upgrades(); + conn_fut.await.unwrap(); + }); + } + }); + let mut tasks = Vec::with_capacity(N_CLIENTS); + for client in 0..N_CLIENTS { + tasks.push(start_client(client)); + } + for handle in tasks { + handle.await.unwrap(); + } + Ok(()) +} + +struct SpawnExecutor; + +impl hyper::rt::Executor for SpawnExecutor +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } +} From f916c3c682efd285f038cf87f83bdb68eb14b8a1 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Fri, 27 Oct 2023 18:53:15 -0600 Subject: [PATCH 2/8] Fix up test --- tests/split.rs | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/split.rs b/tests/split.rs index ee50900..17643f8 100644 --- a/tests/split.rs +++ b/tests/split.rs @@ -29,13 +29,13 @@ use fastwebsockets::WebSocketWrite; use hyper::header::CONNECTION; use hyper::header::UPGRADE; use hyper::upgrade::Upgraded; +use tokio::sync::Mutex; use std::future::Future; +use std::rc::Rc; use tokio::net::TcpStream; -use tokio::sync::mpsc::unbounded_channel; - const N_CLIENTS: usize = 20; async fn handle_client( @@ -98,16 +98,12 @@ async fn connect( } async fn start_client(client_id: usize) -> Result<()> { - let (mut r, _w) = connect(client_id).await.unwrap(); - let (write_queue_tx, _write_queue_rx) = unbounded_channel(); + let (mut r, w) = connect(client_id).await.unwrap(); + let w = Rc::new(Mutex::new(w)); let frame = r - .read_frame(&mut |frame| { - let res = write_queue_tx.send(frame).map_err(|_| { - Box::::from( - "Failed to send frame".to_owned(), - ) - }); - async { res } + .read_frame(&mut move |frame| { + let w = w.clone(); + async move { w.lock().await.write_frame(frame).await } }) .await?; match frame.opcode { From f0e9bb750271f6f7dbb91c29ab07633f259ded13 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Fri, 27 Oct 2023 19:06:12 -0600 Subject: [PATCH 3/8] Gate split behind unstable-split API --- Cargo.toml | 1 + src/fragment.rs | 4 ++++ src/lib.rs | 8 ++++++++ 3 files changed, 13 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index dae33f7..ba62611 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ thiserror = "1.0.40" default = ["simd"] simd = ["simdutf8/aarch64_neon"] upgrade = ["hyper", "pin-project", "base64", "sha1"] +unstable-split = [] [dev-dependencies] tokio = { version = "1.25.0", features = ["full", "macros"] } diff --git a/src/fragment.rs b/src/fragment.rs index 17e12c9..f9e3624 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#[cfg(feature="unstable-split")] use std::future::Future; use crate::error::WebSocketError; @@ -20,6 +21,7 @@ use crate::recv::SharedRecv; use crate::OpCode; use crate::ReadHalf; use crate::WebSocket; +#[cfg(feature="unstable-split")] use crate::WebSocketRead; use crate::WriteHalf; use tokio::io::AsyncReadExt; @@ -139,6 +141,7 @@ impl<'f, S> FragmentCollector { } } +#[cfg(feature="unstable-split")] pub struct FragmentCollectorRead { stream: S, read_half: ReadHalf, @@ -147,6 +150,7 @@ pub struct FragmentCollectorRead { _marker: std::marker::PhantomData, } +#[cfg(feature="unstable-split")] impl<'f, S> FragmentCollectorRead { /// Creates a new `FragmentCollector` with the provided `WebSocket`. pub fn new(ws: WebSocketRead) -> FragmentCollectorRead diff --git a/src/lib.rs b/src/lib.rs index 23a6008..f37a875 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -162,6 +162,7 @@ mod recv; #[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))] pub mod upgrade; +#[cfg(feature="unstable-split")] use std::future::Future; use tokio::io::AsyncReadExt; @@ -170,6 +171,7 @@ use tokio::io::AsyncWriteExt; pub use crate::close::CloseCode; pub use crate::error::WebSocketError; pub use crate::fragment::FragmentCollector; +#[cfg(feature="unstable-split")] pub use crate::fragment::FragmentCollectorRead; pub use crate::frame::Frame; pub use crate::frame::OpCode; @@ -205,18 +207,21 @@ pub(crate) struct ReadHalf { max_message_size: usize, } +#[cfg(feature="unstable-split")] pub struct WebSocketRead { stream: S, read_half: ReadHalf, _marker: UnsendMarker, } +#[cfg(feature="unstable-split")] pub struct WebSocketWrite { stream: S, write_half: WriteHalf, _marker: UnsendMarker, } +#[cfg(feature="unstable-split")] /// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake. pub fn after_handshake_split( read: R, @@ -241,6 +246,7 @@ where ) } +#[cfg(feature="unstable-split")] impl<'f, S> WebSocketRead { /// Consumes the `WebSocketRead` and returns the underlying stream. #[inline] @@ -304,6 +310,7 @@ impl<'f, S> WebSocketRead { } } +#[cfg(feature="unstable-split")] impl<'f, S> WebSocketWrite { /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used. /// @@ -379,6 +386,7 @@ impl<'f, S> WebSocket { } } + #[cfg(feature="unstable-split")] pub fn split( self, split_fn: impl Fn(S) -> (R, W), From c4c078208a6134e1c7de07691b3117a930868c1a Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Fri, 27 Oct 2023 19:14:41 -0600 Subject: [PATCH 4/8] fmt --- src/fragment.rs | 8 ++++---- src/lib.rs | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/fragment.rs b/src/fragment.rs index f9e3624..8a3d755 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] use std::future::Future; use crate::error::WebSocketError; @@ -21,7 +21,7 @@ use crate::recv::SharedRecv; use crate::OpCode; use crate::ReadHalf; use crate::WebSocket; -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] use crate::WebSocketRead; use crate::WriteHalf; use tokio::io::AsyncReadExt; @@ -141,7 +141,7 @@ impl<'f, S> FragmentCollector { } } -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] pub struct FragmentCollectorRead { stream: S, read_half: ReadHalf, @@ -150,7 +150,7 @@ pub struct FragmentCollectorRead { _marker: std::marker::PhantomData, } -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] impl<'f, S> FragmentCollectorRead { /// Creates a new `FragmentCollector` with the provided `WebSocket`. pub fn new(ws: WebSocketRead) -> FragmentCollectorRead diff --git a/src/lib.rs b/src/lib.rs index f37a875..42ade77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -162,7 +162,7 @@ mod recv; #[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))] pub mod upgrade; -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] use std::future::Future; use tokio::io::AsyncReadExt; @@ -171,7 +171,7 @@ use tokio::io::AsyncWriteExt; pub use crate::close::CloseCode; pub use crate::error::WebSocketError; pub use crate::fragment::FragmentCollector; -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] pub use crate::fragment::FragmentCollectorRead; pub use crate::frame::Frame; pub use crate::frame::OpCode; @@ -207,21 +207,21 @@ pub(crate) struct ReadHalf { max_message_size: usize, } -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] pub struct WebSocketRead { stream: S, read_half: ReadHalf, _marker: UnsendMarker, } -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] pub struct WebSocketWrite { stream: S, write_half: WriteHalf, _marker: UnsendMarker, } -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] /// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake. pub fn after_handshake_split( read: R, @@ -246,7 +246,7 @@ where ) } -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] impl<'f, S> WebSocketRead { /// Consumes the `WebSocketRead` and returns the underlying stream. #[inline] @@ -310,7 +310,7 @@ impl<'f, S> WebSocketRead { } } -#[cfg(feature="unstable-split")] +#[cfg(feature = "unstable-split")] impl<'f, S> WebSocketWrite { /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used. /// @@ -386,7 +386,7 @@ impl<'f, S> WebSocket { } } - #[cfg(feature="unstable-split")] + #[cfg(feature = "unstable-split")] pub fn split( self, split_fn: impl Fn(S) -> (R, W), From 0c3154fa1d5d288375ca32ad623865d7034e79f1 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 30 Oct 2023 10:03:15 -0600 Subject: [PATCH 5/8] Add split echo server for testing --- examples/echo_server_split.rs | 85 +++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 examples/echo_server_split.rs diff --git a/examples/echo_server_split.rs b/examples/echo_server_split.rs new file mode 100644 index 0000000..e2ba3ef --- /dev/null +++ b/examples/echo_server_split.rs @@ -0,0 +1,85 @@ +// Copyright 2023 Divy Srivastava +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use fastwebsockets::upgrade; +use fastwebsockets::FragmentCollectorRead; +use fastwebsockets::OpCode; +use fastwebsockets::WebSocketError; +use hyper::server::conn::Http; +use hyper::service::service_fn; +use hyper::Body; +use hyper::Request; +use hyper::Response; +use tokio::net::TcpListener; + +async fn handle_client(fut: upgrade::UpgradeFut) -> Result<(), WebSocketError> { + let ws = fut.await?; + let (rx, mut tx) = ws.split(|ws| tokio::io::split(ws)); + let mut rx = FragmentCollectorRead::new(rx); + loop { + // Empty send_fn is fine because the benchmark does not create obligated writes. + let frame = rx + .read_frame(&mut move |_| async { + unreachable!(); + Ok::<_, WebSocketError>(()) + }) + .await?; + match frame.opcode { + OpCode::Close => break, + OpCode::Text | OpCode::Binary => { + tx.write_frame(frame).await?; + } + _ => {} + } + } + + Ok(()) +} +async fn server_upgrade( + mut req: Request, +) -> Result, WebSocketError> { + let (response, fut) = upgrade::upgrade(&mut req)?; + + tokio::task::spawn(async move { + if let Err(e) = tokio::task::unconstrained(handle_client(fut)).await { + eprintln!("Error in websocket connection: {}", e); + } + }); + + Ok(response) +} + +fn main() -> Result<(), WebSocketError> { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); + + rt.block_on(async move { + let listener = TcpListener::bind("127.0.0.1:8080").await?; + println!("Server started, listening on {}", "127.0.0.1:8080"); + loop { + let (stream, _) = listener.accept().await?; + println!("Client connected"); + tokio::spawn(async move { + let conn_fut = Http::new() + .serve_connection(stream, service_fn(server_upgrade)) + .with_upgrades(); + if let Err(e) = conn_fut.await { + println!("An error occurred: {:?}", e); + } + }); + } + }) +} From 8d974210cc6c4dc985390581e0b407cf7e7d767f Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 30 Oct 2023 10:05:08 -0600 Subject: [PATCH 6/8] Improve docs on split --- src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 42ade77..3ba21e4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -386,6 +386,9 @@ impl<'f, S> WebSocket { } } + /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not + /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that + /// is returned. #[cfg(feature = "unstable-split")] pub fn split( self, From d8c08ff2cb7118fb6ac18f2add61c70f8adc7294 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 30 Oct 2023 10:11:16 -0600 Subject: [PATCH 7/8] Panic if this were ever called --- examples/echo_server_split.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/echo_server_split.rs b/examples/echo_server_split.rs index e2ba3ef..4087f04 100644 --- a/examples/echo_server_split.rs +++ b/examples/echo_server_split.rs @@ -30,9 +30,8 @@ async fn handle_client(fut: upgrade::UpgradeFut) -> Result<(), WebSocketError> { loop { // Empty send_fn is fine because the benchmark does not create obligated writes. let frame = rx - .read_frame(&mut move |_| async { + .read_frame::<_, WebSocketError>(&mut move |_| async { unreachable!(); - Ok::<_, WebSocketError>(()) }) .await?; match frame.opcode { From 2db690f9523dd0e35875e234523aca4b8e1bbb64 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 30 Oct 2023 10:14:05 -0600 Subject: [PATCH 8/8] Gate this error on the feature --- src/error.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/error.rs b/src/error.rs index 46db8fe..444e8e2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,6 +41,7 @@ pub enum WebSocketError { #[cfg(feature = "upgrade")] #[error(transparent)] HTTPError(#[from] hyper::Error), + #[cfg(feature = "unstable-split")] #[error("Failed to send frame")] SendError(#[from] Box), }