Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Am/feat/zk metadata #1522

Merged
merged 7 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ clippy_c_api: install_rs_check_toolchain

.PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint, integer and the js wasm API
clippy_js_wasm_api: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,high-level-client-js-wasm-api,zk-pok \
-p $(TFHE_SPEC) -- --no-deps -D warnings
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,high-level-client-js-wasm-api \
-p $(TFHE_SPEC) -- --no-deps -D warnings
Expand Down
7 changes: 5 additions & 2 deletions tfhe-zk-pok/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tfhe-zk-pok"
version = "0.3.0-alpha.0"
version = "0.3.0-alpha.1"
edition = "2021"
keywords = ["zero", "knowledge", "proof", "vector-commitments"]
homepage = "https://zama.ai/"
Expand All @@ -15,7 +15,9 @@ description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryp
ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" }
ark-ec = { package = "tfhe-ark-ec", version = "0.4.2", features = ["parallel"] }
ark-ff = { package = "tfhe-ark-ff", version = "0.4.3", features = ["parallel"] }
ark-poly = { package = "tfhe-ark-poly", version = "0.4.2", features = ["parallel"] }
ark-poly = { package = "tfhe-ark-poly", version = "0.4.2", features = [
"parallel",
] }
ark-serialize = { version = "0.4.2" }
rand = "0.8.5"
rayon = "1.8.0"
Expand All @@ -26,3 +28,4 @@ num-bigint = "0.4.5"

