diff --git a/src/records.rs b/src/records.rs index cc553ff..12aa942 100644 --- a/src/records.rs +++ b/src/records.rs @@ -12,6 +12,7 @@ //! use kafka_protocol::protocol::Decodable; //! use kafka_protocol::records::RecordBatchDecoder; //! use bytes::Bytes; +//! use kafka_protocol::records::Compression; //! //! # const HEADER: [u8; 45] = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,]; //! # const RECORD: [u8; 79] = [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x43, 0x0, 0x0, 0x0, 0x0, 0x2, 0x73, 0x6d, 0x29, 0x7b, 0x0, 0b00000000, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x22, 0x1, 0xd0, 0xf, 0x2, 0xa, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0xa, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x0,]; @@ -26,12 +27,19 @@ //! for topic in res.responses { //! for partition in topic.partitions { //! let mut records = partition.records.unwrap(); -//! let records = RecordBatchDecoder::decode(&mut records).unwrap(); +//! let records = RecordBatchDecoder::decode(&mut records, Some(decompress_record_batch_data)).unwrap(); //! } //! } +//! +//! fn decompress_record_batch_data(compressed_buffer: &mut bytes::Bytes, compression: Compression) -> anyhow::Result { +//! match compression { +//! Compression::None => Ok(compressed_buffer.to_vec().into()), +//! _ => { panic!("Compression not implemented") } +//! } +//! } //! ``` use anyhow::{anyhow, bail, Result}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use crc::{Crc, CRC_32_ISO_HDLC}; use crc32c::crc32c; use indexmap::IndexMap; @@ -44,7 +52,6 @@ use crate::protocol::{ use super::compression::{self as cmpr, Compressor, Decompressor}; use std::cmp::Ordering; use std::convert::TryFrom; - /// IEEE (checksum) cyclic redundancy check. pub const IEEE: Crc = Crc::::new(&CRC_32_ISO_HDLC); @@ -151,16 +158,25 @@ const MAGIC_BYTE_OFFSET: usize = 16; impl RecordBatchEncoder { /// Encode records into given buffer, using provided encoding options that select the encoding /// strategy based on version. - pub fn encode<'a, B, I>(buf: &mut B, records: I, options: &RecordEncodeOptions) -> Result<()> + /// # Arguments + /// * `compressor` - A function that compresses the given batch of records. + /// If `None`, the right compression algorithm will automatically be selected and applied. + pub fn encode<'a, B, I, CF>( + buf: &mut B, + records: I, + options: &RecordEncodeOptions, + compressor: Option, + ) -> Result<()> where B: ByteBufMut, I: IntoIterator, I::IntoIter: Clone, + CF: Fn(&mut BytesMut, &mut B, Compression) -> Result<()>, { let records = records.into_iter(); match options.version { - 0..=1 => Self::encode_legacy(buf, records, options), - 2 => Self::encode_new(buf, records, options), + 0..=1 => Self::encode_legacy(buf, records, options, compressor), + 2 => Self::encode_new(buf, records, options, compressor), _ => bail!("Unknown record batch version"), } } @@ -178,10 +194,16 @@ impl RecordBatchEncoder { } Ok(()) } - fn encode_legacy<'a, B, I>(buf: &mut B, records: I, options: &RecordEncodeOptions) -> Result<()> + fn encode_legacy<'a, B, I, CF>( + buf: &mut B, + records: I, + options: &RecordEncodeOptions, + compressor: Option, + ) -> Result<()> where B: ByteBufMut, I: Iterator + Clone, + CF: Fn(&mut BytesMut, &mut B, Compression) -> Result<()>, { if options.compression == Compression::None { // No wrapper needed @@ -210,20 +232,26 @@ impl RecordBatchEncoder { // Value (Compressed MessageSet) let size_gap = buf.put_typed_gap(gap::I32); let value_start = buf.offset(); - match options.compression { - Compression::Snappy => cmpr::Snappy::compress(buf, |buf| { - Self::encode_legacy_records(buf, records, &inner_opts) - })?, - Compression::Gzip => cmpr::Gzip::compress(buf, |buf| { - Self::encode_legacy_records(buf, records, &inner_opts) - })?, - Compression::Lz4 => cmpr::Lz4::compress(buf, |buf| { - Self::encode_legacy_records(buf, records, &inner_opts) - })?, - Compression::Zstd => cmpr::Zstd::compress(buf, |buf| { - Self::encode_legacy_records(buf, records, &inner_opts) - })?, - _ => unimplemented!(), + if let Some(compressor) = compressor { + let mut encoded_buf = BytesMut::new(); + Self::encode_legacy_records(&mut encoded_buf, records, &inner_opts)?; + compressor(&mut encoded_buf, buf, options.compression)?; + } else { + match options.compression { + Compression::Snappy => cmpr::Snappy::compress(buf, |buf| { + Self::encode_legacy_records(buf, records, &inner_opts) + })?, + Compression::Gzip => cmpr::Gzip::compress(buf, |buf| { + Self::encode_legacy_records(buf, records, &inner_opts) + })?, + Compression::Lz4 => cmpr::Lz4::compress(buf, |buf| { + Self::encode_legacy_records(buf, records, &inner_opts) + })?, + Compression::Zstd => cmpr::Zstd::compress(buf, |buf| { + Self::encode_legacy_records(buf, records, &inner_opts) + })?, + _ => unimplemented!(), + } } let value_end = buf.offset(); @@ -259,14 +287,16 @@ impl RecordBatchEncoder { Ok(()) } - fn encode_new_batch<'a, B, I>( + fn encode_new_batch<'a, B, I, CF>( buf: &mut B, records: &mut I, options: &RecordEncodeOptions, + compressor: Option<&CF>, ) -> Result where B: ByteBufMut, I: Iterator + Clone, + CF: Fn(&mut BytesMut, &mut B, Compression) -> Result<()>, { let mut record_peeker = records.clone(); @@ -375,24 +405,30 @@ impl RecordBatchEncoder { // Records let records = records.take(num_records); - match options.compression { - Compression::None => cmpr::None::compress(buf, |buf| { - Self::encode_new_records(buf, records, min_offset, min_timestamp, options) - })?, - Compression::Snappy => cmpr::Snappy::compress(buf, |buf| { - Self::encode_new_records(buf, records, min_offset, min_timestamp, options) - })?, - Compression::Gzip => cmpr::Gzip::compress(buf, |buf| { - Self::encode_new_records(buf, records, min_offset, min_timestamp, options) - })?, - Compression::Lz4 => cmpr::Lz4::compress(buf, |buf| { - Self::encode_new_records(buf, records, min_offset, min_timestamp, options) - })?, - Compression::Zstd => cmpr::Zstd::compress(buf, |buf| { - Self::encode_new_records(buf, records, min_offset, min_timestamp, options) - })?, - } + if let Some(compressor) = compressor { + let mut record_buf = BytesMut::new(); + Self::encode_new_records(&mut record_buf, records, min_offset, min_timestamp, options)?; + compressor(&mut record_buf, buf, options.compression)?; + } else { + match options.compression { + Compression::None => cmpr::None::compress(buf, |buf| { + Self::encode_new_records(buf, records, min_offset, min_timestamp, options) + })?, + Compression::Snappy => cmpr::Snappy::compress(buf, |buf| { + Self::encode_new_records(buf, records, min_offset, min_timestamp, options) + })?, + Compression::Gzip => cmpr::Gzip::compress(buf, |buf| { + Self::encode_new_records(buf, records, min_offset, min_timestamp, options) + })?, + Compression::Lz4 => cmpr::Lz4::compress(buf, |buf| { + Self::encode_new_records(buf, records, min_offset, min_timestamp, options) + })?, + Compression::Zstd => cmpr::Zstd::compress(buf, |buf| { + Self::encode_new_records(buf, records, min_offset, min_timestamp, options) + })?, + } + } let batch_end = buf.offset(); // Fill size gap @@ -413,34 +449,49 @@ impl RecordBatchEncoder { Ok(true) } - fn encode_new<'a, B, I>( + fn encode_new<'a, B, I, CF>( buf: &mut B, mut records: I, options: &RecordEncodeOptions, + compressor: Option, ) -> Result<()> where B: ByteBufMut, I: Iterator + Clone, + CF: Fn(&mut BytesMut, &mut B, Compression) -> Result<()>, { - while Self::encode_new_batch(buf, &mut records, options)? {} + while Self::encode_new_batch(buf, &mut records, options, compressor.as_ref())? {} Ok(()) } } impl RecordBatchDecoder { /// Decode the provided buffer into a vec of records. - pub fn decode(buf: &mut B) -> Result> { + /// # Arguments + /// * `decompressor` - A function that decompresses the given batch of records. + /// If `None`, the right decompression algorithm will automatically be selected and applied. + pub fn decode(buf: &mut B, decompressor: Option) -> Result> + where + F: Fn(&mut bytes::Bytes, Compression) -> Result, + { let mut records = Vec::new(); while buf.has_remaining() { - Self::decode_batch(buf, &mut records)?; + Self::decode_batch(buf, &mut records, decompressor.as_ref())?; } Ok(records) } - fn decode_batch(buf: &mut B, records: &mut Vec) -> Result<()> { + fn decode_batch( + buf: &mut B, + records: &mut Vec, + decompress_func: Option<&F>, + ) -> Result<()> + where + F: Fn(&mut bytes::Bytes, Compression) -> Result, + { let version = buf.try_peek_bytes(MAGIC_BYTE_OFFSET..(MAGIC_BYTE_OFFSET + 1))?[0] as i8; match version { 0..=1 => Record::decode_legacy(buf, version, records), - 2 => Self::decode_new_batch(buf, version, records), + 2 => Self::decode_new_batch(buf, version, records, decompress_func), _ => { bail!("Unknown record batch version ({})", version); } @@ -458,11 +509,15 @@ impl RecordBatchDecoder { } Ok(()) } - fn decode_new_batch( + fn decode_new_batch( buf: &mut B, version: i8, records: &mut Vec, - ) -> Result<()> { + decompress_func: Option<&F>, + ) -> Result<()> + where + F: Fn(&mut bytes::Bytes, Compression) -> Result, + { // Base offset let min_offset = types::Int64.decode(buf)?; @@ -554,24 +609,29 @@ impl RecordBatchDecoder { producer_epoch, }; - // Records - match compression { - Compression::None => cmpr::None::decompress(buf, |buf| { - Self::decode_new_records(buf, &batch_decode_info, version, records) - })?, - Compression::Snappy => cmpr::Snappy::decompress(buf, |buf| { - Self::decode_new_records(buf, &batch_decode_info, version, records) - })?, - Compression::Gzip => cmpr::Gzip::decompress(buf, |buf| { - Self::decode_new_records(buf, &batch_decode_info, version, records) - })?, - Compression::Zstd => cmpr::Zstd::decompress(buf, |buf| { - Self::decode_new_records(buf, &batch_decode_info, version, records) - })?, - Compression::Lz4 => cmpr::Lz4::decompress(buf, |buf| { - Self::decode_new_records(buf, &batch_decode_info, version, records) - })?, - }; + if let Some(decompress_func) = decompress_func { + let mut decompressed_buf = decompress_func(buf, compression)?; + + Self::decode_new_records(&mut decompressed_buf, &batch_decode_info, version, records)?; + } else { + match compression { + Compression::None => cmpr::None::decompress(buf, |buf| { + Self::decode_new_records(buf, &batch_decode_info, version, records) + })?, + Compression::Snappy => cmpr::Snappy::decompress(buf, |buf| { + Self::decode_new_records(buf, &batch_decode_info, version, records) + })?, + Compression::Gzip => cmpr::Gzip::decompress(buf, |buf| { + Self::decode_new_records(buf, &batch_decode_info, version, records) + })?, + Compression::Zstd => cmpr::Zstd::decompress(buf, |buf| { + Self::decode_new_records(buf, &batch_decode_info, version, records) + })?, + Compression::Lz4 => cmpr::Lz4::decompress(buf, |buf| { + Self::decode_new_records(buf, &batch_decode_info, version, records) + })?, + }; + } Ok(()) } diff --git a/tests/all_tests/fetch_response.rs b/tests/all_tests/fetch_response.rs index f681658..cf11c4f 100644 --- a/tests/all_tests/fetch_response.rs +++ b/tests/all_tests/fetch_response.rs @@ -1,6 +1,7 @@ #[cfg(feature = "client")] mod client_tests { use bytes::Bytes; + use kafka_protocol::records::Compression; use kafka_protocol::{ messages::FetchResponse, protocol::Decodable, records::RecordBatchDecoder, }; @@ -85,7 +86,9 @@ mod client_tests { assert_eq!(partition.aborted_transactions.as_ref().unwrap().len(), 0); let mut records = partition.records.unwrap(); - let records = RecordBatchDecoder::decode(&mut records).unwrap(); + let records = + RecordBatchDecoder::decode(&mut records, Some(decompress_record_batch_data)) + .unwrap(); assert_eq!(records.len(), 1); for record in records { assert_eq!( @@ -120,7 +123,9 @@ mod client_tests { assert_eq!(partition.aborted_transactions.as_ref().unwrap().len(), 0); let mut records = partition.records.unwrap(); - let records = RecordBatchDecoder::decode(&mut records).unwrap(); + let records = + RecordBatchDecoder::decode(&mut records, Some(decompress_record_batch_data)) + .unwrap(); assert_eq!(records.len(), 1); for record in records { assert_eq!( @@ -156,9 +161,23 @@ mod client_tests { assert_eq!(partition.aborted_transactions.as_ref().unwrap().len(), 0); let mut records = partition.records.unwrap(); - let records = RecordBatchDecoder::decode(&mut records).unwrap(); + let records = + RecordBatchDecoder::decode(&mut records, Some(decompress_record_batch_data)) + .unwrap(); assert_eq!(records.len(), 1); } } } + + fn decompress_record_batch_data( + compressed_buffer: &mut bytes::Bytes, + compression: Compression, + ) -> anyhow::Result { + match compression { + Compression::None => Ok(compressed_buffer.to_vec().into()), + _ => { + panic!("Compression not implemented") + } + } + } } diff --git a/tests/all_tests/produce_fetch.rs b/tests/all_tests/produce_fetch.rs index 0362270..259d101 100644 --- a/tests/all_tests/produce_fetch.rs +++ b/tests/all_tests/produce_fetch.rs @@ -39,6 +39,7 @@ fn record_batch_produce_fetch() { version: 2, compression: Compression::None, }, + Some(compress_record_batch_data), ) .unwrap(); @@ -73,6 +74,7 @@ fn message_set_v1_produce_fetch() { version: 1, compression: Compression::None, }, + Some(compress_record_batch_data), ) .unwrap(); @@ -200,7 +202,9 @@ fn fetch_records( ); let mut fetched_records = partition_response.records.clone().unwrap(); - let fetched_records = RecordBatchDecoder::decode(&mut fetched_records).unwrap(); + let fetched_records = + RecordBatchDecoder::decode(&mut fetched_records, Some(decompress_record_batch_data)) + .unwrap(); eprintln!("{expected:#?}"); eprintln!("{fetched_records:#?}"); @@ -224,3 +228,31 @@ fn new_record(offset: i64, v2: bool) -> Record { headers: Default::default(), } } + +fn decompress_record_batch_data( + compressed_buffer: &mut bytes::Bytes, + compression: Compression, +) -> anyhow::Result { + match compression { + Compression::None => Ok(compressed_buffer.to_vec().into()), + _ => { + panic!("Compression not implemented") + } + } +} + +fn compress_record_batch_data( + src: &mut bytes::BytesMut, + dest: &mut BytesMut, + compression: Compression, +) -> anyhow::Result<()> { + match compression { + Compression::None => { + dest.extend_from_slice(src.as_ref()); + Ok(()) + } + _ => { + panic!("Compression not implemented") + } + } +}