Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: split WebSocket #48

Merged
merged 8 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ thiserror = "1.0.40"
default = ["simd"]
simd = ["simdutf8/aarch64_neon"]
upgrade = ["hyper", "pin-project", "base64", "sha1"]
unstable-split = []

[dev-dependencies]
tokio = { version = "1.25.0", features = ["full", "macros"] }
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,6 @@ pub enum WebSocketError {
#[cfg(feature = "upgrade")]
#[error(transparent)]
HTTPError(#[from] hyper::Error),
#[error("Failed to send frame")]
SendError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
}
59 changes: 59 additions & 0 deletions src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#[cfg(feature = "unstable-split")]
use std::future::Future;

use crate::error::WebSocketError;
use crate::frame::Frame;
use crate::recv::SharedRecv;
use crate::OpCode;
use crate::ReadHalf;
use crate::WebSocket;
#[cfg(feature = "unstable-split")]
use crate::WebSocketRead;
use crate::WriteHalf;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
Expand Down Expand Up @@ -136,6 +141,60 @@ impl<'f, S> FragmentCollector<S> {
}
}

#[cfg(feature = "unstable-split")]
pub struct FragmentCollectorRead<S> {
stream: S,
read_half: ReadHalf,
fragments: Fragments,
// !Sync marker
_marker: std::marker::PhantomData<SharedRecv>,
}

#[cfg(feature = "unstable-split")]
impl<'f, S> FragmentCollectorRead<S> {
/// Creates a new `FragmentCollector` with the provided `WebSocket`.
pub fn new(ws: WebSocketRead<S>) -> FragmentCollectorRead<S>
where
S: AsyncReadExt + Unpin,
{
let (stream, read_half) = ws.into_parts_internal();
FragmentCollectorRead {
stream,
read_half,
fragments: Fragments::new(),
_marker: std::marker::PhantomData,
}
}

/// Reads a WebSocket frame, collecting fragmented messages until the final frame is received and returns the completed message.
///
/// Text frames payload is guaranteed to be valid UTF-8.
pub async fn read_frame<R, E>(
&mut self,
send_fn: &mut impl FnMut(Frame<'f>) -> R,
) -> Result<Frame<'f>, WebSocketError>
where
S: AsyncReadExt + Unpin,
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
R: Future<Output = Result<(), E>>,
{
loop {
let (res, obligated_send) =
self.read_half.read_frame_inner(&mut self.stream).await;
if let Some(frame) = obligated_send {
let res = send_fn(frame).await;
res.map_err(|e| WebSocketError::SendError(e.into()))?;
}
let Some(frame) = res? else {
continue;
};
if let Some(frame) = self.fragments.accumulate(frame)? {
return Ok(frame);
}
}
}
}

/// Accumulates potentially fragmented [`Frame`]s to defragment the incoming WebSocket stream.
struct Fragments {
fragments: Option<Fragment>,
Expand Down
223 changes: 204 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,26 @@ mod recv;
#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
pub mod upgrade;

#[cfg(feature = "unstable-split")]
use std::future::Future;

use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;

pub use crate::close::CloseCode;
pub use crate::error::WebSocketError;
pub use crate::fragment::FragmentCollector;
#[cfg(feature = "unstable-split")]
pub use crate::fragment::FragmentCollectorRead;
pub use crate::frame::Frame;
pub use crate::frame::OpCode;
pub use crate::frame::Payload;
pub use crate::mask::unmask;
use crate::recv::SharedRecv;

#[derive(Copy, Clone, Default)]
struct UnsendMarker(std::marker::PhantomData<SharedRecv>);

#[derive(Copy, Clone, PartialEq)]
pub enum Role {
Server,
Expand All @@ -199,13 +207,150 @@ pub(crate) struct ReadHalf {
max_message_size: usize,
}

#[cfg(feature = "unstable-split")]
pub struct WebSocketRead<S> {
stream: S,
read_half: ReadHalf,
_marker: UnsendMarker,
}

#[cfg(feature = "unstable-split")]
pub struct WebSocketWrite<S> {
stream: S,
write_half: WriteHalf,
_marker: UnsendMarker,
}

#[cfg(feature = "unstable-split")]
/// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake.
pub fn after_handshake_split<R, W>(
read: R,
write: W,
role: Role,
) -> (WebSocketRead<R>, WebSocketWrite<W>)
where
R: AsyncWriteExt + Unpin,
W: AsyncWriteExt + Unpin,
{
(
WebSocketRead {
stream: read,
read_half: ReadHalf::after_handshake(role),
_marker: UnsendMarker::default(),
},
WebSocketWrite {
stream: write,
write_half: WriteHalf::after_handshake(role),
_marker: UnsendMarker::default(),
},
)
}

#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketRead<S> {
/// Consumes the `WebSocketRead` and returns the underlying stream.
#[inline]
pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
(self.stream, self.read_half)
}

pub fn set_writev_threshold(&mut self, threshold: usize) {
self.read_half.writev_threshold = threshold;
}

/// Sets whether to automatically close the connection when a close frame is received. When set to `false`, the application will have to manually send close frames.
///
/// Default: `true`
pub fn set_auto_close(&mut self, auto_close: bool) {
self.read_half.auto_close = auto_close;
}

/// Sets whether to automatically send a pong frame when a ping frame is received.
///
/// Default: `true`
pub fn set_auto_pong(&mut self, auto_pong: bool) {
self.read_half.auto_pong = auto_pong;
}

/// Sets the maximum message size in bytes. If a message is received that is larger than this, the connection will be closed.
///
/// Default: 64 MiB
pub fn set_max_message_size(&mut self, max_message_size: usize) {
self.read_half.max_message_size = max_message_size;
}

/// Sets whether to automatically apply the mask to the frame payload.
///
/// Default: `true`
pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
self.read_half.auto_apply_mask = auto_apply_mask;
}

