Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent key for key-switch and extract #134

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 97 additions & 51 deletions fhe_core/src/key_switch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@ use algebra::{FieldDiscreteGaussianSampler, NTTField, NTTPolynomial, Polynomial}
use lattice::{DecompositionSpace, NTTGadgetRLWE, PolynomialSpace, LWE, NTTRLWE, RLWE};
use rand::{CryptoRng, Rng};

use crate::{BlindRotationType, LWEModulusType, NTRUCiphertext, SecretKeyPack};
use crate::{LWEModulusType, NTRUCiphertext, SecretKeyPack};

#[derive(Debug, Clone, Copy)]
enum Operation {
AddAMulS,
SubAMulS,
}
Comment on lines +9 to +13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the meaning of this enum type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perform b-a*s or b+a*s.


/// The Key Switching Key.
///
Expand Down Expand Up @@ -39,18 +45,18 @@ impl<F: NTTField> KeySwitchingKey<F> {

assert!(extended_lwe_dimension <= ring_dimension);

// negative convertion
// convertion
let convert = |v: &C| {
if *v == C::ZERO {
F::zero()
} else if *v == C::ONE {
F::neg_one()
} else {
F::one()
} else {
F::neg_one()
}
};

// s = [s_0, 0,..., 0, -s_{n-1},..., -s_1]
// s = [s_0, s_1,..., s_{n-1}, 0,..., 0]
let mut s = <Polynomial<F>>::new(
secret_key_pack
.lwe_secret_key()
Expand All @@ -59,20 +65,14 @@ impl<F: NTTField> KeySwitchingKey<F> {
.collect(),
);
s.resize(extended_lwe_dimension, F::zero());
s[0] = -s[0];
s[1..].reverse();

let lwe_sk = s.into_ntt_polynomial();

let len = key_switching_basis.decompose_len();
let basis = F::new(key_switching_basis.basis());
let blind_rotation_type = parameters.blind_rotation_type();

let key = if extended_lwe_dimension == ring_dimension {
let mut ring_sk = match blind_rotation_type {
BlindRotationType::RLWE => -secret_key_pack.ntt_ring_secret_key(),
BlindRotationType::NTRU => secret_key_pack.ntt_ring_secret_key().clone(),
};
let mut ring_sk = secret_key_pack.ntt_ring_secret_key().clone();

let k = (0..len)
.map(|i| {
Expand All @@ -90,32 +90,12 @@ impl<F: NTTField> KeySwitchingKey<F> {
.collect();
vec![NTTGadgetRLWE::new(k, key_switching_basis)]
} else {
let (mut key, mut store): (Vec<Polynomial<F>>, F) = match blind_rotation_type {
BlindRotationType::RLWE => (
secret_key_pack
.ring_secret_key()
.as_slice()
.rchunks_exact(extended_lwe_dimension)
.map(|part| -Polynomial::from_slice(part))
.collect(),
-secret_key_pack.ring_secret_key()[0],
),
BlindRotationType::NTRU => (
secret_key_pack
.ring_secret_key()
.as_slice()
.rchunks_exact(extended_lwe_dimension)
.map(|part| Polynomial::from_slice(part))
.collect(),
secret_key_pack.ring_secret_key()[0],
),
};

for k_i in &mut key {
let temp = -k_i[0];
k_i[0] = store;
store = temp;
}
let key: Vec<Polynomial<F>> = secret_key_pack
.ring_secret_key()
.as_slice()
.chunks_exact(extended_lwe_dimension)
.map(|part| Polynomial::from_slice(part))
.collect();

key.into_iter()
.map(|z| {
Expand Down Expand Up @@ -143,49 +123,115 @@ impl<F: NTTField> KeySwitchingKey<F> {
}

/// Performs key switching operation.
pub fn key_switch_for_rlwe(&self, ciphertext: &RLWE<F>) -> LWE<F> {
pub fn key_switch_for_rlwe(&self, mut ciphertext: RLWE<F>) -> LWE<F> {
let extended_lwe_dimension = self.lwe_dimension.next_power_of_two();

let init = <NTTRLWE<F>>::new(
NTTPolynomial::zero(extended_lwe_dimension),
NTTPolynomial::new(vec![ciphertext.b()[0]; extended_lwe_dimension]),
);

if ciphertext.a_slice().len() != extended_lwe_dimension {
let a = ciphertext.a_mut_slice();
a[0] = -a[0];
a[1..].reverse();
a.chunks_exact_mut(extended_lwe_dimension)
.for_each(|chunk| {
chunk[0] = -chunk[0];
chunk[1..].reverse();
});
}

let iter = ciphertext.a_slice().chunks_exact(extended_lwe_dimension);

self.key_switch_inner(extended_lwe_dimension, init, iter)
self.key_switch_inner(extended_lwe_dimension, init, iter, Operation::SubAMulS)
}

/// Performs key switching operation.
pub fn key_switch_for_ntru(&self, ciphertext: &NTRUCiphertext<F>) -> LWE<F> {
pub fn key_switch_for_ntru(&self, mut ciphertext: NTRUCiphertext<F>) -> LWE<F> {
let extended_lwe_dimension = self.lwe_dimension.next_power_of_two();

// Because the lwe ciphertext extracted from a ntru ciphertext always has `b = 0`.
let init = <NTTRLWE<F>>::zero(extended_lwe_dimension);

if ciphertext.as_slice().len() != extended_lwe_dimension {
let a = ciphertext.as_mut_slice();
a[0] = -a[0];
a[1..].reverse();
a.chunks_exact_mut(extended_lwe_dimension)
.for_each(|chunk| {
chunk[0] = -chunk[0];
chunk[1..].reverse();
});
}

let iter = ciphertext.as_slice().chunks_exact(extended_lwe_dimension);

self.key_switch_inner(extended_lwe_dimension, init, iter)
self.key_switch_inner(extended_lwe_dimension, init, iter, Operation::AddAMulS)
}

/// Performs key switching operation.
pub fn key_switch_for_lwe(&self, mut ciphertext: LWE<F>) -> LWE<F> {
let extended_lwe_dimension = self.lwe_dimension.next_power_of_two();

let init = <NTTRLWE<F>>::new(
NTTPolynomial::zero(extended_lwe_dimension),
NTTPolynomial::new(vec![ciphertext.b(); extended_lwe_dimension]),
);

if ciphertext.a().len() != extended_lwe_dimension {
let a = ciphertext.a_mut();
a.chunks_exact_mut(extended_lwe_dimension)
.for_each(|chunk| {
chunk[1..].reverse();
chunk[1..].iter_mut().for_each(|v| *v = -*v);
});
} else {
let a = ciphertext.a_mut();
a[1..].reverse();
a[1..].iter_mut().for_each(|v| *v = -*v);
}

let iter = ciphertext.a().chunks_exact(extended_lwe_dimension);

self.key_switch_inner(extended_lwe_dimension, init, iter, Operation::SubAMulS)
}

fn key_switch_inner(
&self,
extended_lwe_dimension: usize,
mut init: NTTRLWE<F>,
iter: ChunksExact<F>,
op: Operation,
) -> LWE<F> {
let mut polynomial_space = PolynomialSpace::new(extended_lwe_dimension);
let mut decompose_space = DecompositionSpace::new(extended_lwe_dimension);

self.key.iter().zip(iter).for_each(|(k_i, a_i)| {
polynomial_space.copy_from(a_i);
init.add_assign_gadget_rlwe_mul_polynomial_inplace_fast(
k_i,
&mut polynomial_space,
&mut decompose_space,
);
});
match op {
Operation::AddAMulS => {
self.key.iter().zip(iter).for_each(|(k_i, a_i)| {
polynomial_space.copy_from(a_i);

init.add_assign_gadget_rlwe_mul_polynomial_inplace_fast(
k_i,
&mut polynomial_space,
&mut decompose_space,
);
});
}
Operation::SubAMulS => {
self.key.iter().zip(iter).for_each(|(k_i, a_i)| {
polynomial_space.copy_from(a_i);

init.sub_assign_gadget_rlwe_mul_polynomial_inplace_fast(
k_i,
&mut polynomial_space,
&mut decompose_space,
);
});
}
}

<RLWE<F>>::from(init).extract_partial_lwe_reverse_locally(self.lwe_dimension)
<RLWE<F>>::from(init).extract_partial_lwe_locally(self.lwe_dimension)
}
}
15 changes: 5 additions & 10 deletions fhe_core/src/secret_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,19 @@ impl<C: LWEModulusType, F: NTTField> SecretKeyPack<C, F> {
parameters.ring_secret_key_type() == RingSecretKeyType::Binary
|| parameters.ring_secret_key_type() == RingSecretKeyType::Ternary
);
// negative convertion
// convertion
let convert = |v: &C| {
if *v == C::ZERO {
F::zero()
} else if *v == C::ONE {
F::neg_one()
} else {
F::one()
} else {
F::neg_one()
}
};

// s = [s_0, -s_{n-1},..., -s_1]
let mut s =
<Polynomial<F>>::new(lwe_secret_key.iter().map(convert).collect());
s[0] = -s[0];
s[1..].reverse();

s
// s = [s_0, s_1,..., s_{n-1}]
<Polynomial<F>>::new(lwe_secret_key.iter().map(convert).collect())
}
};
ntt_ring_secret_key = ring_secret_key.clone().into_ntt_polynomial();
Expand Down
17 changes: 15 additions & 2 deletions lattice/src/rlwe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,19 @@ impl<F: NTTField> RLWE<F> {
LWE::<F>::new(a, b[0])
}

/// Extract an LWE sample from RLWE.
#[inline]
pub fn extract_partial_lwe_locally(self, dimension: usize) -> LWE<F> {
let Self { a, b } = self;

let mut a = a.data();
a[1..].reverse();
a[1..].iter_mut().for_each(|v| *v = -*v);

a.truncate(dimension);
LWE::<F>::new(a, b[0])
}

/// Perform `destination = self * (Y^r - 1)` for bootstrapping where `Y = X^(2N/q)`.
pub fn mul_monic_monomial_sub_one_inplace<T: NumCast>(
&self, // N
Expand Down Expand Up @@ -957,7 +970,7 @@ impl<F: NTTField> NTTRLWE<F> {
pub fn sub_assign_gadget_rlwe_mul_polynomial_inplace_fast(
&mut self,
gadget_rlwe: &NTTGadgetRLWE<F>,
polynomial: Polynomial<F>,
polynomial: &mut Polynomial<F>,
decompose_space: &mut DecompositionSpace<F>,
) {
let coeff_count = polynomial.coeff_count();
Expand All @@ -966,7 +979,7 @@ impl<F: NTTField> NTTRLWE<F> {
let decompose_space = decompose_space.get_mut();
let basis = gadget_rlwe.basis();

let mut polynomial = -polynomial;
polynomial.neg_assign();

gadget_rlwe.iter().for_each(|g| {
polynomial.decompose_lsb_bits_inplace(basis, decompose_space.as_mut_slice());
Expand Down
4 changes: 2 additions & 2 deletions zkfhe/src/bfhe/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl<C: LWEModulusType, F: NTTField> EvaluationKey<C, F> {

match parameters.steps_after_blind_rotation() {
StepsAfterBR::KsMs => {
let key_switched = self.key_switching_key.key_switch_for_rlwe(&acc);
let key_switched = self.key_switching_key.key_switch_for_rlwe(acc);

lwe_modulus_switch_inplace(
key_switched,
Expand All @@ -81,7 +81,7 @@ impl<C: LWEModulusType, F: NTTField> EvaluationKey<C, F> {
);
}
StepsAfterBR::Ms => {
let lwe = acc.extract_lwe_reverse_locally();
let lwe = acc.extract_lwe_locally();

lwe_modulus_switch_inplace(
lwe,
Expand Down
2 changes: 1 addition & 1 deletion zkfhe/src/ntru_bfhe/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl<C: LWEModulusType, F: NTTField> EvaluationKey<C, F> {
.step_by(twice_ntru_dimension_div_lwe_modulus)
.for_each(|v| *v += half_delta);

let key_switched = self.key_switching_key.key_switch_for_ntru(&acc);
let key_switched = self.key_switching_key.key_switch_for_ntru(acc);

let round_method = parameters.modulus_switch_round_method();

Expand Down