Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tm/fix signed decomposer #2133

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion tfhe/src/integer/block_decomposition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,21 @@ where
Self::new_(value, bits_per_block, None, Some(padding_bit))
}

/// Creates a block decomposer that will return `block_count` blocks
///
/// * If T is signed, extra block will be sign extended
pub fn with_block_count(value: T, bits_per_block: u32, block_count: u32) -> Self {
let mut decomposer = Self::new(value, bits_per_block);
// If the new number of bits is less than the actual number of bits, it means
// data will be truncated
//
// If the new number of bits is greater than the actual number of bits, it means
// the right shift used internally will correctly sign extend for us
let num_bits_valid = block_count * bits_per_block;
decomposer.num_bits_valid = num_bits_valid;
decomposer
}

pub fn new(value: T, bits_per_block: u32) -> Self {
Self::new_(value, bits_per_block, None, None)
}
Expand Down Expand Up @@ -245,7 +260,8 @@ where
T: Recomposable,
{
pub fn value(&self) -> T {
if self.bit_pos >= T::BITS as u32 {
let is_signed = (T::ONE << (T::BITS as u32 - 1)) < T::ZERO;
if self.bit_pos >= (T::BITS as u32 - u32::from(is_signed)) {
self.data
} else {
let valid_mask = (T::ONE << self.bit_pos) - T::ONE;
Expand Down Expand Up @@ -359,6 +375,49 @@ mod tests {
assert_eq!(expected_blocks, blocks);
}

#[test]
fn test_bit_block_decomposer_with_block_count() {
let bits_per_block = 3;
let expected_blocks = [0, 0, 6, 7, 7, 7, 7, 7, 7];
let value = i8::MIN;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}

let bits_per_block = 3;
let expected_blocks = [7, 7, 1, 0, 0, 0, 0, 0, 0];
let value = i8::MAX;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}

let bits_per_block = 2;
let expected_blocks = [0, 0, 0, 2, 3, 3, 3, 3, 3];
let value = i8::MIN;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}

let bits_per_block = 2;
let expected_blocks = [3, 3, 3, 1, 0, 0, 0, 0, 0, 0];
let value = i8::MAX;
for block_count in 1..expected_blocks.len() as u32 {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count as usize], blocks);
}
}

