Skip to content

Commit

Permalink
feat: groth16 return object in-memory (#268)
Browse files Browse the repository at this point in the history
* chore: add groth16_prove_with_catch methods

* chore: add groth16_prove_with_catch methods

* chore: add groth16_prove_with_cache methods

* chore: modify function name

* chore: change function name

---------

Co-authored-by: eigmax <[email protected]>
  • Loading branch information
ibmp33 and eigmax authored Jul 19, 2024
1 parent c4a24fc commit 31bda35
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 65 deletions.
147 changes: 147 additions & 0 deletions groth16/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,34 @@ pub fn groth16_setup(
Ok(())
}

#[cfg(not(any(feature = "cuda", feature = "opencl")))]
#[allow(clippy::large_enum_variant)]
pub enum SetupResult {
BN128(CircomCircuit<Bn256>, Parameters<Bn256>, VerifyingKey<Bn256>),
BLS12381(CircomCircuit<Bls12>, Parameters<Bls12>, VerifyingKey<Bls12>),
}

#[cfg(not(any(feature = "cuda", feature = "opencl")))]
pub fn groth16_setup_inplace(curve_type: &str, circuit_file: &str) -> Result<SetupResult> {
let mut rng = rand::thread_rng();
let result = match curve_type {
"BN128" => {
let circuit = create_circuit_from_file::<Bn256>(circuit_file, None);
let (pk, vk) = Groth16::circuit_specific_setup(circuit.clone(), &mut rng)?;
SetupResult::BN128(circuit, pk, vk)
}
"BLS12381" => {
let circuit = create_circuit_from_file::<Bls12>(circuit_file, None);
let (pk, vk) = Groth16::circuit_specific_setup(circuit.clone(), &mut rng)?;
SetupResult::BLS12381(circuit, pk, vk)
}
_ => {
bail!(format!("Unknown curve type: {}", curve_type))
}
};
Ok(result)
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
pub fn groth16_setup(
curve_type: &str,
Expand All @@ -88,6 +116,33 @@ pub fn groth16_setup(
Ok(())
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
#[allow(clippy::large_enum_variant)]
pub enum SetupResult {
BLS12381(
CircomCircuit<Scalar>,
Parameters<Bls12>,
VerifyingKey<Bls12>,
),
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
pub fn groth16_setup_inplace(curve_type: &str, circuit_file: &str) -> Result<SetupResult> {
let mut rng = rand::thread_rng();
let result = match curve_type {
"BLS12381" => {
let circuit = create_circuit_from_file::<Scalar>(circuit_file, None);
let (pk, vk): (Parameters<Bls12>, VerifyingKey<Bls12>) =
Groth16::circuit_specific_setup(circuit.clone(), &mut rng)?;
SetupResult::BLS12381(circuit, pk, vk)
}
_ => {
bail!(format!("Unknown curve type: {}", curve_type))
}
};
Ok(result)
}

#[cfg(not(any(feature = "cuda", feature = "opencl")))]
#[allow(clippy::too_many_arguments)]
pub fn groth16_prove(
Expand Down Expand Up @@ -151,6 +206,31 @@ pub fn groth16_prove(
Ok(())
}

#[cfg(not(any(feature = "cuda", feature = "opencl")))]
#[allow(clippy::too_many_arguments)]
pub fn groth16_prove_inplace<E: Engine + crate::json_utils::Parser>(
curve_type: &str,
circuit: CircomCircuit<E>,
wtns_file: &str,
pk: Parameters<E>,
input_file: &str,
public_input_file: &str,
proof_file: &str,
to_hex: bool,
) -> Result<()> {
let mut rng = rand::thread_rng();
let mut wtns = WitnessCalculator::from_file(wtns_file)?;
let inputs = load_input_for_witness(input_file);
let w = wtns.calculate_witness(inputs, false)?;
let circuit1 = create_circuit_add_witness(circuit, w);
let proof = Groth16::prove(&pk, circuit1.clone(), &mut rng)?;
let proof_json = serialize_proof(&proof, curve_type, to_hex)?;
std::fs::write(proof_file, proof_json)?;
let input_json = circuit1.get_public_inputs_json();
std::fs::write(public_input_file, input_json)?;
Ok(())
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
#[allow(clippy::too_many_arguments)]
pub fn groth16_prove(
Expand Down Expand Up @@ -197,6 +277,31 @@ pub fn groth16_prove(
Ok(())
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
#[allow(clippy::too_many_arguments)]
pub fn groth16_prove_inplace(
curve_type: &str,
circuit: CircomCircuit<Scalar>,
wtns_file: &str,
pk: Parameters<Bls12>,
input_file: &str,
public_input_file: &str,
proof_file: &str,
to_hex: bool,
) -> Result<()> {
let mut rng = rand::thread_rng();
let mut wtns = WitnessCalculator::from_file(wtns_file)?;
let inputs = load_input_for_witness(input_file);
let w = wtns.calculate_witness(inputs, false)?;
let circuit1 = create_circuit_add_witness(circuit, w);
let proof = Groth16::prove(&pk, circuit1.clone(), &mut rng)?;
let proof_json = serialize_proof(&proof, curve_type, to_hex)?;
std::fs::write(proof_file, proof_json)?;
let input_json = circuit1.get_public_inputs_json();
std::fs::write(public_input_file, input_json)?;
Ok(())
}

#[cfg(not(any(feature = "cuda", feature = "opencl")))]
pub fn groth16_verify(
curve_type: &str,
Expand Down Expand Up @@ -392,6 +497,27 @@ fn create_circuit_from_file<E: Engine>(
}
}

#[cfg(not(any(feature = "cuda", feature = "opencl")))]
pub fn create_circuit_add_witness<E: Engine>(
mut circuit: CircomCircuit<E>,
witness: Vec<num_bigint::BigInt>,
) -> CircomCircuit<E> {
let witness: Vec<E::Fr> = witness
.iter()
.map(|wi| {
if wi.is_zero() {
E::Fr::zero()
} else {
E::Fr::from_str(&wi.to_string()).unwrap()
}
})
.collect::<Vec<_>>();
circuit.witness = Some(witness);
circuit.wire_mapping = None;
circuit.aux_offset = 0;
circuit
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
fn create_circuit_from_file<E: PrimeField>(
circuit_file: &str,
Expand All @@ -405,6 +531,27 @@ fn create_circuit_from_file<E: PrimeField>(
}
}

#[cfg(any(feature = "cuda", feature = "opencl"))]
pub fn create_circuit_add_witness(
mut circuit: CircomCircuit<Scalar>,
witness: Vec<num_bigint::BigInt>,
) -> CircomCircuit<Scalar> {
let w = witness
.iter()
.map(|wi| {
if wi.is_zero() {
Scalar::ZERO
} else {
Scalar::from_str_vartime(&wi.to_string()).unwrap()
}
})
.collect::<Vec<_>>();
circuit.witness = Some(w);
circuit.wire_mapping = None;
circuit.aux_offset = 0;
circuit
}

#[cfg(not(any(feature = "cuda", feature = "opencl")))]
fn read_pk_from_file<E: Engine>(file_path: &str, checked: bool) -> Result<Parameters<E>> {
let file =
Expand Down
95 changes: 30 additions & 65 deletions groth16/src/groth16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ mod tests {
use num_traits::Zero;

use super::*;
use crate::api::create_circuit_add_witness;
use crate::api::groth16_setup_inplace;
use crate::api::SetupResult;
use crate::bellman_ce::bls12_381::Bls12;
use crate::bellman_ce::bls12_381::Fr as Fr_bls12381;
use crate::bellman_ce::bn256::{Bn256, Fr};
use algebraic::circom_circuit::CircomCircuit;
use algebraic::reader;
Expand Down Expand Up @@ -188,53 +190,33 @@ mod tests {
}

#[test]
fn groth16_proof_bls12381() -> Result<()> {
fn groth16_proof_bls12381_inpace() -> Result<()> {
//1. SRS
let t = std::time::Instant::now();
let circuit: CircomCircuit<Bls12> = CircomCircuit {
r1cs: reader::load_r1cs(CIRCUIT_FILE_BLS12),
witness: None,
wire_mapping: None,
aux_offset: 0,
let setup_result = groth16_setup_inplace("BLS12381", CIRCUIT_FILE_BLS12)?;
let (circuit, pk, vk) = match setup_result {
SetupResult::BLS12381(circuit, pk, vk) => (circuit, pk, vk),
_ => panic!("Expected BLS12381 setup result"),
};
let mut rng = rand::thread_rng();
let params = Groth16::circuit_specific_setup(circuit, &mut rng)?;
let elapsed = t.elapsed().as_secs_f64();
println!("1-groth16-bls12381 setup run time: {} secs", elapsed);

//2. Prove
let t1 = std::time::Instant::now();
// let mut wtns = WitnessCalculator::new(WASM_FILE_BLS12).unwrap();
let mut rng = rand::thread_rng();
let mut wtns = WitnessCalculator::from_file(WASM_FILE_BLS12)?;
let inputs = load_input_for_witness(INPUT_FILE);
let w = wtns.calculate_witness(inputs, false).unwrap();
let w = w
.iter()
.map(|wi| {
if wi.is_zero() {
Fr_bls12381::zero()
} else {
// println!("wi: {}", wi);
Fr_bls12381::from_str(&wi.to_string()).unwrap()
}
})
.collect::<Vec<_>>();
let circuit1: CircomCircuit<Bls12> = CircomCircuit {
r1cs: reader::load_r1cs(CIRCUIT_FILE_BLS12),
witness: Some(w),
wire_mapping: None,
aux_offset: 0,
};
let inputs = circuit1.get_public_inputs().unwrap();
let proof = Groth16::prove(&params.0, circuit1, &mut rng)?;
let circuit1: CircomCircuit<Bls12> = create_circuit_add_witness::<Bls12>(circuit, w);
let proof = Groth16::prove(&pk, circuit1.clone(), &mut rng)?;
let elapsed1 = t1.elapsed().as_secs_f64();
println!("2-groth16-bls12381 prove run time: {} secs", elapsed1);

//3. Verify
let t2 = std::time::Instant::now();
let verified = Groth16::<_, CircomCircuit<Bls12>>::verify_with_processed_vk(
&params.1, &inputs, &proof,
)?;
let inputs = circuit1.get_public_inputs().unwrap();
let verified =
Groth16::<_, CircomCircuit<Bls12>>::verify_with_processed_vk(&vk, &inputs, &proof)?;
let elapsed2 = t2.elapsed().as_secs_f64();
println!("3-groth16-bls12381 verify run time: {} secs", elapsed2);

Expand All @@ -248,11 +230,13 @@ mod tests {
#[cfg(any(feature = "cuda", feature = "opencl"))]
mod tests {
use super::*;
use crate::api::{create_circuit_add_witness, groth16_setup_inplace, SetupResult};
use algebraic::witness::{load_input_for_witness, WitnessCalculator};
use algebraic_gpu::circom_circuit::CircomCircuit;
use algebraic_gpu::reader;
use blstrs::{Bls12, Scalar};
use ff::{Field, PrimeField};
use log::info;
use num_traits::Zero;
use rand_new::rngs::OsRng;
const INPUT_FILE: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../test/multiplier.input.json");
Expand All @@ -266,54 +250,35 @@ mod tests {
);

#[test]
fn groth16_proof() -> Result<()> {
fn groth16_proof_bls12381_inplace() -> Result<()> {
let _ = env_logger::try_init();
//1. SRS
let t = std::time::Instant::now();
let circuit: CircomCircuit<Scalar> = CircomCircuit {
r1cs: reader::load_r1cs(CIRCUIT_FILE_BLS12),
witness: None,
wire_mapping: None,
aux_offset: 0,
let setup_result = groth16_setup_inplace("BLS12381", CIRCUIT_FILE_BLS12)?;
let (circuit, pk, vk) = match setup_result {
SetupResult::BLS12381(circuit, pk, vk) => (circuit, pk, vk),
_ => panic!("Expected BLS12381 setup result"),
};
let params = Groth16::circuit_specific_setup(circuit, &mut OsRng)?;
let elapsed = t.elapsed().as_secs_f64();
println!("1-groth16-bls12381 setup run time: {} secs", elapsed);
info!("1-groth16-bls12381 setup run time: {} secs", elapsed);

//2. Prove
let t1 = std::time::Instant::now();
let mut wtns = WitnessCalculator::from_file(WASM_FILE_BLS12)?;
let inputs = load_input_for_witness(INPUT_FILE);
let w = wtns.calculate_witness(inputs, false).unwrap();
let w = w
.iter()
.map(|wi| {
if wi.is_zero() {
<Bls12 as Engine>::Fr::ZERO
} else {
// println!("wi: {}", wi);
<Bls12 as Engine>::Fr::from_str_vartime(&wi.to_string()).unwrap()
}
})
.collect::<Vec<_>>();
let circuit1: CircomCircuit<Scalar> = CircomCircuit {
r1cs: reader::load_r1cs(CIRCUIT_FILE_BLS12),
witness: Some(w),
wire_mapping: None,
aux_offset: 0,
};
let inputs = circuit1.get_public_inputs().unwrap();
let proof: bellperson::groth16::Proof<Bls12> =
Groth16::prove(&params.0, circuit1, &mut OsRng)?;
let circuit1: CircomCircuit<Scalar> = create_circuit_add_witness(circuit, w);
let proof = Groth16::prove(&pk, circuit1.clone(), &mut OsRng)?;
let elapsed1 = t1.elapsed().as_secs_f64();
println!("2-groth16-bls12381 prove run time: {} secs", elapsed1);
info!("2-groth16-bls12381 prove run time: {} secs", elapsed1);

//3. Verify
let t2 = std::time::Instant::now();
let verified = Groth16::<_, CircomCircuit<Scalar>>::verify_with_processed_vk(
&params.1, &inputs, &proof,
)?;
let inputs = circuit1.get_public_inputs().unwrap();
let verified =
Groth16::<_, CircomCircuit<Scalar>>::verify_with_processed_vk(&vk, &inputs, &proof)?;
let elapsed2 = t2.elapsed().as_secs_f64();
println!("3-groth16-bls12381 verify run time: {} secs", elapsed2);
info!("3-groth16-bls12381 verify run time: {} secs", elapsed2);

assert!(verified);

Expand Down

0 comments on commit 31bda35

Please sign in to comment.