diff --git a/src/decompress.rs b/src/decompress.rs
index f2bb2c4..89fd98b 100644
--- a/src/decompress.rs
+++ b/src/decompress.rs
@@ -1,4 +1,5 @@
 use simd_adler32::Adler32;
+use std::num::NonZeroUsize;
 
 use crate::{
     huffman::{self, build_table},
@@ -96,11 +97,9 @@ pub struct Decompressor {
     // Number of bytes left for uncompressed block.
     uncompressed_bytes_left: u16,
 
-    buffer: u64,
-    nbits: u8,
+    bits: BitBuffer,
 
-    queued_rle: Option<(u8, usize)>,
-    queued_backref: Option<(usize, usize)>,
+    queued_output: Option<QueuedOutput>,
 
     last_block: bool,
     fixed_table: bool,
@@ -119,8 +118,7 @@ impl Decompressor {
     /// Create a new decompressor.
     pub fn new() -> Self {
         Self {
-            buffer: 0,
-            nbits: 0,
+            bits: BitBuffer::new(),
             compression: CompressedBlock {
                 litlen_table: Box::new([0; 4096]),
                 dist_table: Box::new([0; 512]),
@@ -139,8 +137,7 @@ impl Decompressor {
                 code_lengths: [0; 320],
             },
             uncompressed_bytes_left: 0,
-            queued_rle: None,
-            queued_backref: None,
+            queued_output: None,
             checksum: Adler32::new(),
             state: State::ZlibHeader,
             last_block: false,
@@ -154,68 +151,41 @@ impl Decompressor {
         self.ignore_adler32 = true;
     }
 
-    fn fill_buffer(&mut self, input: &mut &[u8]) {
-        if input.len() >= 8 {
-            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits;
-            *input = &input[(63 - self.nbits as usize) / 8..];
-            self.nbits |= 56;
-        } else {
-            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
-            let mut input_data = [0; 8];
-            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
-            self.buffer |= u64::from_le_bytes(input_data)
-                .checked_shl(self.nbits as u32)
-                .unwrap_or(0);
-            self.nbits += nbytes as u8 * 8;
-            *input = &input[nbytes..];
-        }
-    }
-
-    fn peak_bits(&mut self, nbits: u8) -> u64 {
-        debug_assert!(nbits <= 56 && nbits <= self.nbits);
-        self.buffer & ((1u64 << nbits) - 1)
-    }
-    fn consume_bits(&mut self, nbits: u8) {
-        debug_assert!(self.nbits >= nbits);
-        self.buffer >>= nbits;
-        self.nbits -= nbits;
-    }
-
     fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
-        self.fill_buffer(remaining_input);
-        if self.nbits < 10 {
+        self.bits.fill_buffer(remaining_input);
+        if self.bits.nbits < 10 {
             return Ok(());
         }
-        let start = self.peak_bits(3);
+        let start = self.bits.peek_bits(3);
         self.last_block = start & 1 != 0;
 
         match start >> 1 {
             0b00 => {
-                let align_bits = (self.nbits - 3) % 8;
+                let align_bits = (self.bits.nbits - 3) % 8;
                 let header_bits = 3 + 32 + align_bits;
-                if self.nbits < header_bits {
+                if self.bits.nbits < header_bits {
                     return Ok(());
                 }
-                let len = (self.peak_bits(align_bits + 19) >> (align_bits + 3)) as u16;
-                let nlen = (self.peak_bits(header_bits) >> (align_bits + 19)) as u16;
+                let len = (self.bits.peek_bits(align_bits + 19) >> (align_bits + 3)) as u16;
+                let nlen = (self.bits.peek_bits(header_bits) >> (align_bits + 19)) as u16;
                 if nlen != !len {
                     return Err(DecompressionError::InvalidUncompressedBlockLength);
                 }
                 self.state = State::UncompressedData;
                 self.uncompressed_bytes_left = len;
-                self.consume_bits(header_bits);
+                self.bits.consume_bits(header_bits);
                 Ok(())
             }
             0b01 => {
-                self.consume_bits(3);
+                self.bits.consume_bits(3);
                 // Check for an entirely empty block, which can happen if there are "partial
                 // flushes" in the deflate stream. With fixed huffman codes, the EOF symbol is
                 // 7-bits of zeros so we peek ahead and see if the next 7-bits are all zero.
-                if self.peak_bits(7) == 0 {
-                    self.consume_bits(7);
+                if self.bits.peek_bits(7) == 0 {
+                    self.bits.consume_bits(7);
                     if self.last_block {
                         self.state = State::Checksum;
                         return Ok(());
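The hunk below consumes entirely empty fixed-Huffman blocks by peeking at upcoming bits. The bit layout it relies on: the three block-header bits are BFINAL followed by BTYPE (01 for fixed Huffman), stored least-significant-bit first, and the fixed-code end-of-block symbol is seven zero bits, so an empty non-final fixed block appears as the ten low bits 0b0000000010. A minimal sketch of that check against a bare LSB-first bit buffer (illustrative only; the patch itself does this through BitBuffer::peek_bits):

fn is_empty_nonfinal_fixed_block(bit_buffer: u64, nbits: u8) -> bool {
    // BFINAL = 0, BTYPE = 01, followed by the all-zero end-of-block code.
    nbits >= 10 && bit_buffer & 0x3ff == 0b010
}

#[test]
fn detects_empty_fixed_block() {
    assert!(is_empty_nonfinal_fixed_block(0b000_0000_010, 10));
    // A final empty block (BFINAL = 1) is handled by the Checksum transition above instead.
    assert!(!is_empty_nonfinal_fixed_block(0b000_0000_011, 10));
}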
@@ -226,9 +196,9 @@ impl Decompressor {
                     // here. But without it, a long sequence of empty fixed-blocks might cause a
                     // stack overflow. Instead, we consume all empty blocks in a loop and then
                     // recurse. This is the only recursive call in this function, and thus is safe.
-                    while self.nbits >= 10 && self.peak_bits(10) == 0b010 {
-                        self.consume_bits(10);
-                        self.fill_buffer(remaining_input);
+                    while self.bits.nbits >= 10 && self.bits.peek_bits(10) == 0b010 {
+                        self.bits.consume_bits(10);
+                        self.bits.fill_buffer(remaining_input);
                     }
                     return self.read_block_header(remaining_input);
                 }
@@ -251,13 +221,13 @@ impl Decompressor {
                 Ok(())
             }
             0b10 => {
-                if self.nbits < 17 {
+                if self.bits.nbits < 17 {
                     return Ok(());
                 }
 
-                self.header.hlit = (self.peak_bits(8) >> 3) as usize + 257;
-                self.header.hdist = (self.peak_bits(13) >> 8) as usize + 1;
-                self.header.hclen = (self.peak_bits(17) >> 13) as usize + 4;
+                self.header.hlit = (self.bits.peek_bits(8) >> 3) as usize + 257;
+                self.header.hdist = (self.bits.peek_bits(13) >> 8) as usize + 1;
+                self.header.hclen = (self.bits.peek_bits(17) >> 13) as usize + 4;
                 if self.header.hlit > 286 {
                     return Err(DecompressionError::InvalidHlit);
                 }
@@ -265,7 +235,7 @@ impl Decompressor {
                     return Err(DecompressionError::InvalidHdist);
                 }
 
-                self.consume_bits(17);
+                self.bits.consume_bits(17);
                 self.state = State::CodeLengthCodes;
                 self.fixed_table = false;
                 Ok(())
@@ -279,20 +249,20 @@ impl Decompressor {
         &mut self,
         remaining_input: &mut &[u8],
     ) -> Result<(), DecompressionError> {
-        self.fill_buffer(remaining_input);
-        if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
+        self.bits.fill_buffer(remaining_input);
+        if self.bits.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
             return Ok(());
         }
 
         let mut code_length_lengths = [0; 19];
         for i in 0..self.header.hclen {
-            code_length_lengths[CLCL_ORDER[i]] = self.peak_bits(3) as u8;
-            self.consume_bits(3);
+            code_length_lengths[CLCL_ORDER[i]] = self.bits.peek_bits(3) as u8;
+            self.bits.consume_bits(3);
 
             // We need to refill the buffer after reading 3 * 18 = 54 bits since the buffer holds
             // between 56 and 63 bits total.
             if i == 17 {
-                self.fill_buffer(remaining_input);
+                self.bits.fill_buffer(remaining_input);
             }
         }
 
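The refill comment above depends on the semantics of the BitBuffer type this patch introduces (defined near the end of the diff): fill_buffer tops the 64-bit buffer up to 56..=63 valid bits, peek_bits returns the lowest nbits without consuming them, and consume_bits shifts them out. A stripped-down stand-in, only to show the LSB-first behaviour (this is not the crate's implementation — it drops the 8-byte fast path that the real fill_buffer uses):

struct MiniBits {
    buffer: u64,
    nbits: u8,
}

impl MiniBits {
    fn fill(&mut self, input: &mut &[u8]) {
        // Shift whole bytes in above the bits we already hold, LSB first.
        while self.nbits <= 56 && !input.is_empty() {
            self.buffer |= u64::from(input[0]) << self.nbits;
            self.nbits += 8;
            *input = &input[1..];
        }
    }
    fn peek_bits(&self, nbits: u8) -> u64 {
        debug_assert!(nbits <= 56 && nbits <= self.nbits);
        self.buffer & ((1u64 << nbits) - 1)
    }
    fn consume_bits(&mut self, nbits: u8) {
        debug_assert!(nbits <= self.nbits);
        self.buffer >>= nbits;
        self.nbits -= nbits;
    }
}

#[test]
fn reads_three_bit_fields_lsb_first() {
    let mut bits = MiniBits { buffer: 0, nbits: 0 };
    let mut input: &[u8] = &[0b10_110_001];
    bits.fill(&mut input);
    assert_eq!(bits.peek_bits(3), 0b001); // the lowest three bits come out first
    bits.consume_bits(3);
    assert_eq!(bits.peek_bits(3), 0b110);
}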
@@ -317,12 +287,12 @@ impl Decompressor {
     fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
         let total_lengths = self.header.hlit + self.header.hdist;
         while self.header.num_lengths_read < total_lengths {
-            self.fill_buffer(remaining_input);
-            if self.nbits < 7 {
+            self.bits.fill_buffer(remaining_input);
+            if self.bits.nbits < 7 {
                 return Ok(());
             }
 
-            let code = self.peak_bits(7);
+            let code = self.bits.peek_bits(7);
             let entry = self.header.table[code as usize];
             let length = (entry & 0x7) as u8;
             let symbol = (entry >> 16) as u8;
@@ -332,7 +302,7 @@ impl Decompressor {
                 0..=15 => {
                     self.header.code_lengths[self.header.num_lengths_read] = symbol;
                     self.header.num_lengths_read += 1;
-                    self.consume_bits(length);
+                    self.bits.consume_bits(length);
                 }
                 16..=18 => {
                     let (base_repeat, extra_bits) = match symbol {
@@ -342,7 +312,7 @@ impl Decompressor {
                         _ => unreachable!(),
                     };
 
-                    if self.nbits < length + extra_bits {
+                    if self.bits.nbits < length + extra_bits {
                         return Ok(());
                     }
@@ -361,7 +331,7 @@ impl Decompressor {
                     };
 
                     let repeat =
-                        (self.peak_bits(length + extra_bits) >> length) as usize + base_repeat;
+                        (self.bits.peek_bits(length + extra_bits) >> length) as usize + base_repeat;
                     if self.header.num_lengths_read + repeat > total_lengths {
                         return Err(DecompressionError::InvalidCodeLengthRepeat);
                     }
@@ -370,7 +340,7 @@ impl Decompressor {
                         self.header.code_lengths[self.header.num_lengths_read + i] = value;
                     }
                     self.header.num_lengths_read += repeat;
-                    self.consume_bits(length + extra_bits);
+                    self.bits.consume_bits(length + extra_bits);
                 }
                 _ => unreachable!(),
             }
@@ -446,12 +416,17 @@ impl Decompressor {
         Ok(())
     }
 
+    /// Returns:
+    /// - Whether this compressed block ended or not
+    /// - The new value of `output_index`
     fn read_compressed(
         &mut self,
         remaining_input: &mut &[u8],
         output: &mut [u8],
         mut output_index: usize,
-    ) -> Result<usize, DecompressionError> {
+    ) -> Result<(CompressedBlockStatus, usize), DecompressionError> {
+        debug_assert_eq!(self.state, State::CompressedData);
+
         // Fast decoding loop.
         //
         // This loop is optimized for speed and is the main decoding loop for the decompressor,
@@ -464,24 +439,21 @@ impl Decompressor {
         // - The litlen_entry for the next loop iteration can be loaded in parallel with refilling
         //   the bit buffer. This is because when the input is non-empty, the bit buffer actually
         //   has 64-bits of valid data (even though nbits will be in 56..=63).
-        self.fill_buffer(remaining_input);
-        let mut litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];
-        while self.state == State::CompressedData
-            && output_index + 8 <= output.len()
-            && remaining_input.len() >= 8
-        {
+        self.bits.fill_buffer(remaining_input);
+        let mut litlen_entry = self.compression.litlen_table[(self.bits.buffer & 0xfff) as usize];
+        while output_index + 8 <= output.len() && remaining_input.len() >= 8 {
             // First check whether the next symbol is a literal. This code does up to 2 additional
             // table lookups to decode more literals.
             let mut bits;
             let mut litlen_code_bits = litlen_entry as u8;
             if litlen_entry & LITERAL_ENTRY != 0 {
                 let litlen_entry2 = self.compression.litlen_table
-                    [(self.buffer >> litlen_code_bits & 0xfff) as usize];
+                    [(self.bits.buffer >> litlen_code_bits & 0xfff) as usize];
                 let litlen_code_bits2 = litlen_entry2 as u8;
                 let litlen_entry3 = self.compression.litlen_table
-                    [(self.buffer >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize];
+                    [(self.bits.buffer >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize];
                 let litlen_code_bits3 = litlen_entry3 as u8;
-                let litlen_entry4 = self.compression.litlen_table[(self.buffer
+                let litlen_entry4 = self.compression.litlen_table[(self.bits.buffer
                     >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
                     & 0xfff) as usize];
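The fast path above treats each 32-bit litlen table entry as a packed record. Judging from the reads in this hunk and in the slow loop later in the patch, the low byte is the number of bits to consume, bits 8..=11 say how many output bytes the entry produces (one or two), and bytes 2 and 3 hold the literal values; LITERAL_ENTRY is a flag constant defined elsewhere in the crate and not shown in this diff. A hedged sketch of that unpacking — the field layout is inferred, and the flag is passed in rather than guessed:

fn unpack_literal_entry(entry: u32, literal_flag: u32) -> Option<(u8, usize, [u8; 2])> {
    if entry & literal_flag == 0 {
        // Back-reference, end-of-block, or secondary-table entry: handled elsewhere.
        return None;
    }
    let code_bits = entry as u8; // bits to consume from the bit buffer
    let advance_output_bytes = ((entry & 0xf00) >> 8) as usize; // 1 or 2 literal bytes
    let bytes = [(entry >> 16) as u8, (entry >> 24) as u8];
    Some((code_bits, advance_output_bytes, bytes))
}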
@@ -504,27 +476,28 @@ impl Decompressor {
                         output_index += advance_output_bytes3;
                         litlen_entry = litlen_entry4;
-                        self.consume_bits(litlen_code_bits + litlen_code_bits2 + litlen_code_bits3);
-                        self.fill_buffer(remaining_input);
+                        self.bits
+                            .consume_bits(litlen_code_bits + litlen_code_bits2 + litlen_code_bits3);
+                        self.bits.fill_buffer(remaining_input);
                         continue;
                     } else {
-                        self.consume_bits(litlen_code_bits + litlen_code_bits2);
+                        self.bits.consume_bits(litlen_code_bits + litlen_code_bits2);
                         litlen_entry = litlen_entry3;
                         litlen_code_bits = litlen_code_bits3;
-                        self.fill_buffer(remaining_input);
-                        bits = self.buffer;
+                        self.bits.fill_buffer(remaining_input);
+                        bits = self.bits.buffer;
                     }
                 } else {
-                    self.consume_bits(litlen_code_bits);
-                    bits = self.buffer;
+                    self.bits.consume_bits(litlen_code_bits);
+                    bits = self.bits.buffer;
                     litlen_entry = litlen_entry2;
                     litlen_code_bits = litlen_code_bits2;
-                    if self.nbits < 48 {
-                        self.fill_buffer(remaining_input);
+                    if self.bits.nbits < 48 {
+                        self.bits.fill_buffer(remaining_input);
                     }
                 }
             } else {
-                bits = self.buffer;
+                bits = self.bits.buffer;
             }
 
             // The next symbol is either a 13+ bit literal, back-reference, or an EOF symbol.
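The hunks that follow turn a non-literal symbol into a back-reference: litlen symbols 257..=285 are mapped to a match length via LEN_SYM_TO_LEN_BASE plus a few extra bits read from the stream, and distance codes work the same way with their own base/extra tables. For reference, the mapping those tables encode comes from RFC 1951 section 3.2.5; the closed form below is a spec-derived illustration, not the crate's table-driven code:

/// Match length for litlen symbols 257..=285, per RFC 1951 section 3.2.5.
/// `extra` must already contain the symbol's extra bits, read LSB-first.
fn match_length(symbol: u16, extra: u16) -> u16 {
    assert!((257..=285).contains(&symbol));
    if symbol == 285 {
        return 258; // maximum length, no extra bits
    }
    let index = symbol - 257;
    if index < 8 {
        return 3 + index; // lengths 3..=10, no extra bits
    }
    let extra_bits = index / 4 - 1; // 1..=5
    let base = ((4 + index % 4) << extra_bits) + 3;
    base + extra
}

#[test]
fn length_bases_match_rfc_1951() {
    assert_eq!(match_length(257, 0), 3);
    assert_eq!(match_length(269, 2), 21); // base 19, two extra bits
    assert_eq!(match_length(285, 0), 258);
}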
@@ -545,21 +518,17 @@ impl Decompressor {
                 match litlen_symbol {
                     0..=255 => {
-                        self.consume_bits(litlen_code_bits);
+                        self.bits.consume_bits(litlen_code_bits);
                         litlen_entry =
-                            self.compression.litlen_table[(self.buffer & 0xfff) as usize];
-                        self.fill_buffer(remaining_input);
+                            self.compression.litlen_table[(self.bits.buffer & 0xfff) as usize];
+                        self.bits.fill_buffer(remaining_input);
                         output[output_index] = litlen_symbol as u8;
                         output_index += 1;
                         continue;
                     }
                     256 => {
-                        self.consume_bits(litlen_code_bits);
-                        self.state = match self.last_block {
-                            true => State::Checksum,
-                            false => State::BlockHeader,
-                        };
-                        break;
+                        self.bits.consume_bits(litlen_code_bits);
+                        return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
                     }
                     _ => (
                         LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
@@ -570,12 +539,8 @@ impl Decompressor {
             } else if litlen_code_bits == 0 {
                 return Err(DecompressionError::InvalidLiteralLengthCode);
             } else {
-                self.consume_bits(litlen_code_bits);
-                self.state = match self.last_block {
-                    true => State::Checksum,
-                    false => State::BlockHeader,
-                };
-                break;
+                self.bits.consume_bits(litlen_code_bits);
+                return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
             };
 
             bits >>= litlen_code_bits;
@@ -615,19 +580,19 @@ impl Decompressor {
                 return Err(DecompressionError::DistanceTooFarBack);
             }
 
-            self.consume_bits(
+            self.bits.consume_bits(
                 litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
             );
-            self.fill_buffer(remaining_input);
-            litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];
+            self.bits.fill_buffer(remaining_input);
+            litlen_entry = self.compression.litlen_table[(self.bits.buffer & 0xfff) as usize];
 
             let copy_length = length.min(output.len() - output_index);
             if dist == 1 {
                 let last = output[output_index - 1];
                 output[output_index..][..copy_length].fill(last);
 
-                if copy_length < length {
-                    self.queued_rle = Some((last, length - copy_length));
+                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
+                    self.queued_output = Some(QueuedOutput::Rle { data: last, length });
                     output_index = output.len();
                     break;
                 }
@@ -652,8 +617,8 @@ impl Decompressor {
                 )
             }
 
-            if copy_length < length {
-                self.queued_backref = Some((dist, length - copy_length));
+            if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
+                self.queued_output = Some(QueuedOutput::Backref { dist, length });
                 output_index = output.len();
                 break;
             }
@@ -665,13 +630,13 @@ impl Decompressor {
         //
         // This loop processes the remaining input when we're too close to the end of the input or
         // output to use the fast loop.
-        while let State::CompressedData = self.state {
-            self.fill_buffer(remaining_input);
+        loop {
+            self.bits.fill_buffer(remaining_input);
             if output_index == output.len() {
                 break;
             }
 
-            let mut bits = self.buffer;
+            let mut bits = self.bits.buffer;
             let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize];
             let litlen_code_bits = litlen_entry as u8;
@@ -680,26 +645,29 @@ impl Decompressor {
                 //   output bytes and we can directly write them to the output buffer.
                 let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
-                if self.nbits < litlen_code_bits {
+                if self.bits.nbits < litlen_code_bits {
                     break;
                 } else if output_index + 1 < output.len() {
                     output[output_index] = (litlen_entry >> 16) as u8;
                     output[output_index + 1] = (litlen_entry >> 24) as u8;
                     output_index += advance_output_bytes;
-                    self.consume_bits(litlen_code_bits);
+                    self.bits.consume_bits(litlen_code_bits);
                     continue;
                 } else if output_index + advance_output_bytes == output.len() {
                     debug_assert_eq!(advance_output_bytes, 1);
                     output[output_index] = (litlen_entry >> 16) as u8;
                     output_index += 1;
-                    self.consume_bits(litlen_code_bits);
+                    self.bits.consume_bits(litlen_code_bits);
                     break;
                 } else {
                     debug_assert_eq!(advance_output_bytes, 2);
                     output[output_index] = (litlen_entry >> 16) as u8;
-                    self.queued_rle = Some(((litlen_entry >> 24) as u8, 1));
+                    self.queued_output = Some(QueuedOutput::Rle {
+                        data: (litlen_entry >> 24) as u8,
+                        length: NonZeroUsize::new(1).unwrap(),
+                    });
                     output_index += 1;
-                    self.consume_bits(litlen_code_bits);
+                    self.bits.consume_bits(litlen_code_bits);
                     break;
                 }
             }
@@ -719,20 +687,16 @@ impl Decompressor {
                 let litlen_symbol = secondary_entry >> 4;
                 let litlen_code_bits = (secondary_entry & 0xf) as u8;
 
-                if self.nbits < litlen_code_bits {
+                if self.bits.nbits < litlen_code_bits {
                     break;
                 } else if litlen_symbol < 256 {
-                    self.consume_bits(litlen_code_bits);
+                    self.bits.consume_bits(litlen_code_bits);
                     output[output_index] = litlen_symbol as u8;
                     output_index += 1;
                     continue;
                 } else if litlen_symbol == 256 {
-                    self.consume_bits(litlen_code_bits);
-                    self.state = match self.last_block {
-                        true => State::Checksum,
-                        false => State::BlockHeader,
-                    };
-                    break;
+                    self.bits.consume_bits(litlen_code_bits);
+                    return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
                 }
 
                 (
@@ -743,15 +707,11 @@ impl Decompressor {
             } else if litlen_code_bits == 0 {
                 return Err(DecompressionError::InvalidLiteralLengthCode);
             } else {
-                if self.nbits < litlen_code_bits {
+                if self.bits.nbits < litlen_code_bits {
                     break;
                 }
-                self.consume_bits(litlen_code_bits);
-                self.state = match self.last_block {
-                    true => State::Checksum,
-                    false => State::BlockHeader,
-                };
-                break;
+                self.bits.consume_bits(litlen_code_bits);
+                return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
             };
 
             bits >>= litlen_code_bits;
@@ -766,7 +726,7 @@ impl Decompressor {
                     (dist_entry >> 8) as u8 & 0xf,
                     dist_entry as u8,
                 )
-            } else if self.nbits > litlen_code_bits + length_extra_bits + 9 {
+            } else if self.bits.nbits > litlen_code_bits + length_extra_bits + 9 {
                 if dist_entry >> 8 == 0 {
                     return Err(DecompressionError::InvalidDistanceCode);
                 }
@@ -794,21 +754,21 @@ impl Decompressor {
             let total_bits =
                 litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;
 
-            if self.nbits < total_bits {
+            if self.bits.nbits < total_bits {
                 break;
             } else if dist > output_index {
                 return Err(DecompressionError::DistanceTooFarBack);
             }
 
-            self.consume_bits(total_bits);
+            self.bits.consume_bits(total_bits);
 
             let copy_length = length.min(output.len() - output_index);
             if dist == 1 {
                 let last = output[output_index - 1];
                 output[output_index..][..copy_length].fill(last);
 
-                if copy_length < length {
-                    self.queued_rle = Some((last, length - copy_length));
+                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
+                    self.queued_output = Some(QueuedOutput::Rle { data: last, length });
                     output_index = output.len();
                     break;
                 }
@@ -833,8 +793,8 @@ impl Decompressor {
                 )
             }
 
-            if copy_length < length {
-                self.queued_backref = Some((dist, length - copy_length));
+            if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
+                self.queued_output = Some(QueuedOutput::Backref { dist, length });
                 output_index = output.len();
                 break;
             }
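Both loops above split an LZ77 back-reference that does not fit in the remaining output into an immediate copy plus a queued remainder, which Decompressor::read finishes on the next call (see the hunk against the read method below). A self-contained sketch of that split, reusing the QueuedOutput names from this patch but with the surrounding state handling stripped out:

use std::num::NonZeroUsize;

enum QueuedOutput {
    Rle { data: u8, length: NonZeroUsize },
    Backref { dist: usize, length: NonZeroUsize },
}

/// Copy `length` bytes from `dist` bytes back, starting at `index`. Anything that does not
/// fit in `output` is returned so the caller can resume later. Assumes `dist <= index`,
/// which the patch enforces via DistanceTooFarBack.
fn copy_backref(
    output: &mut [u8],
    mut index: usize,
    dist: usize,
    length: usize,
) -> (usize, Option<QueuedOutput>) {
    let copy_length = length.min(output.len() - index);
    if dist == 1 {
        // A distance of one is just a run of the previous byte, i.e. RLE.
        let last = output[index - 1];
        output[index..][..copy_length].fill(last);
    } else {
        // Source and destination may overlap, so copy byte by byte.
        for i in 0..copy_length {
            output[index + i] = output[index + i - dist];
        }
    }
    index += copy_length;
    // Remember whatever did not fit; None when the whole reference was written.
    let queued = NonZeroUsize::new(length - copy_length).map(|remaining| {
        if dist == 1 {
            QueuedOutput::Rle { data: output[index - 1], length: remaining }
        } else {
            QueuedOutput::Backref { dist, length: remaining }
        }
    });
    (index, queued)
}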
@@ -842,20 +802,16 @@ impl Decompressor {
             output_index += copy_length;
         }
 
-        if self.state == State::CompressedData
-            && self.queued_backref.is_none()
-            && self.queued_rle.is_none()
-            && self.nbits >= 15
-            && self.peak_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code
+        if self.queued_output.is_none()
+            && self.bits.nbits >= 15
+            && self.bits.peek_bits(15) as u16 & self.compression.eof_mask
+                == self.compression.eof_code
         {
-            self.consume_bits(self.compression.eof_bits);
-            self.state = match self.last_block {
-                true => State::Checksum,
-                false => State::BlockHeader,
-            };
+            self.bits.consume_bits(self.compression.eof_bits);
+            return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
         }
 
-        Ok(output_index)
+        Ok((CompressedBlockStatus::MoreDataPresent, output_index))
     }
 
     /// Decompresses a chunk of data.
@@ -892,24 +848,30 @@ impl Decompressor {
         let mut remaining_input = input;
         let mut output_index = output_position;
 
-        if let Some((data, len)) = self.queued_rle.take() {
-            let n = len.min(output.len() - output_index);
-            output[output_index..][..n].fill(data);
-            output_index += n;
-            if n < len {
-                self.queued_rle = Some((data, len - n));
-                return Ok((0, n));
-            }
-        }
-        if let Some((dist, len)) = self.queued_backref.take() {
-            let n = len.min(output.len() - output_index);
-            for i in 0..n {
-                output[output_index + i] = output[output_index + i - dist];
-            }
-            output_index += n;
-            if n < len {
-                self.queued_backref = Some((dist, len - n));
-                return Ok((0, n));
+        if let Some(queued_output) = self.queued_output.take() {
+            match queued_output {
+                QueuedOutput::Rle { data, length } => {
+                    let length: usize = length.into();
+                    let n = length.min(output.len() - output_index);
+                    output[output_index..][..n].fill(data);
+                    output_index += n;
+                    if let Ok(length) = NonZeroUsize::try_from(length - n) {
+                        self.queued_output = Some(QueuedOutput::Rle { data, length });
+                        return Ok((0, n));
+                    }
+                }
+                QueuedOutput::Backref { dist, length } => {
+                    let length: usize = length.into();
+                    let n = length.min(output.len() - output_index);
+                    for i in 0..n {
+                        output[output_index + i] = output[output_index + i - dist];
+                    }
+                    output_index += n;
+                    if let Ok(length) = NonZeroUsize::try_from(length - n) {
+                        self.queued_output = Some(QueuedOutput::Backref { dist, length });
+                        return Ok((0, n));
+                    }
+                }
             }
         }
 
@@ -919,13 +881,13 @@ impl Decompressor {
             last_state = Some(self.state);
             match self.state {
                 State::ZlibHeader => {
-                    self.fill_buffer(&mut remaining_input);
-                    if self.nbits < 16 {
+                    self.bits.fill_buffer(&mut remaining_input);
+                    if self.bits.nbits < 16 {
                         break;
                     }
 
-                    let input0 = self.peak_bits(8);
-                    let input1 = self.peak_bits(16) >> 8 & 0xff;
+                    let input0 = self.bits.peek_bits(8);
+                    let input1 = self.bits.peek_bits(16) >> 8 & 0xff;
                     if input0 & 0x0f != 0x08
                         || (input0 & 0xf0) > 0x70
                         || input1 & 0x20 != 0
@@ -934,7 +896,7 @@ impl Decompressor {
                         return Err(DecompressionError::BadZlibHeader);
                     }
 
-                    self.consume_bits(16);
+                    self.bits.consume_bits(16);
                     self.state = State::BlockHeader;
                 }
                 State::BlockHeader => {
@@ -947,24 +909,31 @@ impl Decompressor {
                     self.read_code_lengths(&mut remaining_input)?;
                 }
                 State::CompressedData => {
-                    output_index =
-                        self.read_compressed(&mut remaining_input, output, output_index)?
+                    let (compressed_block_status, new_output_index) =
+                        self.read_compressed(&mut remaining_input, output, output_index)?;
+                    output_index = new_output_index;
+                    if compressed_block_status == CompressedBlockStatus::ReachedEndOfBlock {
+                        self.state = match self.last_block {
+                            true => State::Checksum,
+                            false => State::BlockHeader,
+                        };
+                    }
                 }
                 State::UncompressedData => {
                     // Drain any bytes from our buffer.
-                    debug_assert_eq!(self.nbits % 8, 0);
-                    while self.nbits > 0
+                    debug_assert_eq!(self.bits.nbits % 8, 0);
+                    while self.bits.nbits > 0
                         && self.uncompressed_bytes_left > 0
                         && output_index < output.len()
                     {
-                        output[output_index] = self.peak_bits(8) as u8;
-                        self.consume_bits(8);
+                        output[output_index] = self.bits.peek_bits(8) as u8;
+                        self.bits.consume_bits(8);
                         output_index += 1;
                         self.uncompressed_bytes_left -= 1;
                     }
 
                     // Buffer may contain one additional byte. Clear it to avoid confusion.
-                    if self.nbits == 0 {
-                        self.buffer = 0;
+                    if self.bits.nbits == 0 {
+                        self.bits.buffer = 0;
                     }
 
                     // Copy subsequent bytes directly from the input.
@@ -986,22 +955,23 @@ impl Decompressor {
                     }
                 }
                 State::Checksum => {
-                    self.fill_buffer(&mut remaining_input);
+                    self.bits.fill_buffer(&mut remaining_input);
 
-                    let align_bits = self.nbits % 8;
-                    if self.nbits >= 32 + align_bits {
+                    let align_bits = self.bits.nbits % 8;
+                    if self.bits.nbits >= 32 + align_bits {
                         self.checksum.write(&output[output_position..output_index]);
                         if align_bits != 0 {
-                            self.consume_bits(align_bits);
+                            self.bits.consume_bits(align_bits);
                         }
                         #[cfg(not(fuzzing))]
                         if !self.ignore_adler32
-                            && (self.peak_bits(32) as u32).swap_bytes() != self.checksum.finish()
+                            && (self.bits.peek_bits(32) as u32).swap_bytes()
+                                != self.checksum.finish()
                         {
                             return Err(DecompressionError::WrongChecksum);
                         }
                         self.state = State::Done;
-                        self.consume_bits(32);
+                        self.bits.consume_bits(32);
                         break;
                     }
                 }
@@ -1027,6 +997,61 @@ impl Decompressor {
     }
 }
 
+#[derive(Debug)]
+struct BitBuffer {
+    buffer: u64,
+    nbits: u8,
+}
+
+impl BitBuffer {
+    fn new() -> Self {
+        Self {
+            buffer: 0,
+            nbits: 0,
+        }
+    }
+
+    fn fill_buffer(&mut self, input: &mut &[u8]) {
+        if input.len() >= 8 {
+            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits;
+            *input = &input[(63 - self.nbits as usize) / 8..];
+            self.nbits |= 56;
+        } else {
+            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
+            let mut input_data = [0; 8];
+            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
+            self.buffer |= u64::from_le_bytes(input_data)
+                .checked_shl(self.nbits as u32)
+                .unwrap_or(0);
+            self.nbits += nbytes as u8 * 8;
+            *input = &input[nbytes..];
+        }
+    }
+
+    fn peek_bits(&mut self, nbits: u8) -> u64 {
+        debug_assert!(nbits <= 56 && nbits <= self.nbits);
+        self.buffer & ((1u64 << nbits) - 1)
+    }
+
+    fn consume_bits(&mut self, nbits: u8) {
+        debug_assert!(self.nbits >= nbits);
+        self.buffer >>= nbits;
+        self.nbits -= nbits;
+    }
+}
+
+#[derive(Debug)]
+enum QueuedOutput {
+    Rle { data: u8, length: NonZeroUsize },
+    Backref { dist: usize, length: NonZeroUsize },
+}
+
+#[derive(Debug, Eq, PartialEq)]
+enum CompressedBlockStatus {
+    MoreDataPresent,
+    ReachedEndOfBlock,
+}
+
 /// Decompress the given data.
 pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
     match decompress_to_vec_bounded(input, usize::MAX) {