#[test]
fn test_bit_block_decomposer_recomposer_carry_handling_in_between() {
let value = u16::MAX as u32;
Expand Down
11 changes: 2 additions & 9 deletions tfhe/src/integer/encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ where
// We need to concretize the iterator type to be able to pass callbacks consuming the iterator,
// having an opaque return impl Iterator does not allow to take callbacks at this moment, not sure
// the Fn(impl Trait) syntax can be made to work nicely with the rest of the language
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Take<
std::iter::Chain<std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>, std::iter::Repeat<u64>>,
>;
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>;

pub(crate) fn create_clear_radix_block_iterator<T>(
message: T,
Expand All @@ -105,12 +103,7 @@ where
T: DecomposableInto<u64>,
{
let bits_in_block = message_modulus.0.ilog2();
let decomposer = BlockDecomposer::new(message, bits_in_block);

decomposer
.iter_as::<u64>()
.chain(std::iter::repeat(0u64))
.take(num_blocks)
BlockDecomposer::with_block_count(message, bits_in_block, num_blocks as u32).iter_as::<u64>()
}

pub(crate) fn encrypt_crt<BlockKey, Block, CrtCiphertextType, F>(
Expand Down
10 changes: 6 additions & 4 deletions tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,12 @@ impl CudaServerKey {
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
};

let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(0))
.take(num_blocks);
let decomposer = BlockDecomposer::with_block_count(
scalar,
self.message_modulus.0.ilog2(),
num_blocks as u32,
)
.iter_as::<u64>();
let mut cpu_lwe_list = LweCiphertextList::new(
0,
lwe_size,
Expand Down
52 changes: 47 additions & 5 deletions tfhe/src/integer/server_key/radix/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits_specific_values);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits_specific_values);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits);
create_parameterized_test_classical_params!(integer_encrypt_auto_cast);
create_parameterized_test_classical_params!(integer_unchecked_add);
create_parameterized_test_classical_params!(integer_smart_add);
create_parameterized_test!(
Expand Down Expand Up @@ -157,7 +158,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
for _ in 0..10 {
let clear = rng.gen::<u128>();

Expand All @@ -172,7 +173,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
{
let a = u64::MAX as u128;
let ct = cks.encrypt_radix(a, num_block);
Expand Down Expand Up @@ -220,7 +221,7 @@ fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters)
fn integer_encrypt_decrypt_256_bits_specific_values(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
{
let a = (u64::MAX as u128) << 64;
let b = 0;
Expand All @@ -245,7 +246,7 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();
let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;

for _ in 0..10 {
let clear0 = rng.gen::<u128>();
Expand All @@ -261,11 +262,52 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
}
}

fn integer_encrypt_auto_cast(param: ClassicPBSParameters) {
// The goal is to test that encrypting a value stored in a type
// for which the bit count does not match the target block count of the encrypted
// radix properly applies upcasting/downcasting

let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();

let num_blocks = 32u32.div_ceil(param.message_modulus.0.ilog2()) as usize;

// Positive signed value
let value = rng.gen_range(0..=i32::MAX);
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
let d: i64 = cks.decrypt_signed_radix(&ct);
assert_eq!(i64::from(value), d);

let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
let d: i16 = cks.decrypt_signed_radix(&ct);
assert_eq!(value as i16, d);

// Negative signed value
let value = rng.gen_range(i8::MIN..0);
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
let d: i64 = cks.decrypt_signed_radix(&ct);
assert_eq!(i64::from(value), d);

let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
let d: i16 = cks.decrypt_signed_radix(&ct);
assert_eq!(value as i16, d);

// Unsigned value
let value = rng.gen::<u32>();
let ct = cks.encrypt_radix(value, num_blocks * 2);
let d: u64 = cks.decrypt_radix(&ct);
assert_eq!(u64::from(value), d);

let ct = cks.encrypt_radix(value, num_blocks.div_ceil(2));
let d: u16 = cks.decrypt_radix(&ct);
assert_eq!(value as u16, d);
}

fn integer_smart_add_128_bits(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;

for _ in 0..100 {
let clear_0 = rng.gen::<u128>();
Expand Down
16 changes: 7 additions & 9 deletions tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,13 @@ impl ServerKey {
self.full_propagate_parallelized(ct);
}

let scalar_blocks = BlockDecomposer::new(scalar, self.message_modulus().0.ilog2())
.iter_as::<u8>()
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
(self.message_modulus().0 - 1) as u8
} else {
0
}))
.take(ct.blocks().len())
.collect();
let scalar_blocks = BlockDecomposer::with_block_count(
scalar,
self.message_modulus().0.ilog2(),
ct.blocks().len() as u32,
)
.iter_as::<u8>()
.collect();

const COMPUTE_OVERFLOW: bool = false;
const INPUT_CARRY: bool = false;
Expand Down
82 changes: 38 additions & 44 deletions tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -757,12 +757,10 @@ impl ServerKey {
.map(|chunk_of_two| self.pack_block_chunk(chunk_of_two))
.collect::<Vec<_>>();

let padding_value = (packed_modulus - 1) * u64::from(b < Scalar::ZERO);
let mut b_blocks = BlockDecomposer::new(b, packed_modulus.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(padding_value))
.take(a.len())
.collect::<Vec<_>>();
let mut b_blocks =
BlockDecomposer::with_block_count(b, packed_modulus.ilog2(), a.len() as u32)
.iter_as::<u64>()
.collect::<Vec<_>>();

if !num_block_is_even && b < Scalar::ZERO {
let last_index = b_blocks.len() - 1;
Expand Down Expand Up @@ -1058,25 +1056,23 @@ impl ServerKey {
Scalar: DecomposableInto<u64>,
{
let is_superior = self.unchecked_scalar_gt_parallelized(lhs, rhs);
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
0u64
} else {
self.message_modulus().0 - 1
}))
.take(lhs.blocks().len())
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_superior, block| {
if is_superior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let luts = BlockDecomposer::with_block_count(
rhs,
self.message_modulus().0.ilog2(),
lhs.blocks().len() as u32,
)
.iter_as::<u64>()
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_superior, block| {
if is_superior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();

let new_blocks = lhs
.blocks()
Expand All @@ -1097,25 +1093,23 @@ impl ServerKey {
Scalar: DecomposableInto<u64>,
{
let is_inferior = self.unchecked_scalar_lt_parallelized(lhs, rhs);
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
0u64
} else {
self.message_modulus().0 - 1
}))
.take(lhs.blocks().len())
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_inferior, block| {
if is_inferior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let luts = BlockDecomposer::with_block_count(
rhs,
self.message_modulus().0.ilog2(),
lhs.blocks().len() as u32,
)
.iter_as::<u64>()
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_inferior, block| {
if is_inferior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();

let new_blocks = lhs
.blocks()
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ impl ServerKey {
const INPUT_CARRY: bool = true;
let flipped_scalar = !scalar;
let decomposed_flipped_scalar =
// We don't use BlockDecomposer::with_block_count as we are doing something special
BlockDecomposer::new(flipped_scalar, self.message_modulus().0.ilog2())
.iter_as::<u8>()
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
Expand Down
Loading