From 55a4d1898a46f8b45071e9113cc6205437da7b69 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Fri, 28 Feb 2025 13:40:26 +0100 Subject: [PATCH] fix: upcasting of signed integer when block decomposing Some parts of the code did not use the correct way to decompose a clear integer into blocks which could be encrypted or used in scalar ops. The sign extension was not always properly done, leading for example in the encryption of a negative integer stored on a i8 to a SignedRadixCiphertext with a num_blocks greater than i8 to be incorrect: ``` let ct = cks.encrypt_signed(-1i8, 16) // 2_2 parameters let d: i32 = cks.decrypt_signed(&ct); assert_eq!(d, i32::from(-1i8)); // Fails ``` To fix, a BlockDecomposer::with_block_count function is added and used This function will properly do the sign extension when needed --- tfhe/src/integer/block_decomposition.rs | 61 +++++++++++++- tfhe/src/integer/encryption.rs | 11 +-- tfhe/src/integer/gpu/server_key/radix/mod.rs | 10 ++- tfhe/src/integer/server_key/radix/tests.rs | 52 ++++++++++-- .../server_key/radix_parallel/scalar_add.rs | 16 ++-- .../radix_parallel/scalar_comparison.rs | 82 +++++++++---------- .../server_key/radix_parallel/scalar_sub.rs | 1 + 7 files changed, 161 insertions(+), 72 deletions(-) diff --git a/tfhe/src/integer/block_decomposition.rs b/tfhe/src/integer/block_decomposition.rs index 4eb61c6f54..9610747074 100644 --- a/tfhe/src/integer/block_decomposition.rs +++ b/tfhe/src/integer/block_decomposition.rs @@ -115,6 +115,21 @@ where Self::new_(value, bits_per_block, None, Some(padding_bit)) } + /// Creates a block decomposer that will return `block_count` blocks + /// + /// * If T is signed, extra block will be sign extended + pub fn with_block_count(value: T, bits_per_block: u32, block_count: u32) -> Self { + let mut decomposer = Self::new(value, bits_per_block); + // If the new number of bits is less than the actual number of bits, it means + // data will be truncated + // + // If the new number of bits is greater than the actual number of bits, it means + // the right shift used internally will correctly sign extend for us + let num_bits_valid = block_count * bits_per_block; + decomposer.num_bits_valid = num_bits_valid; + decomposer + } + pub fn new(value: T, bits_per_block: u32) -> Self { Self::new_(value, bits_per_block, None, None) } @@ -238,7 +253,8 @@ where T: Recomposable, { pub fn value(&self) -> T { - if self.bit_pos >= T::BITS as u32 { + let is_signed = (T::ONE << (T::BITS as u32 - 1)) < T::ZERO; + if self.bit_pos >= (T::BITS as u32 - u32::from(is_signed)) { self.data } else { let valid_mask = (T::ONE << self.bit_pos) - T::ONE; @@ -351,6 +367,49 @@ mod tests { assert_eq!(expected_blocks, blocks); } + #[test] + fn test_bit_block_decomposer_with_block_count() { + let bits_per_block = 3; + let expected_blocks = vec![0, 0, 6, 7, 7, 7, 7, 7, 7]; + let value = i8::MIN; + for block_count in 1..expected_blocks.len() as u32 { + let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count) + .iter_as::() + .collect::>(); + assert_eq!(expected_blocks[..block_count as usize], blocks); + } + + let bits_per_block = 3; + let expected_blocks = vec![7, 7, 1, 0, 0, 0, 0, 0, 0]; + let value = i8::MAX; + for block_count in 1..expected_blocks.len() as u32 { + let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count) + .iter_as::() + .collect::>(); + assert_eq!(expected_blocks[..block_count as usize], blocks); + } + + let bits_per_block = 2; + let expected_blocks = vec![0, 0, 0, 2, 3, 3, 3, 3, 3]; + let value = i8::MIN; + for block_count in 1..expected_blocks.len() as u32 { + let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count) + .iter_as::() + .collect::>(); + assert_eq!(expected_blocks[..block_count as usize], blocks); + } + + let bits_per_block = 2; + let expected_blocks = vec![3, 3, 3, 1, 0, 0, 0, 0, 0, 0]; + let value = i8::MAX; + for block_count in 1..expected_blocks.len() as u32 { + let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count) + .iter_as::() + .collect::>(); + assert_eq!(expected_blocks[..block_count as usize], blocks); + } + } + #[test] fn test_bit_block_decomposer_recomposer_carry_handling_in_between() { let value = u16::MAX as u32; diff --git a/tfhe/src/integer/encryption.rs b/tfhe/src/integer/encryption.rs index 1b8831c33e..247ef56480 100644 --- a/tfhe/src/integer/encryption.rs +++ b/tfhe/src/integer/encryption.rs @@ -92,9 +92,7 @@ where // We need to concretize the iterator type to be able to pass callbacks consuming the iterator, // having an opaque return impl Iterator does not allow to take callbacks at this moment, not sure // the Fn(impl Trait) syntax can be made to work nicely with the rest of the language -pub(crate) type ClearRadixBlockIterator = std::iter::Take< - std::iter::Chain, fn(T) -> u64>, std::iter::Repeat>, ->; +pub(crate) type ClearRadixBlockIterator = std::iter::Map, fn(T) -> u64>; pub(crate) fn create_clear_radix_block_iterator( message: T, @@ -105,12 +103,7 @@ where T: DecomposableInto, { let bits_in_block = message_modulus.0.ilog2(); - let decomposer = BlockDecomposer::new(message, bits_in_block); - - decomposer - .iter_as::() - .chain(std::iter::repeat(0u64)) - .take(num_blocks) + BlockDecomposer::with_block_count(message, bits_in_block, num_blocks as u32).iter_as::() } pub(crate) fn encrypt_crt( diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index 67b0c9e224..3f530e00ae 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -187,10 +187,12 @@ impl CudaServerKey { PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(), }; - let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2()) - .iter_as::() - .chain(std::iter::repeat(0)) - .take(num_blocks); + let decomposer = BlockDecomposer::with_block_count( + scalar, + self.message_modulus.0.ilog2(), + num_blocks as u32, + ) + .iter_as::(); let mut cpu_lwe_list = LweCiphertextList::new( 0, lwe_size, diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs index 184a67c207..2d33277b3e 100644 --- a/tfhe/src/integer/server_key/radix/tests.rs +++ b/tfhe/src/integer/server_key/radix/tests.rs @@ -48,6 +48,7 @@ create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits); create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits_specific_values); create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits_specific_values); create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits); +create_parameterized_test_classical_params!(integer_encrypt_auto_cast); create_parameterized_test_classical_params!(integer_unchecked_add); create_parameterized_test_classical_params!(integer_smart_add); create_parameterized_test!( @@ -157,7 +158,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) { let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); let mut rng = rand::thread_rng(); - let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize; for _ in 0..10 { let clear = rng.gen::(); @@ -172,7 +173,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) { fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters) { let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize; { let a = u64::MAX as u128; let ct = cks.encrypt_radix(a, num_block); @@ -220,7 +221,7 @@ fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters) fn integer_encrypt_decrypt_256_bits_specific_values(param: ClassicPBSParameters) { let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize; { let a = (u64::MAX as u128) << 64; let b = 0; @@ -245,7 +246,7 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) { let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); let mut rng = rand::thread_rng(); - let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize; for _ in 0..10 { let clear0 = rng.gen::(); @@ -261,11 +262,52 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) { } } +fn integer_encrypt_auto_cast(param: ClassicPBSParameters) { + // The goal is to test that encrypting a value stored in a type + // for which the bit count does not match the target block count of the encrypted + // radix properly applies upcasting/downcasting + + let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let mut rng = rand::thread_rng(); + + let num_blocks = 32u32.div_ceil(param.message_modulus.0.ilog2()) as usize; + + // Positive signed value + let value = rng.gen_range(0..=i32::MAX); + let ct = cks.encrypt_signed_radix(value, num_blocks * 2); + let d: i64 = cks.decrypt_signed_radix(&ct); + assert_eq!(i64::from(value), d); + + let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2)); + let d: i16 = cks.decrypt_signed_radix(&ct); + assert_eq!(value as i16, d); + + // Negative signed value + let value = rng.gen_range(i8::MIN..0); + let ct = cks.encrypt_signed_radix(value, num_blocks * 2); + let d: i64 = cks.decrypt_signed_radix(&ct); + assert_eq!(i64::from(value), d); + + let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2)); + let d: i16 = cks.decrypt_signed_radix(&ct); + assert_eq!(value as i16, d); + + // Unsigned value + let value = rng.gen::(); + let ct = cks.encrypt_radix(value, num_blocks * 2); + let d: u64 = cks.decrypt_radix(&ct); + assert_eq!(u64::from(value), d); + + let ct = cks.encrypt_radix(value, num_blocks.div_ceil(2)); + let d: u16 = cks.decrypt_radix(&ct); + assert_eq!(value as u16, d); +} + fn integer_smart_add_128_bits(param: ClassicPBSParameters) { let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); let mut rng = rand::thread_rng(); - let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize; for _ in 0..100 { let clear_0 = rng.gen::(); diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs index d8dca15382..afad90302d 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs @@ -276,15 +276,13 @@ impl ServerKey { self.full_propagate_parallelized(ct); } - let scalar_blocks = BlockDecomposer::new(scalar, self.message_modulus().0.ilog2()) - .iter_as::() - .chain(std::iter::repeat(if scalar < Scalar::ZERO { - (self.message_modulus().0 - 1) as u8 - } else { - 0 - })) - .take(ct.blocks().len()) - .collect(); + let scalar_blocks = BlockDecomposer::with_block_count( + scalar, + self.message_modulus().0.ilog2(), + ct.blocks().len() as u32, + ) + .iter_as::() + .collect(); const COMPUTE_OVERFLOW: bool = false; const INPUT_CARRY: bool = false; diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs index b43d8d24f5..0275541ff8 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs @@ -757,12 +757,10 @@ impl ServerKey { .map(|chunk_of_two| self.pack_block_chunk(chunk_of_two)) .collect::>(); - let padding_value = (packed_modulus - 1) * u64::from(b < Scalar::ZERO); - let mut b_blocks = BlockDecomposer::new(b, packed_modulus.ilog2()) - .iter_as::() - .chain(std::iter::repeat(padding_value)) - .take(a.len()) - .collect::>(); + let mut b_blocks = + BlockDecomposer::with_block_count(b, packed_modulus.ilog2(), a.len() as u32) + .iter_as::() + .collect::>(); if !num_block_is_even && b < Scalar::ZERO { let last_index = b_blocks.len() - 1; @@ -1058,25 +1056,23 @@ impl ServerKey { Scalar: DecomposableInto, { let is_superior = self.unchecked_scalar_gt_parallelized(lhs, rhs); - let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2()) - .iter_as::() - .chain(std::iter::repeat(if rhs >= Scalar::ZERO { - 0u64 - } else { - self.message_modulus().0 - 1 - })) - .take(lhs.blocks().len()) - .map(|scalar_block| { - self.key - .generate_lookup_table_bivariate(|is_superior, block| { - if is_superior == 1 { - block - } else { - scalar_block - } - }) - }) - .collect::>(); + let luts = BlockDecomposer::with_block_count( + rhs, + self.message_modulus().0.ilog2(), + lhs.blocks().len() as u32, + ) + .iter_as::() + .map(|scalar_block| { + self.key + .generate_lookup_table_bivariate(|is_superior, block| { + if is_superior == 1 { + block + } else { + scalar_block + } + }) + }) + .collect::>(); let new_blocks = lhs .blocks() @@ -1097,25 +1093,23 @@ impl ServerKey { Scalar: DecomposableInto, { let is_inferior = self.unchecked_scalar_lt_parallelized(lhs, rhs); - let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2()) - .iter_as::() - .chain(std::iter::repeat(if rhs >= Scalar::ZERO { - 0u64 - } else { - self.message_modulus().0 - 1 - })) - .take(lhs.blocks().len()) - .map(|scalar_block| { - self.key - .generate_lookup_table_bivariate(|is_inferior, block| { - if is_inferior == 1 { - block - } else { - scalar_block - } - }) - }) - .collect::>(); + let luts = BlockDecomposer::with_block_count( + rhs, + self.message_modulus().0.ilog2(), + lhs.blocks().len() as u32, + ) + .iter_as::() + .map(|scalar_block| { + self.key + .generate_lookup_table_bivariate(|is_inferior, block| { + if is_inferior == 1 { + block + } else { + scalar_block + } + }) + }) + .collect::>(); let new_blocks = lhs .blocks() diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs index 3155f9eaa0..613554610a 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs @@ -641,6 +641,7 @@ impl ServerKey { const INPUT_CARRY: bool = true; let flipped_scalar = !scalar; let decomposed_flipped_scalar = + // We don't use BlockDecomposer::with_block_count as we are doing something special BlockDecomposer::new(flipped_scalar, self.message_modulus().0.ilog2()) .iter_as::() .chain(std::iter::repeat(if scalar < Scalar::ZERO {