From 0ee036b6a6f6439f7e5a6d37ada245c47b21fa91 Mon Sep 17 00:00:00 2001 From: Stephen <81497928+eigmax@users.noreply.github.com> Date: Tue, 23 Jan 2024 21:51:26 +0800 Subject: [PATCH] chore: optimize memory usage (#176) * chore: drop memory * chore: opt interpolate_prepare and bit_reverse * chore: opt interpolate_prepare and bit_reverse * chore: fix mem * fix: number may consists of hex and negtives * chore: opt calculate h1h2 * chore: opt calculate h1h2 * chore: opt fft * fix: str to digit * fix: str to digit * chore: recover UT * chore: code polish & fix rayon hashmap * chore: code polish & fix rayon hashmap * chore: drop memory * fix: handle powdr all constant case * feat: seralize setup info * chore: dev * chore: dev * chore: dev * chore: dev * chore: dev * chore: dev * chore: dev * chore: dev * chore: dev * chore: dev * doc: add todo * doc: add todo * doc: add todo * fix: rebase * chore: stash * fix: ut * chore: stash code * chore: stash code --- algebraic/src/reader.rs | 1 - starky/Cargo.toml | 1 + starky/benches/batch_inverse.rs | 3 +- starky/src/digest.rs | 18 ++++ starky/src/f3g.rs | 3 +- starky/src/f5g.rs | 3 +- starky/src/fft.rs | 10 ++- starky/src/fft_p.rs | 43 +++++---- starky/src/fri.rs | 3 - starky/src/interpreter.rs | 11 ++- starky/src/lib.rs | 2 +- starky/src/merklehash.rs | 38 ++++++++ starky/src/merklehash_bls12381.rs | 36 ++++++++ starky/src/merklehash_bn128.rs | 46 ++++++++-- starky/src/polsarray.rs | 17 ++-- starky/src/polutils.rs | 20 +++++ starky/src/prove.rs | 18 ++-- starky/src/stark_gen.rs | 141 ++++++++++++++++++------------ starky/src/stark_setup.rs | 57 +++++++++++- starky/src/stark_verify.rs | 11 +-- starky/src/starkinfo.rs | 16 ++-- starky/src/starkinfo_Z.rs | 2 - starky/src/starkinfo_codegen.rs | 53 ++++++++--- starky/src/starkinfo_cp_ver.rs | 8 +- starky/src/starkinfo_map.rs | 11 +-- starky/src/traits.rs | 35 +++----- starky/src/transcript.rs | 1 - starky/src/types.rs | 15 ++++ zkvm/Cargo.toml | 7 +- zkvm/src/lib.rs | 50 ++++------- zkvm/vm/evm/Cargo.toml | 2 +- zkvm/vm/lr/Cargo.toml | 12 +++ zkvm/vm/lr/rust-toolchain.toml | 4 + zkvm/vm/lr/src/lib.rs | 46 ++++++++++ 34 files changed, 524 insertions(+), 220 deletions(-) create mode 100644 zkvm/vm/lr/Cargo.toml create mode 100644 zkvm/vm/lr/rust-toolchain.toml create mode 100644 zkvm/vm/lr/src/lib.rs diff --git a/algebraic/src/reader.rs b/algebraic/src/reader.rs index 94fcf8ec..87be1c9a 100644 --- a/algebraic/src/reader.rs +++ b/algebraic/src/reader.rs @@ -91,7 +91,6 @@ pub fn load_witness_from_bin_reader(mut reader: R) -> reader.read_exact(&mut wtns_header)?; if wtns_header != [119, 116, 110, 115] { // python -c 'print([ord(c) for c in "wtns"])' => [119, 116, 110, 115] - // bail!("Invalid file header".to_string())); bail!("Invalid file header"); } let version = reader.read_u32::()?; diff --git a/starky/Cargo.toml b/starky/Cargo.toml index e546f788..f35a778a 100644 --- a/starky/Cargo.toml +++ b/starky/Cargo.toml @@ -24,6 +24,7 @@ lazy_static = "1.0" ## threading rayon = { version = "1.5"} num_cpus = "1.0" +hashbrown = { version = "0.14.3", features = ["rayon"] } # error and log thiserror="1.0" diff --git a/starky/benches/batch_inverse.rs b/starky/benches/batch_inverse.rs index 5c93abfd..7ea75a1c 100644 --- a/starky/benches/batch_inverse.rs +++ b/starky/benches/batch_inverse.rs @@ -4,7 +4,8 @@ extern crate criterion; use criterion::{BenchmarkId, Criterion}; use starky::dev::gen_rand_goldfields; use starky::f3g::F3G; -use starky::traits::{batch_inverse, FieldExtension}; +use starky::polutils::batch_inverse; +use starky::traits::FieldExtension; const MIN_K: usize = 6; const MAX_K: usize = 24; diff --git a/starky/src/digest.rs b/starky/src/digest.rs index 9994abd6..09010f70 100644 --- a/starky/src/digest.rs +++ b/starky/src/digest.rs @@ -1,11 +1,14 @@ #![allow(non_snake_case)] +use crate::errors::Result; use crate::field_bls12381::Fr as Fr_bls12381; use crate::field_bls12381::FrRepr as FrRepr_bls12381; use crate::field_bn128::{Fr, FrRepr}; use crate::traits::MTNodeType; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use ff::*; use plonky::field_gl::Fr as FGL; use std::fmt::Display; +use std::io::{Read, Write}; #[repr(C)] #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -46,6 +49,21 @@ impl MTNodeType for ElementDigest { } y } + + fn save(&self, writer: &mut W) -> Result<()> { + for i in &self.0 { + writer.write_u64::(i.as_int())?; + } + Ok(()) + } + + fn load(reader: &mut R) -> Result { + let e1 = reader.read_u64::()?; + let e2 = reader.read_u64::()?; + let e3 = reader.read_u64::()?; + let e4 = reader.read_u64::()?; + Ok(Self::new(&[e1.into(), e2.into(), e3.into(), e4.into()])) + } } impl Display for ElementDigest { diff --git a/starky/src/f3g.rs b/starky/src/f3g.rs index 35f03412..46a91123 100644 --- a/starky/src/f3g.rs +++ b/starky/src/f3g.rs @@ -629,7 +629,8 @@ impl Display for F3G { #[cfg(test)] pub mod tests { use crate::f3g::F3G; - use crate::traits::{batch_inverse, FieldExtension}; + use crate::polutils::batch_inverse; + use crate::traits::FieldExtension; use plonky::field_gl::Fr; use plonky::Field; use std::ops::{Add, Mul}; diff --git a/starky/src/f5g.rs b/starky/src/f5g.rs index 04d2a60f..73063a20 100644 --- a/starky/src/f5g.rs +++ b/starky/src/f5g.rs @@ -725,7 +725,8 @@ impl F5G { #[cfg(test)] pub mod tests { use crate::f5g::F5G; - use crate::traits::{batch_inverse, FieldExtension}; + use crate::polutils::batch_inverse; + use crate::traits::FieldExtension; use plonky::field_gl::Fr; use plonky::Field; use std::ops::{Add, Mul}; diff --git a/starky/src/fft.rs b/starky/src/fft.rs index 8ff879f5..15631545 100644 --- a/starky/src/fft.rs +++ b/starky/src/fft.rs @@ -3,6 +3,7 @@ use crate::constant::MG; use crate::helper::log2_any; use crate::traits::FieldExtension; +use rayon::prelude::*; #[allow(clippy::upper_case_acronyms)] #[derive(Default)] @@ -77,9 +78,12 @@ impl FFT { let n = p.len(); let n2inv = F::from(p.len()).inv(); let mut res = vec![F::ZERO; q.len()]; - for i in 0..n { - res[(n - i) % n] = q[i] * n2inv; - } + + res[0] = q[0] * n2inv; + res[1..] + .par_iter_mut() + .enumerate() + .for_each(|(i, out)| *out = q[n - i - 1] * n2inv); res } } diff --git a/starky/src/fft_p.rs b/starky/src/fft_p.rs index 6bb79cdb..e9ee4ee3 100644 --- a/starky/src/fft_p.rs +++ b/starky/src/fft_p.rs @@ -97,11 +97,14 @@ pub fn bit_reverse( let len = n * n_pols; assert_eq!(len, buffdst.len()); - for j in 0..len { - let i = j / n_pols; - let k = j % n_pols; - buffdst[j] = buffsrc[ris[i] * n_pols + k]; - } + buffdst[0..len] + .par_iter_mut() + .enumerate() + .for_each(|(j, out)| { + let i = j / n_pols; + let k = j % n_pols; + *out = buffsrc[ris[i] * n_pols + k]; + }); } pub fn interpolate_bit_reverse( @@ -113,12 +116,15 @@ pub fn interpolate_bit_reverse( let n = 1 << nbits; let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache. - for i in 0..n { - let rii = (n - ris[i]) % n; - for k in 0..n_pols { - buffdst[i * n_pols + k] = buffsrc[rii * n_pols + k]; - } - } + buffdst[0..n * n_pols] + .par_chunks_mut(n_pols) + .enumerate() + .for_each(|(i, out)| { + let rii = (n - ris[i]) % n; + for k in 0..n_pols { + out[k] = buffsrc[rii * n_pols + k]; + } + }); } pub fn inv_bit_reverse( @@ -133,12 +139,15 @@ pub fn inv_bit_reverse( let len = n * n_pols; assert_eq!(len, buffdst.len()); - for j in 0..len { - let i = j / n_pols; - let k = j % n_pols; - let rii = (n - ris[i]) % n; - buffdst[j] = buffsrc[rii * n_pols + k] * n_inv; - } + buffdst[0..len] + .par_iter_mut() + .enumerate() + .for_each(|(j, out)| { + let i = j / n_pols; + let k = j % n_pols; + let rii = (n - ris[i]) % n; + *out = buffsrc[rii * n_pols + k] * n_inv; + }); } pub fn interpolate_prepare(buff: &mut Vec, n_pols: usize, nbits: usize) { diff --git a/starky/src/fri.rs b/starky/src/fri.rs index c7d94824..833bab7c 100644 --- a/starky/src/fri.rs +++ b/starky/src/fri.rs @@ -58,7 +58,6 @@ impl FRI { let mut pol = pol.to_owned(); let mut standard_fft = FFT::new(); let mut pol_bits = log2_any(pol.len()); - log::trace!("fri prove {} {}", pol.len(), 1 << pol_bits); assert_eq!(1 << pol_bits, pol.len()); assert_eq!(pol_bits, self.in_nbits); @@ -93,7 +92,6 @@ impl FRI { sinv *= wi; } } - log::trace!("pol2_e 0={}, 1={}", pol2_e[0], pol2_e[1]); if si < self.steps.len() - 1 { let n_groups = 1 << self.steps[si + 1].nBits; let group_size = (1 << stepi.nBits) / n_groups; @@ -185,7 +183,6 @@ impl FRI { let n_queries = self.n_queries; let mut ys = transcript.get_permutations(self.n_queries, self.steps[0].nBits)?; let mut pol_bits = self.in_nbits; - log::trace!("ys: {:?}, pol_bits {}", ys, self.in_nbits); let mut shift = F::from(*SHIFT); let check_query_fn = |si: usize, diff --git a/starky/src/interpreter.rs b/starky/src/interpreter.rs index 738e42ba..ed3a73a0 100644 --- a/starky/src/interpreter.rs +++ b/starky/src/interpreter.rs @@ -4,6 +4,7 @@ use crate::starkinfo::StarkInfo; use crate::starkinfo_codegen::Node; use crate::starkinfo_codegen::Section; use crate::traits::FieldExtension; +use crate::types::parse_pil_number; use std::fmt; #[derive(Clone, Debug)] @@ -471,12 +472,10 @@ fn get_ref( panic!("Invalid dom"); } } - "number" => Expr::new( - Ops::Vari(F::from(r.value.clone().unwrap().parse::().unwrap())), - vec![], - vec![], - vec![], - ), + "number" => { + let n_val = parse_pil_number(r.value.as_ref().unwrap()); + Expr::new(Ops::Vari(F::from(n_val)), vec![], vec![], vec![]) + } "public" => Expr::new( Ops::Refer, vec!["publics".to_string()], diff --git a/starky/src/lib.rs b/starky/src/lib.rs index e0e52558..67d28016 100644 --- a/starky/src/lib.rs +++ b/starky/src/lib.rs @@ -5,7 +5,7 @@ #![cfg_attr(feature = "avx512", feature(stdsimd))] pub mod polsarray; -mod polutils; +pub mod polutils; pub mod stark_verifier_circom; pub mod stark_verifier_circom_bn128; pub mod traits; diff --git a/starky/src/merklehash.rs b/starky/src/merklehash.rs index fac411b1..0bf1dc3e 100644 --- a/starky/src/merklehash.rs +++ b/starky/src/merklehash.rs @@ -28,8 +28,10 @@ use crate::poseidon_opt::Poseidon; use crate::traits::MTNodeType; use crate::traits::MerkleTree; use anyhow::{bail, Result}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use plonky::field_gl::Fr as FGL; use rayon::prelude::*; +use std::io::{Read, Write}; use std::time::Instant; #[derive(Default)] @@ -262,6 +264,42 @@ impl MerkleTree for MerkleTreeGL { } } + // TODO: https://github.com/0xEigenLabs/eigen-zkvm/issues/187 + fn save(&self, writer: &mut W) -> Result<()> { + writer.write_u64::(self.width as u64)?; + writer.write_u64::(self.height as u64)?; + writer.write_u64::(self.elements.len() as u64)?; + for i in &self.elements { + writer.write_u64::(i.as_int())?; + } + writer.write_u64::(self.nodes.len() as u64)?; + for i in &self.nodes { + i.save(writer)?; + } + Ok(()) + } + + fn load(reader: &mut R) -> Result { + let mut mt = Self::new(); + mt.width = reader.read_u64::()? as usize; + mt.height = reader.read_u64::()? as usize; + + let es = reader.read_u64::()? as usize; + mt.elements = vec![FGL::ZERO; es]; + for i in 0..es { + let e = reader.read_u64::()?; + mt.elements[i] = FGL::from(e); + } + + let ns = reader.read_u64::()? as usize; + mt.nodes = vec![ElementDigest::<4>::new(&[FGL::ZERO, FGL::ZERO, FGL::ZERO, FGL::ZERO]); ns]; + for i in 0..ns { + mt.nodes[i] = ElementDigest::<4>::load(reader)?; + } + + Ok(mt) + } + fn element_size(&self) -> usize { self.elements.len() } diff --git a/starky/src/merklehash_bls12381.rs b/starky/src/merklehash_bls12381.rs index 2cc80656..1d150e8d 100644 --- a/starky/src/merklehash_bls12381.rs +++ b/starky/src/merklehash_bls12381.rs @@ -8,9 +8,11 @@ use crate::poseidon_bls12381_opt::Poseidon; use crate::traits::MTNodeType; use crate::traits::MerkleTree; use anyhow::{bail, Result}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use ff::Field; use plonky::field_gl::Fr as FGL; use rayon::prelude::*; +use std::io::{Read, Write}; use std::time::Instant; #[derive(Default)] @@ -156,6 +158,40 @@ impl MerkleTree for MerkleTreeBLS12381 { poseidon: Poseidon::new(), } } + fn save(&self, writer: &mut W) -> Result<()> { + writer.write_u64::(self.width as u64)?; + writer.write_u64::(self.height as u64)?; + writer.write_u64::(self.elements.len() as u64)?; + for i in &self.elements { + writer.write_u64::(i.as_int())?; + } + writer.write_u64::(self.nodes.len() as u64)?; + for i in &self.nodes { + i.save(writer)?; + } + Ok(()) + } + + fn load(reader: &mut R) -> Result { + let mut mt = Self::new(); + mt.width = reader.read_u64::()? as usize; + mt.height = reader.read_u64::()? as usize; + + let es = reader.read_u64::()? as usize; + mt.elements = vec![FGL::ZERO; es]; + for i in 0..es { + let e = reader.read_u64::()?; + mt.elements[i] = FGL::from(e); + } + + let ns = reader.read_u64::()? as usize; + mt.nodes = vec![ElementDigest::<4>::new(&[FGL::ZERO, FGL::ZERO, FGL::ZERO, FGL::ZERO]); ns]; + for i in 0..ns { + mt.nodes[i] = ElementDigest::<4>::load(reader)?; + } + + Ok(mt) + } fn element_size(&self) -> usize { self.elements.len() diff --git a/starky/src/merklehash_bn128.rs b/starky/src/merklehash_bn128.rs index a37fb510..86548d2a 100644 --- a/starky/src/merklehash_bn128.rs +++ b/starky/src/merklehash_bn128.rs @@ -8,10 +8,11 @@ use crate::poseidon_bn128_opt::Poseidon; use crate::traits::MTNodeType; use crate::traits::MerkleTree; use anyhow::{bail, Result}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use ff::Field; use plonky::field_gl::Fr as FGL; use rayon::prelude::*; -use std::time::Instant; +use std::io::{Read, Write}; #[derive(Default)] pub struct MerkleTreeBN128 { @@ -158,6 +159,41 @@ impl MerkleTree for MerkleTreeBN128 { } } + fn save(&self, writer: &mut W) -> Result<()> { + writer.write_u64::(self.width as u64)?; + writer.write_u64::(self.height as u64)?; + writer.write_u64::(self.elements.len() as u64)?; + for i in &self.elements { + writer.write_u64::(i.as_int())?; + } + writer.write_u64::(self.nodes.len() as u64)?; + for i in &self.nodes { + i.save(writer)?; + } + Ok(()) + } + + fn load(reader: &mut R) -> Result { + let mut mt = Self::new(); + mt.width = reader.read_u64::()? as usize; + mt.height = reader.read_u64::()? as usize; + + let es = reader.read_u64::()? as usize; + mt.elements = vec![FGL::ZERO; es]; + for i in 0..es { + let e = reader.read_u64::()?; + mt.elements[i] = FGL::from(e); + } + + let ns = reader.read_u64::()? as usize; + mt.nodes = vec![ElementDigest::<4>::new(&[FGL::ZERO, FGL::ZERO, FGL::ZERO, FGL::ZERO]); ns]; + for i in 0..ns { + mt.nodes[i] = ElementDigest::<4>::load(reader)?; + } + + Ok(mt) + } + fn element_size(&self) -> usize { self.elements.len() } @@ -187,7 +223,6 @@ impl MerkleTree for MerkleTreeBN128 { } // calculate the nodes of the specific height Merkle tree let mut nodes = vec![ElementDigest::<4>::default(); get_n_nodes(height)]; - let now = Instant::now(); if !buff.is_empty() { nodes .par_chunks_mut(n_per_thread_f) @@ -200,7 +235,6 @@ impl MerkleTree for MerkleTreeBN128 { }); }); } - log::trace!("linearhash time cost: {}", now.elapsed().as_secs_f64()); // merklize level self.nodes = nodes; @@ -213,13 +247,7 @@ impl MerkleTree for MerkleTreeBN128 { let mut p_in: usize = 0; let mut p_out: usize = p_in + next_n256 * 16; while n256 > 1 { - let now = Instant::now(); self.merklize_level(p_in, next_n256, p_out)?; - log::trace!( - "merklize_level {} time cost: {}", - next_n256, - now.elapsed().as_secs_f64() - ); n256 = next_n256; next_n256 = (n256 - 1) / 16 + 1; p_in = p_out; diff --git a/starky/src/polsarray.rs b/starky/src/polsarray.rs index 72f41be6..589b2ad4 100644 --- a/starky/src/polsarray.rs +++ b/starky/src/polsarray.rs @@ -3,6 +3,7 @@ use crate::errors::Result; use crate::{traits::FieldExtension, types::PIL}; use plonky::field_gl::Fr as FGL; use profiler_macro::time_profiler; +use rayon::prelude::*; use std::collections::HashMap; use std::fs::File; use std::io::{Read, Write}; @@ -138,7 +139,7 @@ impl PolsArray { pol.id + k } - #[time_profiler("load_pols_array")] + #[time_profiler("load_cm_pols_array")] pub fn load(&mut self, fileName: &str) -> Result<()> { let mut f = File::open(fileName)?; let maxBufferSize = 1024 * 1024 * 32; @@ -223,12 +224,14 @@ impl PolsArray { } pub fn write_buff(&self) -> Vec { - let mut buff: Vec = vec![]; - for i in 0..self.n { - for j in 0..self.nPols { - buff.push(F::from(self.array[j][i])); - } - } + let mut buff: Vec = vec![F::ZERO; self.n * self.nPols]; + buff.par_chunks_mut(self.nPols) + .enumerate() + .for_each(|(i, chunk)| { + for j in 0..self.nPols { + chunk[j] = F::from(self.array[j][i]); + } + }); buff } } diff --git a/starky/src/polutils.rs b/starky/src/polutils.rs index 6e456c7f..681b8d51 100644 --- a/starky/src/polutils.rs +++ b/starky/src/polutils.rs @@ -31,3 +31,23 @@ pub fn extend_pol(p: &[F], extend_bits: usize) -> Vec { res.extend_from_slice(&zeros); standard_fft.fft(&res) } + +pub fn batch_inverse(elems: &[F]) -> Vec { + if elems.is_empty() { + return vec![]; + } + + let mut tmp: Vec = vec![F::ZERO; elems.len()]; + tmp[0] = elems[0]; + for i in 1..elems.len() { + tmp[i] = elems[i] * (tmp[i - 1]); + } + let mut z = tmp[tmp.len() - 1].inv(); + let mut res: Vec = vec![F::ZERO; elems.len()]; + for i in (1..elems.len()).rev() { + res[i] = z * tmp[i - 1]; + z *= elems[i]; + } + res[0] = z; + res +} diff --git a/starky/src/prove.rs b/starky/src/prove.rs index 2ce6d69d..f9efbb1e 100644 --- a/starky/src/prove.rs +++ b/starky/src/prove.rs @@ -43,8 +43,8 @@ pub fn stark_prove( match stark_struct.verificationHashType.as_str() { "BN128" => prove::( &mut pil, - &const_pol, - &cm_pol, + const_pol, + cm_pol, &stark_struct, false, norm_stage, @@ -54,8 +54,8 @@ pub fn stark_prove( ), "BLS12381" => prove::( &mut pil, - &const_pol, - &cm_pol, + const_pol, + cm_pol, &stark_struct, false, norm_stage, @@ -65,8 +65,8 @@ pub fn stark_prove( ), "GL" => prove::( &mut pil, - &const_pol, - &cm_pol, + const_pol, + cm_pol, &stark_struct, agg_stage, norm_stage, @@ -82,8 +82,8 @@ pub fn stark_prove( #[allow(clippy::too_many_arguments)] fn prove>, T: Transcript>( pil: &mut PIL, - const_pol: &PolsArray, - cm_pol: &PolsArray, + const_pol: PolsArray, + cm_pol: PolsArray, stark_struct: &StarkStruct, agg_stage: bool, norm_stage: bool, @@ -91,7 +91,7 @@ fn prove>, T: Transcript>( zkin: &str, prover_addr: &str, ) -> Result<()> { - let mut setup = StarkSetup::::new(const_pol, pil, stark_struct, None)?; + let mut setup = StarkSetup::::new(&const_pol, pil, stark_struct, None)?; let mut starkproof = StarkProof::::stark_gen::( cm_pol, const_pol, diff --git a/starky/src/stark_gen.rs b/starky/src/stark_gen.rs index 903c5691..71c19ad8 100644 --- a/starky/src/stark_gen.rs +++ b/starky/src/stark_gen.rs @@ -10,15 +10,15 @@ use crate::fri::FRI; use crate::helper::pretty_print_array; use crate::interpreter::compile_code; use crate::polsarray::PolsArray; +use crate::polutils::batch_inverse; use crate::starkinfo::{Program, StarkInfo}; use crate::starkinfo_codegen::{Polynom, Segment}; -use crate::traits::{batch_inverse, FieldExtension}; -use crate::traits::{MTNodeType, MerkleTree, Transcript}; +use crate::traits::{FieldExtension, MTNodeType, MerkleTree, Transcript}; use crate::types::{StarkStruct, PIL}; +use hashbrown::HashMap; use plonky::field_gl::Fr as FGL; use profiler_macro::time_profiler; use rayon::prelude::*; -use std::collections::HashMap; pub struct StarkContext { pub nbits: usize, @@ -190,8 +190,8 @@ impl<'a, M: MerkleTree> StarkProof { #[allow(clippy::too_many_arguments, clippy::type_complexity)] #[time_profiler()] pub fn stark_gen( - cm_pols: &PolsArray, - const_pols: &PolsArray, + cm_pols: PolsArray, + const_pols: PolsArray, const_tree: &M, starkinfo: &'a StarkInfo, program: &Program, @@ -200,8 +200,6 @@ impl<'a, M: MerkleTree> StarkProof { prover_addr: &str, ) -> Result> { let mut ctx = StarkContext::::default(); - //log::trace!("starkinfo: {}", starkinfo); - //log::trace!("program: {}", program); let mut fftobj = FFT::new(); ctx.nbits = stark_struct.nBits; @@ -212,7 +210,10 @@ impl<'a, M: MerkleTree> StarkProof { let mut n_cm = starkinfo.n_cm1; + log::trace!("Alloc context memory"); ctx.cm1_n = cm_pols.write_buff(); + drop(cm_pols); + ctx.cm2_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.cm2_n) * ctx.N]; ctx.cm3_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.cm3_n) * ctx.N]; ctx.tmpexp_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.tmpexp_n) * ctx.N]; @@ -228,27 +229,28 @@ impl<'a, M: MerkleTree> StarkProof { ctx.x_n = vec![M::ExtendField::ZERO; ctx.N]; - let mut xx = M::ExtendField::ONE; + let xx = M::ExtendField::ONE; // Using the precomputing value let w_nbits: M::ExtendField = M::ExtendField::from(MG.0[ctx.nbits]); - for i in 0..ctx.N { - ctx.x_n[i] = xx; - xx *= w_nbits; - } + ctx.x_n.par_iter_mut().enumerate().for_each(|(k, xb)| { + *xb = xx * w_nbits.exp(k); + }); let extend_bits = ctx.nbits_ext - ctx.nbits; ctx.x_2ns = vec![M::ExtendField::ZERO; ctx.N << extend_bits]; - let mut xx: M::ExtendField = M::ExtendField::from(*SHIFT); - for i in 0..(ctx.N << extend_bits) { - ctx.x_2ns[i] = xx; - xx *= M::ExtendField::from(MG.0[ctx.nbits_ext]); - } + let shift_ext: M::ExtendField = M::ExtendField::from(*SHIFT); + let w_nbits_ext: M::ExtendField = M::ExtendField::from(MG.0[ctx.nbits_ext]); + ctx.x_2ns.par_iter_mut().enumerate().for_each(|(k, xb)| { + *xb = shift_ext * w_nbits_ext.exp(k); + }); ctx.Zi = build_Zh_Inv::(ctx.nbits, extend_bits, 0); + log::trace!("Convert const pols to array"); ctx.const_n = const_pols.write_buff(); const_tree.to_extend(&mut ctx.const_2ns); + drop(const_pols); ctx.publics = vec![M::ExtendField::ZERO; starkinfo.publics.len()]; for (i, pe) in starkinfo.publics.iter().enumerate() { @@ -276,8 +278,10 @@ impl<'a, M: MerkleTree> StarkProof { transcript.put(&b[..])?; } + //Do pre-allocation + let mut result = vec![M::ExtendField::ZERO; (1 << stark_struct.nBitsExt) * 8]; log::trace!("Merkelizing 1...."); - let tree1 = extend_and_merkelize::(&mut ctx, starkinfo, "cm1_n")?; + let tree1 = extend_and_merkelize::(&mut ctx, starkinfo, "cm1_n", &mut result)?; tree1.to_extend(&mut ctx.cm1_2ns); log::trace!( @@ -306,7 +310,7 @@ impl<'a, M: MerkleTree> StarkProof { } log::trace!("Merkelizing 2...."); - let tree2 = extend_and_merkelize::(&mut ctx, starkinfo, "cm2_n")?; + let tree2 = extend_and_merkelize::(&mut ctx, starkinfo, "cm2_n", &mut result)?; tree2.to_extend(&mut ctx.cm2_2ns); transcript.put(&[tree2.root().as_elements().to_vec()])?; log::trace!( @@ -353,7 +357,7 @@ impl<'a, M: MerkleTree> StarkProof { log::trace!("Merkelizing 3...."); - let tree3 = extend_and_merkelize::(&mut ctx, starkinfo, "cm3_n")?; + let tree3 = extend_and_merkelize::(&mut ctx, starkinfo, "cm3_n", &mut result)?; tree3.to_extend(&mut ctx.cm3_2ns); transcript.put(&[tree3.root().as_elements().to_vec()])?; @@ -368,12 +372,15 @@ impl<'a, M: MerkleTree> StarkProof { calculate_exps_parallel(&mut ctx, starkinfo, &program.step42ns, "2ns", "step4"); + log::trace!("Calculate c polynomial"); let mut qq1 = vec![M::ExtendField::ZERO; ctx.q_2ns.len()]; let mut qq2 = vec![M::ExtendField::ZERO; starkinfo.q_dim * ctx.Next * starkinfo.q_deg]; ifft(&ctx.q_2ns, starkinfo.q_dim, ctx.nbits_ext, &mut qq1); let mut cur_s = M::ExtendField::ONE; - let shift_in = (M::ExtendField::inv(&M::ExtendField::from(*SHIFT))).exp(ctx.N); + let shift_inv = (M::ExtendField::inv(&shift_ext)).exp(ctx.N); + + log::trace!("Calculate qq2"); for p in 0..starkinfo.q_deg { for i in 0..ctx.N { for k in 0..starkinfo.q_dim { @@ -381,9 +388,10 @@ impl<'a, M: MerkleTree> StarkProof { qq1[p * ctx.N * starkinfo.q_dim + i * starkinfo.q_dim + k] * cur_s; } } - cur_s *= shift_in; + cur_s *= shift_inv; } + // powdr may produce constant polynomial only if starkinfo.q_deg > 0 { fft( &qq2, @@ -416,9 +424,8 @@ impl<'a, M: MerkleTree> StarkProof { LEv[0] = M::ExtendField::from(FGL::from(1u64)); LpEv[0] = M::ExtendField::from(FGL::from(1u64)); - let xis = ctx.challenge[7] / M::ExtendField::from(*SHIFT); - let wxis = (ctx.challenge[7] * M::ExtendField::from(MG.0[ctx.nbits])) - / M::ExtendField::from(*SHIFT); + let xis = ctx.challenge[7] / shift_ext; + let wxis = (ctx.challenge[7] * w_nbits) / shift_ext; for i in 1..ctx.N { LEv[i] = LEv[i - 1] * xis; @@ -429,6 +436,7 @@ impl<'a, M: MerkleTree> StarkProof { let LpEv = fftobj.ifft(&LpEv); ctx.evals = vec![M::ExtendField::ZERO; starkinfo.ev_map.len()]; + log::trace!("Evals"); let N = ctx.N; for (i, ev) in starkinfo.ev_map.iter().enumerate() { let p = match ev.type_.as_str() { @@ -445,7 +453,6 @@ impl<'a, M: MerkleTree> StarkProof { } }; let l = if ev.prime { &LpEv } else { &LEv }; - log::trace!("calculate acc: N={}", N); let acc = (0..N) .into_par_iter() .map(|k| { @@ -464,6 +471,7 @@ impl<'a, M: MerkleTree> StarkProof { ctx.evals[i] = acc; } + log::trace!("Add evals to transcript"); for i in 0..ctx.evals.len() { let b = ctx.evals[i] .as_elements() @@ -492,9 +500,9 @@ impl<'a, M: MerkleTree> StarkProof { let mut x_buff = vec![M::ExtendField::ZERO; extend_size]; + let w_ext = M::ExtendField::from(MG.0[ctx.nbits + extend_bits]); x_buff.par_iter_mut().enumerate().for_each(|(k, xb)| { - *xb = M::ExtendField::from(*SHIFT) - * M::ExtendField::from(MG.0[ctx.nbits + extend_bits]).exp(k); + *xb = shift_ext * w_ext.exp(k); }); tmp_den @@ -568,7 +576,7 @@ impl<'a, M: MerkleTree> StarkProof { ) -> T { ctx.tmp = vec![T::ZERO; seg.tmp_used]; let t = compile_code(ctx, starkinfo, &seg.first, "n", true); - log::trace!("calculate_exp_at_point compile_code ctx.first:\n{}", t); + //log::trace!("calculate_exp_at_point compile_code ctx.first:\n{}", t); // just let public codegen run multiple times //log::trace!("{} = {} @ {}", res, ctx.cm1_n[1 + 2 * idx], idx); @@ -625,32 +633,35 @@ fn set_pol( } } -#[time_profiler()] +#[time_profiler("calculate_H1H2")] fn calculate_H1H2(f: Vec, t: Vec) -> (Vec, Vec) { - let mut idx_t: HashMap = HashMap::new(); - let mut s: Vec<(F, usize)> = vec![]; + let mut idx_t: HashMap = HashMap::with_capacity(t.len()); + let mut s: Vec<(F, usize)> = vec![(F::ZERO, 0); t.len() + f.len()]; for (i, e) in t.iter().enumerate() { idx_t.insert(*e, i); - s.push((*e, i)); + s[i] = (*e, i); } - for e in f.iter() { + for (i, e) in f.iter().enumerate() { let idx = idx_t.get(e); if idx.is_none() { panic!("Number not included: {:?}", e); } - s.push((*e, *idx.unwrap())); + s[i + t.len()] = (*e, *idx.unwrap()); } s.sort_by(|a, b| a.1.cmp(&b.1)); let mut h1 = vec![F::ZERO; f.len()]; let mut h2 = vec![F::ZERO; f.len()]; - for i in 0..f.len() { - h1[i] = s[2 * i].0; - h2[i] = s[2 * i + 1].0; - } + h1.par_iter_mut() + .zip(h2.par_iter_mut()) + .enumerate() + .for_each(|(i, (h1_, h2_))| { + *h1_ = s[2 * i].0; + *h2_ = s[2 * i + 1].0; + }); (h1, h2) } @@ -710,18 +721,22 @@ pub fn get_pol( res } -#[time_profiler()] +#[time_profiler("extend_and_merkelize")] pub fn extend_and_merkelize( ctx: &mut StarkContext, starkinfo: &StarkInfo, section_name: &'static str, + result: &mut Vec, ) -> Result { let nBitsExt = ctx.nbits_ext; let nBits = ctx.nbits; let n_pols = starkinfo.map_sectionsN.get(section_name); - let mut result = vec![M::ExtendField::ZERO; (1 << nBitsExt) * n_pols]; + + let curr_size = (1 << nBitsExt) * n_pols; + result.resize(curr_size, M::ExtendField::ZERO); + let p = ctx.get_mut(section_name); - interpolate(p, n_pols, nBits, &mut result, nBitsExt); + interpolate(p, n_pols, nBits, result, nBitsExt); let mut p_be = vec![FGL::ZERO; result.len()]; p_be.par_iter_mut() .zip(result) @@ -733,7 +748,7 @@ pub fn extend_and_merkelize( Ok(tree) } -#[time_profiler()] +#[time_profiler("merkelize")] pub fn merkelize( ctx: &mut StarkContext, starkinfo: &StarkInfo, @@ -756,18 +771,18 @@ pub fn calculate_exps( starkinfo: &StarkInfo, seg: &Segment, dom: &str, - step: &str, + //step: &str, N: usize, ) { ctx.tmp = vec![F::ZERO; seg.tmp_used]; let c_first = compile_code(ctx, starkinfo, &seg.first, dom, false); + /* log::trace!( "calculate_exps compile_code {} ctx.first:\n{}", step, c_first ); - /* let mut N = if dom == "n" { ctx.N } else { ctx.Next }; let _c_i = compile_code(ctx, starkinfo, &seg.i, dom, false); let _c_last = compile_code(ctx, starkinfo, &seg.last, dom, false); @@ -1054,7 +1069,7 @@ pub fn calculate_exps_parallel( *tmp = vec![F::ZERO; so.width * (cur_n + next)]; } } - calculate_exps(tmp_ctx, starkinfo, seg, dom, step, cur_n); + calculate_exps(tmp_ctx, starkinfo, seg, dom, cur_n); }); // write back the output @@ -1107,8 +1122,8 @@ pub mod tests { log::trace!("setup {}", fr_root); let starkproof = StarkProof::::stark_gen::( - &cm_pol, - &const_pol, + cm_pol, + const_pol, &setup.const_tree, &setup.starkinfo, &setup.program, @@ -1143,8 +1158,8 @@ pub mod tests { let mut setup = StarkSetup::::new(&const_pol, &mut pil, &stark_struct, None).unwrap(); let starkproof = StarkProof::::stark_gen::( - &cm_pol, - &const_pol, + cm_pol, + const_pol, &setup.const_tree, &setup.starkinfo, &setup.program, @@ -1178,8 +1193,8 @@ pub mod tests { let mut setup = StarkSetup::::new(&const_pol, &mut pil, &stark_struct, None).unwrap(); let starkproof = StarkProof::::stark_gen::( - &cm_pol, - &const_pol, + cm_pol, + const_pol, &setup.const_tree, &setup.starkinfo, &setup.program, @@ -1208,11 +1223,16 @@ pub mod tests { let mut cm_pol = PolsArray::new(&pil, PolKind::Commit); cm_pol.load("data/connection.cm").unwrap(); let stark_struct = load_json::("data/starkStruct.json").unwrap(); - let mut setup = + let setup_ = StarkSetup::::new(&const_pol, &mut pil, &stark_struct, None).unwrap(); + + let sp = "/tmp/connection.setup"; + setup_.save(sp).unwrap(); + let mut setup = StarkSetup::load(sp).unwrap(); + let starkproof = StarkProof::::stark_gen::( - &cm_pol, - &const_pol, + cm_pol, + const_pol, &setup.const_tree, &setup.starkinfo, &setup.program, @@ -1241,11 +1261,16 @@ pub mod tests { let mut cm_pol = PolsArray::new(&pil, PolKind::Commit); cm_pol.load("data/plookup.cm.gl").unwrap(); let stark_struct = load_json::("data/starkStruct.json.gl").unwrap(); - let mut setup = + let setup_ = StarkSetup::::new(&const_pol, &mut pil, &stark_struct, None).unwrap(); + + let sp = "/tmp/plonkup.setup"; + setup_.save(sp).unwrap(); + let mut setup = StarkSetup::load(sp).unwrap(); + let starkproof = StarkProof::::stark_gen::( - &cm_pol, - &const_pol, + cm_pol, + const_pol, &setup.const_tree, &setup.starkinfo, &setup.program, diff --git a/starky/src/stark_setup.rs b/starky/src/stark_setup.rs index e6c73215..f4b24dcd 100644 --- a/starky/src/stark_setup.rs +++ b/starky/src/stark_setup.rs @@ -1,4 +1,7 @@ #![allow(non_snake_case, dead_code)] +use rayon::prelude::*; +use std::fs; +use std::path; use crate::errors::Result; use crate::fft_p::interpolate; @@ -8,7 +11,6 @@ use crate::traits::{FieldExtension, MerkleTree}; use crate::types::{StarkStruct, PIL}; use plonky::field_gl::Fr as FGL; use profiler_macro::time_profiler; -use rayon::prelude::*; #[derive(Default)] pub struct StarkSetup { @@ -18,6 +20,50 @@ pub struct StarkSetup { pub program: Program, } +impl StarkSetup { + pub fn save(&self, base_dir: &str) -> Result<()> { + if path::Path::new(base_dir).exists() { + fs::remove_dir_all(base_dir)?; + } + std::fs::create_dir_all(base_dir)?; + let base_dir = path::Path::new(base_dir); + let ct = base_dir.join("const_tree"); + let mut writer = fs::File::create(ct)?; + self.const_tree.save(&mut writer)?; + + let si = base_dir.join("starkinfo"); + let si = fs::File::create(si)?; + serde_json::to_writer(si, &self.starkinfo)?; + + let pg = base_dir.join("program"); + let pg = fs::File::create(pg)?; + serde_json::to_writer(pg, &self.program)?; + Ok(()) + } + + pub fn load(base_dir: &str) -> Result { + let base_dir = path::Path::new(base_dir); + let ct = base_dir.join("const_tree"); + let mut reader = fs::File::open(ct)?; + let const_tree = M::load(&mut reader)?; + let const_root = const_tree.root(); + + let si = base_dir.join("starkinfo"); + let si = fs::File::open(si)?; + let starkinfo: StarkInfo = serde_json::from_reader(si)?; + + let pg = base_dir.join("program"); + let pg = fs::File::open(pg)?; + let program: Program = serde_json::from_reader(pg)?; + Ok(StarkSetup { + const_tree, + const_root, + starkinfo, + program, + }) + } +} + /// STARK SETUP /// /// calculate the trace polynomial over extended field, return the new polynomial's coefficient. @@ -36,11 +82,13 @@ impl StarkSetup { let mut p: Vec> = vec![Vec::new(); const_pol.nPols]; for i in 0..const_pol.nPols { - for j in 0..const_pol.n { - p[i].push(const_pol.array[i][j]) - } + p[i] = vec![FGL::ZERO; const_pol.n]; + p[i].par_iter_mut().enumerate().for_each(|(j, out)| { + *out = const_pol.array[i][j]; + }); } + log::trace!("Write const pol buff and interpolate"); let const_buff = const_pol.write_buff(); //extend and merkelize let mut const_pols_array_e = vec![M::ExtendField::ZERO; (1 << nBitsExt) * pil.nConstants]; @@ -62,6 +110,7 @@ impl StarkSetup { }); let mut const_tree = M::new(); + log::trace!("Merkelize const tree"); const_tree.merkelize( const_pols_array_e_be, const_pol.nPols, diff --git a/starky/src/stark_verify.rs b/starky/src/stark_verify.rs index ba5433ea..e3015902 100644 --- a/starky/src/stark_verify.rs +++ b/starky/src/stark_verify.rs @@ -9,6 +9,7 @@ use crate::starkinfo_codegen::{Node, Section}; use crate::traits; use crate::traits::FieldExtension; use crate::traits::{MTNodeType, MerkleTree, Transcript}; +use crate::types::parse_pil_number; use crate::types::StarkStruct; use anyhow::{bail, Result}; use plonky::field_gl::Fr as FGL; @@ -16,7 +17,7 @@ use profiler_macro::time_profiler; use std::collections::HashMap; //FIXME it doesn't make sense to ask for a mutable program -#[time_profiler()] +#[time_profiler("stark_verify")] pub fn stark_verify( proof: &StarkProof, const_root: &M::MTNode, @@ -73,7 +74,7 @@ pub fn stark_verify( (ctx.challenge[7] * M::ExtendField::from(MG.0[ctx.nbits])).exp(ctx.N) - M::ExtendField::ONE; log::trace!("verifier_code {}", program.verifier_code); - let res = execute_code(&mut ctx, &mut program.verifier_code.first); + let res = execute_code(&ctx, &mut program.verifier_code.first); log::trace!("starkinfo: {}", starkinfo); let mut x_acc = M::ExtendField::ONE; @@ -136,7 +137,7 @@ pub fn stark_verify( .as_elements(); let vals = vec![execute_code( - &mut ctx_query, + &ctx_query, &mut program.verifier_query_code.first, )]; @@ -146,7 +147,7 @@ pub fn stark_verify( fri.verify(&mut transcript, &proof.fri_proof, check_query) } -fn execute_code(ctx: &mut StarkContext, code: &mut Vec
) -> F { +fn execute_code(ctx: &StarkContext, code: &mut Vec
) -> F { let mut tmp: HashMap = HashMap::new(); let extract_val = |arr: &Vec, pos: usize, dim: usize| -> F { @@ -170,7 +171,7 @@ fn execute_code(ctx: &mut StarkContext, code: &mut Vec extract_val(&ctx.tree4, r.tree_pos, r.dim), "const" => ctx.consts[r.id].into(), "eval" => ctx.evals[r.id], - "number" => F::from(r.value.clone().unwrap().parse::().unwrap()), + "number" => F::from(parse_pil_number(r.value.as_ref().unwrap())), "public" => ctx.publics[r.id], "challenge" => ctx.challenge[r.id], // TODO: Support F5G diff --git a/starky/src/starkinfo.rs b/starky/src/starkinfo.rs index 54dd7cc6..04abc6ec 100644 --- a/starky/src/starkinfo.rs +++ b/starky/src/starkinfo.rs @@ -6,11 +6,11 @@ use crate::starkinfo_codegen::{ }; use crate::types::{Expression, Public, StarkStruct, PIL}; use anyhow::{bail, Result}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; -#[derive(Default, Debug, Serialize)] +#[derive(Default, Debug, Serialize, Deserialize)] pub struct PCCTX { pub f_exp_id: usize, pub t_exp_id: usize, @@ -23,7 +23,7 @@ pub struct PCCTX { pub den_id: usize, } -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct Program { pub publics_code: Vec, pub step2prev: Segment, @@ -46,7 +46,7 @@ impl fmt::Display for Program { } } -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct StarkInfo { pub var_pol_map: Vec, pub n_cm1: usize, @@ -410,7 +410,7 @@ impl StarkInfo { program: &mut Program, ) -> Result<()> { let ppi = pil.plookupIdentities.clone(); - log::trace!("generate_step2: [{:?}]", ppi); + //log::trace!("generate_step2: [{:?}]", ppi); for pi in ppi.iter() { let u = E::challenge("u".to_string()); let def_val = E::challenge("defVal".to_string()); @@ -481,11 +481,11 @@ impl StarkInfo { } program.step2prev = build_code(ctx, pil); - log::trace!("pu_ctx {:?}", self.pu_ctx); - log::trace!("step2prev {}", program.step2prev); + //log::trace!("pu_ctx {:?}", self.pu_ctx); + //log::trace!("step2prev {}", program.step2prev); ctx.calculated.clear(); self.n_cm2 = pil.nCommitments - self.n_cm1; - log::trace!("n_cm2 {}", self.n_cm2); + //log::trace!("n_cm2 {}", self.n_cm2); Ok(()) } } diff --git a/starky/src/starkinfo_Z.rs b/starky/src/starkinfo_Z.rs index 2ef4c8b7..d28f01c4 100644 --- a/starky/src/starkinfo_Z.rs +++ b/starky/src/starkinfo_Z.rs @@ -159,7 +159,6 @@ impl StarkInfo { pil.nQ += 1; num_exp.keep = Some(true); pu_ctx.num_id = pil.expressions.len(); - log::trace!("num_exp: {} {}", i, num_exp); pil.expressions.push(num_exp); // G(\beta, \gamma) @@ -178,7 +177,6 @@ impl StarkInfo { pil.nQ += 1; pu_ctx.den_id = pil.expressions.len(); den_exp.keep = Some(true); - log::trace!("den_exp: {} {}", i, den_exp); pil.expressions.push(den_exp); let num = E::exp(pu_ctx.num_id, None); diff --git a/starky/src/starkinfo_codegen.rs b/starky/src/starkinfo_codegen.rs index a65f6a9c..ec8d9cfb 100644 --- a/starky/src/starkinfo_codegen.rs +++ b/starky/src/starkinfo_codegen.rs @@ -5,7 +5,10 @@ use crate::traits::FieldExtension; use crate::types::Expression; use crate::types::PIL; use anyhow::{bail, Result}; -use serde::Serialize; +use serde::ser::SerializeSeq; +use serde::Deserializer; +use serde::{Deserialize, Serialize}; + use std::collections::HashMap; use std::fmt; @@ -42,7 +45,7 @@ pub struct Code { pub idQ: Option, } -#[derive(Clone, Debug, Default, Serialize)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct Node { pub type_: String, pub id: usize, @@ -78,14 +81,14 @@ impl Node { } /// Subcode -#[derive(Clone, Debug, Default, Serialize)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct Section { pub op: String, pub dest: Node, pub src: Vec, } -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct Segment { pub first: Vec
, pub i: Vec
, @@ -115,7 +118,7 @@ impl Segment { } } -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct IndexVec { pub cm1_n: Vec, pub cm1_2ns: Vec, @@ -149,7 +152,7 @@ impl IndexVec { } } -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct Index { pub cm1_n: usize, pub cm1_2ns: usize, @@ -222,7 +225,7 @@ impl Index { } } -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct PolType { pub section: String, pub section_pos: usize, @@ -239,12 +242,39 @@ pub struct Polynom<'a, F: FieldExtension> { pub dim: usize, } -#[derive(Debug, Clone, Default, Serialize)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct EVIdx { + #[serde(serialize_with = "serialize_map", deserialize_with = "deserialize_map")] pub cm: HashMap<(usize, usize), usize>, + #[serde(serialize_with = "serialize_map", deserialize_with = "deserialize_map")] pub const_: HashMap<(usize, usize), usize>, } +fn serialize_map( + value: &HashMap, + serializer: S, +) -> std::result::Result +where + S: serde::Serializer, +{ + let vec_map = value.iter().collect::>(); + let mut map = serializer.serialize_seq(Some(value.len()))?; + for v in &vec_map { + map.serialize_element(v)?; + } + map.end() +} + +fn deserialize_map<'de, D>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: Deserializer<'de>, +{ + let vec = Vec::deserialize(deserializer)?; + Ok(HashMap::from_iter(vec)) +} + impl EVIdx { pub fn new() -> Self { EVIdx { @@ -477,7 +507,6 @@ pub fn eval_exp( } "mulc" => { let a = eval_exp(code_ctx, pil, &values[0], prime)?; - log::trace!("mulc: {:?}", exp); let b = Node::new( "number".to_string(), 0, @@ -632,7 +661,7 @@ pub fn expression_error(_pil: &PIL, strerr: String, _e1: usize, _e2: usize) -> R bail!(strerr); } -pub fn build_code(ctx: &mut Context, pil: &mut PIL) -> Segment { +pub fn build_code(ctx: &mut Context, pil: &PIL) -> Segment { let seg = Segment { first: build_linear_code(ctx, pil, "first"), i: build_linear_code(ctx, pil, "i"), @@ -651,7 +680,7 @@ pub fn build_code(ctx: &mut Context, pil: &mut PIL) -> Segment { seg } -pub fn build_linear_code(ctx: &mut Context, pil: &PIL, loop_pos: &str) -> Vec
{ +pub fn build_linear_code(ctx: &Context, pil: &PIL, loop_pos: &str) -> Vec
{ let exp_and_expprimes = match loop_pos { "i" | "last" => get_exp_and_expprimes(ctx, pil), _ => HashMap::::new(), @@ -674,7 +703,7 @@ pub fn build_linear_code(ctx: &mut Context, pil: &PIL, loop_pos: &str) -> Vec HashMap { +fn get_exp_and_expprimes(ctx: &Context, pil: &PIL) -> HashMap { let mut calc_exps = HashMap::::new(); for i in 0..ctx.code.len() { if (pil.expressions[ctx.code[i].exp_id].idQ.is_some()) diff --git a/starky/src/starkinfo_cp_ver.rs b/starky/src/starkinfo_cp_ver.rs index 8fc5198b..79ebd1a6 100644 --- a/starky/src/starkinfo_cp_ver.rs +++ b/starky/src/starkinfo_cp_ver.rs @@ -11,12 +11,12 @@ impl StarkInfo { pil: &mut PIL, program: &mut Program, ) -> Result<()> { - log::trace!("cp ver begin ctx {:?}, c_exp: {}", ctx, self.c_exp); + //log::trace!("cp ver begin ctx {:?}, c_exp: {}", ctx, self.c_exp); pil_code_gen(ctx, pil, self.c_exp, false, "", 0, true)?; - log::trace!("cp ver buildcode ctx begin {:?}", ctx); + //log::trace!("cp ver buildcode ctx begin {:?}", ctx); let mut code = build_code(ctx, pil); - log::trace!("cp ver buildcode {}", code); + //log::trace!("cp ver buildcode {}", code); let mut ctx_f = ContextF { exp_map: HashMap::new(), @@ -25,7 +25,7 @@ impl StarkInfo { tmpexps: &mut HashMap::new(), starkinfo: self, }; - log::trace!("cp ver code.tmp_used begin {}", code.tmp_used); + //log::trace!("cp ver code.tmp_used begin {}", code.tmp_used); let fix_ref = |r: &mut Node, ctx: &mut ContextF, _pil: &mut PIL| { let p = if r.prime { 1 } else { 0 }; diff --git a/starky/src/starkinfo_map.rs b/starky/src/starkinfo_map.rs index b0cd694d..57145b8c 100644 --- a/starky/src/starkinfo_map.rs +++ b/starky/src/starkinfo_map.rs @@ -42,7 +42,7 @@ impl StarkInfo { pil.cm_dims[i] = 1 } - log::trace!("pu: {:?}", self.pu_ctx); + //log::trace!("pu: {:?}", self.pu_ctx); for (i, pu) in self.pu_ctx.iter().enumerate() { let dim = std::cmp::max( Self::get_exp_dim(pil, &pil.expressions[pu.f_exp_id]), @@ -236,7 +236,7 @@ impl StarkInfo { }); self.f_2ns.push(ppf_2ns); - log::trace!("cm_dims: {:?}", pil.cm_dims); + //log::trace!("cm_dims: {:?}", pil.cm_dims); self.map_section()?; let N = 1 << stark_struct.nBits; let Next = 1 << stark_struct.nBitsExt; @@ -345,12 +345,7 @@ impl StarkInfo { } } - fn get_dim( - &mut self, - r: &mut Node, - tmp_dim: &mut HashMap, - dim_x: usize, - ) -> usize { + fn get_dim(&mut self, r: &mut Node, tmp_dim: &HashMap, dim_x: usize) -> usize { #[allow(unused_assignments)] let mut d = 0; match r.type_.as_str() { diff --git a/starky/src/traits.rs b/starky/src/traits.rs index 144df66d..4a062a36 100644 --- a/starky/src/traits.rs +++ b/starky/src/traits.rs @@ -8,16 +8,25 @@ use plonky::Field; use serde::ser::Serialize; use std::fmt::{Debug, Display}; use std::hash::Hash; +use std::io::{Read, Write}; -pub trait MTNodeType { +pub trait MTNodeType +where + Self: Sized, +{ fn as_elements(&self) -> &[FGL]; fn new(value: &[FGL]) -> Self; fn from_scalar(e: &T) -> Self; fn as_scalar(&self) -> T::Repr; + fn save(&self, writer: &mut W) -> Result<()>; + fn load(reader: &mut R) -> Result; } #[allow(clippy::type_complexity)] -pub trait MerkleTree { +pub trait MerkleTree +where + Self: Sized, +{ type MTNode: Copy + std::fmt::Display + Clone + Default + MTNodeType + core::fmt::Debug; type BaseField: Clone + Default @@ -39,6 +48,8 @@ pub trait MerkleTree { fn root(&self) -> Self::MTNode; fn eq_root(&self, r1: &Self::MTNode, r2: &Self::MTNode) -> bool; fn element_size(&self) -> usize; + fn save(&self, writer: &mut W) -> Result<()>; + fn load(reader: &mut R) -> Result; } pub trait Transcript { @@ -104,23 +115,3 @@ pub trait FieldExtension: // fn rand_ // (&self) -> &[u8]; } - -pub fn batch_inverse(elems: &[F]) -> Vec { - if elems.is_empty() { - return vec![]; - } - - let mut tmp: Vec = vec![F::ZERO; elems.len()]; - tmp[0] = elems[0]; - for i in 1..elems.len() { - tmp[i] = elems[i] * (tmp[i - 1]); - } - let mut z = tmp[tmp.len() - 1].inv(); - let mut res: Vec = vec![F::ZERO; elems.len()]; - for i in (1..elems.len()).rev() { - res[i] = z * tmp[i - 1]; - z *= elems[i]; - } - res[0] = z; - res -} diff --git a/starky/src/transcript.rs b/starky/src/transcript.rs index d31e030d..a03be2d6 100644 --- a/starky/src/transcript.rs +++ b/starky/src/transcript.rs @@ -24,7 +24,6 @@ impl TranscriptGL { Ok(()) } fn add_1(&mut self, e: &FGL) -> Result<()> { - log::trace!("add_1: {}", e); self.out = Vec::new(); self.pending.push(*e); if self.pending.len() == 8 { diff --git a/starky/src/types.rs b/starky/src/types.rs index e49ff3cf..16555477 100644 --- a/starky/src/types.rs +++ b/starky/src/types.rs @@ -217,6 +217,21 @@ where Ok(serde_json::from_str(&data)?) } +#[inline(always)] +pub fn parse_pil_number(raw_val: &str) -> u64 { + //let raw_val = r.value.as_ref().unwrap(); + let mut n_val: i128 = match raw_val.starts_with("0x") { + true => i128::from_str_radix(&raw_val[2..], 16).unwrap(), + _ => raw_val.parse::().unwrap(), + }; + // FIXME: Goldilocks modular, try to fetch it from FieldExtension + if n_val < 0 { + n_val += 18446744069414584321; + } + n_val %= 18446744069414584321; + n_val as u64 +} + #[cfg(test)] mod test { use super::*; diff --git a/zkvm/Cargo.toml b/zkvm/Cargo.toml index 18b3a0eb..8edc8e67 100644 --- a/zkvm/Cargo.toml +++ b/zkvm/Cargo.toml @@ -10,13 +10,14 @@ itertools = "0.12.0" # serialization log = "0.4.0" -powdr = { git = "https://github.com/powdr-labs/powdr", branch = "main" } -backend = { git = "https://github.com/powdr-labs/powdr", branch = "main", package = "backend" } -models = { git = "https://github.com/powdr-labs/powdr_revm", branch = "continuations", package = "models" } +powdr = { git = "https://github.com/eigmax/powdr", branch = "main" } +backend = { git = "https://github.com/eigmax/powdr", branch = "main", package = "backend" } +models = { git = "https://github.com/eigmax/powdr_revm", branch = "continuations", package = "models" } hex = "0.4.3" thiserror = "1.0" revm = { git = "https://github.com/powdr-labs/revm", branch = "serde-no-std", default-features = false, features = [ "serde" ] } serde_json = "1.0.108" +anyhow = "1.0.79" [dev-dependencies] env_logger = "0.10" diff --git a/zkvm/src/lib.rs b/zkvm/src/lib.rs index 724f74eb..78d36743 100644 --- a/zkvm/src/lib.rs +++ b/zkvm/src/lib.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use backend::BackendType; use powdr::number::GoldilocksField; use powdr::pipeline::{Pipeline, Stage}; @@ -9,12 +10,12 @@ use powdr::riscv_executor; use std::path::Path; use std::time::Instant; -pub fn zkvm_evm_prove_one(suite_json: String, output_path: &str) -> Result<(), String> { +pub fn zkvm_evm_prove_one(task: &str, suite_json: String, output_path: &str) -> Result<()> { log::debug!("Compiling Rust..."); let force_overwrite = true; let with_bootloader = true; let (asm_file_path, asm_contents) = compile_rust( - "vm/evm", + &format!("vm/{task}"), Path::new(output_path), force_overwrite, &CoProcessors::base().with_poseidon(), @@ -25,6 +26,7 @@ pub fn zkvm_evm_prove_one(suite_json: String, output_path: &str) -> Result<(), S let mk_pipeline = || { Pipeline::::default() + .with_output(output_path.into(), true) .from_asm_string(asm_contents.clone(), Some(asm_file_path.clone())) .with_prover_inputs(vec![]) }; @@ -103,6 +105,7 @@ mod tests { //use revm::primitives::address; + // RUST_MIN_STACK=2073741821 RUST_LOG=debug proxychains nohup cargo test --release test_zkvm_evm_prove -- --nocapture & #[test] #[ignore] fn test_zkvm_evm_prove() { @@ -111,36 +114,17 @@ mod tests { let test_file = "test-vectors/solidityExample.json"; let suite_json = std::fs::read_to_string(test_file).unwrap(); - /* - let map_caller_keys: HashMap<_, _> = [ - ( - b256!("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8"), - address!("a94f5374fce5edbc8e2a8697c15331677e6ebf0b"), - ), - ( - b256!("c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4"), - address!("cd2a3d9f938e13cd947ec05abc7fe734df8dd826"), - ), - ( - b256!("044852b2a670ade5407e78fb2863c51de9fcb96542a07186fe3aeda6bb8a116d"), - address!("82a978b3f5962a5b0957d9ee9eef472ee55b42f1"), - ), - ( - b256!("6a7eeac5f12b409d42028f66b0b2132535ee158cfda439e3bfdd4558e8f4bf6c"), - address!("c9c5a15a403e41498b6f69f6f89dd9f5892d21f7"), - ), - ( - b256!("a95defe70ebea7804f9c3be42d20d24375e2a92b9d9666b832069c5f3cd423dd"), - address!("3fb1cd2cd96c6d5c0b5eb3322d807b34482481d4"), - ), - ( - b256!("fe13266ff57000135fb9aa854bbfe455d8da85b21f626307bf3263a0c2a8e7fe"), - address!("dcc5ba93a1ed7e045690d722f2bf460a51c61415"), - ), - ] - .into(); - */ - - zkvm_evm_prove_one(suite_json, "/tmp/test").unwrap(); + zkvm_evm_prove_one("evm", suite_json, "/tmp/test_evm").unwrap(); + } + + #[test] + #[ignore] + fn test_zkvm_lr_prove() { + env_logger::try_init().unwrap_or_default(); + //let test_file = "test-vectors/blockInfo.json"; + let test_file = "test-vectors/solidityExample.json"; + let suite_json = std::fs::read_to_string(test_file).unwrap(); + + zkvm_evm_prove_one("lr", suite_json, "/tmp/test_lr").unwrap(); } } diff --git a/zkvm/vm/evm/Cargo.toml b/zkvm/vm/evm/Cargo.toml index a58b78b4..1a3f7884 100644 --- a/zkvm/vm/evm/Cargo.toml +++ b/zkvm/vm/evm/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] revm = { git = "https://github.com/powdr-labs/revm", branch = "serde-no-std", default-features = false, features = [ "serde" ] } -powdr_riscv_rt = { git = "https://github.com/powdr-labs/powdr", branch = "main" } +powdr_riscv_rt = { git = "https://github.com/eigmax/powdr", branch = "main" } models = { git = "https://github.com/eigmax/powdr_revm", branch = "continuations", package = "models" } serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] } serde_json = { version = "1.0", default-features = false, features = ["alloc"] } diff --git a/zkvm/vm/lr/Cargo.toml b/zkvm/vm/lr/Cargo.toml new file mode 100644 index 00000000..54f1da69 --- /dev/null +++ b/zkvm/vm/lr/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "zk-lr" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +powdr_riscv_rt = { git = "https://github.com/eigmax/powdr", branch = "main" } + + +[workspace] diff --git a/zkvm/vm/lr/rust-toolchain.toml b/zkvm/vm/lr/rust-toolchain.toml new file mode 100644 index 00000000..8e6c8652 --- /dev/null +++ b/zkvm/vm/lr/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "nightly-2023-01-03" +targets = ["riscv32imac-unknown-none-elf"] +profile = "minimal" diff --git a/zkvm/vm/lr/src/lib.rs b/zkvm/vm/lr/src/lib.rs new file mode 100644 index 00000000..788ccc56 --- /dev/null +++ b/zkvm/vm/lr/src/lib.rs @@ -0,0 +1,46 @@ +#![no_std] +extern crate alloc; +use alloc::{vec, vec::Vec}; +//use runtime::get_prover_input; + +fn simple_linear_regression(values: &[(f64, f64)]) -> (f64, f64) { + let (x, y): (Vec, Vec) = values.iter().cloned().unzip(); + + let x_mean = mean(&x); + let y_mean = mean(&y); + + let numerator: f64 = values + .iter() + .map(|&(x, y)| (x - x_mean) * (y - y_mean)) + .sum(); + let denominator: f64 = x.iter().map(|&x| (x - x_mean) * (x - x_mean)).sum(); + + let slope = numerator / denominator; + let y_intercept = y_mean - slope * x_mean; + + (y_intercept, slope) +} + +fn mean(data: &[f64]) -> f64 { + let sum: f64 = data.iter().sum(); + sum / data.len() as f64 +} + +#[no_mangle] +pub fn main() { + /* + let size = get_prover_input(0) as u32; + + let mut line = vec![]; + for i in (1..(size+1)).step_by(2) { + let (x, y) = ( + get_prover_input(i) as i32, + get_prover_input(i+1) as i32 + ); + line.push((f64::from(1), f64::from(y))); + } + */ + + let line = vec![(1.0, 1.0), (2.0, 2.0)]; + let (y_intercept, slope) = simple_linear_regression(&line); +}