From 718f46f56032724ef354a19dedb3af77f483ec46 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Fri, 28 Feb 2025 10:31:34 +0100 Subject: [PATCH] fix: BlockDecomposer The BlockDecomposer gave the possibility when the number of bits per block was not a multiple of the number of bits in the original integer to force the extra bits of the last block to a particular value. However, the way this was done could only work when setting these bits to 1, when wanting to set them to 0 it would not work. Good news is that we actually never wanted to set them to 0, but it should still be fixed for completeness, and allow other feature to be added without bugs --- tfhe/src/integer/block_decomposition.rs | 133 ++++++++++++------ .../integer/server_key/radix/scalar_sub.rs | 15 +- 2 files changed, 96 insertions(+), 52 deletions(-) diff --git a/tfhe/src/integer/block_decomposition.rs b/tfhe/src/integer/block_decomposition.rs index de3d88f62b..44e3af3d35 100644 --- a/tfhe/src/integer/block_decomposition.rs +++ b/tfhe/src/integer/block_decomposition.rs @@ -2,7 +2,7 @@ use crate::core_crypto::prelude::{CastFrom, CastInto, Numeric}; use crate::integer::bigint::static_signed::StaticSignedBigInt; use crate::integer::bigint::static_unsigned::StaticUnsignedBigInt; use core::ops::{AddAssign, BitAnd, ShlAssign, ShrAssign}; -use std::ops::{BitOrAssign, Shl, Sub}; +use std::ops::{BitOrAssign, Not, Shl, Shr, Sub}; // These work for signed number as rust uses 2-Complements // And Arithmetic shift for signed number (logical for unsigned) @@ -14,8 +14,10 @@ pub trait Decomposable: + ShrAssign + Eq + CastFrom + + Shr + Shl + BitOrAssign + + Not { } pub trait Recomposable: @@ -86,13 +88,20 @@ impl RecomposableFrom for StaticUnsignedBigInt {} impl DecomposableInto for StaticUnsignedBigInt {} impl DecomposableInto for StaticUnsignedBigInt {} +#[derive(Copy, Clone)] +#[repr(u32)] +pub enum PaddingBitValue { + Zero = 0, + One = 1, +} + #[derive(Clone)] pub struct BlockDecomposer { data: T, bit_mask: T, num_bits_in_mask: u32, num_bits_valid: u32, - padding_bit: T, + padding_bit: Option, limit: Option, } @@ -100,19 +109,27 @@ impl BlockDecomposer where T: Decomposable, { + /// Creates a block decomposer that will stop when the value reaches zero pub fn with_early_stop_at_zero(value: T, bits_per_block: u32) -> Self { - Self::new_(value, bits_per_block, Some(T::ZERO), T::ZERO) + Self::new_(value, bits_per_block, Some(T::ZERO), None) } - pub fn with_padding_bit(value: T, bits_per_block: u32, padding_bit: T) -> Self { - Self::new_(value, bits_per_block, None, padding_bit) + /// Creates a block decomposer that will set the surplus bits to a specific value + /// when bits_per_block is not a multiple of T::BITS + pub fn with_padding_bit(value: T, bits_per_block: u32, padding_bit: PaddingBitValue) -> Self { + Self::new_(value, bits_per_block, None, Some(padding_bit)) } pub fn new(value: T, bits_per_block: u32) -> Self { - Self::new_(value, bits_per_block, None, T::ZERO) + Self::new_(value, bits_per_block, None, None) } - fn new_(value: T, bits_per_block: u32, limit: Option, padding_bit: T) -> Self { + fn new_( + value: T, + bits_per_block: u32, + limit: Option, + padding_bit: Option, + ) -> Self { assert!(bits_per_block <= T::BITS as u32); let num_bits_valid = T::BITS as u32; @@ -129,6 +146,31 @@ where padding_bit, } } + + // We concretize the iterator type to allow usage of callbacks working on iterator for generic + // integer encryption + pub fn iter_as(self) -> std::iter::Map V> + where + V: Numeric, + T: CastInto, + { + assert!(self.num_bits_in_mask <= V::BITS as u32); + self.map(CastInto::cast_into) + } + + pub fn next_as(&mut self) -> Option + where + V: CastFrom, + { + self.next().map(|masked| V::cast_from(masked)) + } + + pub fn checked_next_as(&mut self) -> Option + where + V: TryFrom, + { + self.next().and_then(|masked| V::try_from(masked).ok()) + } } impl Iterator for BlockDecomposer @@ -159,11 +201,18 @@ where if self.num_bits_valid < self.num_bits_in_mask { // This will be the case when self.num_bits_in_mask is not a multiple - // of T::BITS. We replace bits that - // do not come from the actual T but from the padding - // intoduced by the shift, to a specific value. - for i in self.num_bits_valid..self.num_bits_in_mask { - masked |= self.padding_bit << i; + // of T::BITS. + // + // We replace bits that do not come from the actual T but from the padding + // introduced by the shift, to a specific value, if one was provided. + if let Some(padding_bit) = self.padding_bit { + let padding_mask = (self.bit_mask >> self.num_bits_valid) << self.num_bits_valid; + masked = masked & !padding_mask; + + let padding_bit = T::cast_from(padding_bit as u32); + for i in self.num_bits_valid..self.num_bits_in_mask { + masked |= padding_bit << i; + } } } @@ -184,36 +233,6 @@ where } } -impl BlockDecomposer -where - T: Decomposable, -{ - // We concretize the iterator type to allow usage of callbacks working on iterator for generic - // integer encryption - pub fn iter_as(self) -> std::iter::Map V> - where - V: Numeric, - T: CastInto, - { - assert!(self.num_bits_in_mask <= V::BITS as u32); - self.map(CastInto::cast_into) - } - - pub fn next_as(&mut self) -> Option - where - V: CastFrom, - { - self.next().map(|masked| V::cast_from(masked)) - } - - pub fn checked_next_as(&mut self) -> Option - where - V: TryFrom, - { - self.next().and_then(|masked| V::try_from(masked).ok()) - } -} - pub struct BlockRecomposer { data: T, bit_mask: T, @@ -310,6 +329,36 @@ mod tests { assert_eq!(expected_blocks, blocks); } + #[test] + fn test_bit_block_decomposer_3() { + let bits_per_block = 3; + + let value = -1i8; + let blocks = BlockDecomposer::new(value, bits_per_block) + .iter_as::() + .collect::>(); + // We expect the last block padded with 1s as a consequence of arithmetic shift + let expected_blocks = vec![7, 7, 7]; + assert_eq!(expected_blocks, blocks); + + let value = i8::MIN; + let blocks = BlockDecomposer::new(value, bits_per_block) + .iter_as::() + .collect::>(); + // We expect the last block padded with 1s as a consequence of arithmetic shift + let expected_blocks = vec![0, 0, 6]; + assert_eq!(expected_blocks, blocks); + + let value = -1i8; + let blocks = + BlockDecomposer::with_padding_bit(value, bits_per_block, PaddingBitValue::Zero) + .iter_as::() + .collect::>(); + // We expect the last block padded with 0s as we force that + let expected_blocks = vec![7, 7, 3]; + assert_eq!(expected_blocks, blocks); + } + #[test] fn test_bit_block_decomposer_recomposer_carry_handling_in_between() { let value = u16::MAX as u32; diff --git a/tfhe/src/integer/server_key/radix/scalar_sub.rs b/tfhe/src/integer/server_key/radix/scalar_sub.rs index 8174e2b436..c8ce6cb054 100644 --- a/tfhe/src/integer/server_key/radix/scalar_sub.rs +++ b/tfhe/src/integer/server_key/radix/scalar_sub.rs @@ -1,5 +1,5 @@ use crate::core_crypto::prelude::Numeric; -use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; +use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto, PaddingBitValue}; use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext}; use crate::integer::server_key::CheckError; use crate::integer::ServerKey; @@ -92,17 +92,12 @@ impl ServerKey { // The only case where these msb could become 0 after the addition // is if scalar == T::ZERO (=> !T::ZERO == T::MAX => T::MAX + 1 == overflow), // but this case has been handled earlier. - let padding_bit = 1u32; // To handle when bits is not a multiple of T::BITS - // All bits of message set to one let pad_block = (1 << bits_in_message as u8) - 1; - let decomposer = BlockDecomposer::with_padding_bit( - neg_scalar, - bits_in_message, - Scalar::cast_from(padding_bit), - ) - .iter_as::() - .chain(std::iter::repeat(pad_block)); + let decomposer = + BlockDecomposer::with_padding_bit(neg_scalar, bits_in_message, PaddingBitValue::One) + .iter_as::() + .chain(std::iter::repeat(pad_block)); Some(decomposer) }