diff --git a/tfhe/benches/integer/bench.rs b/tfhe/benches/integer/bench.rs index c1028aeefe..586e76bf55 100644 --- a/tfhe/benches/integer/bench.rs +++ b/tfhe/benches/integer/bench.rs @@ -1068,6 +1068,12 @@ define_server_key_bench_unary_fn!(method_name: smart_abs_parallelized, display_n define_server_key_bench_unary_default_fn!(method_name: neg_parallelized, display_name: negation); define_server_key_bench_unary_default_fn!(method_name: abs_parallelized, display_name: abs); +define_server_key_bench_unary_default_fn!(method_name: leading_zeros, display_name: leading_zeros); +define_server_key_bench_unary_default_fn!(method_name: leading_ones, display_name: leading_ones); +define_server_key_bench_unary_default_fn!(method_name: trailing_zeros, display_name: trailing_zeros); +define_server_key_bench_unary_default_fn!(method_name: trailing_ones, display_name: trailing_ones); +define_server_key_bench_unary_default_fn!(method_name: ilog2, display_name: ilog2); +define_server_key_bench_unary_default_fn!(method_name: checked_ilog2, display_name: checked_ilog2); define_server_key_bench_unary_default_fn!(method_name: unchecked_abs_parallelized, display_name: abs); @@ -2069,6 +2075,12 @@ criterion_group!( rotate_left_parallelized, rotate_right_parallelized, ciphertexts_sum_parallelized, + leading_zeros, + leading_ones, + trailing_zeros, + trailing_ones, + ilog2, + checked_ilog2, ); criterion_group!( diff --git a/tfhe/benches/integer/signed_bench.rs b/tfhe/benches/integer/signed_bench.rs index e5abf4fbcc..b0c42518eb 100644 --- a/tfhe/benches/integer/signed_bench.rs +++ b/tfhe/benches/integer/signed_bench.rs @@ -337,6 +337,12 @@ define_server_key_bench_unary_signed_clean_input_fn!( method_name: abs_parallelized, display_name: abs ); +define_server_key_bench_unary_signed_clean_input_fn!(method_name: leading_zeros, display_name: leading_zeros); +define_server_key_bench_unary_signed_clean_input_fn!(method_name: leading_ones, display_name: leading_ones); +define_server_key_bench_unary_signed_clean_input_fn!(method_name: trailing_zeros, display_name: trailing_zeros); +define_server_key_bench_unary_signed_clean_input_fn!(method_name: trailing_ones, display_name: trailing_ones); +define_server_key_bench_unary_signed_clean_input_fn!(method_name: ilog2, display_name: ilog2); +define_server_key_bench_unary_signed_clean_input_fn!(method_name: checked_ilog2, display_name: checked_ilog2); define_server_key_bench_binary_signed_clean_inputs_fn!( method_name: add_parallelized, @@ -492,6 +498,12 @@ criterion_group!( right_shift_parallelized, rotate_left_parallelized, rotate_right_parallelized, + leading_zeros, + leading_ones, + trailing_zeros, + trailing_ones, + ilog2, + checked_ilog2, ); criterion_group!( diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index dc746ca72d..f344b8d898 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -573,9 +573,10 @@ impl ServerKey { (carries_out, last_block_out_carry) } + /// Computes a prefix sum/scan in parallel using Hillis & Steel algorithm pub(crate) fn compute_prefix_sum_hillis_steele( &self, - mut generates_or_propagates: Vec, + mut blocks: Vec, sum_function: F, ) -> Vec where @@ -583,27 +584,27 @@ impl ServerKey { { debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 4)); - let num_blocks = generates_or_propagates.len(); - let num_steps = generates_or_propagates.len().ceil_ilog2() as usize; + let num_blocks = blocks.len(); + let num_steps = blocks.len().ceil_ilog2() as usize; let mut space = 1; - let mut step_output = generates_or_propagates.clone(); + let mut step_output = blocks.clone(); for _ in 0..num_steps { step_output[space..num_blocks] .par_iter_mut() .enumerate() .for_each(|(i, block)| { - let prev_block_carry = &generates_or_propagates[i]; + let prev_block_carry = &blocks[i]; sum_function(block, prev_block_carry); }); for i in space..num_blocks { - generates_or_propagates[i].clone_from(&step_output[i]); + blocks[i].clone_from(&step_output[i]); } space *= 2; } - generates_or_propagates + blocks } /// This add_assign two numbers @@ -773,11 +774,8 @@ impl ServerKey { /// Computes the sum of the ciphertexts in parallel. /// - /// - Returns None if ciphertexts is empty - /// - /// - Expects all ciphertexts to have empty carries - /// - Expects all ciphertexts to have the same size - pub fn unchecked_sum_ciphertexts_vec_parallelized( + /// Returns a result that has non propagated carries + pub(crate) fn unchecked_partial_sum_ciphertexts_vec_parallelized( &self, mut ciphertexts: Vec, ) -> Option @@ -893,9 +891,26 @@ impl ServerKey { self.unchecked_add_assign(result, term); } + Some(result.clone()) + } + + /// Computes the sum of the ciphertexts in parallel. + /// + /// - Returns None if ciphertexts is empty + /// + /// - Expects all ciphertexts to have empty carries + /// - Expects all ciphertexts to have the same size + pub fn unchecked_sum_ciphertexts_vec_parallelized(&self, ciphertexts: Vec) -> Option + where + T: IntegerRadixCiphertext, + { + let non_propagated_result = + self.unchecked_partial_sum_ciphertexts_vec_parallelized(ciphertexts)?; + let num_blocks = non_propagated_result.blocks().len(); + let (message_blocks, carry_blocks) = rayon::join( || { - result + non_propagated_result .blocks() .par_iter() .map(|block| self.key.message_extract(block)) @@ -903,7 +918,7 @@ impl ServerKey { }, || { let mut carry_blocks = Vec::with_capacity(num_blocks); - result.blocks()[..num_blocks - 1] // last carry is not interesting + non_propagated_result.blocks()[..num_blocks - 1] // last carry is not interesting .par_iter() .map(|block| self.key.carry_extract(block)) .collect_into_vec(&mut carry_blocks); @@ -1313,6 +1328,7 @@ mod tests { ); } } + #[test] fn test_hillis_steele_choice_4_threads() { const NUM_THREADS: usize = 4; diff --git a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs index 2abba3ce0b..f390496d90 100644 --- a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs +++ b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs @@ -1,8 +1,9 @@ use crate::core_crypto::algorithms::misc::divide_ceil; -use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, ServerKey}; - +use crate::integer::{ + BooleanBlock, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, ServerKey, + SignedRadixCiphertext, +}; use crate::shortint::Ciphertext; - use rayon::prelude::*; /// A 'bit' value @@ -215,6 +216,137 @@ impl ServerKey { self.count_consecutive_bits(ct, Direction::Leading, BitValue::One) } + /// Returns the base 2 logarithm of the number, rounded down. + /// + /// See [Self::ilog2] for an example + /// + /// Expects ct to have clean carries + pub fn unchecked_ilog2(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + if ct.blocks().is_empty() { + return self.create_trivial_zero_radix(ct.blocks().len()); + } + + let num_bits_in_message = self.key.message_modulus.0.ilog2(); + let original_num_blocks = ct.blocks().len(); + + let num_bits_in_ciphertext = num_bits_in_message + .checked_mul(original_num_blocks as u32) + .expect("Number of bits encrypted exceeds u32::MAX"); + + // `num_bits_in_ciphertext-1` is the max value we want to represent + // its ilog2 + 1 gives use how many bits we need to be able to represent it. + // We add `1` to this number as we are going to use signed numbers later + // + // The ilog2 of a number that is on n bits, is in range 1..=n-1 + let counter_num_blocks = divide_ceil( + (num_bits_in_ciphertext - 1).ilog2() + 1 + 1, + self.message_modulus().0.ilog2(), + ) as usize; + + // x.ilog2() = (x.num_bit() - 1) - x.leading_zeros() + // - (x.num_bit() - 1) is trivially known + // - we can get leading zeros via a sum + // + // However, the sum include a full propagation, thus the subtraction + // will add another full propagation which is costly. + // + // However, we can do better: + // let N = (x.num_bit() - 1) + // let L0 = x.leading_zeros() + // ``` + // x.ilog2() = N - L0 + // x.ilog2() = -(-(N - L0)) + // x.ilog2() = -(-N + L0) + // ``` + // Since N is a clear number, getting -N is free, + // meaning -N + L0 where L0 is actually `sum(L0[b0], .., L0[num_blocks-1])` + // can be done with `sum(-N, L0[b0], .., L0[num_blocks-1]), by switching to signed + // numbers. + // + // Also, to do -(-N + L0) aka -sum(-N, L0[b0], .., L0[num_blocks-1]) + // we can make the sum not return a fully propagated result, + // and extract message/carry blocks while negating them at the same time + // using the fact that in twos complement -X = bitnot(X) + 1 + // so given a non propagated `C`, we can compute the fully propagated `PC` + // PC = bitnot(message(C)) + bitnot(blockshift(carry(C), 1)) + 2 + + let leading_zeros_per_blocks = + self.prepare_count_of_consecutive_bits(ct.clone(), Direction::Leading, BitValue::Zero); + + let mut cts = leading_zeros_per_blocks + .into_iter() + .map(|block| { + let mut ct: SignedRadixCiphertext = + self.create_trivial_zero_radix(counter_num_blocks); + ct.blocks[0] = block; + ct + }) + .collect::>(); + cts.push( + self.create_trivial_radix(-(num_bits_in_ciphertext as i32 - 1i32), counter_num_blocks), + ); + + let result = self + .unchecked_partial_sum_ciphertexts_vec_parallelized(cts) + .expect("internal error, empty ciphertext count"); + + // This is the part where we extract message and carry blocks + // while inverting their bits + let (message_blocks, carry_blocks) = rayon::join( + || { + let lut = self.key.generate_lookup_table(|x| { + // extract message + let x = x % self.key.message_modulus.0 as u64; + // bitnot the message + (!x) % self.key.message_modulus.0 as u64 + }); + result + .blocks() + .par_iter() + .map(|block| self.key.apply_lookup_table(block, &lut)) + .collect::>() + }, + || { + let lut = self.key.generate_lookup_table(|x| { + // extract carry + let x = x / self.key.message_modulus.0 as u64; + // bitnot the carry + (!x) % self.key.message_modulus.0 as u64 + }); + let mut carry_blocks = Vec::with_capacity(counter_num_blocks as usize); + result.blocks()[..counter_num_blocks - 1] // last carry is not interesting + .par_iter() + .map(|block| self.key.apply_lookup_table(block, &lut)) + .collect_into_vec(&mut carry_blocks); + // Normally this would be 0, but we want the bitnot of 0, which is msg_mod-1 + carry_blocks.insert( + 0, + self.key + .create_trivial((self.message_modulus().0 - 1) as u64), + ); + carry_blocks + }, + ); + + let message = SignedRadixCiphertext::from(message_blocks); + let carry = SignedRadixCiphertext::from(carry_blocks); + let result = self + .sum_ciphertexts_parallelized( + [ + message, + carry, + self.create_trivial_radix(2u32, counter_num_blocks), + ] + .iter(), + ) + .unwrap(); + + self.cast_to_unsigned(result, counter_num_blocks) + } + //============================================================================================== // Smart //============================================================================================== @@ -267,6 +399,37 @@ impl ServerKey { self.unchecked_leading_ones(ct) } + /// Returns the base 2 logarithm of the number, rounded down. + /// + /// See [Self::ilog2] for an example + pub fn smart_ilog2(&self, ct: &mut T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + if !ct.block_carries_are_empty() { + self.full_propagate_parallelized(ct); + } + + self.unchecked_ilog2(ct) + } + + /// Returns the base 2 logarithm of the number, rounded down. + /// + /// See [Self::checked_ilog2] for an example + /// + /// Also returns a BooleanBlock, encrypting true (1) if the result is + /// valid (input is > 0), otherwise 0. + pub fn smart_checked_ilog2(&self, ct: &mut T) -> (RadixCiphertext, BooleanBlock) + where + T: IntegerRadixCiphertext, + { + if !ct.block_carries_are_empty() { + self.full_propagate_parallelized(ct); + } + + rayon::join(|| self.ilog2(ct), || self.scalar_gt_parallelized(ct, 0)) + } + //============================================================================================== // Default //============================================================================================== @@ -446,6 +609,87 @@ impl ServerKey { }; self.unchecked_leading_ones(ct) } + + /// Returns the base 2 logarithm of the number, rounded down. + /// + /// # Example + /// + /// ``` + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 5i8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// let n = sks.ilog2(&ct1); + /// + /// // Decrypt: + /// let n: u32 = cks.decrypt(&n); + /// assert_eq!(n, msg.ilog2()); + /// ``` + pub fn ilog2(&self, ct: &T) -> RadixCiphertext + where + T: IntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.clone(); + self.full_propagate_parallelized(&mut tmp); + &tmp + }; + + self.unchecked_ilog2(ct) + } + + /// Returns the base 2 logarithm of the number, rounded down. + /// + /// Also returns a BooleanBlock, encrypting true (1) if the result is + /// valid (input is > 0), otherwise 0. + /// + /// # Example + /// + /// ``` + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 5i8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// let (n, is_oks) = sks.checked_ilog2(&ct1); + /// + /// // Decrypt: + /// let n: u32 = cks.decrypt(&n); + /// assert_eq!(n, msg.ilog2()); + /// let is_ok = cks.decrypt_bool(&is_oks); + /// assert!(is_ok); + /// ``` + pub fn checked_ilog2(&self, ct: &T) -> (RadixCiphertext, BooleanBlock) + where + T: IntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.clone(); + self.full_propagate_parallelized(&mut tmp); + &tmp + }; + + rayon::join(|| self.ilog2(ct), || self.scalar_gt_parallelized(ct, 0)) + } } #[cfg(test)] @@ -609,6 +853,263 @@ pub(crate) mod tests_unsigned { { default_test_count_consecutive_bits(Direction::Leading, BitValue::One, param, executor); } + + pub(crate) fn default_ilog2_test(param: P, mut executor: T) + where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, + { + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + let ctxt = cks.encrypt(0u64); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let counter_num_blocks = divide_ceil( + (num_bits - 1).ilog2() + 1 + 1, + cks.parameters().message_modulus().0.ilog2(), + ) as usize; + let expected_result = (1u32 + << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) + - 1; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for 0.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt(clear); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..NB_TESTS_SMALLER { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = clear_2.wrapping_add(clear) % modulus; + if clear != 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: u64 = cks.decrypt(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + pub(crate) fn default_checked_ilog2_test(param: P, mut executor: T) + where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, (RadixCiphertext, BooleanBlock)>, + { + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + let ctxt = cks.encrypt(0u64); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(is_ok.as_ref().degree.get(), 1); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let counter_num_blocks = divide_ceil( + (num_bits - 1).ilog2() + 1 + 1, + cks.parameters().message_modulus().0.ilog2(), + ) as usize; + let expected_result = (1u32 + << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) + - 1; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for 0.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(!is_ok); + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt(clear); + + let (ct_res, is_ok) = executor.execute(&ctxt); + let (tmp, tmp_is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + assert_eq!(is_ok, tmp_is_ok); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + + for _ in 0..NB_TESTS_SMALLER { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = clear_2.wrapping_add(clear) % modulus; + if clear != 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: u64 = cks.decrypt(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(is_ok.as_ref().degree.get(), 1); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + } + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + } + } } #[cfg(test)] @@ -782,4 +1283,266 @@ pub(crate) mod tests_signed { ServerKey::leading_ones, ); } + + pub(crate) fn default_ilog2_test

(param: P) + where + P: Into, + { + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + for clear in [0i64, rng.gen_range(-modulus..=-1i64)] { + let ctxt = cks.encrypt_signed(clear); + + let ct_res = sks.ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = if clear < 0 { + num_bits - 1 + } else { + let counter_num_blocks = divide_ceil( + (num_bits - 1).ilog2() + 1 + 1, + cks.parameters().message_modulus().0.ilog2(), + ) as usize; + (1u32 + << (counter_num_blocks as u32 + * cks.parameters().message_modulus().0.ilog2())) + - 1 + }; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt_signed(clear); + + let ct_res = sks.ilog2(&ctxt); + let tmp = sks.ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..NB_TESTS_SMALLER { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = signed_add_under_modulus(clear, clear_2, modulus); + if clear > 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: i64 = cks.decrypt_signed(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = sks.ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt: SignedRadixCiphertext = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = sks.ilog2(&ctxt); + let tmp = sks.ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + pub(crate) fn default_checked_ilog2_test

(param: P) + where + P: Into, + { + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + for clear in [0i64, rng.gen_range(-modulus..=-1i64)] { + let ctxt = cks.encrypt_signed(clear); + + let (ct_res, is_ok) = sks.checked_ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = if clear < 0 { + num_bits - 1 + } else { + let counter_num_blocks = divide_ceil( + (num_bits - 1).ilog2() + 1 + 1, + cks.parameters().message_modulus().0.ilog2(), + ) as usize; + (1u32 + << (counter_num_blocks as u32 + * cks.parameters().message_modulus().0.ilog2())) + - 1 + }; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(!is_ok); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt_signed(clear); + + let (ct_res, is_ok) = sks.checked_ilog2(&ctxt); + let (tmp, tmp_is_ok) = sks.checked_ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + assert_eq!(is_ok, tmp_is_ok); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + + for _ in 0..NB_TESTS_SMALLER { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = signed_add_under_modulus(clear, clear_2, modulus); + if clear > 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: i64 = cks.decrypt_signed(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let (ct_res, is_ok) = sks.checked_ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(is_ok.as_ref().degree.get(), 1); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + (0..NB_TESTS_SMALLER.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt: SignedRadixCiphertext = sks.create_trivial_radix(clear, NB_CTXT); + + let (ct_res, is_ok) = sks.checked_ilog2(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + } + } } diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index 76bb990f70..9997324e77 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -2874,8 +2874,8 @@ where // Re-exports to still have tests case accessible from the same location pub(crate) use crate::integer::server_key::radix_parallel::ilog2::tests_unsigned::{ - default_leading_ones_test, default_leading_zeros_test, default_trailing_ones_test, - default_trailing_zeros_test, + default_checked_ilog2_test, default_ilog2_test, default_leading_ones_test, + default_leading_zeros_test, default_trailing_ones_test, default_trailing_zeros_test, }; //============================================================================= diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs index ad57c6e73f..28dd2a7f8c 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed.rs @@ -1537,6 +1537,17 @@ create_parametrized_test!(integer_signed_default_trailing_zeros); create_parametrized_test!(integer_signed_default_trailing_ones); create_parametrized_test!(integer_signed_default_leading_zeros); create_parametrized_test!(integer_signed_default_leading_ones); +create_parametrized_test!(integer_signed_default_ilog2); +create_parametrized_test!(integer_signed_default_checked_ilog2 { + // uses comparison so 1_1 parameters are not supported + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + PARAM_MESSAGE_3_CARRY_3_KS_PBS, + PARAM_MESSAGE_4_CARRY_4_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS +}); fn integer_signed_default_add

(param: P) where @@ -2652,6 +2663,22 @@ where ); } +fn integer_signed_default_ilog2

(param: P) +where + P: Into, +{ + crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_ilog2_test(param); +} + +fn integer_signed_default_checked_ilog2

(param: P) +where + P: Into, +{ + crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_checked_ilog2_test( + param, + ); +} + //================================================================================ // Unchecked Scalar Tests //================================================================================ diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs index 93531095cb..ec852e3fbe 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs @@ -234,6 +234,17 @@ create_parametrized_test!(integer_default_trailing_zeros); create_parametrized_test!(integer_default_trailing_ones); create_parametrized_test!(integer_default_leading_zeros); create_parametrized_test!(integer_default_leading_ones); +create_parametrized_test!(integer_default_ilog2); +create_parametrized_test!(integer_default_checked_ilog2 { + // This uses comparisons, so require more than 1 bit + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + PARAM_MESSAGE_3_CARRY_3_KS_PBS, + PARAM_MESSAGE_4_CARRY_4_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS +}); create_parametrized_test!(integer_unchecked_add); create_parametrized_test!(integer_unchecked_mul); @@ -885,6 +896,22 @@ where default_leading_ones_test(param, executor); } +fn integer_default_ilog2

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::ilog2); + default_ilog2_test(param, executor); +} + +fn integer_default_checked_ilog2

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::checked_ilog2); + default_checked_ilog2_test(param, executor); +} + //============================================================================= // Default Scalar Tests //=============================================================================