diff --git a/starky/src/digest.rs b/starky/src/digest.rs index 088558c8..1209abb1 100644 --- a/starky/src/digest.rs +++ b/starky/src/digest.rs @@ -2,22 +2,31 @@ use crate::field_bls12381::Fr as Fr_bls12381; use crate::field_bls12381::FrRepr as FrRepr_bls12381; use crate::field_bn128::{Fr, FrRepr}; +use crate::helper; use crate::traits::MTNodeType; use ff::*; use fields::field_gl::Fr as FGL; use serde::de::{SeqAccess, Visitor}; use serde::ser::SerializeSeq; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use std::any::TypeId; use std::fmt; use std::fmt::Display; use std::marker::PhantomData; /// the trait F is used to keep track of source data type, so we can implement its deserializer -// TODO: Remove the generic type: F. As it's never used. #[repr(C)] #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ElementDigest(pub [FGL; N], PhantomData); +impl ElementDigest { + // FIXME: this is a bit tricky that assuming the len is 4, replace it by N here. + pub fn is_dim_1(&self) -> bool { + let e = self.as_elements(); + e[1] == e[2] && e[1] == e[3] && e[1] == FGL::ZERO + } +} + impl MTNodeType for ElementDigest { type BaseField = F; #[inline(always)] @@ -77,13 +86,28 @@ impl Serialize for ElementDigest where S: Serializer, { - let elems = self.0.to_vec(); - - let mut seq = serializer.serialize_seq(Some(elems.len()))?; - for v in elems.iter() { - seq.serialize_element(&v.as_int().to_string())?; + let source = TypeId::of::(); + if source == TypeId::of::() { + let r: Fr = Fr(self.as_scalar::()); + return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string()); + } + if source == TypeId::of::() { + let r: Fr_bls12381 = Fr_bls12381(self.as_scalar::()); + return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string()); + } + if source == TypeId::of::() { + let e = self.as_elements(); + if self.is_dim_1() { + return serializer.serialize_str(&e[0].as_int().to_string()); + } else { + let mut seq = serializer.serialize_seq(Some(4))?; + for v in e.iter() { + seq.serialize_element(&v.as_int().to_string())?; + } + return seq.end(); + } } - seq.end() + panic!("Invalid element to serialize, {:?}", self.0) } } @@ -105,24 +129,34 @@ impl<'de, const N: usize, F: PrimeField + Default> Deserialize<'de> for ElementD where A: SeqAccess<'de>, { - let mut entries = Vec::with_capacity(N); + let mut entries = Vec::new(); while let Some(entry) = seq.next_element::()? { let entry: u64 = entry.parse().unwrap(); - - entries.push(FGL::from_repr(fields::field_gl::FrRepr::from(entry)).unwrap()); + entries.push(FGL::from(entry)); } Ok(ElementDigest::::new(&entries)) } + // it could be one-dim GL, BN128, or BLS12381 fn visit_str(self, s: &str) -> std::result::Result where E: de::Error, { - let entry: u64 = s.parse().unwrap(); - - let data = vec![FGL::from_repr(fields::field_gl::FrRepr::from(entry)).unwrap(); N]; - - Ok(ElementDigest::::new(&data)) + let source = TypeId::of::(); + if source == TypeId::of::() { + // one-dim GL elements + let value = FGL::from_str(s).unwrap(); + Ok(ElementDigest::::new(&[ + value, + FGL::ZERO, + FGL::ZERO, + FGL::ZERO, + ])) + } else { + // BN128 or BLS12381 + let t = F::from_str(s).unwrap(); + Ok(ElementDigest::::from_scalar(&t)) + } } } deserializer.deserialize_any(EntriesVisitor::(Default::default())) @@ -280,11 +314,10 @@ mod tests { #[test] fn test_element_digest_serialize_and_deserialize() { - const N: usize = 3; + const N: usize = 4; let fields = vec![FGL::one(); N]; let data = ElementDigest::::new(&fields); let serialized = serde_json::to_string(&data).unwrap(); - println!("Serialized: {}", serialized); let expect: ElementDigest = serde_json::from_str(&serialized).unwrap(); diff --git a/starky/src/field_bls12381.rs b/starky/src/field_bls12381.rs index e7cde1fe..6b7744a1 100644 --- a/starky/src/field_bls12381.rs +++ b/starky/src/field_bls12381.rs @@ -2,11 +2,51 @@ use core::ops::{Add, Div, Mul, Neg, Sub}; use ff::*; +use crate::helper; +use serde::de::{Error, SeqAccess, Visitor}; +use serde::ser::SerializeSeq; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; + #[derive(PrimeField)] #[PrimeFieldModulus = "52435875175126190479447740508185965837690552500527637822603658699938581184513"] #[PrimeFieldGenerator = "7"] pub struct Fr(pub FrRepr); +impl Serialize for Fr { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&helper::fr_to_biguint(self).to_string()) + } +} + +impl<'de> Deserialize<'de> for Fr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct EntriesVisitor; + + impl<'de> Visitor<'de> for EntriesVisitor { + type Value = Fr; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Bls12381's Fr") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + Ok(Self::Value::from_str(v).unwrap()) + } + } + deserializer.deserialize_any(EntriesVisitor) + } +} + #[cfg(test)] mod tests { use crate::field_bls12381::*; diff --git a/starky/src/field_bn128.rs b/starky/src/field_bn128.rs index d3e2509c..76ae7f1a 100644 --- a/starky/src/field_bn128.rs +++ b/starky/src/field_bn128.rs @@ -1,11 +1,52 @@ #![allow(unused_imports, clippy::too_many_arguments)] use ff::*; +use crate::helper; +use ff::*; +use serde::de::{Error, SeqAccess, Visitor}; +use serde::ser::SerializeSeq; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; + #[derive(PrimeField)] #[PrimeFieldModulus = "21888242871839275222246405745257275088548364400416034343698204186575808495617"] #[PrimeFieldGenerator = "7"] pub struct Fr(pub FrRepr); +impl Serialize for Fr { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&helper::fr_to_biguint(self).to_string()) + } +} + +impl<'de> Deserialize<'de> for Fr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct EntriesVisitor; + + impl<'de> Visitor<'de> for EntriesVisitor { + type Value = Fr; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Bn128's Fr") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + Ok(Self::Value::from_str(v).unwrap()) + } + } + deserializer.deserialize_any(EntriesVisitor) + } +} + #[cfg(test)] mod tests { use crate::field_bn128::*; diff --git a/starky/src/merklehash.rs b/starky/src/merklehash.rs index 8446d189..bc7f3855 100644 --- a/starky/src/merklehash.rs +++ b/starky/src/merklehash.rs @@ -283,6 +283,10 @@ impl MerkleTree for MerkleTreeGL { vec![node.as_elements().to_vec()[0]] } + fn from_basefield(node: &FGL) -> Self::MTNode { + Self::MTNode::new(&[*node, FGL::ZERO, FGL::ZERO, FGL::ZERO]) + } + #[cfg(not(any( target_feature = "avx512bw", target_feature = "avx512cd", @@ -582,12 +586,8 @@ mod tests { #[test] fn test_merkle_tree_gl_serialize_and_deserialize() { let data = MerkleTreeGL::new(); - let serialized = serde_json::to_string(&data).unwrap(); - println!("Serialized: {}", serialized); - let expect: MerkleTreeGL = serde_json::from_str(&serialized).unwrap(); - assert_eq!(data, expect); } } diff --git a/starky/src/merklehash_bls12381.rs b/starky/src/merklehash_bls12381.rs index 2327d60f..942cf067 100644 --- a/starky/src/merklehash_bls12381.rs +++ b/starky/src/merklehash_bls12381.rs @@ -176,6 +176,10 @@ impl MerkleTree for MerkleTreeBLS12381 { vec![Fr(node.as_scalar::())] } + fn from_basefield(node: &Fr) -> Self::MTNode { + Self::MTNode::from_scalar(node) + } + fn merkelize(&mut self, buff: Vec, width: usize, height: usize) -> Result<()> { let max_workers = get_max_workers(); @@ -373,12 +377,8 @@ mod tests { #[test] fn test_merkle_tree_bls381_serialize_and_deserialize() { let data = MerkleTreeBLS12381::new(); - let serialized = serde_json::to_string(&data).unwrap(); - println!("Serialized: {}", serialized); - let expect: MerkleTreeBLS12381 = serde_json::from_str(&serialized).unwrap(); - assert_eq!(data, expect); } } diff --git a/starky/src/merklehash_bn128.rs b/starky/src/merklehash_bn128.rs index 8dc1dc30..aeca33f2 100644 --- a/starky/src/merklehash_bn128.rs +++ b/starky/src/merklehash_bn128.rs @@ -176,6 +176,10 @@ impl MerkleTree for MerkleTreeBN128 { vec![Fr(node.as_scalar::())] } + fn from_basefield(node: &Fr) -> Self::MTNode { + Self::MTNode::from_scalar(node) + } + fn merkelize(&mut self, buff: Vec, width: usize, height: usize) -> Result<()> { let max_workers = get_max_workers(); @@ -367,12 +371,8 @@ mod tests { #[test] fn test_merkle_tree_bn128_serialize_and_deserialize() { let data = MerkleTreeBN128::new(); - let serialized = serde_json::to_string(&data).unwrap(); - println!("Serialized: {}", serialized); - let expect: MerkleTreeBN128 = serde_json::from_str(&serialized).unwrap(); - assert_eq!(data, expect); } } diff --git a/starky/src/serializer.rs b/starky/src/serializer.rs index 58d0d22c..755bb5cf 100644 --- a/starky/src/serializer.rs +++ b/starky/src/serializer.rs @@ -2,15 +2,11 @@ #![allow(non_snake_case)] use crate::f3g::F3G; use crate::f5g::F5G; -use crate::field_bls12381::Fr as Fr_BLS12381; -use crate::field_bn128::Fr; use crate::fri::FRIProof; use crate::fri::Query; -use crate::helper; use crate::stark_gen::StarkProof; use crate::traits::FieldExtension; use crate::traits::{MTNodeType, MerkleTree}; -use ff::PrimeField; use fields::field_gl::Fr as FGL; use serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer}; use serde::{ @@ -138,118 +134,6 @@ impl<'de> Deserialize<'de> for F5G { } } -// Is it making sense to wrap? -#[derive(Clone, Debug, PartialEq)] -pub struct NodeWrapper(pub T); - -impl NodeWrapper { - pub fn new(e: T) -> Self { - NodeWrapper(e) - } - pub fn is_dim_1(&self) -> bool { - let e = self.0.as_elements(); - e[1] == e[2] && e[1] == e[3] && e[1] == FGL::ZERO - } -} - -impl Serialize for NodeWrapper { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let source = TypeId::of::(); - if source == TypeId::of::() { - let r: Fr = Fr(self.0.clone().as_scalar::()); - return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string()); - } - if source == TypeId::of::() { - let r: Fr_BLS12381 = Fr_BLS12381(self.0.clone().as_scalar::()); - return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string()); - } - if source == TypeId::of::() { - let e = self.0.as_elements(); - if self.is_dim_1() { - return serializer.serialize_str(&e[0].as_int().to_string()); - } else { - let mut seq = serializer.serialize_seq(Some(4))?; - for v in e.iter() { - seq.serialize_element(&v.as_int().to_string())?; - } - return seq.end(); - } - } - panic!("Invalid element for seralizing, {:?}", self.0) - } -} - -impl<'de, T: MTNodeType> Deserialize<'de> for NodeWrapper { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct EntriesVisitor(PhantomData); - - impl<'de, MT: MTNodeType> Visitor<'de> for EntriesVisitor { - type Value = NodeWrapper; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("struct NodeWrapper") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - let mut entries = Vec::new(); - while let Some(entry) = seq.next_element::()? { - let entry: u64 = entry.parse().unwrap(); - entries.push(FGL::from(entry)); - } - Ok(NodeWrapper(MT::new(&entries))) - } - - // it could be one-dim GL, BN128, or BLS12381 - fn visit_str(self, s: &str) -> Result - where - E: de::Error, - { - let source = TypeId::of::(); - if source == TypeId::of::() { - // one-dim GL elements - let value = FGL::from_str(s).unwrap(); - let one_fgl: NodeWrapper = NodeWrapper::from(value); - Ok(one_fgl) - } else { - // BN128 or BLS12381 - let t = ::BaseField::from_str(s).unwrap(); - Ok(NodeWrapper(MT::from_scalar(&t))) - } - } - } - deserializer.deserialize_any(EntriesVisitor::(Default::default())) - } -} - -impl From for NodeWrapper { - fn from(val: Fr) -> Self { - let e = T::from_scalar(&val); - Self(e) - } -} - -impl From for NodeWrapper { - fn from(val: Fr_BLS12381) -> Self { - let e = T::from_scalar(&val); - Self(e) - } -} - -impl From for NodeWrapper { - fn from(val: FGL) -> Self { - Self(T::new(&[val, FGL::ZERO, FGL::ZERO, FGL::ZERO])) - } -} - impl Serialize for StarkProof { fn serialize(&self, serializer: S) -> Result where @@ -260,22 +144,19 @@ impl Serialize for StarkProof { let mut map = serializer.serialize_map(Some(len))?; if self.rootC.is_some() { - map.serialize_entry("rootC", &NodeWrapper::::new(self.rootC.unwrap()))?; + map.serialize_entry("rootC", &self.rootC.unwrap())?; } - map.serialize_entry("root1", &NodeWrapper::::new(self.root1))?; - map.serialize_entry("root2", &NodeWrapper::::new(self.root2))?; - map.serialize_entry("root3", &NodeWrapper::::new(self.root3))?; - map.serialize_entry("root4", &NodeWrapper::::new(self.root4))?; + map.serialize_entry("root1", &self.root1)?; + map.serialize_entry("root2", &self.root2)?; + map.serialize_entry("root3", &self.root3)?; + map.serialize_entry("root4", &self.root4)?; map.serialize_entry("evals", &self.evals)?; for i in 1..(self.fri_proof.queries.len()) { - map.serialize_entry( - &format!("s{}_root", i), - &NodeWrapper::new(self.fri_proof.queries[i].root), - )?; + map.serialize_entry(&format!("s{}_root", i), &self.fri_proof.queries[i].root)?; let mut vals: Vec> = vec![]; - let mut sibs: Vec>>> = vec![]; + let mut sibs: Vec>> = vec![]; for q in 0..self.fri_proof.queries[0].pol_queries.len() { let qe = &self.fri_proof.queries[i].pol_queries[q]; vals.push(qe[0].0.iter().map(|e| F3G::from(*e)).collect::>()); @@ -285,10 +166,10 @@ impl Serialize for StarkProof { .iter() .map(|e| { e.iter() - .map(|ee| ee.clone().into()) - .collect::>>() + .map(|ee| M::from_basefield(ee)) + .collect::>() }) - .collect::>>>(), + .collect::>>(), ); } map.serialize_entry(&format!("s{}_vals", i), &vals)?; @@ -300,11 +181,11 @@ impl Serialize for StarkProof { let mut s0_vals3: Vec> = vec![]; let mut s0_vals4: Vec> = vec![]; let mut s0_valsC: Vec> = vec![]; - let mut s0_siblings1: Vec>>> = vec![]; - let mut s0_siblings2: Vec>>> = vec![]; - let mut s0_siblings3: Vec>>> = vec![]; - let mut s0_siblings4: Vec>>> = vec![]; - let mut s0_siblingsC: Vec>>> = vec![]; + let mut s0_siblings1: Vec>> = vec![]; + let mut s0_siblings2: Vec>> = vec![]; + let mut s0_siblings3: Vec>> = vec![]; + let mut s0_siblings4: Vec>> = vec![]; + let mut s0_siblingsC: Vec>> = vec![]; for i in 0..self.fri_proof.queries[0].pol_queries.len() { //(leaf, path) represents each query @@ -316,72 +197,63 @@ impl Serialize for StarkProof { .iter() .map(|e| { e.iter() - .map(|ee| ee.clone().into()) - .collect::>>() + .map(|ee| M::from_basefield(ee)) + .collect::>() }) - .collect::>>>(), + .collect::>>(), ); - if !qe[1].0.is_empty() { - s0_vals2.push(qe[1].0.iter().map(|e| F3G::from(*e)).collect::>()); - s0_siblings2.push( - qe[1] - .1 - .iter() - .map(|e| { - e.iter() - .map(|ee| ee.clone().into()) - .collect::>>() - }) - .collect::>>>(), - ); - } + s0_vals2.push(qe[1].0.iter().map(|e| F3G::from(*e)).collect::>()); + s0_siblings2.push( + qe[1] + .1 + .iter() + .map(|e| { + e.iter() + .map(|ee| M::from_basefield(ee)) + .collect::>() + }) + .collect::>>(), + ); - if !qe[2].0.is_empty() { - s0_vals3.push(qe[2].0.iter().map(|e| F3G::from(*e)).collect::>()); - s0_siblings3.push( - qe[2] - .1 - .iter() - .map(|e| { - e.iter() - .map(|ee| ee.clone().into()) - .collect::>>() - }) - .collect::>>>(), - ); - } + s0_vals3.push(qe[2].0.iter().map(|e| F3G::from(*e)).collect::>()); + s0_siblings3.push( + qe[2] + .1 + .iter() + .map(|e| { + e.iter() + .map(|ee| M::from_basefield(ee)) + .collect::>() + }) + .collect::>>(), + ); - let qe = &self.fri_proof.queries[0].pol_queries[i]; - if !qe[3].0.is_empty() { - s0_vals4.push(qe[3].0.iter().map(|e| F3G::from(*e)).collect::>()); - s0_siblings4.push( - qe[3] - .1 - .iter() - .map(|e| { - e.iter() - .map(|ee| ee.clone().into()) - .collect::>>() - }) - .collect::>>>(), - ); - } + s0_vals4.push(qe[3].0.iter().map(|e| F3G::from(*e)).collect::>()); + s0_siblings4.push( + qe[3] + .1 + .iter() + .map(|e| { + e.iter() + .map(|ee| M::from_basefield(ee)) + .collect::>() + }) + .collect::>>(), + ); - if !qe[4].0.is_empty() { - s0_valsC.push(qe[4].0.iter().map(|e| F3G::from(*e)).collect::>()); - s0_siblingsC.push( - qe[4] - .1 - .iter() - .map(|e| { - e.iter() - .map(|ee| ee.clone().into()) - .collect::>>() - }) - .collect::>>>(), - ); - } + s0_valsC.push(qe[4].0.iter().map(|e| F3G::from(*e)).collect::>()); + s0_siblingsC.push( + qe[4] + .1 + .iter() + .map(|e| { + e.iter() + .map(|ee| M::from_basefield(ee)) + .collect::>() + }) + .collect::>>(), + ); } map.serialize_entry("s0_vals1", &s0_vals1)?; @@ -438,27 +310,26 @@ impl<'de, T: MerkleTree + Default> Deserialize<'de> for StarkProof { map.insert(key, value); } let mut sp: StarkProof = Default::default(); - let root: NodeWrapper = + let root: MT::MTNode = serde_json::from_value(map.get("root1").unwrap().clone()).unwrap(); - sp.root1 = root.0; + sp.root1 = root; - let root: NodeWrapper = + let root: MT::MTNode = serde_json::from_value(map.get("root2").unwrap().clone()).unwrap(); - sp.root2 = root.0; + sp.root2 = root; - let root: NodeWrapper = + let root: MT::MTNode = serde_json::from_value(map.get("root3").unwrap().clone()).unwrap(); - sp.root3 = root.0; + sp.root3 = root; - let root: NodeWrapper = + let root: MT::MTNode = serde_json::from_value(map.get("root4").unwrap().clone()).unwrap(); - sp.root4 = root.0; + sp.root4 = root; let root = map.get("rootC"); if root.is_some() { - let root: NodeWrapper = - serde_json::from_value(root.unwrap().clone()).unwrap(); - sp.rootC = Some(root.0); + let root: MT::MTNode = serde_json::from_value(root.unwrap().clone()).unwrap(); + sp.rootC = Some(root); } let prover_addr = map.get("proverAddr"); @@ -516,16 +387,8 @@ impl<'de, T: MerkleTree + Default> Deserialize<'de> for StarkProof { .collect(); let key = map.get(&format!("s0_siblings{}", j)); - let s0_siblings: Vec>>> = + let s0_siblings: Vec>> = serde_json::from_value(key.unwrap().clone()).unwrap(); - let s0_siblings: Vec>> = s0_siblings - .iter() - .map(|e| { - e.iter() - .map(|e2| e2.iter().map(|e3| e3.0).collect()) - .collect() - }) - .collect(); s0_vals_all.push(s0_vals); s0_siblings_all.push(s0_siblings); } @@ -568,9 +431,8 @@ impl<'de, T: MerkleTree + Default> Deserialize<'de> for StarkProof { // handle query 1 to num_query for i in 1..=num_query { let key = map.get(&format!("s{}_root", i)); - let root: NodeWrapper = - serde_json::from_value(key.unwrap().clone()).unwrap(); - fri_proof.queries[i].root = root.0; + let root: MT::MTNode = serde_json::from_value(key.unwrap().clone()).unwrap(); + fri_proof.queries[i].root = root; let key = map.get(&format!("s{}_vals", i)); let val: Vec> = serde_json::from_value(key.unwrap().clone()).unwrap(); @@ -589,17 +451,8 @@ impl<'de, T: MerkleTree + Default> Deserialize<'de> for StarkProof { .collect(); let key = map.get(&format!("s{}_siblings", i)); - let sibs: Vec>>> = + let sibs: Vec>> = serde_json::from_value(key.unwrap().clone()).unwrap(); - let sibs: Vec>> = sibs - .iter() - .map(|e| { - e.iter() - .map(|e2| e2.iter().map(|e3| e3.0).collect()) - .collect() - }) - .collect(); - fri_proof.queries[i].pol_queries = vec![vec![]; num_pol_queries]; for q in 0..num_pol_queries { let node_to_bf = crate::traits::mt_node_to_basefield::(&sibs[q]); @@ -630,7 +483,6 @@ mod tests { use crate::merklehash_bn128::MerkleTreeBN128; use crate::polsarray::PolKind; use crate::polsarray::PolsArray; - use crate::serializer::NodeWrapper; use crate::serializer::StarkProof; use crate::stark_setup::StarkSetup; use crate::traits::FieldExtension; @@ -691,7 +543,7 @@ mod tests { } #[test] - fn test_serialize_node_wrapper() { + fn test_serialize_element_digest() { env_logger::try_init().unwrap_or_default(); let mut rng = rand::thread_rng(); let four_fgl = ElementDigest::<4, FGL>::new(&[ @@ -701,28 +553,25 @@ mod tests { FGL::rand(&mut rng), ]); - let four_fgl = NodeWrapper::>::new(four_fgl); let four_fgl_ser = serde_json::to_string(&four_fgl).unwrap(); - let actual_four_fgl: NodeWrapper> = - serde_json::from_str(&four_fgl_ser).unwrap(); + let actual_four_fgl: ElementDigest<4, FGL> = serde_json::from_str(&four_fgl_ser).unwrap(); assert_eq!(four_fgl.0, actual_four_fgl.0); - let one_fgl: NodeWrapper> = NodeWrapper::from(FGL::rand(&mut rng)); + let one_fgl: ElementDigest<4, FGL> = + ElementDigest::<4, FGL>::new(&[FGL::rand(&mut rng), FGL::ZERO, FGL::ZERO, FGL::ZERO]); let one_fgl_ser = serde_json::to_string(&one_fgl).unwrap(); - let actual_one_fgl: NodeWrapper> = - serde_json::from_str(&one_fgl_ser).unwrap(); + let actual_one_fgl: ElementDigest<4, FGL> = serde_json::from_str(&one_fgl_ser).unwrap(); assert_eq!(one_fgl.0, actual_one_fgl.0); - let one_fr: NodeWrapper> = NodeWrapper::from(Fr::rand(&mut rng)); + let one_fr: ElementDigest<4, Fr> = ElementDigest::<4, Fr>::from_scalar(&Fr::rand(&mut rng)); let one_fr_ser = serde_json::to_string(&one_fr).unwrap(); - let actual_one_fr: NodeWrapper> = - serde_json::from_str(&one_fr_ser).unwrap(); + let actual_one_fr: ElementDigest<4, Fr> = serde_json::from_str(&one_fr_ser).unwrap(); assert_eq!(one_fr.0, actual_one_fr.0); - let one_fr: NodeWrapper> = - NodeWrapper::from(Fr_BLS12381::rand(&mut rng)); + let one_fr: ElementDigest<4, Fr_BLS12381> = + ElementDigest::<4, Fr_BLS12381>::from_scalar(&Fr_BLS12381::rand(&mut rng)); let one_fr_ser = serde_json::to_string(&one_fr).unwrap(); - let actual_one_fr: NodeWrapper> = + let actual_one_fr: ElementDigest<4, Fr_BLS12381> = serde_json::from_str(&one_fr_ser).unwrap(); assert_eq!(one_fr.0, actual_one_fr.0); } diff --git a/starky/src/stark_gen.rs b/starky/src/stark_gen.rs index c655c7dc..506a67d6 100644 --- a/starky/src/stark_gen.rs +++ b/starky/src/stark_gen.rs @@ -1137,10 +1137,12 @@ pub mod tests { "273030697313060285579891744179749754319274977764", ) .unwrap(); + let ser = serde_json::to_string(&starkproof).unwrap(); + let de: StarkProof = serde_json::from_str(&ser).unwrap(); log::trace!("verify the proof..."); let result = stark_verify::( - &starkproof, + &de, &setup.const_root, &setup.starkinfo, &stark_struct, @@ -1176,11 +1178,12 @@ pub mod tests { "273030697313060285579891744179749754319274977764", ) .unwrap(); - + let ser = serde_json::to_string(&starkproof).unwrap(); + let de: StarkProof = serde_json::from_str(&ser).unwrap(); log::trace!("verify the proof..."); let result = stark_verify::( - &starkproof, + &de, &setup.const_root, &setup.starkinfo, &stark_struct, @@ -1213,9 +1216,11 @@ pub mod tests { "273030697313060285579891744179749754319274977764", ) .unwrap(); + let ser = serde_json::to_string(&starkproof).unwrap(); + let de: StarkProof = serde_json::from_str(&ser).unwrap(); log::trace!("verify the proof..."); let result = stark_verify::( - &starkproof, + &de, &setup.const_root, &setup.starkinfo, &stark_struct, @@ -1250,9 +1255,11 @@ pub mod tests { "273030697313060285579891744179749754319274977764", ) .unwrap(); + let ser = serde_json::to_string(&starkproof).unwrap(); + let de: StarkProof = serde_json::from_str(&ser).unwrap(); log::trace!("verify the proof..."); let result = stark_verify::( - &starkproof, + &de, &setup.const_root, &setup.starkinfo, &stark_struct, @@ -1287,6 +1294,8 @@ pub mod tests { "273030697313060285579891744179749754319274977764", ) .unwrap(); + let ser = serde_json::to_string(&starkproof).unwrap(); + let de: StarkProof = serde_json::from_str(&ser).unwrap(); log::trace!("verify the proof..."); let result = stark_verify::( &starkproof, @@ -1297,5 +1306,15 @@ pub mod tests { ) .unwrap(); assert!(result); + + let result = stark_verify::( + &de, + &setup.const_root, + &setup.starkinfo, + &stark_struct, + &setup.program, + ) + .unwrap(); + assert!(result); } } diff --git a/starky/src/traits.rs b/starky/src/traits.rs index 3089787a..027136a0 100644 --- a/starky/src/traits.rs +++ b/starky/src/traits.rs @@ -1,4 +1,3 @@ -use crate::serializer::NodeWrapper; use ::rand::Rand; use anyhow::Result; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; @@ -34,12 +33,12 @@ where + Debug + Serialize + DeserializeOwned; - // TODO: the BaseField is the container of flatten MTNode, merge BaseField and MTNode - type BaseField: Clone + Default + Debug + PartialEq + Into>; type ExtendField: FieldExtension; + type BaseField: Clone + Default + Debug + PartialEq + Serialize + DeserializeOwned; fn new() -> Self; fn to_extend(&self, p_be: &mut Vec); fn to_basefield(node: &Self::MTNode) -> Vec; + fn from_basefield(node: &Self::BaseField) -> Self::MTNode; fn merkelize(&mut self, buff: Vec, width: usize, height: usize) -> Result<()>; fn get_element(&self, idx: usize, sub_idx: usize) -> FGL; fn get_group_proof(&self, idx: usize) -> Result<(Vec, Vec>)>;