Skip to content

Commit

Permalink
fix: BlockDecomposer
Browse files Browse the repository at this point in the history
The BlockDecomposer gave the possibility when the number of bits per
block was not a multiple of the number of bits in the original integer
to force the extra bits of the last block to a particular value.

However, the way this was done could only work when setting these bits
to 1, when wanting to set them to 0 it would not work.

Good news is that we actually never wanted to set them to 0,
but it should still be fixed for completeness, and allow other
feature to be added without bugs
  • Loading branch information
tmontaigu committed Mar 4, 2025
1 parent 371e823 commit 718f46f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 52 deletions.
133 changes: 91 additions & 42 deletions tfhe/src/integer/block_decomposition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::core_crypto::prelude::{CastFrom, CastInto, Numeric};
use crate::integer::bigint::static_signed::StaticSignedBigInt;
use crate::integer::bigint::static_unsigned::StaticUnsignedBigInt;
use core::ops::{AddAssign, BitAnd, ShlAssign, ShrAssign};
use std::ops::{BitOrAssign, Shl, Sub};
use std::ops::{BitOrAssign, Not, Shl, Shr, Sub};

// These work for signed number as rust uses 2-Complements
// And Arithmetic shift for signed number (logical for unsigned)
Expand All @@ -14,8 +14,10 @@ pub trait Decomposable:
+ ShrAssign<u32>
+ Eq
+ CastFrom<u32>
+ Shr<u32, Output = Self>
+ Shl<u32, Output = Self>
+ BitOrAssign<Self>
+ Not<Output = Self>
{
}
pub trait Recomposable:
Expand Down Expand Up @@ -86,33 +88,48 @@ impl<const N: usize> RecomposableFrom<u8> for StaticUnsignedBigInt<N> {}
impl<const N: usize> DecomposableInto<u64> for StaticUnsignedBigInt<N> {}
impl<const N: usize> DecomposableInto<u8> for StaticUnsignedBigInt<N> {}

#[derive(Copy, Clone)]
#[repr(u32)]
pub enum PaddingBitValue {
Zero = 0,
One = 1,
}

#[derive(Clone)]
pub struct BlockDecomposer<T> {
data: T,
bit_mask: T,
num_bits_in_mask: u32,
num_bits_valid: u32,
padding_bit: T,
padding_bit: Option<PaddingBitValue>,
limit: Option<T>,
}

impl<T> BlockDecomposer<T>
where
T: Decomposable,
{
/// Creates a block decomposer that will stop when the value reaches zero
pub fn with_early_stop_at_zero(value: T, bits_per_block: u32) -> Self {
Self::new_(value, bits_per_block, Some(T::ZERO), T::ZERO)
Self::new_(value, bits_per_block, Some(T::ZERO), None)
}

pub fn with_padding_bit(value: T, bits_per_block: u32, padding_bit: T) -> Self {
Self::new_(value, bits_per_block, None, padding_bit)
/// Creates a block decomposer that will set the surplus bits to a specific value
/// when bits_per_block is not a multiple of T::BITS
pub fn with_padding_bit(value: T, bits_per_block: u32, padding_bit: PaddingBitValue) -> Self {
Self::new_(value, bits_per_block, None, Some(padding_bit))
}

pub fn new(value: T, bits_per_block: u32) -> Self {
Self::new_(value, bits_per_block, None, T::ZERO)
Self::new_(value, bits_per_block, None, None)
}

fn new_(value: T, bits_per_block: u32, limit: Option<T>, padding_bit: T) -> Self {
fn new_(
value: T,
bits_per_block: u32,
limit: Option<T>,
padding_bit: Option<PaddingBitValue>,
) -> Self {
assert!(bits_per_block <= T::BITS as u32);
let num_bits_valid = T::BITS as u32;

Expand All @@ -129,6 +146,31 @@ where
padding_bit,
}
}

// We concretize the iterator type to allow usage of callbacks working on iterator for generic
// integer encryption
pub fn iter_as<V>(self) -> std::iter::Map<Self, fn(T) -> V>
where
V: Numeric,
T: CastInto<V>,
{
assert!(self.num_bits_in_mask <= V::BITS as u32);
self.map(CastInto::cast_into)
}

pub fn next_as<V>(&mut self) -> Option<V>
where
V: CastFrom<T>,
{
self.next().map(|masked| V::cast_from(masked))
}

pub fn checked_next_as<V>(&mut self) -> Option<V>
where
V: TryFrom<T>,
{
self.next().and_then(|masked| V::try_from(masked).ok())
}
}

impl<T> Iterator for BlockDecomposer<T>
Expand Down Expand Up @@ -159,11 +201,18 @@ where

if self.num_bits_valid < self.num_bits_in_mask {
// This will be the case when self.num_bits_in_mask is not a multiple
// of T::BITS. We replace bits that
// do not come from the actual T but from the padding
// intoduced by the shift, to a specific value.
for i in self.num_bits_valid..self.num_bits_in_mask {
masked |= self.padding_bit << i;
// of T::BITS.
//
// We replace bits that do not come from the actual T but from the padding
// introduced by the shift, to a specific value, if one was provided.
if let Some(padding_bit) = self.padding_bit {
let padding_mask = (self.bit_mask >> self.num_bits_valid) << self.num_bits_valid;
masked = masked & !padding_mask;

let padding_bit = T::cast_from(padding_bit as u32);
for i in self.num_bits_valid..self.num_bits_in_mask {
masked |= padding_bit << i;
}
}
}

Expand All @@ -184,36 +233,6 @@ where
}
}

