Skip to content

Commit

Permalink
fix(integer): fix cast in scalar_shift/rotate
Browse files Browse the repository at this point in the history
In scalar_shift/rotate, we get the number of bits to shift/rotate
as a generic type, the can be casted to u64.

We compute the total number of bits the ciphertext has, cast that number
to the same type as the scalar, and do "shift % num_bits".

However, if the number of bits computed exceeds the max value the scalar
type can hold, we could end up doing a remainder with 0.

e.g 256bits ciphertext and scalar type u8 => 256u64 casted to u8 results
in 0.

Fix that by casting the scalar value to u64.
  • Loading branch information
tmontaigu committed Jan 23, 2024
1 parent 3e2833a commit 6060882
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 39 deletions.
28 changes: 28 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -809,3 +809,31 @@ fn test_if_then_else() {
if clear_a <= clear_b { clear_b } else { clear_a }
);
}

#[test]
fn test_scalar_shift_when_clear_type_is_small() {
// This is a regression tests
// The goal is to make sure that doing a scalar shift / rotate
// with a clear type that does not have enough bits to represent
// the number of bits of the fhe type correctly works.

let config = ConfigBuilder::default().build();
let (client_key, server_key) = generate_keys(config);
set_server_key(server_key);

let mut a = FheUint256::encrypt(U256::ONE, &client_key);
// The fhe type has 256 bits, the clear type is u8,
// a u8 cannot represent the value '256'.
// This used to result in the shift/rotate panicking
let clear = 1u8;

let _ = &a << clear;
let _ = &a >> clear;
let _ = (&a).rotate_left(clear);
let _ = (&a).rotate_right(clear);

a <<= clear;
a >>= clear;
a.rotate_left_assign(clear);
a.rotate_right_assign(clear);
}
22 changes: 2 additions & 20 deletions tfhe/src/integer/server_key/radix_parallel/scalar_rotate.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::ops::Rem;

