From 863e0c275bed9d79aec08e1bd5e477ac067dc6e6 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Fri, 16 Feb 2024 11:36:58 +0100 Subject: [PATCH] feat(integer): add [trailing/leading]_[zeros/ones] --- .../server_key/radix_parallel/ilog2.rs | 785 ++++++++++++++++++ .../integer/server_key/radix_parallel/mod.rs | 1 + .../radix_parallel/tests_cases_unsigned.rs | 14 +- .../server_key/radix_parallel/tests_signed.rs | 50 +- .../radix_parallel/tests_unsigned.rs | 52 +- 5 files changed, 892 insertions(+), 10 deletions(-) create mode 100644 tfhe/src/integer/server_key/radix_parallel/ilog2.rs diff --git a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs new file mode 100644 index 0000000000..2abba3ce0b --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs @@ -0,0 +1,785 @@ +use crate::core_crypto::algorithms::misc::divide_ceil; +use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, ServerKey}; + +use crate::shortint::Ciphertext; + +use rayon::prelude::*; + +/// A 'bit' value +/// +/// Used to improved readability over using a `bool`. +#[derive(Copy, Clone, Eq, PartialEq)] +#[repr(u64)] +enum BitValue { + Zero = 0, + One = 1, +} + +impl BitValue { + fn opposite(self) -> Self { + match self { + Self::One => Self::Zero, + Self::Zero => Self::One, + } + } +} + +/// Direction to count consecutive bits +#[derive(Copy, Clone, Eq, PartialEq)] +enum Direction { + /// Count starting from the LSB + Trailing, + /// Count starting from MSB + Leading, +} + +impl ServerKey { + /// This function takes a ciphertext in radix representation + /// and returns a vec of blocks, where each blocks holds the number of leading_zeros/ones + /// + /// This contains the logic of making a block have 0 leading_ones/zeros if its preceding + /// block was not full of ones/zeros + fn prepare_count_of_consecutive_bits( + &self, + ct: T, + direction: Direction, + bit_value: BitValue, + ) -> Vec + where + T: IntegerRadixCiphertext, + { + assert!( + self.carry_modulus().0 >= self.message_modulus().0, + "A carry modulus as least as big as the message modulus is required" + ); + + let mut blocks = ct.into_blocks(); + + let lut = match direction { + Direction::Trailing => self.key.generate_lookup_table(|x| { + let x = x % self.key.message_modulus.0 as u64; + + let mut count = 0; + for i in 0..self.key.message_modulus.0.ilog2() { + if (x >> i) & 1 == bit_value.opposite() as u64 { + break; + } + count += 1; + } + count + }), + Direction::Leading => self.key.generate_lookup_table(|x| { + let x = x % self.key.message_modulus.0 as u64; + + let mut count = 0; + for i in (0..self.key.message_modulus.0.ilog2()).rev() { + if (x >> i) & 1 == bit_value.opposite() as u64 { + break; + } + count += 1; + } + count + }), + }; + + // Assign to each block its number of leading/trailing zeros/ones + // in the message space + blocks.par_iter_mut().for_each(|block| { + self.key.apply_lookup_table_assign(block, &lut); + }); + + if direction == Direction::Leading { + // Our blocks are from lsb to msb + // `leading` means starting from the msb, so we reverse block + // for the cum sum process done later + blocks.reverse(); + } + + // Use hillis-steele cumulative-sum algorithm + // Here, each block either keeps his value (the number of leading zeros) + // or becomes 0 if the preceding block + // had a bit set to one in it (leading_zeros != num bits in message) + let num_bits_in_message = self.key.message_modulus.0.ilog2() as u64; + let sum_lut = self.key.generate_lookup_table_bivariate( + |block_num_bit_count, more_significant_block_bit_count| { + if more_significant_block_bit_count == num_bits_in_message { + block_num_bit_count + } else { + 0 + } + }, + ); + + let sum_function = + |block_num_bit_count: &mut Ciphertext, + more_significant_block_bit_count: &Ciphertext| { + self.key.unchecked_apply_lookup_table_bivariate_assign( + block_num_bit_count, + more_significant_block_bit_count, + &sum_lut, + ); + }; + self.compute_prefix_sum_hillis_steele(blocks, sum_function) + } + + /// Counts how many consecutive bits there are + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + fn count_consecutive_bits( + &self, + ct: &T, + direction: Direction, + bit_value: BitValue, + ) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + if ct.blocks().is_empty() { + return self.create_trivial_zero_radix(0); + } + + let num_bits_in_message = self.key.message_modulus.0.ilog2(); + let original_num_blocks = ct.blocks().len(); + + let num_bits_in_ciphertext = num_bits_in_message + .checked_mul(original_num_blocks as u32) + .expect("Number of bits encrypted exceeds u32::MAX"); + + let leading_count_per_blocks = + self.prepare_count_of_consecutive_bits(ct.clone(), direction, bit_value); + + // `num_bits_in_ciphertext` is the max value we want to represent + // its ilog2 + 1 gives use how many bits we need to be able to represent it. + let counter_num_blocks = divide_ceil( + num_bits_in_ciphertext.ilog2() + 1, + self.message_modulus().0.ilog2(), + ); + + let cts = leading_count_per_blocks + .into_iter() + .map(|block| { + let mut ct: RadixCiphertext = + self.create_trivial_zero_radix(counter_num_blocks as usize); + ct.blocks[0] = block; + ct + }) + .collect::>(); + + self.unchecked_sum_ciphertexts_vec_parallelized(cts) + .expect("internal error, empty ciphertext count") + } + + //============================================================================================== + // Unchecked + //============================================================================================== + + /// See [Self::trailing_zeros] + /// + /// Expects ct to have clean carries + pub fn unchecked_trailing_zeros(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + self.count_consecutive_bits(ct, Direction::Trailing, BitValue::Zero) + } + + /// See [Self::trailing_ones] + /// + /// Expects ct to have clean carries + pub fn unchecked_trailing_ones(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + self.count_consecutive_bits(ct, Direction::Trailing, BitValue::One) + } + + /// See [Self::leading_zeros] + /// + /// Expects ct to have clean carries + pub fn unchecked_leading_zeros(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + self.count_consecutive_bits(ct, Direction::Leading, BitValue::Zero) + } + + /// See [Self::leading_ones] + /// + /// Expects ct to have clean carries + pub fn unchecked_leading_ones(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + self.count_consecutive_bits(ct, Direction::Leading, BitValue::One) + } + + //============================================================================================== + // Smart + //============================================================================================== + + /// See [Self::trailing_zeros] + pub fn smart_trailing_zeros(&self, ct: &mut T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + if !ct.block_carries_are_empty() { + self.full_propagate_parallelized(ct); + } + + self.unchecked_trailing_zeros(ct) + } + + /// See [Self::trailing_ones] + pub fn smart_trailing_ones(&self, ct: &mut T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + if !ct.block_carries_are_empty() { + self.full_propagate_parallelized(ct); + } + + self.unchecked_trailing_ones(ct) + } + + /// See [Self::leading_zeros] + pub fn smart_leading_zeros(&self, ct: &mut T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + if !ct.block_carries_are_empty() { + self.full_propagate_parallelized(ct); + } + + self.unchecked_leading_zeros(ct) + } + + /// See [Self::leading_ones] + pub fn smart_leading_ones(&self, ct: &mut T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + if !ct.block_carries_are_empty() { + self.full_propagate_parallelized(ct); + } + + self.unchecked_leading_ones(ct) + } + + //============================================================================================== + // Default + //============================================================================================== + + /// Returns the number of trailing zeros in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. + /// + /// # Example + /// + /// ``` + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = -4i8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// let n = sks.trailing_zeros(&ct1); + /// + /// // Decrypt: + /// let n: u32 = cks.decrypt(&n); + /// assert_eq!(n, msg.trailing_zeros()); + /// ``` + pub fn trailing_zeros(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.clone(); + self.full_propagate_parallelized(&mut tmp); + &tmp + }; + self.unchecked_trailing_zeros(ct) + } + + /// Returns the number of trailing ones in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. + /// + /// # Example + /// + /// ``` + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = -4i8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// let n = sks.trailing_ones(&ct1); + /// + /// // Decrypt: + /// let n: u32 = cks.decrypt(&n); + /// assert_eq!(n, msg.trailing_ones()); + /// ``` + pub fn trailing_ones(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.clone(); + self.full_propagate_parallelized(&mut tmp); + &tmp + }; + self.unchecked_trailing_ones(ct) + } + + /// Returns the number of leading zeros in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. + /// + /// # Example + /// + /// ``` + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = -4i8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// let n = sks.leading_zeros(&ct1); + /// + /// // Decrypt: + /// let n: u32 = cks.decrypt(&n); + /// assert_eq!(n, msg.leading_zeros()); + /// ``` + pub fn leading_zeros(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.clone(); + self.full_propagate_parallelized(&mut tmp); + &tmp + }; + self.unchecked_leading_zeros(ct) + } + + /// Returns the number of leading ones in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. + /// + /// # Example + /// + /// ``` + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = -4i8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// let n = sks.leading_ones(&ct1); + /// + /// // Decrypt: + /// let n: u32 = cks.decrypt(&n); + /// assert_eq!(n, msg.leading_ones()); + /// ``` + pub fn leading_ones(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.clone(); + self.full_propagate_parallelized(&mut tmp); + &tmp + }; + self.unchecked_leading_ones(ct) + } +} + +#[cfg(test)] +pub(crate) mod tests_unsigned { + use super::*; + use crate::integer::keycache::KEY_CACHE; + use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + random_non_zero_value, FunctionExecutor, NB_CTXT, NB_TESTS_SMALLER, + }; + use crate::integer::{IntegerKeyKind, RadixClientKey}; + use crate::shortint::PBSParameters; + use rand::Rng; + use std::sync::Arc; + + fn default_test_count_consecutive_bits( + direction: Direction, + bit_value: BitValue, + param: P, + mut executor: T, + ) where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, + { + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + let compute_expected_clear = |x: u64| match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + x.trailing_zeros() + } + } + (Direction::Trailing, BitValue::One) => x.trailing_ones(), + (Direction::Leading, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + (x << (u64::BITS - num_bits)).leading_zeros() + } + } + (Direction::Leading, BitValue::One) => (x << (u64::BITS - num_bits)).leading_ones(), + }; + + let method_name = match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => "trailing_zeros", + (Direction::Trailing, BitValue::One) => "trailing_ones", + (Direction::Leading, BitValue::Zero) => "leading_zeros", + (Direction::Leading, BitValue::One) => "leading_ones", + }; + + let input_values = [0u64, modulus - 1] + .into_iter() + .chain((0..NB_TESTS_SMALLER).map(|_| rng.gen::() % modulus)) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt(clear); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..NB_TESTS_SMALLER { + // Add non-zero scalar to have non-clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let clear = clear.wrapping_add(clear_2) % modulus; + + let d0: u64 = cks.decrypt(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = compute_expected_clear(clear); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = [0u64, modulus - 1] + .into_iter() + .chain((0..NB_TESTS_SMALLER).map(|_| rng.gen::() % modulus)); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + pub(crate) fn default_trailing_zeros_test(param: P, executor: T) + where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, + { + default_test_count_consecutive_bits(Direction::Trailing, BitValue::Zero, param, executor); + } + + pub(crate) fn default_trailing_ones_test(param: P, executor: T) + where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, + { + default_test_count_consecutive_bits(Direction::Trailing, BitValue::One, param, executor); + } + + pub(crate) fn default_leading_zeros_test(param: P, executor: T) + where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, + { + default_test_count_consecutive_bits(Direction::Leading, BitValue::Zero, param, executor); + } + + pub(crate) fn default_leading_ones_test(param: P, executor: T) + where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, + { + default_test_count_consecutive_bits(Direction::Leading, BitValue::One, param, executor); + } +} + +#[cfg(test)] +pub(crate) mod tests_signed { + use super::*; + use crate::integer::keycache::KEY_CACHE; + use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, signed_add_under_modulus, NB_CTXT, NB_TESTS_SMALLER, + }; + use crate::integer::{IntegerKeyKind, RadixClientKey, SignedRadixCiphertext}; + use crate::shortint::PBSParameters; + use rand::Rng; + + fn default_test_count_consecutive_bits( + direction: Direction, + bit_value: BitValue, + param: P, + sks_method: F, + ) where + P: Into, + F: for<'a> Fn(&'a ServerKey, &'a SignedRadixCiphertext) -> RadixCiphertext, + { + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + let compute_expected_clear = |x: i64| match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + x.trailing_zeros() + } + } + (Direction::Trailing, BitValue::One) => x.trailing_ones().min(num_bits), + (Direction::Leading, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + (x << (u64::BITS - num_bits)).leading_zeros() + } + } + (Direction::Leading, BitValue::One) => (x << (u64::BITS - num_bits)).leading_ones(), + }; + + let method_name = match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => "trailing_zeros", + (Direction::Trailing, BitValue::One) => "trailing_ones", + (Direction::Leading, BitValue::Zero) => "leading_zeros", + (Direction::Leading, BitValue::One) => "leading_ones", + }; + + let input_values = [-modulus, 0i64, modulus - 1] + .into_iter() + .chain((0..NB_TESTS_SMALLER).map(|_| rng.gen::() % modulus)) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt_signed(clear); + + let ct_res = sks_method(&sks, &ctxt); + let tmp = sks_method(&sks, &ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..NB_TESTS_SMALLER { + // Add non-zero scalar to have non-clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let clear = signed_add_under_modulus(clear, clear_2, modulus); + + let d0: i64 = cks.decrypt_signed(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = sks_method(&sks, &ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = compute_expected_clear(clear); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = [-modulus, 0i64, modulus - 1] + .into_iter() + .chain((0..NB_TESTS_SMALLER).map(|_| rng.gen::() % modulus)); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = sks_method(&sks, &ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + pub(crate) fn default_trailing_zeros_test