impl<T> BlockDecomposer<T>
where
T: Decomposable,
{
// We concretize the iterator type to allow usage of callbacks working on iterator for generic
// integer encryption
pub fn iter_as<V>(self) -> std::iter::Map<Self, fn(T) -> V>
where
V: Numeric,
T: CastInto<V>,
{
assert!(self.num_bits_in_mask <= V::BITS as u32);
self.map(CastInto::cast_into)
}

pub fn next_as<V>(&mut self) -> Option<V>
where
V: CastFrom<T>,
{
self.next().map(|masked| V::cast_from(masked))
}

pub fn checked_next_as<V>(&mut self) -> Option<V>
where
V: TryFrom<T>,
{
self.next().and_then(|masked| V::try_from(masked).ok())
}
}

pub struct BlockRecomposer<T> {
data: T,
bit_mask: T,
Expand Down Expand Up @@ -310,6 +329,36 @@ mod tests {
assert_eq!(expected_blocks, blocks);
}

#[test]
fn test_bit_block_decomposer_3() {
let bits_per_block = 3;

let value = -1i8;
let blocks = BlockDecomposer::new(value, bits_per_block)
.iter_as::<u64>()
.collect::<Vec<_>>();
// We expect the last block padded with 1s as a consequence of arithmetic shift
let expected_blocks = vec![7, 7, 7];
assert_eq!(expected_blocks, blocks);

let value = i8::MIN;
let blocks = BlockDecomposer::new(value, bits_per_block)
.iter_as::<u64>()
.collect::<Vec<_>>();
// We expect the last block padded with 1s as a consequence of arithmetic shift
let expected_blocks = vec![0, 0, 6];
assert_eq!(expected_blocks, blocks);

let value = -1i8;
let blocks =
BlockDecomposer::with_padding_bit(value, bits_per_block, PaddingBitValue::Zero)
.iter_as::<u64>()
.collect::<Vec<_>>();
// We expect the last block padded with 0s as we force that
let expected_blocks = vec![7, 7, 3];
assert_eq!(expected_blocks, blocks);
}

#[test]
fn test_bit_block_decomposer_recomposer_carry_handling_in_between() {
let value = u16::MAX as u32;
Expand Down
15 changes: 5 additions & 10 deletions tfhe/src/integer/server_key/radix/scalar_sub.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::core_crypto::prelude::Numeric;
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto, PaddingBitValue};
use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext};
use crate::integer::server_key::CheckError;
use crate::integer::ServerKey;
Expand Down Expand Up @@ -92,17 +92,12 @@ impl ServerKey {
// The only case where these msb could become 0 after the addition
// is if scalar == T::ZERO (=> !T::ZERO == T::MAX => T::MAX + 1 == overflow),
// but this case has been handled earlier.
let padding_bit = 1u32; // To handle when bits is not a multiple of T::BITS
// All bits of message set to one
let pad_block = (1 << bits_in_message as u8) - 1;

let decomposer = BlockDecomposer::with_padding_bit(
neg_scalar,
bits_in_message,
Scalar::cast_from(padding_bit),
)
.iter_as::<u8>()
.chain(std::iter::repeat(pad_block));
let decomposer =
BlockDecomposer::with_padding_bit(neg_scalar, bits_in_message, PaddingBitValue::One)
.iter_as::<u8>()
.chain(std::iter::repeat(pad_block));
Some(decomposer)
}

Expand Down

0 comments on commit 718f46f

Please sign in to comment.