use crate::core_crypto::prelude::CastFrom;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::ServerKey;
Expand Down Expand Up @@ -43,7 +41,6 @@ impl ServerKey {
pub fn smart_scalar_rotate_right_parallelized<T, Scalar>(&self, ct: &mut T, n: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down Expand Up @@ -85,7 +82,6 @@ impl ServerKey {
pub fn smart_scalar_rotate_right_assign_parallelized<T, Scalar>(&self, ct: &mut T, n: Scalar)
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down Expand Up @@ -135,7 +131,6 @@ impl ServerKey {
pub fn scalar_rotate_right_parallelized<T, Scalar>(&self, ct_right: &T, n: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct_right.clone();
Expand Down Expand Up @@ -174,7 +169,6 @@ impl ServerKey {
pub fn scalar_rotate_right_assign_parallelized<T, Scalar>(&self, ct: &mut T, n: Scalar)
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down Expand Up @@ -224,7 +218,6 @@ impl ServerKey {
pub fn unchecked_scalar_rotate_right_parallelized<T, Scalar>(&self, ct: &T, n: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct.clone();
Expand Down Expand Up @@ -275,7 +268,6 @@ impl ServerKey {
n: Scalar,
) where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
// The general idea, is that we know by how much we want to
Expand All @@ -290,9 +282,7 @@ impl ServerKey {
let num_bits_in_message = self.key.message_modulus.0.ilog2() as u64;
let total_num_bits = num_bits_in_message * ct.blocks().len() as u64;

let n = n % Scalar::cast_from(total_num_bits);
let n = u64::cast_from(n);

let n = u64::cast_from(n) % total_num_bits;
if n == 0 {
return;
}
Expand Down Expand Up @@ -382,7 +372,6 @@ impl ServerKey {
pub fn smart_scalar_rotate_left_parallelized<T, Scalar>(&self, ct: &mut T, n: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down Expand Up @@ -424,7 +413,6 @@ impl ServerKey {
pub fn smart_scalar_rotate_left_assign_parallelized<T, Scalar>(&self, ct: &mut T, n: Scalar)
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down Expand Up @@ -474,7 +462,6 @@ impl ServerKey {
pub fn scalar_rotate_left_parallelized<T, Scalar>(&self, ct_left: &T, n: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct_left.clone();
Expand Down Expand Up @@ -513,7 +500,6 @@ impl ServerKey {
pub fn scalar_rotate_left_assign_parallelized<T, Scalar>(&self, ct: &mut T, n: Scalar)
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down Expand Up @@ -563,7 +549,6 @@ impl ServerKey {
pub fn unchecked_scalar_rotate_left_parallelized<T, Scalar>(&self, ct: &T, n: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct.clone();
Expand Down Expand Up @@ -611,7 +596,6 @@ impl ServerKey {
pub fn unchecked_scalar_rotate_left_assign_parallelized<T, Scalar>(&self, ct: &mut T, n: Scalar)
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
// The general idea, is that we know by how much we want to
Expand All @@ -626,9 +610,7 @@ impl ServerKey {
let num_bits_in_message = self.key.message_modulus.0.ilog2() as u64;
let total_num_bits = num_bits_in_message * ct.blocks().len() as u64;

let n = u64::cast_from(n);
let n = n % total_num_bits;

let n = u64::cast_from(n) % total_num_bits;
if n == 0 {
return;
}
Expand Down
22 changes: 3 additions & 19 deletions tfhe/src/integer/server_key/radix_parallel/scalar_shift.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::ops::Rem;

use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::prelude::CastFrom;
use crate::integer::ciphertext::IntegerRadixCiphertext;
Expand Down Expand Up @@ -50,7 +48,6 @@ impl ServerKey {
pub fn unchecked_scalar_right_shift_parallelized<T, Scalar>(&self, ct: &T, shift: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct.clone();
Expand Down Expand Up @@ -97,7 +94,6 @@ impl ServerKey {
shift: Scalar,
) where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if T::IS_SIGNED {
Expand All @@ -116,7 +112,6 @@ impl ServerKey {
) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct.clone();
Expand All @@ -130,7 +125,6 @@ impl ServerKey {
shift: Scalar,
) where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
// The general idea, is that we know by how much we want to shift
Expand All @@ -147,8 +141,7 @@ impl ServerKey {
let num_bits_in_block = self.key.message_modulus.0.ilog2() as u64;
let total_num_bits = num_bits_in_block * ct.blocks().len() as u64;

let shift = shift % Scalar::cast_from(total_num_bits);
let shift = u64::cast_from(shift);
let shift = u64::cast_from(shift) % total_num_bits;
if shift == 0 {
return;
}
Expand Down Expand Up @@ -249,7 +242,6 @@ impl ServerKey {
shift: Scalar,
) where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
// The general idea, is that we know by how much we want to shift
Expand All @@ -266,8 +258,7 @@ impl ServerKey {
let num_bits_in_block = self.key.message_modulus.0.ilog2() as u64;
let total_num_bits = num_bits_in_block * ct.blocks().len() as u64;

let shift = shift % Scalar::cast_from(total_num_bits);
let shift = u64::cast_from(shift);
let shift = u64::cast_from(shift) % total_num_bits;
if shift == 0 {
return;
}
Expand Down Expand Up @@ -405,7 +396,6 @@ impl ServerKey {
pub fn scalar_right_shift_parallelized<T, Scalar>(&self, ct: &T, shift: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct.clone();
Expand Down Expand Up @@ -451,7 +441,6 @@ impl ServerKey {
pub fn scalar_right_shift_assign_parallelized<T, Scalar>(&self, ct: &mut T, shift: Scalar)
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down Expand Up @@ -507,7 +496,6 @@ impl ServerKey {
) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct_left.clone();
Expand Down Expand Up @@ -556,7 +544,6 @@ impl ServerKey {
shift: Scalar,
) where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
// The general idea, is that we know by how much we want to shift
Expand All @@ -573,8 +560,7 @@ impl ServerKey {
let num_bits_in_block = self.key.message_modulus.0.ilog2() as u64;
let total_num_bits = num_bits_in_block * ct.blocks().len() as u64;

let shift = shift % Scalar::cast_from(total_num_bits);
let shift = u64::cast_from(shift);
let shift = u64::cast_from(shift) % total_num_bits;
if shift == 0 {
return;
}
Expand Down Expand Up @@ -686,7 +672,6 @@ impl ServerKey {
pub fn scalar_left_shift_parallelized<T, Scalar>(&self, ct_left: &T, shift: Scalar) -> T
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
let mut result = ct_left.clone();
Expand Down Expand Up @@ -732,7 +717,6 @@ impl ServerKey {
pub fn scalar_left_shift_assign_parallelized<T, Scalar>(&self, ct: &mut T, shift: Scalar)
where
T: IntegerRadixCiphertext,
Scalar: Rem<Scalar, Output = Scalar> + CastFrom<u64>,
u64: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
Expand Down

0 comments on commit 6060882

Please sign in to comment.