Skip to content

Commit

Permalink
fix: remove node wrapper and add UTs (#234)
Browse files Browse the repository at this point in the history
* fix: remove node wrapper and add UTs

* fix: do not ignore the empty leaf

* fix: donot ignore the empty leaf

* fix: typos
  • Loading branch information
eigmax authored Mar 15, 2024
1 parent 6f8114f commit 59d2152
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 278 deletions.
67 changes: 50 additions & 17 deletions starky/src/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,31 @@
use crate::field_bls12381::Fr as Fr_bls12381;
use crate::field_bls12381::FrRepr as FrRepr_bls12381;
use crate::field_bn128::{Fr, FrRepr};
use crate::helper;
use crate::traits::MTNodeType;
use ff::*;
use fields::field_gl::Fr as FGL;
use serde::de::{SeqAccess, Visitor};
use serde::ser::SerializeSeq;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use std::any::TypeId;
use std::fmt;
use std::fmt::Display;
use std::marker::PhantomData;

/// the trait F is used to keep track of source data type, so we can implement its deserializer
// TODO: Remove the generic type: F. As it's never used.
#[repr(C)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ElementDigest<const N: usize, F: PrimeField + Default>(pub [FGL; N], PhantomData<F>);

impl<const N: usize, F: PrimeField + Default> ElementDigest<N, F> {
// FIXME: this is a bit tricky that assuming the len is 4, replace it by N here.
pub fn is_dim_1(&self) -> bool {
let e = self.as_elements();
e[1] == e[2] && e[1] == e[3] && e[1] == FGL::ZERO
}
}

impl<const N: usize, F: PrimeField + Default> MTNodeType for ElementDigest<N, F> {
type BaseField = F;
#[inline(always)]
Expand Down Expand Up @@ -77,13 +86,28 @@ impl<const N: usize, F: PrimeField + Default> Serialize for ElementDigest<N, F>
where
S: Serializer,
{
let elems = self.0.to_vec();

let mut seq = serializer.serialize_seq(Some(elems.len()))?;
for v in elems.iter() {
seq.serialize_element(&v.as_int().to_string())?;
let source = TypeId::of::<F>();
if source == TypeId::of::<Fr>() {
let r: Fr = Fr(self.as_scalar::<Fr>());
return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string());
}
if source == TypeId::of::<Fr_bls12381>() {
let r: Fr_bls12381 = Fr_bls12381(self.as_scalar::<Fr_bls12381>());
return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string());
}
if source == TypeId::of::<FGL>() {
let e = self.as_elements();
if self.is_dim_1() {
return serializer.serialize_str(&e[0].as_int().to_string());
} else {
let mut seq = serializer.serialize_seq(Some(4))?;
for v in e.iter() {
seq.serialize_element(&v.as_int().to_string())?;
}
return seq.end();
}
}
seq.end()
panic!("Invalid element to serialize, {:?}", self.0)
}
}

Expand All @@ -105,24 +129,34 @@ impl<'de, const N: usize, F: PrimeField + Default> Deserialize<'de> for ElementD
where
A: SeqAccess<'de>,
{
let mut entries = Vec::with_capacity(N);
let mut entries = Vec::new();
while let Some(entry) = seq.next_element::<String>()? {
let entry: u64 = entry.parse().unwrap();

entries.push(FGL::from_repr(fields::field_gl::FrRepr::from(entry)).unwrap());
entries.push(FGL::from(entry));
}
Ok(ElementDigest::<N, F>::new(&entries))
}

// it could be one-dim GL, BN128, or BLS12381
fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
let entry: u64 = s.parse().unwrap();

let data = vec![FGL::from_repr(fields::field_gl::FrRepr::from(entry)).unwrap(); N];

Ok(ElementDigest::<N, F>::new(&data))
let source = TypeId::of::<F>();
if source == TypeId::of::<FGL>() {
// one-dim GL elements
let value = FGL::from_str(s).unwrap();
Ok(ElementDigest::<N, F>::new(&[
value,
FGL::ZERO,
FGL::ZERO,
FGL::ZERO,
]))
} else {
// BN128 or BLS12381
let t = F::from_str(s).unwrap();
Ok(ElementDigest::<N, F>::from_scalar(&t))
}
}
}
deserializer.deserialize_any(EntriesVisitor::<N, F>(Default::default()))
Expand Down Expand Up @@ -280,11 +314,10 @@ mod tests {

#[test]
fn test_element_digest_serialize_and_deserialize() {
const N: usize = 3;
const N: usize = 4;
let fields = vec![FGL::one(); N];
let data = ElementDigest::<N, FGL>::new(&fields);
let serialized = serde_json::to_string(&data).unwrap();
println!("Serialized: {}", serialized);

let expect: ElementDigest<N, FGL> = serde_json::from_str(&serialized).unwrap();

Expand Down
40 changes: 40 additions & 0 deletions starky/src/field_bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,51 @@
use core::ops::{Add, Div, Mul, Neg, Sub};
use ff::*;

use crate::helper;
use serde::de::{Error, SeqAccess, Visitor};
use serde::ser::SerializeSeq;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;

#[derive(PrimeField)]
#[PrimeFieldModulus = "52435875175126190479447740508185965837690552500527637822603658699938581184513"]
#[PrimeFieldGenerator = "7"]
pub struct Fr(pub FrRepr);

impl Serialize for Fr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&helper::fr_to_biguint(self).to_string())
}
}

