Skip to content

Commit

Permalink
fix: upcasting of signed integer when block decomposing
Browse files Browse the repository at this point in the history
Some parts of the code did not use the correct way to
decompose a clear integer into blocks which could be encrypted
or used in scalar ops.

The sign extension was not always properly done, leading for example
in the encryption of a negative integer stored on a i8 to a
SignedRadixCiphertext with a num_blocks greater than i8 to be incorrect:

```
let ct = cks.encrypt_signed(-1i8, 16) // 2_2 parameters
let d: i32 = cks.decrypt_signed(&ct);
assert_eq!(d, i32::from(-1i8)); // Fails
```

To fix, a BlockDecomposer::with_block_count function is added and used
This function will properly do the sign extension when needed
  • Loading branch information
tmontaigu committed Feb 28, 2025
1 parent 2ffefe7 commit 55a4d18
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 72 deletions.
61 changes: 60 additions & 1 deletion tfhe/src/integer/block_decomposition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,21 @@ where
Self::new_(value, bits_per_block, None, Some(padding_bit))
}

/// Creates a block decomposer that will return `block_count` blocks
///
/// * If T is signed, extra block will be sign extended
pub fn with_block_count(value: T, bits_per_block: u32, block_count: u32) -> Self {
let mut decomposer = Self::new(value, bits_per_block);
// If the new number of bits is less than the actual number of bits, it means
// data will be truncated
//
// If the new number of bits is greater than the actual number of bits, it means
// the right shift used internally will correctly sign extend for us
let num_bits_valid = block_count * bits_per_block;
decomposer.num_bits_valid = num_bits_valid;
decomposer
}

pub fn new(value: T, bits_per_block: u32) -> Self {
Self::new_(value, bits_per_block, None, None)
}
Expand Down Expand Up @@ -238,7 +253,8 @@ where
T: Recomposable,
{
pub fn value(&self) -> T {
if self.bit_pos >= T::BITS as u32 {
let is_signed = (T::ONE << (T::BITS as u32 - 1)) < T::ZERO;
if self.bit_pos >= (T::BITS as u32 - u32::from(is_signed)) {
self.data
} else {
let valid_mask = (T::ONE << self.bit_pos) - T::ONE;
Expand Down Expand Up @@ -351,6 +367,49 @@ mod tests {
assert_eq!(expected_blocks, blocks);
}

#[test]
fn test_bit_block_decomposer_with_block_count() {
let bits_per_block = 3;
let expected_blocks = vec![0, 0, 6, 7, 7, 7, 7, 7, 7];
let value = i8::MIN;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}

let bits_per_block = 3;
let expected_blocks = vec![7, 7, 1, 0, 0, 0, 0, 0, 0];
let value = i8::MAX;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}

let bits_per_block = 2;
let expected_blocks = vec![0, 0, 0, 2, 3, 3, 3, 3, 3];
let value = i8::MIN;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}

let bits_per_block = 2;
let expected_blocks = vec![3, 3, 3, 1, 0, 0, 0, 0, 0, 0];
let value = i8::MAX;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}
}