/// Reads a frame from the stream.
pub async fn read_frame<R, E>(
&mut self,
send_fn: &mut impl FnMut(Frame<'f>) -> R,
) -> Result<Frame, WebSocketError>
where
S: AsyncReadExt + Unpin,
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
R: Future<Output = Result<(), E>>,
{
loop {
let (res, obligated_send) =
self.read_half.read_frame_inner(&mut self.stream).await;
if let Some(frame) = obligated_send {
let res = send_fn(frame).await;
res.map_err(|e| WebSocketError::SendError(e.into()))?;
}
if let Some(frame) = res? {
break Ok(frame);
}
}
}
}

#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketWrite<S> {
/// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used.
///
/// Default: `true`
pub fn set_writev(&mut self, vectored: bool) {
self.write_half.vectored = vectored;
}

pub fn set_writev_threshold(&mut self, threshold: usize) {
self.write_half.writev_threshold = threshold;
}

/// Sets whether to automatically apply the mask to the frame payload.
///
/// Default: `true`
pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
self.write_half.auto_apply_mask = auto_apply_mask;
}

pub fn is_closed(&self) -> bool {
self.write_half.closed
}

pub async fn write_frame(
&mut self,
frame: Frame<'f>,
) -> Result<(), WebSocketError>
where
S: AsyncWriteExt + Unpin,
{
self.write_half.write_frame(&mut self.stream, frame).await
}
}

/// WebSocket protocol implementation over an async stream.
pub struct WebSocket<S> {
stream: S,
write_half: WriteHalf,
read_half: ReadHalf,
// !Sync marker
_marker: std::marker::PhantomData<SharedRecv>,
_marker: UnsendMarker,
}

impl<'f, S> WebSocket<S> {
Expand Down Expand Up @@ -235,25 +380,36 @@ impl<'f, S> WebSocket<S> {
recv::init_once();
Self {
stream,
write_half: WriteHalf {
role,
closed: false,
auto_apply_mask: true,
vectored: true,
writev_threshold: 1024,
write_buffer: Vec::with_capacity(2),
write_half: WriteHalf::after_handshake(role),
read_half: ReadHalf::after_handshake(role),
_marker: UnsendMarker::default(),
}
}

#[cfg(feature = "unstable-split")]
pub fn split<R, W>(
self,
split_fn: impl Fn(S) -> (R, W),
) -> (WebSocketRead<R>, WebSocketWrite<W>)
where
S: AsyncReadExt + AsyncWriteExt + Unpin,
R: AsyncReadExt + Unpin,
W: AsyncWriteExt + Unpin,
{
let (stream, read, write) = self.into_parts_internal();
let (r, w) = split_fn(stream);
(
WebSocketRead {
stream: r,
read_half: read,
_marker: UnsendMarker::default(),
},
read_half: ReadHalf {
role,
spill: None,
auto_apply_mask: true,
auto_close: true,
auto_pong: true,
writev_threshold: 1024,
max_message_size: 64 << 20,
WebSocketWrite {
stream: w,
write_half: write,
_marker: UnsendMarker::default(),
},
_marker: std::marker::PhantomData,
}
)
}

/// Consumes the `WebSocket` and returns the underlying stream.
Expand Down Expand Up @@ -310,6 +466,10 @@ impl<'f, S> WebSocket<S> {
self.write_half.auto_apply_mask = auto_apply_mask;
}

pub fn is_closed(&self) -> bool {
self.write_half.closed
}

/// Writes a frame to the stream.
///
/// # Example
Expand Down Expand Up @@ -388,6 +548,18 @@ impl<'f, S> WebSocket<S> {
}

impl ReadHalf {
pub fn after_handshake(role: Role) -> Self {
Self {
role,
spill: None,
auto_apply_mask: true,
auto_close: true,
auto_pong: true,
writev_threshold: 1024,
max_message_size: 64 << 20,
}
}

/// Attempt to read a single frame from from the incoming stream, returning any send obligations if
/// `auto_close` or `auto_pong` are enabled. Callers to this function are obligated to send the
/// frame in the latter half of the tuple if one is specified, unless the write half of this socket
Expand Down Expand Up @@ -573,6 +745,17 @@ impl ReadHalf {
}

impl WriteHalf {
pub fn after_handshake(role: Role) -> Self {
Self {
role,
closed: false,
auto_apply_mask: true,
vectored: true,
writev_threshold: 1024,
write_buffer: Vec::with_capacity(2),
}
}

/// Writes a frame to the provided stream.
pub async fn write_frame<'a, S>(
&'a mut self,
Expand All @@ -588,6 +771,8 @@ impl WriteHalf {

if frame.opcode == OpCode::Close {
self.closed = true;
} else if self.closed {
return Err(WebSocketError::ConnectionClosed);
}

if self.vectored && frame.payload.len() > self.writev_threshold {
Expand Down
Loading