From b6236b175bc756a228bfaffe533c969922d15b17 Mon Sep 17 00:00:00 2001 From: Oliver Daff Date: Tue, 22 Nov 2022 16:25:53 +1000 Subject: [PATCH] feat/sink-utils (#22) * Added checksum module. * Added counter --- CHANGELOG.md | 4 + Cargo.toml | 17 ++++ docker-compose.yml | 2 +- src/checksum.rs | 205 +++++++++++++++++++++++++++++++++++++++++++++ src/counter.rs | 200 +++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 7 +- 6 files changed, 433 insertions(+), 2 deletions(-) create mode 100644 src/checksum.rs create mode 100644 src/counter.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b4b6e6..7433161 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Added checksum module with CRC32Sink to calculate CRC32. +- Added counter module with ByteCounter and ByteLimit. +- Enable optional features generation on docs.rs. + ## 0.3.0 - Added a `try_finally` helper for running async cleanup code at the end of diff --git a/Cargo.toml b/Cargo.toml index 3b22c04..72bdf4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,10 +15,27 @@ include = [ "LICENCE", ] +[features] +# No features on by default +default = [] + +# Shorthand for enabling everything +full = ["checksum"] + +checksum = ["crc32fast", "anyhow"] + [dependencies] +anyhow = { version = "1.0.66", optional = true } +crc32fast = { version = "1.3.2", optional=true } futures = "0.3.25" +pin-project-lite = "0.2.9" [dev-dependencies] anyhow = "1.0.66" tokio = { version = "1.22.0", features=["macros"] } tokio-test = "0.4.2" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + diff --git a/docker-compose.yml b/docker-compose.yml index 8605e3c..1e69b5e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.7' services: cargo: - image: harrisonai/rust:1.60-0.2 + image: harrisonai/rust:1.65 entrypoint: cargo volumes: - 
'~/.cargo/registry:/usr/local/cargo/registry' diff --git a/src/checksum.rs b/src/checksum.rs new file mode 100644 index 0000000..f7c4d67 --- /dev/null +++ b/src/checksum.rs @@ -0,0 +1,205 @@ +//! Module for checksum calculation. + +use std::marker::PhantomData; +use std::task::Poll; + +use crc32fast::Hasher; +use futures::prelude::*; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::Context; + +pin_project! { +#[must_use = "sinks do nothing unless polled"] +/// A [Sink] that will calculate the CRC32 of any +/// `AsRef<[u8]>`. +/// +/// ```rust +/// # use cobalt_async::checksum::CRC32Sink; +/// # use futures::sink::SinkExt; +/// # use tokio_test; +/// # tokio_test::block_on(async { +/// let mut sink = CRC32Sink::default(); +/// sink.send("this is a test".as_bytes()).await?; +/// sink.close().await.unwrap(); +/// assert_eq!(220129258, sink.value().unwrap()); +/// # Ok::<(), anyhow::Error>(()) +/// # }); +/// +/// ``` +/// Attempting to get the [value](`CRC32Sink::value`) of the [Sink] before +/// calling [Sink::poll_close] results in a [None] being returned +/// ```rust +/// # use cobalt_async::checksum::CRC32Sink; +/// # use futures::sink::SinkExt; +/// # use tokio_test; +/// # tokio_test::block_on(async { +/// let mut sink = CRC32Sink::default(); +/// sink.send("this is a test".as_bytes()).await?; +/// assert!(sink.value().is_none()); +/// # Ok::<(), anyhow::Error>(()) +/// # }); +/// +pub struct CRC32Sink> { + digest: Option, + value: Option, + //Needed to allow StreamExt to determine the type of Item + marker: std::marker::PhantomData +} +} + +impl> CRC32Sink { + /// Produces a new CRC32Sink. + pub fn new() -> CRC32Sink { + CRC32Sink { + digest: Some(Hasher::new()), + value: None, + marker: PhantomData, + } + } + + /// Returns the crc32 of the values passed into the + /// [Sink] once the [Sink] has been closed. 
+ pub fn value(&self) -> Option { + self.value + } +} + +impl> Default for CRC32Sink { + fn default() -> Self { + Self::new() + } +} + +/// The [futures] crate provides a [into_sink](futures::io::AsyncWriteExt::into_sink) for AsyncWrite but it is +/// not possible to get the value out of it afterwards as it takes ownership. +impl> futures::sink::Sink for CRC32Sink { + type Error = anyhow::Error; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + match self.digest { + Some(_) => Poll::Ready(Ok(())), + None => Poll::Ready(Err(anyhow::Error::msg("Close has been called."))), + } + } + + fn start_send( + self: Pin<&mut CRC32Sink>, + item: Item, + ) -> std::result::Result<(), Self::Error> { + let mut this = self.project(); + match &mut this.digest { + Some(digest) => { + digest.update(item.as_ref()); + Ok(()) + } + None => Err(anyhow::Error::msg("Close has been called.")), + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + match self.digest { + Some(_) => Poll::Ready(Ok(())), + None => Poll::Ready(Err(anyhow::Error::msg("Close has been called."))), + } + } + + fn poll_close( + mut self: Pin<&mut CRC32Sink>, + _cx: &mut Context<'_>, + ) -> Poll> { + match std::mem::take(&mut self.digest) { + Some(digest) => { + self.value = Some(digest.finalize()); + Poll::Ready(Ok(())) + } + None => Poll::Ready(Err(anyhow::Error::msg("Close has been called."))), + } + } +} + +impl AsyncWrite for CRC32Sink<&[u8]> { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut this = self.project(); + match &mut this.digest { + Some(digest) => { + digest.update(buf); + Poll::Ready(Ok(buf.len())) + } + None => Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Close has been called", + ))), + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + match &mut this.digest { + Some(_) => 
Poll::Ready(Ok(())), + None => Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Close has been called", + ))), + } + } + + fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match std::mem::take(&mut self.digest) { + Some(digest) => { + self.value = Some(digest.finalize()); + Poll::Ready(Ok(())) + } + None => Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Close has been called", + ))), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_crc32() { + let mut sink = CRC32Sink::default(); + sink.send(&[0, 100]).await.unwrap(); + sink.close().await.unwrap(); + assert_eq!(184989630, sink.value.unwrap()); + } + + #[tokio::test] + async fn test_crc32_write_after_close() { + let mut sink = CRC32Sink::default(); + sink.send(&[0, 100]).await.unwrap(); + sink.close().await.unwrap(); + assert!(sink.send(&[0, 100]).await.is_err()); + } + + #[tokio::test] + async fn test_crc32_flush_after_close() { + let mut sink = CRC32Sink::default(); + sink.send(&[0, 100]).await.unwrap(); + sink.close().await.unwrap(); + assert!(sink.flush().await.is_err()); + } + + #[tokio::test] + async fn test_crc32_close_after_close() { + let mut sink = CRC32Sink::<&[u8]>::default(); + futures::SinkExt::close(&mut sink).await.unwrap(); + assert!(futures::SinkExt::close(&mut sink).await.is_err()); + } +} diff --git a/src/counter.rs b/src/counter.rs new file mode 100644 index 0000000..4308118 --- /dev/null +++ b/src/counter.rs @@ -0,0 +1,200 @@ +//! Module holding utilities related to counts +//! of bytes passed to [AsyncWrite]. + +use std::task::Poll; + +use futures::ready; +use futures::AsyncWrite; +use pin_project_lite::pin_project; +use std::io::{Error, ErrorKind, Result}; +use std::pin::Pin; +use std::task::Context; + +pin_project! { + ///An [AsyncWrite] which counts bytes written to the + ///wrapped [AsyncWrite]. 
+ ///```rust + /// # use cobalt_async::counter::ByteCounter; + /// # use futures::AsyncWriteExt; + /// # use futures::io::sink; + /// # use tokio_test; + /// # tokio_test::block_on(async { + /// let mut counter = ByteCounter::new(sink()); + /// counter.write("this is a test".as_bytes()).await?; + /// assert_eq!(14, counter.byte_count()); + /// # Ok::<(), anyhow::Error>(()) + /// # }); + ///```` + /// + ///## Note + ///The count is stored as a [u128] and unchecked addition + ///is used to increment the count which means wrapping a + ///long running [AsyncWrite] may lead to an overflow. + #[derive(Debug)] + pub struct ByteCounter { + byte_count: u128, + #[pin] + inner: T + } +} + +impl ByteCounter { + /// Returns a new ByteCounter + /// wrapping the inner [AsyncWrite], with + /// the byte count initialised to 0. + pub fn new(inner: T) -> Self { + ByteCounter { + byte_count: 0, + inner, + } + } + + /// Returns the current count of bytes + /// written into the [AsyncWrite]. + pub fn byte_count(&self) -> u128 { + self.byte_count + } + + /// Returns the inner [AsyncWrite], consuming + /// this [Self]. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl AsyncWrite for ByteCounter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = self.project(); + let written = ready!(this.inner.poll_write(cx, buf))?; + *this.byte_count += u128::try_from(written).map_err(|e| Error::new(ErrorKind::Other, e))?; + Poll::Ready(Ok(written)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} + +pin_project! 
{ + ///An AsyncWrite which raises an Error if the number of bytes + ///written is more than the `byte_limit` + ///```rust + /// # use cobalt_async::counter::ByteLimit; + /// # use futures::AsyncWriteExt; + /// # use futures::io::sink; + /// # use tokio_test; + /// # tokio_test::block_on(async { + /// let mut counter = ByteLimit::new_from_inner(sink(), 10); + /// counter.write("this is a ".as_bytes()).await?; + /// assert!(counter.write("error".as_bytes()).await.is_err()); + /// # Ok::<(), anyhow::Error>(()) + /// # }); + ///```` + /// + ///## Note + ///The count is stored as a [u128] and unchecked addition + ///is used to increment the count which means wrapping a + ///long running [AsyncWrite] may lead to an overflow. + #[derive(Debug)] + pub struct ByteLimit { + byte_limit: u128, + #[pin] + counter: ByteCounter + } +} + +impl ByteLimit { + /// Creates a new [ByteLimit] using the provided [ByteCounter] + /// which will raise an error when the `byte_limit` is exceeded. + pub fn new(counter: ByteCounter, byte_limit: u128) -> Self { + ByteLimit { + byte_limit, + counter, + } + } + + /// Convenience method to create a `ByteLimit` and `ByteCounter` + /// from the inner `AsyncWrite`. + pub fn new_from_inner(inner: T, byte_limit: u128) -> Self { + ByteLimit { + byte_limit, + counter: ByteCounter::new(inner), + } + } + + /// Returns the inner [ByteCounter] consuming the [Self]. + pub fn into_innner(self) -> ByteCounter { + self.counter + } +} + +impl AsyncWrite for ByteLimit { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = self.project(); + let current_count = this.counter.byte_count(); + match this.counter.poll_write(cx, buf)? 
{ + Poll::Ready(written) => { + let written_u128 = + u128::try_from(written).map_err(|e| Error::new(ErrorKind::Other, e))?; + if current_count + written_u128 > *this.byte_limit { + Poll::Ready(Err(Error::new( + ErrorKind::Other, + format!("Byte Limit Reached: {} bytes", this.byte_limit), + ))) + } else { + Poll::Ready(Ok(written)) + } + } + _ => Poll::Pending, + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().counter.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().counter.poll_close(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::prelude::*; + + #[tokio::test] + async fn test_byte_count() { + let buffer = vec![]; + let mut counter = ByteCounter::new(buffer); + let bytes_to_write = 100_usize; + assert!(counter.write_all(&vec![0; bytes_to_write]).await.is_ok()); + counter.close().await.unwrap(); + assert_eq!( + u128::try_from(bytes_to_write).unwrap(), + counter.byte_count() + ); + } + + #[tokio::test] + async fn test_byte_count_limit_over() { + let buffer = vec![]; + let mut counter = ByteLimit::new_from_inner(buffer, 99); + let bytes_to_write = 100_usize; + assert!(counter.write_all(&vec![0; bytes_to_write]).await.is_err()); + } + + #[tokio::test] + async fn test_byte_count_limit_reached() { + let buffer = vec![]; + let mut counter = ByteLimit::new_from_inner(buffer, 100); + let bytes_to_write = 100_usize; + assert!(counter.write_all(&vec![0; bytes_to_write]).await.is_ok()); + let counter = counter.into_innner(); + assert_eq!(counter.byte_count(), bytes_to_write as u128); + } +} diff --git a/src/lib.rs b/src/lib.rs index e1b8fbe..8d7d898 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,13 @@ //! At [harrison.ai](https://harrison.ai) our mission is to create AI-as-a-medical-device solutions through //! ventures and ultimately improve the standard of healthcare for 1 million lives every day. //! - +//! 
+#![cfg_attr(docsrs, feature(doc_cfg))] +#[cfg(feature = "checksum")] +#[cfg_attr(docsrs, doc(cfg(feature = "checksum")))] +pub mod checksum; mod chunker; +pub mod counter; mod try_finally; pub use chunker::{apply_chunker, try_apply_chunker, ChunkResult};