[dev-dependencies]
serde_json = "~1.0"
itertools = "0.11.0"
156 changes: 100 additions & 56 deletions tfhe-zk-pok/src/proofs/pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ pub fn commit<G: Curve>(
pub fn prove<G: Curve>(
public: (&PublicParams<G>, &PublicCommit<G>),
private_commit: &PrivateCommit<G>,
metadata: &[u8],
load: ComputeLoad,
rng: &mut dyn RngCore,
) -> Proof<G> {
Expand Down Expand Up @@ -347,7 +348,10 @@ pub fn prove<G: Curve>(
.collect::<Box<_>>();

let mut y = vec![G::Zp::ZERO; n];
G::Zp::hash(&mut y, &[hash, x_bytes, c_hat.to_bytes().as_ref()]);
G::Zp::hash(
&mut y,
&[hash, metadata, x_bytes, c_hat.to_bytes().as_ref()],
);
let y = OneBased(y);

let scalars = (n + 1 - big_d..n + 1)
Expand All @@ -360,6 +364,7 @@ pub fn prove<G: Curve>(
&mut theta,
&[
hash_lmap,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand All @@ -379,6 +384,7 @@ pub fn prove<G: Curve>(
&mut t,
&[
hash_t,
metadata,
&(1..n + 1)
.flat_map(|i| y[i].to_bytes().as_ref().to_vec())
.collect::<Box<_>>(),
Expand All @@ -394,6 +400,7 @@ pub fn prove<G: Curve>(
&mut delta,
&[
hash_agg,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand Down Expand Up @@ -472,6 +479,7 @@ pub fn prove<G: Curve>(
core::array::from_mut(&mut z),
&[
hash_z,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand Down Expand Up @@ -512,6 +520,7 @@ pub fn prove<G: Curve>(
core::array::from_mut(&mut w),
&[
hash_w,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand Down Expand Up @@ -698,6 +707,7 @@ fn compute_a_theta<G: Curve>(
pub fn verify<G: Curve>(
proof: &Proof<G>,
public: (&PublicParams<G>, &PublicCommit<G>),
metadata: &[u8],
) -> Result<(), ()> {
let &Proof {
c_hat,
Expand Down Expand Up @@ -760,14 +770,18 @@ pub fn verify<G: Curve>(
.collect::<Box<_>>();

let mut y = vec![G::Zp::ZERO; n];
G::Zp::hash(&mut y, &[hash, x_bytes, c_hat.to_bytes().as_ref()]);
G::Zp::hash(
&mut y,
&[hash, metadata, x_bytes, c_hat.to_bytes().as_ref()],
);
let y = OneBased(y);

let mut theta = vec![G::Zp::ZERO; d + k + 1];
G::Zp::hash(
&mut theta,
&[
hash_lmap,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand All @@ -792,6 +806,7 @@ pub fn verify<G: Curve>(
&mut t,
&[
hash_t,
metadata,
&(1..n + 1)
.flat_map(|i| y[i].to_bytes().as_ref().to_vec())
.collect::<Box<_>>(),
Expand All @@ -807,6 +822,7 @@ pub fn verify<G: Curve>(
&mut delta,
&[
hash_agg,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand All @@ -821,6 +837,7 @@ pub fn verify<G: Curve>(
core::array::from_mut(&mut z),
&[
hash_z,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand Down Expand Up @@ -873,6 +890,7 @@ pub fn verify<G: Curve>(
core::array::from_mut(&mut w),
&[
hash_w,
metadata,
x_bytes,
c_hat.to_bytes().as_ref(),
c_y.to_bytes().as_ref(),
Expand Down Expand Up @@ -1053,6 +1071,15 @@ mod tests {
.wrapping_add((delta * m[i] as u64) as i64);
}

// One of our usecases uses 320 bits of additional metadata
const METADATA_LEN: usize = (320 / u8::BITS) as usize;

let mut metadata = [0u8; METADATA_LEN];
metadata.fill_with(|| rng.gen::<u8>());

let mut fake_metadata = [255u8; METADATA_LEN];
fake_metadata.fill_with(|| rng.gen::<u8>());

let mut m_roundtrip = vec![0i64; k];
for i in 0..k {
let mut dot = 0i128;
Expand Down Expand Up @@ -1093,60 +1120,77 @@ mod tests {
let public_param_that_was_not_compressed =
serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();

for public_param in [
original_public_param,
public_param_that_was_compressed,
public_param_that_was_not_compressed,
] {
for use_fake_e1 in [false, true] {
for use_fake_e2 in [false, true] {
for use_fake_m in [false, true] {
for use_fake_r in [false, true] {
let (public_commit, private_commit) = commit(
a.clone(),
b.clone(),
c1.clone(),
c2.clone(),
if use_fake_r {
fake_r.clone()
} else {
r.clone()
},
if use_fake_e1 {
fake_e1.clone()
} else {
e1.clone()
},
if use_fake_m {
fake_m.clone()
} else {
m.clone()
},
if use_fake_e2 {
fake_e2.clone()
} else {
e2.clone()
},
&public_param,
rng,
);

for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof = prove(
(&public_param, &public_commit),
&private_commit,
load,
rng,
);

assert_eq!(
verify(&proof, (&public_param, &public_commit)).is_err(),
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
);
}
}
}
}
for (
public_param,
use_fake_e1,
use_fake_e2,
use_fake_m,
use_fake_r,
use_fake_metadata_verify,
) in itertools::iproduct!(
[
original_public_param,
public_param_that_was_compressed,
public_param_that_was_not_compressed,
],
[false, true],
[false, true],
[false, true],
[false, true],
[false, true]
) {
let (public_commit, private_commit) = commit(
a.clone(),
b.clone(),
c1.clone(),
c2.clone(),
if use_fake_r {
fake_r.clone()
} else {
r.clone()
},
if use_fake_e1 {
fake_e1.clone()
} else {
e1.clone()
},
if use_fake_m {
fake_m.clone()
} else {
m.clone()
},
if use_fake_e2 {
fake_e2.clone()
} else {
e2.clone()
},
&public_param,
rng,
);

for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof = prove(
(&public_param, &public_commit),
&private_commit,
&metadata,
load,
rng,
);

let verify_metadata = if use_fake_metadata_verify {
&fake_metadata
} else {
&metadata
};

assert_eq!(
verify(&proof, (&public_param, &public_commit), verify_metadata).is_err(),
use_fake_e1
|| use_fake_e2
|| use_fake_r
|| use_fake_m
|| use_fake_metadata_verify
);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tfhe/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tfhe"
version = "0.8.0-alpha.4"
version = "0.8.0-alpha.5"
edition = "2021"
readme = "../README.md"
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"]
Expand Down Expand Up @@ -75,7 +75,7 @@ sha3 = { version = "0.10", optional = true }
# While we wait for repeat_n in rust standard library
itertools = "0.11.0"
rand_core = { version = "0.6.4", features = ["std"] }
tfhe-zk-pok = { version = "0.3.0-alpha.0", path = "../tfhe-zk-pok", optional = true }
tfhe-zk-pok = { version = "0.3.0-alpha.1", path = "../tfhe-zk-pok", optional = true }
tfhe-versionable = { version = "0.2.1", path = "../utils/tfhe-versionable" }

# wasm deps
Expand Down
18 changes: 15 additions & 3 deletions tfhe/benches/integer/zk_pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
mod utilities;

use criterion::{criterion_group, criterion_main, Criterion};
use rand::prelude::*;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
Expand Down Expand Up @@ -49,6 +50,11 @@ fn pke_zk_proof(c: &mut Criterion) {
let _casting_key =
KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), _param_casting);

// We have a use case with 320 bits of metadata
let mut metadata = [0u8; (320 / u8::BITS) as usize];
let mut rng = rand::thread_rng();
metadata.fill_with(|| rng.gen());

for bits in [640usize, 1280, 4096] {
assert_eq!(bits % 64, 0);
// Packing, so we take the message and carry modulus to compute our block count
Expand Down Expand Up @@ -77,7 +83,7 @@ fn pke_zk_proof(c: &mut Criterion) {
b.iter(|| {
let _ct1 = tfhe::integer::ProvenCompactCiphertextList::builder(&pk)
.extend(messages.iter().copied())
.build_with_proof_packed(public_params, compute_load)
.build_with_proof_packed(public_params, &metadata, compute_load)
.unwrap();
})
});
Expand Down Expand Up @@ -129,6 +135,11 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) {
let casting_key =
KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), param_casting);

// We have a use case with 320 bits of metadata
let mut metadata = [0u8; (320 / u8::BITS) as usize];
let mut rng = rand::thread_rng();
metadata.fill_with(|| rng.gen());

for bits in [640usize, 1280, 4096] {
assert_eq!(bits % 64, 0);
// Packing, so we take the message and carry modulus to compute our block count
Expand Down Expand Up @@ -184,7 +195,7 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) {
println!("Generating proven ciphertext ({zk_load})... ");
let ct1 = tfhe::integer::ProvenCompactCiphertextList::builder(&pk)
.extend(messages.iter().copied())
.build_with_proof_packed(public_params, compute_load)
.build_with_proof_packed(public_params, &metadata, compute_load)
.unwrap();

let proven_ciphertext_list_serialized = bincode::serialize(&ct1).unwrap();
Expand Down Expand Up @@ -231,7 +242,7 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) {

bench_group.bench_function(&bench_id_verify, |b| {
b.iter(|| {
let _ret = ct1.verify(public_params, &pk);
let _ret = ct1.verify(public_params, &pk, &metadata);
});
});

Expand All @@ -241,6 +252,7 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) {
.verify_and_expand(
public_params,
&pk,
&metadata,
IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(&sks),
IntegerCompactCiphertextListCastingMode::CastIfNecessary(
casting_key.as_view(),
Expand Down
Loading
Loading