#[test]
fn test_bit_block_decomposer_recomposer_carry_handling_in_between() {
let value = u16::MAX as u32;
Expand Down
11 changes: 2 additions & 9 deletions tfhe/src/integer/encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ where
// We need to concretize the iterator type to be able to pass callbacks consuming the iterator,
// having an opaque return impl Iterator does not allow to take callbacks at this moment, not sure
// the Fn(impl Trait) syntax can be made to work nicely with the rest of the language
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Take<
std::iter::Chain<std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>, std::iter::Repeat<u64>>,
>;
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>;

pub(crate) fn create_clear_radix_block_iterator<T>(
message: T,
Expand All @@ -105,12 +103,7 @@ where
T: DecomposableInto<u64>,
{
let bits_in_block = message_modulus.0.ilog2();
let decomposer = BlockDecomposer::new(message, bits_in_block);

decomposer
.iter_as::<u64>()
.chain(std::iter::repeat(0u64))
.take(num_blocks)
BlockDecomposer::with_block_count(message, bits_in_block, num_blocks as u32).iter_as::<u64>()
}

pub(crate) fn encrypt_crt<BlockKey, Block, CrtCiphertextType, F>(
Expand Down
10 changes: 6 additions & 4 deletions tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,12 @@ impl CudaServerKey {
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
};

let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(0))
.take(num_blocks);
let decomposer = BlockDecomposer::with_block_count(
scalar,
self.message_modulus.0.ilog2(),
num_blocks as u32,
)
.iter_as::<u64>();
let mut cpu_lwe_list = LweCiphertextList::new(
0,
lwe_size,
Expand Down
52 changes: 47 additions & 5 deletions tfhe/src/integer/server_key/radix/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits_specific_values);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits_specific_values);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits);
create_parameterized_test_classical_params!(integer_encrypt_auto_cast);
create_parameterized_test_classical_params!(integer_unchecked_add);
create_parameterized_test_classical_params!(integer_smart_add);
create_parameterized_test!(
Expand Down Expand Up @@ -157,7 +158,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
for _ in 0..10 {
let clear = rng.gen::<u128>();

Expand All @@ -172,7 +173,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
{
let a = u64::MAX as u128;
let ct = cks.encrypt_radix(a, num_block);
Expand Down Expand Up @@ -220,7 +221,7 @@ fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters)
fn integer_encrypt_decrypt_256_bits_specific_values(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
{
let a = (u64::MAX as u128) << 64;
let b = 0;
Expand All @@ -245,7 +246,7 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();
let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;

for _ in 0..10 {
let clear0 = rng.gen::<u128>();
Expand All @@ -261,11 +262,52 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
}
}

fn integer_encrypt_auto_cast(param: ClassicPBSParameters) {
// The goal is to test that encrypting a value stored in a type
// for which the bit count does not match the target block count of the encrypted
// radix properly applies upcasting/downcasting

let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();

let num_blocks = 32u32.div_ceil(param.message_modulus.0.ilog2()) as usize;

// Positive signed value
let value = rng.gen_range(0..=i32::MAX);
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
let d: i64 = cks.decrypt_signed_radix(&ct);
assert_eq!(i64::from(value), d);

let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
let d: i16 = cks.decrypt_signed_radix(&ct);
assert_eq!(value as i16, d);

// Negative signed value
let value = rng.gen_range(i8::MIN..0);
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
let d: i64 = cks.decrypt_signed_radix(&ct);
assert_eq!(i64::from(value), d);

let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
let d: i16 = cks.decrypt_signed_radix(&ct);
assert_eq!(value as i16, d);

// Unsigned value
let value = rng.gen::<u32>();
let ct = cks.encrypt_radix(value, num_blocks * 2);
let d: u64 = cks.decrypt_radix(&ct);
assert_eq!(u64::from(value), d);

let ct = cks.encrypt_radix(value, num_blocks.div_ceil(2));
let d: u16 = cks.decrypt_radix(&ct);
assert_eq!(value as u16, d);
}

fn integer_smart_add_128_bits(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;

for _ in 0..100 {
let clear_0 = rng.gen::<u128>();
Expand Down
16 changes: 7 additions & 9 deletions tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,13 @@ impl ServerKey {
self.full_propagate_parallelized(ct);
}

let scalar_blocks = BlockDecomposer::new(scalar, self.message_modulus().0.ilog2())
.iter_as::<u8>()
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
(self.message_modulus().0 - 1) as u8
} else {
0
}))
.take(ct.blocks().len())
.collect();
let scalar_blocks = BlockDecomposer::with_block_count(
scalar,
self.message_modulus().0.ilog2(),
ct.blocks().len() as u32,
)
.iter_as::<u8>()
.collect();

const COMPUTE_OVERFLOW: bool = false;
const INPUT_CARRY: bool = false;
Expand Down
82 changes: 38 additions & 44 deletions tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -757,12 +757,10 @@ impl ServerKey {
.map(|chunk_of_two| self.pack_block_chunk(chunk_of_two))
.collect::<Vec<_>>();

let padding_value = (packed_modulus - 1) * u64::from(b < Scalar::ZERO);
let mut b_blocks = BlockDecomposer::new(b, packed_modulus.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(padding_value))
.take(a.len())
.collect::<Vec<_>>();
let mut b_blocks =
BlockDecomposer::with_block_count(b, packed_modulus.ilog2(), a.len() as u32)
.iter_as::<u64>()
.collect::<Vec<_>>();

if !num_block_is_even && b < Scalar::ZERO {
let last_index = b_blocks.len() - 1;
Expand Down Expand Up @@ -1058,25 +1056,23 @@ impl ServerKey {
Scalar: DecomposableInto<u64>,
{
let is_superior = self.unchecked_scalar_gt_parallelized(lhs, rhs);
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
0u64
} else {
self.message_modulus().0 - 1
}))
.take(lhs.blocks().len())
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_superior, block| {
if is_superior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let luts = BlockDecomposer::with_block_count(
rhs,
self.message_modulus().0.ilog2(),
lhs.blocks().len() as u32,
)
.iter_as::<u64>()
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_superior, block| {
if is_superior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();

let new_blocks = lhs
.blocks()
Expand All @@ -1097,25 +1093,23 @@ impl ServerKey {
Scalar: DecomposableInto<u64>,
{
let is_inferior = self.unchecked_scalar_lt_parallelized(lhs, rhs);
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
0u64
} else {
self.message_modulus().0 - 1
}))
.take(lhs.blocks().len())
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_inferior, block| {
if is_inferior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let luts = BlockDecomposer::with_block_count(
rhs,
self.message_modulus().0.ilog2(),
lhs.blocks().len() as u32,
)
.iter_as::<u64>()
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_inferior, block| {
if is_inferior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();

let new_blocks = lhs
.blocks()
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ impl ServerKey {
const INPUT_CARRY: bool = true;
let flipped_scalar = !scalar;
let decomposed_flipped_scalar =
// We don't use BlockDecomposer::with_block_count as we are doing something special
BlockDecomposer::new(flipped_scalar, self.message_modulus().0.ilog2())
.iter_as::<u8>()
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
Expand Down

0 comments on commit 55a4d18

Please sign in to comment.