From 36d75adb593fa9826f41e790dce1b35a31e04a93 Mon Sep 17 00:00:00 2001 From: Paul Cheng Date: Sun, 12 Nov 2023 12:02:12 +0800 Subject: [PATCH] refact: remove redunction code in generate witness --- algebraic/src/witness/circom.rs | 132 +++++++----------- algebraic/src/witness/witness_calculator.rs | 133 ++++++++----------- groth16/src/api.rs | 4 +- groth16/src/groth16.rs | 10 +- plonky/src/api.rs | 11 +- starky/src/compressor12/compressor12_exec.rs | 5 +- 6 files changed, 118 insertions(+), 177 deletions(-) diff --git a/algebraic/src/witness/circom.rs b/algebraic/src/witness/circom.rs index 84a8732a..c63dbecf 100644 --- a/algebraic/src/witness/circom.rs +++ b/algebraic/src/witness/circom.rs @@ -5,111 +5,79 @@ use wasmer::{Function, Instance, Store, Value}; #[derive(Clone, Debug)] pub struct Wasm(Instance); -// pub trait CircomBase { -// fn init(&self, sanity_check: bool) -> Result<()>; -// fn func(&self, name: &str) -> &Function; -// fn get_ptr_witness_buffer(&self) -> Result; -// fn get_ptr_witness(&self, w: u32) -> Result; -// fn get_signal_offset32( -// &self, -// p_sig_offset: u32, -// component: u32, -// hash_msb: u32, -// hash_lsb: u32, -// ) -> Result<()>; -// fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>; -// fn get_u32(&self, name: &str) -> Result; -// // Only exists natively in Circom2, hardcoded for Circom -// fn get_version(&self) -> Result; -// } -// -// pub trait Circom { -// fn get_field_num_len32(&self) -> Result; -// fn get_raw_prime(&self) -> Result<()>; -// fn read_shared_rw_memory(&self, i: u32) -> Result; -// fn write_shared_rw_memory(&self, i: u32, v: u32) -> Result<()>; -// fn set_input_signal(&self, hmsb: u32, hlsb: u32, pos: u32) -> Result<()>; -// fn get_witness(&self, i: u32) -> Result<()>; -// fn get_witness_size(&self) -> Result; -// } - -// impl Circom for Wasm { impl Wasm { - pub(crate) fn get_field_num_len32(&self) -> Result { - self.get_u32("getFieldNumLen32") + pub(crate) fn get_field_num_len32(&self, store: &mut Store) -> Result { + self.get_u32(store, "getFieldNumLen32") } - pub(crate) fn get_raw_prime(&self) -> Result<()> { - let func = self.func("getRawPrime"); - let mut store = Store::default(); - func.call(&mut store, &[])?; + pub(crate) fn get_raw_prime(&self, store: &mut Store) -> Result<()> { + let func = self.func(store, "getRawPrime"); + func.call(store, &[])?; Ok(()) } - pub(crate) fn read_shared_rw_memory(&self, i: u32) -> Result { - let func = self.func("readSharedRWMemory"); - let mut store = Store::default(); - let result = func.call(&mut store, &[i.into()])?; + pub(crate) fn read_shared_rw_memory(&self, store: &mut Store, i: u32) -> Result { + let func = self.func(store, "readSharedRWMemory"); + let result = func.call(store, &[i.into()])?; Ok(result[0].unwrap_i32() as u32) } - pub(crate) fn write_shared_rw_memory(&self, i: u32, v: u32) -> Result<()> { - let func = self.func("writeSharedRWMemory"); - let mut store = Store::default(); - func.call(&mut store, &[i.into(), v.into()])?; + pub(crate) fn write_shared_rw_memory(&self, store: &mut Store, i: u32, v: u32) -> Result<()> { + let func = self.func(store, "writeSharedRWMemory"); + func.call(store, &[i.into(), v.into()])?; Ok(()) } - pub(crate) fn set_input_signal(&self, hmsb: u32, hlsb: u32, pos: u32) -> Result<()> { - let func = self.func("setInputSignal"); - let mut store = Store::default(); - func.call(&mut store, &[hmsb.into(), hlsb.into(), pos.into()])?; + pub(crate) fn set_input_signal( + &self, + store: &mut Store, + hmsb: u32, + hlsb: u32, + pos: u32, + ) -> Result<()> { + let func = self.func(store, "setInputSignal"); + func.call(store, &[hmsb.into(), hlsb.into(), pos.into()])?; Ok(()) } - pub(crate) fn get_witness(&self, i: u32) -> Result<()> { - let func = self.func("getWitness"); - let mut store = Store::default(); - func.call(&mut store, &[i.into()])?; + pub(crate) fn get_witness(&self, store: &mut Store, i: u32) -> Result<()> { + let func = self.func(store, "getWitness"); + func.call(store, &[i.into()])?; Ok(()) } - pub(crate) fn get_witness_size(&self) -> Result { - self.get_u32("getWitnessSize") + pub(crate) fn get_witness_size(&self, store: &mut Store) -> Result { + self.get_u32(store, "getWitnessSize") } - // } - // - // impl CircomBase for Wasm { - pub(crate) fn init(&self, sanity_check: bool) -> Result<()> { - let func = self.func("init"); - let mut store = Store::default(); - func.call(&mut store, &[Value::I32(sanity_check as i32)])?; + + pub(crate) fn init(&self, store: &mut Store, sanity_check: bool) -> Result<()> { + let func = self.func(store, "init"); + func.call(store, &[Value::I32(sanity_check as i32)])?; Ok(()) } - pub(crate) fn get_ptr_witness_buffer(&self) -> Result { - self.get_u32("getWitnessBuffer") + pub(crate) fn get_ptr_witness_buffer(&self, store: &mut Store) -> Result { + self.get_u32(store, "getWitnessBuffer") } - pub(crate) fn get_ptr_witness(&self, w: u32) -> Result { - let func = self.func("getPWitness"); - let mut store = Store::default(); - let res = func.call(&mut store, &[w.into()])?; + pub(crate) fn get_ptr_witness(&self, store: &mut Store, w: u32) -> Result { + let func = self.func(store, "getPWitness"); + let res = func.call(store, &[w.into()])?; Ok(res[0].unwrap_i32() as u32) } pub(crate) fn get_signal_offset32( &self, + store: &mut Store, p_sig_offset: u32, component: u32, hash_msb: u32, hash_lsb: u32, ) -> Result<()> { - let func = self.func("getSignalOffset32"); - let mut store = Store::default(); + let func = self.func(store, "getSignalOffset32"); func.call( - &mut store, + store, &[ p_sig_offset.into(), component.into(), @@ -123,15 +91,15 @@ impl Wasm { pub(crate) fn set_signal( &self, + store: &mut Store, c_idx: u32, component: u32, signal: u32, p_val: u32, ) -> Result<()> { - let func = self.func("setSignal"); - let mut store = Store::default(); + let func = self.func(store, "setSignal"); func.call( - &mut store, + store, &[c_idx.into(), component.into(), signal.into(), p_val.into()], )?; @@ -139,32 +107,26 @@ impl Wasm { } // Default to version 1 if it isn't explicitly defined - pub(crate) fn get_version(&self) -> Result { + pub(crate) fn get_version(&self, store: &mut Store) -> Result { match self.0.exports.get_function("getVersion") { - Ok(func) => { - let mut store = Store::default(); - Ok(func.call(&mut store, &[])?[0].unwrap_i32() as u32) - } + Ok(func) => Ok(func.call(store, &[])?[0].unwrap_i32() as u32), Err(_) => Ok(1), } } - pub(crate) fn get_u32(&self, name: &str) -> Result { - let func = self.func(name); - let mut store = Store::default(); - let result = func.call(&mut store, &[])?; + pub(crate) fn get_u32(&self, store: &mut Store, name: &str) -> Result { + let func = self.func(store, name); + let result = func.call(store, &[])?; Ok(result[0].unwrap_i32() as u32) } - pub(crate) fn func(&self, name: &str) -> &Function { + pub(crate) fn func(&self, store: &mut Store, name: &str) -> &Function { self.0 .exports .get_function(name) .unwrap_or_else(|_| panic!("function {} not found", name)) } - // } - // - // impl Wasm { + pub fn new(instance: Instance) -> Self { Self(instance) } diff --git a/algebraic/src/witness/witness_calculator.rs b/algebraic/src/witness/witness_calculator.rs index 465f4310..c7369137 100644 --- a/algebraic/src/witness/witness_calculator.rs +++ b/algebraic/src/witness/witness_calculator.rs @@ -53,52 +53,50 @@ fn to_array32(s: &BigInt, size: usize) -> Vec { } impl WitnessCalculator { - pub fn new(path: impl AsRef) -> Result { - Self::from_file(path) - } - - pub fn from_file(path: impl AsRef) -> Result { - let store = Store::default(); + pub fn from_file(path: impl AsRef) -> Result<(Store, Self)> { + let mut store = Store::default(); let module = Module::from_file(&store, path).expect("correct wtns file"); - Self::from_module(module) + let wtns = Self::from_module(&mut store, module).unwrap(); + Ok((store, wtns)) } - pub fn from_module(module: Module) -> Result { - // let store = module.store(); - let mut store = Store::default(); - + pub fn from_module(store: &mut Store, module: Module) -> Result { // Set up the memory - let memory = Memory::new(&mut store, MemoryType::new(2000, None, false)).unwrap(); + let memory = Memory::new(store, MemoryType::new(2000, None, false)).unwrap(); let import_object = imports! { "env" => { "memory" => memory.clone(), }, // Host function callbacks from the WASM "runtime" => { - "error" => runtime::error(&mut store), - "logSetSignal" => runtime::log_signal(&mut store), - "logGetSignal" => runtime::log_signal(&mut store), - "logFinishComponent" => runtime::log_component(&mut store), - "logStartComponent" => runtime::log_component(&mut store), - "log" => runtime::log_component(&mut store), - "exceptionHandler" => runtime::exception_handler(&mut store), - "showSharedRWMemory" => runtime::show_memory(&mut store), - "printErrorMessage" => runtime::print_error_message(&mut store), - "writeBufferMessage" => runtime::write_buffer_message(&mut store), + "error" => runtime::error(store), + "logSetSignal" => runtime::log_signal(store), + "logGetSignal" => runtime::log_signal(store), + "logFinishComponent" => runtime::log_component(store), + "logStartComponent" => runtime::log_component(store), + "log" => runtime::log_component(store), + "exceptionHandler" => runtime::exception_handler(store), + "showSharedRWMemory" => runtime::show_memory(store), + "printErrorMessage" => runtime::print_error_message(store), + "writeBufferMessage" => runtime::write_buffer_message(store), } }; - let instance = Wasm::new(Instance::new(&mut store, &module, &import_object)?); + let instance = Wasm::new(Instance::new(store, &module, &import_object)?); // Circom 2 feature flag with version 2 - fn new_circom(instance: Wasm, memory: Memory) -> Result { - let version = instance.get_version().unwrap_or(1); - - let n32 = instance.get_field_num_len32()?; + fn new_circom( + store: &mut Store, + instance: Wasm, + memory: Memory, + ) -> Result { + let version = instance.get_version(store).unwrap_or(1); + + let n32 = instance.get_field_num_len32(store)?; let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); - instance.get_raw_prime()?; + instance.get_raw_prime(store)?; let mut arr = vec![0; n32 as usize]; for i in 0..n32 { - let res = instance.read_shared_rw_memory(i)?; + let res = instance.read_shared_rw_memory(store, i)?; arr[(n32 as usize) - (i as usize) - 1] = res; } let prime = from_array32(arr); @@ -114,20 +112,21 @@ impl WitnessCalculator { }) } - new_circom(instance, memory) + new_circom(store, instance, memory) } pub fn calculate_witness)>>( &mut self, + store: &mut Store, inputs: I, sanity_check: bool, ) -> Result> { - self.instance.init(sanity_check)?; - let wtns_u32 = self.calculate_witness_circom(inputs, sanity_check)?; - let n32 = self.instance.get_field_num_len32()?; + self.instance.init(store, sanity_check)?; + let wtns_u32 = self.calculate_witness_circom(store, inputs, sanity_check)?; + let n32 = self.instance.get_field_num_len32(store)?; let mut wo = Vec::new(); - let witness_size = self.instance.get_witness_size()?; + let witness_size = self.instance.get_witness_size(store)?; for i in 0..witness_size { let mut arr = vec![0u32; n32 as usize]; for j in 0..n32 { @@ -140,22 +139,24 @@ impl WitnessCalculator { pub fn calculate_witness_bin)>>( &mut self, + store: &mut Store, inputs: I, sanity_check: bool, ) -> Result> { - self.instance.init(sanity_check)?; - self.calculate_witness_circom(inputs, sanity_check) + self.instance.init(store, sanity_check)?; + self.calculate_witness_circom(store, inputs, sanity_check) } // Circom 2 feature flag with version 2 fn calculate_witness_circom)>>( &mut self, + store: &mut Store, inputs: I, sanity_check: bool, ) -> Result> { - self.instance.init(sanity_check)?; + self.instance.init(store, sanity_check)?; - let n32 = self.instance.get_field_num_len32()?; + let n32 = self.instance.get_field_num_len32(store)?; // allocate the inputs for (name, values) in inputs.into_iter() { @@ -164,20 +165,23 @@ impl WitnessCalculator { for (i, value) in values.into_iter().enumerate() { let f_arr = to_array32(&value, n32 as usize); for j in 0..n32 { - self.instance - .write_shared_rw_memory(j, f_arr[(n32 as usize) - 1 - (j as usize)])?; + self.instance.write_shared_rw_memory( + store, + j, + f_arr[(n32 as usize) - 1 - (j as usize)], + )?; } - self.instance.set_input_signal(msb, lsb, i as u32)?; + self.instance.set_input_signal(store, msb, lsb, i as u32)?; } } let mut w = Vec::new(); - let witness_size = self.instance.get_witness_size()?; + let witness_size = self.instance.get_witness_size(store)?; for i in 0..witness_size { - self.instance.get_witness(i)?; + self.instance.get_witness(store, i)?; for j in 0..n32 { - w.push(self.instance.read_shared_rw_memory(j)?); + w.push(self.instance.read_shared_rw_memory(store, j)?); } } @@ -187,6 +191,7 @@ impl WitnessCalculator { #[cfg(not(feature = "wasm"))] pub fn save_witness_to_bin_file( &self, + store: &mut Store, filename: &str, w: &Vec, ) -> Result<()> { @@ -197,15 +202,16 @@ impl WitnessCalculator { .expect("unable to open."); let writer = BufWriter::new(writer); - self.save_witness_from_bin_writer::(writer, w) + self.save_witness_from_bin_writer::(store, writer, w) } pub fn save_witness_from_bin_writer( &self, + store: &mut Store, mut writer: W, wtns: &Vec, ) -> Result<()> { - let n32 = self.instance.get_field_num_len32()?; + let n32 = self.instance.get_field_num_len32(store)?; let wtns_header = [119, 116, 110, 115]; writer.write_all(&wtns_header)?; @@ -254,37 +260,6 @@ impl WitnessCalculator { } Ok(()) } - - pub fn calculate_witness_element< - E: ScalarEngine, - I: IntoIterator)>, - >( - &mut self, - inputs: I, - sanity_check: bool, - ) -> Result> { - let witness = self.calculate_witness(inputs, sanity_check)?; - let modulus = BigUint::from_str( - "21888242871839275222246405745257275088548364400416034343698204186575808495617", - )?; - - // convert it to field elements - use num_traits::Signed; - let witness = witness - .into_iter() - .map(|w| { - let w = if w.sign() == num_bigint::Sign::Minus { - // Need to negate the witness element if negative - modulus.clone() - w.abs().to_biguint().unwrap() - } else { - w.to_biguint().unwrap() - }; - E::Fr::from_str(&w.to_string()).unwrap() - }) - .collect::>(); - - Ok(witness) - } } #[allow(dead_code)] @@ -448,7 +423,7 @@ mod tests { // TODO: test complex samples fn run_test(case: TestCase) { - let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap(); + let (mut store, mut wtns) = WitnessCalculator::from_file(case.circuit_path).unwrap(); assert_eq!( wtns.memory.prime.to_str_radix(16), "30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase() @@ -477,7 +452,7 @@ mod tests { }) .collect::>(); - let res = wtns.calculate_witness(inputs, true).unwrap(); + let res = wtns.calculate_witness(&mut store, inputs, true).unwrap(); for (r, w) in res.iter().zip(case.witness) { assert_eq!(r, &BigInt::from_str(w).unwrap()); } diff --git a/groth16/src/api.rs b/groth16/src/api.rs index 3f962e49..c0e80960 100644 --- a/groth16/src/api.rs +++ b/groth16/src/api.rs @@ -59,9 +59,9 @@ pub fn groth16_prove( ) -> Result<()> { let mut rng = rand::thread_rng(); - let mut wtns = WitnessCalculator::new(wtns_file)?; + let (mut store, mut wtns) = WitnessCalculator::from_file(wtns_file)?; let inputs = load_input_for_witness(input_file); - let w = wtns.calculate_witness(inputs, false)?; + let w = wtns.calculate_witness(&mut store, inputs, false)?; match curve_type { "BN128" => { let pk: Parameters = read_pk_from_file(pk_file, false)?; diff --git a/groth16/src/groth16.rs b/groth16/src/groth16.rs index 9aa10f94..47f629a4 100644 --- a/groth16/src/groth16.rs +++ b/groth16/src/groth16.rs @@ -83,9 +83,10 @@ mod tests { //2. Prove let t1 = std::time::Instant::now(); - let mut wtns = WitnessCalculator::new(WASM_FILE).unwrap(); + // let mut wtns = WitnessCalculator::new(WASM_FILE).unwrap(); + let (mut store, mut wtns) = WitnessCalculator::from_file(WASM_FILE)?; let inputs = load_input_for_witness(INPUT_FILE); - let w = wtns.calculate_witness(inputs, false).unwrap(); + let w = wtns.calculate_witness(&mut store, inputs, false).unwrap(); let w = w .iter() .map(|wi| { @@ -138,9 +139,10 @@ mod tests { //2. Prove let t1 = std::time::Instant::now(); - let mut wtns = WitnessCalculator::new(WASM_FILE_BLS12).unwrap(); + // let mut wtns = WitnessCalculator::new(WASM_FILE_BLS12).unwrap(); + let (mut store, mut wtns) = WitnessCalculator::from_file(WASM_FILE_BLS12)?; let inputs = load_input_for_witness(INPUT_FILE); - let w = wtns.calculate_witness(inputs, false).unwrap(); + let w = wtns.calculate_witness(&mut store, inputs, false).unwrap(); let w = w .iter() .map(|wi| { diff --git a/plonky/src/api.rs b/plonky/src/api.rs index 60b9fe69..a0de55ee 100644 --- a/plonky/src/api.rs +++ b/plonky/src/api.rs @@ -82,16 +82,17 @@ pub fn prove( } pub fn calculate_witness(wasm_file: &str, input_json: &str, output: &str) -> Result<()> { - let mut wtns = WitnessCalculator::new(wasm_file).unwrap(); + let inputs = load_input_for_witness(input_json); + + // let mut wtns = WitnessCalculator::new(wasm_file).unwrap(); + let (mut store, mut wtns) = WitnessCalculator::from_file(wasm_file)?; assert_eq!( wtns.memory.prime.to_str_radix(16), "30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase() ); - let inputs = load_input_for_witness(input_json); - - let wtns_buf = wtns.calculate_witness_bin(inputs, false)?; - wtns.save_witness_to_bin_file::(output, &wtns_buf) + let wtns_buf = wtns.calculate_witness_bin(&mut store, inputs, false)?; + wtns.save_witness_to_bin_file::(&mut store, output, &wtns_buf) } pub fn export_verification_key( diff --git a/starky/src/compressor12/compressor12_exec.rs b/starky/src/compressor12/compressor12_exec.rs index 0944c1c5..51f71fb9 100644 --- a/starky/src/compressor12/compressor12_exec.rs +++ b/starky/src/compressor12/compressor12_exec.rs @@ -38,9 +38,10 @@ pub fn exec( let mut cm_pols = PolsArray::new(&pil_json, PolKind::Commit); // 3. calculate witness. wasm+input->witness - let mut wtns = WitnessCalculator::new(wasm_file).unwrap(); let inputs = load_input_for_witness(input_file); - let w = wtns.calculate_witness(inputs, false).unwrap(); + // let mut wtns = WitnessCalculator::new(wasm_file).unwrap(); + let (mut store, mut wtns) = WitnessCalculator::from_file(wasm_file)?; + let w = wtns.calculate_witness(&mut store, inputs, false).unwrap(); let mut w = w .iter() .map(|wi| {