diff --git a/bitar/src/chunker/streaming_chunker.rs b/bitar/src/chunker/streaming_chunker.rs index 152e5bc..23abafd 100644 --- a/bitar/src/chunker/streaming_chunker.rs +++ b/bitar/src/chunker/streaming_chunker.rs @@ -5,8 +5,11 @@ use std::{ }; use bytes::BytesMut; -use futures_util::{ready, Stream}; -use tokio::io::{AsyncRead, ReadBuf}; +use futures_util::{ready, FutureExt, Stream}; +use tokio::{ + io::{AsyncRead, AsyncReadExt}, + pin, +}; use crate::{chunker::Chunker, Chunk}; @@ -47,65 +50,28 @@ where return Poll::Ready(Some(Ok((offset, chunk)))); } } - // No chunk found in the buffer. Read data and append to buffer. - match ready!(refill_buf(cx, &mut me.buf, &mut me.reader)) { - Ok(0) if me.buf.is_empty() => { - // EOF and empty buffer. - return Poll::Ready(None); - } - Ok(0) => { - // EOF but some data in buffer (last chunk). - let chunk = Chunk(me.buf.split().freeze()); - return Poll::Ready(Some(Ok((me.chunk_start, chunk)))); - } - Ok(_) => { - // Buffer refilled. - } - Err(err) => return Poll::Ready(Some(Err(err))), + // Append more data to buffer since no chunk was found. + if me.buf.capacity() < me.buf.len() + REFILL_SIZE { + me.buf.reserve(REFILL_SIZE); } - } - } -} - -fn refill_buf(cx: &mut Context, buf: &mut BytesMut, mut reader: R) -> Poll> -where - R: AsyncRead + Unpin, -{ - let mut read_count = 0; - let before_size = buf.len(); - { - let new_size = before_size + REFILL_SIZE; - if buf.capacity() < new_size { - buf.reserve(REFILL_SIZE); - } - unsafe { - // Use unsafe set_len() here instead of resize as we don't care for - // zeroing the content of buf. - buf.set_len(new_size); - } - } - while read_count < REFILL_SIZE { - let offset = before_size + read_count; - let mut read_buf = ReadBuf::new(&mut buf[offset..]); - let rc = match Pin::new(&mut reader).poll_read(cx, &mut read_buf) { - Poll::Ready(Ok(())) if read_buf.filled().is_empty() => break, // EOF - Poll::Ready(Ok(())) => read_buf.filled().len(), - Poll::Ready(Err(err)) => { - buf.resize(before_size + read_count, 0); - return Poll::Ready(Err(err)); - } - Poll::Pending => { - buf.resize(before_size + read_count, 0); - if read_count > 0 { - return Poll::Ready(Ok(read_count)); + let read_f = me.reader.read_buf(&mut me.buf); + pin!(read_f); + match ready!(read_f.poll_unpin(cx))? { + 0 => { + // End of file/reader. + // Return a last chunk if there is data left in buffer. + let last_chunk = if me.buf.is_empty() { + None + } else { + let chunk = Chunk(me.buf.split().freeze()); + Some(Ok((me.chunk_start, chunk))) + }; + return Poll::Ready(last_chunk); } - return Poll::Pending; + _rc => {} } - }; - read_count += rc; + } } - buf.resize(before_size + read_count, 0); - Poll::Ready(Ok(read_count)) } #[cfg(test)] @@ -115,6 +81,7 @@ mod tests { use crate::chunker::{Config, FilterBits, FilterConfig}; use futures_util::StreamExt; use std::cmp; + use tokio::io::ReadBuf; // The MockSource will return bytes_per_read bytes every other read // and Pending every other, to replicate a source with limited I/O. @@ -139,7 +106,7 @@ mod tests { impl AsyncRead for MockSource { fn poll_read( mut self: Pin<&mut Self>, - _cx: &mut Context, + cx: &mut Context, buf: &mut ReadBuf, ) -> Poll> { let data_available = self.data.len() - self.offset; @@ -147,11 +114,12 @@ mod tests { Poll::Ready(Ok(())) } else if self.pending { self.pending = false; + cx.waker().wake_by_ref(); Poll::Pending } else { let read = cmp::min( data_available, - cmp::min(buf.initialized().len(), self.bytes_per_read), + cmp::min(buf.remaining(), self.bytes_per_read), ); buf.put_slice(&self.data[self.offset..self.offset + read]); self.offset += read;