(param: P) + where + P: Into, + { + default_test_count_consecutive_bits( + Direction::Trailing, + BitValue::Zero, + param, + ServerKey::trailing_zeros, + ); + } + + pub(crate) fn default_trailing_ones_test

(param: P) + where + P: Into, + { + default_test_count_consecutive_bits( + Direction::Trailing, + BitValue::One, + param, + ServerKey::trailing_ones, + ); + } + + pub(crate) fn default_leading_zeros_test

(param: P) + where + P: Into, + { + default_test_count_consecutive_bits( + Direction::Leading, + BitValue::Zero, + param, + ServerKey::leading_zeros, + ); + } + + pub(crate) fn default_leading_ones_test

(param: P) + where + P: Into, + { + default_test_count_consecutive_bits( + Direction::Leading, + BitValue::One, + param, + ServerKey::leading_ones, + ); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs index 83e2390c67..39ef570510 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -19,6 +19,7 @@ mod scalar_sub; mod shift; pub(crate) mod sub; +mod ilog2; #[cfg(test)] pub(crate) mod tests_cases_comparisons; #[cfg(test)] diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index b16d83a16e..76bb990f70 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -11,14 +11,14 @@ use rand::Rng; use std::sync::Arc; /// Number of loop iteration within randomized tests -const NB_TESTS: usize = 30; +pub(crate) const NB_TESTS: usize = 30; /// Smaller number of loop iteration within randomized test, /// meant for test where the function tested is more expensive -const NB_TESTS_SMALLER: usize = 10; -const NB_CTXT: usize = 4; +pub(crate) const NB_TESTS_SMALLER: usize = 10; +pub(crate) const NB_CTXT: usize = 4; -fn random_non_zero_value(rng: &mut ThreadRng, modulus: u64) -> u64 { +pub(crate) fn random_non_zero_value(rng: &mut ThreadRng, modulus: u64) -> u64 { rng.gen_range(1..modulus) } @@ -2872,6 +2872,12 @@ where } } +// Re-exports to still have tests case accessible from the same location +pub(crate) use crate::integer::server_key::radix_parallel::ilog2::tests_unsigned::{ + default_leading_ones_test, default_leading_zeros_test, default_trailing_ones_test, + default_trailing_zeros_test, +}; + //============================================================================= // Default Scalar Tests //============================================================================= diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs index 3fe91be0ad..ad57c6e73f 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs @@ -11,10 +11,10 @@ use rand::rngs::ThreadRng; use rand::Rng; /// Number of loop iteration within randomized tests -const NB_TESTS: usize = 30; +pub(crate) const NB_TESTS: usize = 30; -const NB_TESTS_SMALLER: usize = 10; -const NB_CTXT: usize = 4; +pub(crate) const NB_TESTS_SMALLER: usize = 10; +pub(crate) const NB_CTXT: usize = 4; macro_rules! create_parametrized_test{ ($name:ident { $($param:ident),* }) => { @@ -48,7 +48,7 @@ macro_rules! create_parametrized_test{ // Helper functions //================================================================================ -fn signed_add_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 { +pub(crate) fn signed_add_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 { signed_overflowing_add_under_modulus(lhs, rhs, modulus).0 } @@ -341,7 +341,7 @@ fn create_iterator_of_signed_random_pairs( izip!(lhs_values, rhs_values) } -fn random_non_zero_value(rng: &mut ThreadRng, modulus: i64) -> i64 { +pub(crate) fn random_non_zero_value(rng: &mut ThreadRng, modulus: i64) -> i64 { loop { let value = rng.gen::() % modulus; if value != 0 { @@ -1533,6 +1533,10 @@ create_parametrized_test!(integer_signed_default_rotate_right { PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS }); +create_parametrized_test!(integer_signed_default_trailing_zeros); +create_parametrized_test!(integer_signed_default_trailing_ones); +create_parametrized_test!(integer_signed_default_leading_zeros); +create_parametrized_test!(integer_signed_default_leading_ones); fn integer_signed_default_add

(param: P) where @@ -2612,6 +2616,42 @@ where } } +fn integer_signed_default_trailing_zeros

