diff --git a/noodles-cram/src/codecs/rans_4x8/encode/order_1.rs b/noodles-cram/src/codecs/rans_4x8/encode/order_1.rs index 79915354e..4d771e414 100644 --- a/noodles-cram/src/codecs/rans_4x8/encode/order_1.rs +++ b/noodles-cram/src/codecs/rans_4x8/encode/order_1.rs @@ -24,34 +24,21 @@ pub fn encode(src: &[u8]) -> io::Result<Vec<u8>> { let mut buf = Vec::new(); let mut states = [LOWER_BOUND; STATE_COUNT]; - // The input data is split into 4 equally sized chunks. - let quarter = src.len() / states.len(); - - let chunks = [ - &src[0..quarter], - &src[quarter..2 * quarter], - &src[2 * quarter..3 * quarter], - &src[3 * quarter..4 * quarter], - ]; - - // The remainder of the input buffer is processed by the last state. - if src.len() > 4 * quarter { - // The last chunk includes the last symbol of the fourth chunk (the subtraction by 1). This - // is safe because chunks are guaranteed to be nonempty. - let remainder = &src[4 * quarter - 1..]; + let [chunk_0, chunk_1, chunk_2, chunk_3, chunk_4] = split_chunks(src); - for syms in remainder.windows(2).rev() { - let (i, j) = (usize::from(syms[0]), usize::from(syms[1])); - states[3] = state_renormalize(states[3], frequencies[i][j], &mut buf)?; - states[3] = state_step(states[3], frequencies[i][j], cumulative_frequencies[i][j]); - } + // § 2.2.1 "rANS entropy encoding: Interleaving" (2023-03-15): "Any remainder, when the input + // buffer is not divisible by 4, is processed ... by the 4th rANS state." 
+ for syms in chunk_4.windows(2).rev() { + let (i, j) = (usize::from(syms[0]), usize::from(syms[1])); + states[3] = state_renormalize(states[3], frequencies[i][j], &mut buf)?; + states[3] = state_step(states[3], frequencies[i][j], cumulative_frequencies[i][j]); } let mut windows = [ - chunks[0].windows(2).rev(), - chunks[1].windows(2).rev(), - chunks[2].windows(2).rev(), - chunks[3].windows(2).rev(), + chunk_0.windows(2).rev(), + chunk_1.windows(2).rev(), + chunk_2.windows(2).rev(), + chunk_3.windows(2).rev(), ]; let mut n = 0; @@ -68,7 +55,8 @@ pub fn encode(src: &[u8]) -> io::Result<Vec<u8>> { n += 1; } - // The last state updates are for the starting contexts, i.e., `(0, chunks[i][0])`. + let chunks = [chunk_0, chunk_1, chunk_2, chunk_3]; + for (state, chunk) in states.iter_mut().rev().zip(chunks.iter().rev()) { let (i, j) = (usize::from(NUL), usize::from(chunk[0])); *state = state_renormalize(*state, frequencies[i][j], &mut buf)?; @@ -181,6 +169,20 @@ fn build_cumulative_frequencies(frequencies: &Frequencies) -> CumulativeFrequenc cumulative_frequencies } +fn split_chunks(buf: &[u8]) -> [&[u8]; 5] { + let chunk_size = buf.len() / STATE_COUNT; + + let (left_chunk, right_chunk) = buf.split_at(2 * chunk_size); + let (chunk_0, chunk_1) = left_chunk.split_at(chunk_size); + let (chunk_2, chunk_3_4) = right_chunk.split_at(chunk_size); + let (chunk_3, _) = chunk_3_4.split_at(chunk_size); + + // The last chunk includes the last byte of chunk 3 for context. + let chunk_4 = &chunk_3_4[chunk_size - 1..]; + + [chunk_0, chunk_1, chunk_2, chunk_3, chunk_4] +} + #[cfg(test)] mod tests { use super::*; @@ -282,4 +284,17 @@ mod tests { assert_eq!(normalize_frequencies(&raw_frequencies), expected); } + + #[test] + fn test_split_chunks() { + assert_eq!( + split_chunks(&[1, 2, 3, 4]), + [&[1][..], &[2][..], &[3][..], &[4][..], &[4][..]] + ); + + assert_eq!( + split_chunks(&[1, 2, 3, 4, 5]), + [&[1][..], &[2][..], &[3][..], &[4][..], &[4, 5][..]] + ); + } }