impl<'de> Deserialize<'de> for Fr {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct EntriesVisitor;

impl<'de> Visitor<'de> for EntriesVisitor {
type Value = Fr;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct Bls12381's Fr")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Self::Value::from_str(v).unwrap())
}
}
deserializer.deserialize_any(EntriesVisitor)
}
}

#[cfg(test)]
mod tests {
use crate::field_bls12381::*;
Expand Down
41 changes: 41 additions & 0 deletions starky/src/field_bn128.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,52 @@
#![allow(unused_imports, clippy::too_many_arguments)]
use ff::*;

use crate::helper;
use ff::*;
use serde::de::{Error, SeqAccess, Visitor};
use serde::ser::SerializeSeq;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;

#[derive(PrimeField)]
#[PrimeFieldModulus = "21888242871839275222246405745257275088548364400416034343698204186575808495617"]
#[PrimeFieldGenerator = "7"]
pub struct Fr(pub FrRepr);

impl Serialize for Fr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&helper::fr_to_biguint(self).to_string())
}
}

impl<'de> Deserialize<'de> for Fr {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct EntriesVisitor;

impl<'de> Visitor<'de> for EntriesVisitor {
type Value = Fr;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct Bn128's Fr")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Self::Value::from_str(v).unwrap())
}
}
deserializer.deserialize_any(EntriesVisitor)
}
}

#[cfg(test)]
mod tests {
use crate::field_bn128::*;
Expand Down
8 changes: 4 additions & 4 deletions starky/src/merklehash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ impl MerkleTree for MerkleTreeGL {
vec![node.as_elements().to_vec()[0]]
}

fn from_basefield(node: &FGL) -> Self::MTNode {
Self::MTNode::new(&[*node, FGL::ZERO, FGL::ZERO, FGL::ZERO])
}

#[cfg(not(any(
target_feature = "avx512bw",
target_feature = "avx512cd",
Expand Down Expand Up @@ -582,12 +586,8 @@ mod tests {
#[test]
fn test_merkle_tree_gl_serialize_and_deserialize() {
let data = MerkleTreeGL::new();

let serialized = serde_json::to_string(&data).unwrap();
println!("Serialized: {}", serialized);

let expect: MerkleTreeGL = serde_json::from_str(&serialized).unwrap();

assert_eq!(data, expect);
}
}
8 changes: 4 additions & 4 deletions starky/src/merklehash_bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ impl MerkleTree for MerkleTreeBLS12381 {
vec![Fr(node.as_scalar::<Fr>())]
}

fn from_basefield(node: &Fr) -> Self::MTNode {
Self::MTNode::from_scalar(node)
}

fn merkelize(&mut self, buff: Vec<FGL>, width: usize, height: usize) -> Result<()> {
let max_workers = get_max_workers();

Expand Down Expand Up @@ -373,12 +377,8 @@ mod tests {
#[test]
fn test_merkle_tree_bls381_serialize_and_deserialize() {
let data = MerkleTreeBLS12381::new();

let serialized = serde_json::to_string(&data).unwrap();
println!("Serialized: {}", serialized);

let expect: MerkleTreeBLS12381 = serde_json::from_str(&serialized).unwrap();

assert_eq!(data, expect);
}
}
8 changes: 4 additions & 4 deletions starky/src/merklehash_bn128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ impl MerkleTree for MerkleTreeBN128 {
vec![Fr(node.as_scalar::<Fr>())]
}

fn from_basefield(node: &Fr) -> Self::MTNode {
Self::MTNode::from_scalar(node)
}

fn merkelize(&mut self, buff: Vec<FGL>, width: usize, height: usize) -> Result<()> {
let max_workers = get_max_workers();

Expand Down Expand Up @@ -367,12 +371,8 @@ mod tests {
#[test]
fn test_merkle_tree_bn128_serialize_and_deserialize() {
let data = MerkleTreeBN128::new();

let serialized = serde_json::to_string(&data).unwrap();
println!("Serialized: {}", serialized);

let expect: MerkleTreeBN128 = serde_json::from_str(&serialized).unwrap();

assert_eq!(data, expect);
}
}
Loading

0 comments on commit 59d2152

Please sign in to comment.