(param: P) +where + P: Into, +{ + crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_trailing_zeros_test( + param, + ); +} + +fn integer_signed_default_trailing_ones

(param: P) +where + P: Into, +{ + crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_trailing_ones_test( + param, + ); +} + +fn integer_signed_default_leading_zeros

(param: P) +where + P: Into, +{ + crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_leading_zeros_test( + param, + ); +} + +fn integer_signed_default_leading_ones

(param: P) +where + P: Into, +{ + crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_leading_ones_test( + param, + ); +} + //================================================================================ // Unchecked Scalar Tests //================================================================================ diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs index 022586e3af..93531095cb 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs @@ -1,5 +1,5 @@ use crate::integer::keycache::KEY_CACHE; -use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; use crate::shortint::parameters::*; use paste::paste; use rand::Rng; @@ -230,6 +230,10 @@ create_parametrized_test!(integer_default_overflowing_scalar_add); create_parametrized_test!(integer_smart_if_then_else); create_parametrized_test!(integer_default_if_then_else); create_parametrized_test!(integer_trim_radix_msb_blocks_handles_dirty_inputs); +create_parametrized_test!(integer_default_trailing_zeros); +create_parametrized_test!(integer_default_trailing_ones); +create_parametrized_test!(integer_default_leading_zeros); +create_parametrized_test!(integer_default_leading_ones); create_parametrized_test!(integer_unchecked_add); create_parametrized_test!(integer_unchecked_mul); @@ -272,6 +276,21 @@ impl CpuFunctionExecutor { /// /// impl TestExecutor<(I,), O> for CpuTestExecutor /// would be possible tho. +impl<'a, F> FunctionExecutor<&'a RadixCiphertext, (RadixCiphertext, BooleanBlock)> + for CpuFunctionExecutor +where + F: Fn(&ServerKey, &RadixCiphertext) -> (RadixCiphertext, BooleanBlock), +{ + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { + self.sks = Some(sks); + } + + fn execute(&mut self, input: &'a RadixCiphertext) -> (RadixCiphertext, BooleanBlock) { + let sks = self.sks.as_ref().expect("setup was not properly called"); + (self.func)(sks, input) + } +} + impl<'a, F> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext> for CpuFunctionExecutor where F: Fn(&ServerKey, &RadixCiphertext) -> RadixCiphertext, @@ -835,6 +854,37 @@ where default_if_then_else_test(param, executor); } +fn integer_default_trailing_zeros

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::trailing_zeros); + default_trailing_zeros_test(param, executor); +} +fn integer_default_trailing_ones

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::trailing_ones); + default_trailing_ones_test(param, executor); +} + +fn integer_default_leading_zeros

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::leading_zeros); + default_leading_zeros_test(param, executor); +} + +fn integer_default_leading_ones

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::leading_ones); + default_leading_ones_test(param, executor); +} + //============================================================================= // Default Scalar Tests //=============================================================================