diff --git a/Cargo.lock b/Cargo.lock index 177f323176..10da4557ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2659,7 +2659,6 @@ dependencies = [ "ark-ff", "ark-poly", "ark-serialize", - "ark-std", "clap 4.4.18", "ctor", "hex", diff --git a/saffron/Cargo.toml b/saffron/Cargo.toml index 75a057731f..e88ac43a82 100644 --- a/saffron/Cargo.toml +++ b/saffron/Cargo.toml @@ -39,7 +39,6 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = [ "ansi", "env-filter", "fmt", "time" ] } [dev-dependencies] -ark-std.workspace = true ctor = "0.2" proptest.workspace = true once_cell.workspace = true diff --git a/saffron/src/blob.rs b/saffron/src/blob.rs index 2f86c36471..87b7252532 100644 --- a/saffron/src/blob.rs +++ b/saffron/src/blob.rs @@ -82,7 +82,7 @@ impl FieldBlob { for p in blob.data { let evals = p.evaluate_over_domain(domain).evals; for x in evals { - decode_into(&mut buffer, x); + decode_into(&mut buffer, &x); bytes.extend_from_slice(&buffer[(m - n)..m]); } } diff --git a/saffron/src/utils.rs b/saffron/src/utils.rs index 7b5484b1f7..2a34c9d73d 100644 --- a/saffron/src/utils.rs +++ b/saffron/src/utils.rs @@ -1,5 +1,7 @@ use ark_ff::{BigInteger, PrimeField}; use ark_poly::EvaluationDomain; +use o1_utils::FieldHelpers; +use tracing::instrument; // For injectivity, you can only use this on inputs of length at most // 'F::MODULUS_BIT_SIZE / 8', e.g. for Vesta this is 31. @@ -7,14 +9,17 @@ fn encode(bytes: &[u8]) -> Fp { Fp::from_be_bytes_mod_order(bytes) } -pub fn decode_into(buffer: &mut [u8], x: Fp) { +pub fn decode_into(buffer: &mut [u8], x: &Fp) { let bytes = x.into_bigint().to_bytes_be(); buffer.copy_from_slice(&bytes); } -pub fn get_31_bytes(x: F) -> Vec { - let bytes = x.into_bigint().to_bytes_be(); - bytes[1..32].to_vec() +fn decode(x: &F) -> Vec { + let n = (F::MODULUS_BIT_SIZE / 8) as usize; + let m = F::size_in_bytes(); + let mut buffer = vec![0u8; m]; + decode_into(&mut buffer, x); + buffer[(m - n)..m].to_vec() } pub fn encode_as_field_elements(bytes: &[u8]) -> Vec { @@ -61,47 +66,43 @@ pub struct QueryBytes { /// The inner vector represent polynomials pub struct FieldElt { /// the number of the polynomial the data point is attached too - pub poly_nb: usize, + poly_index: usize, /// the number of the root of unity the data point is attached too - pub eval_nb: usize, + eval_index: usize, + domain_size: usize, + n_polys: usize, } /// Represents a query in term of Field element #[derive(Debug)] -pub struct QueryField { - pub start: FieldElt, +pub struct QueryField { + start: FieldElt, /// how many bytes we need to trim from the first 31bytes chunk /// we get from the first field element we decode - pub leftover_start: usize, - pub end: FieldElt, + leftover_start: usize, + end: FieldElt, /// how many bytes we need to trim from the last 31bytes chunk /// we get from the last field element we decode - pub leftover_end: usize, + leftover_end: usize, + tag: std::marker::PhantomData, } -impl QueryField { - pub fn is_valid(&self, nb_poly: usize) -> bool { - self.start.eval_nb < 1 << 16 - && self.end.eval_nb < 1 << 16 - && self.end.poly_nb < nb_poly - && self.start <= self.end - && self.leftover_end <= (F::MODULUS_BIT_SIZE as usize) / 8 - && self.leftover_start <= (F::MODULUS_BIT_SIZE as usize) / 8 - } - - pub fn apply(self, data: Vec>) -> Vec { - assert!(self.is_valid::(data.len()), "Invalid query"); +impl QueryField { + #[instrument(skip_all, level = "debug")] + pub fn apply(self, data: &[Vec]) -> Vec { let mut answer: Vec = Vec::new(); - let mut field_elt = self.start; - while field_elt <= self.end { - if data[field_elt.poly_nb][field_elt.eval_nb] == F::zero() { - println!() + let mut current = self.start; + + while current <= self.end { + let value = data[current.poly_index][current.eval_index]; + answer.extend(decode(&value)); + if let Some(next) = current.next() { + current = next; + } else { + break; } - let mut to_append = get_31_bytes(data[field_elt.poly_nb][field_elt.eval_nb]); - answer.append(&mut to_append); - field_elt = field_elt.next().unwrap(); } + let n = answer.len(); - // trimming the first and last 31bytes chunk answer[(self.leftover_start)..(n - self.leftover_end)].to_vec() } } @@ -109,38 +110,53 @@ impl QueryField { impl Iterator for FieldElt { type Item = FieldElt; fn next(&mut self) -> Option { - if self.eval_nb < (1 << 16) - 1 { - self.eval_nb += 1; - } else { - self.poly_nb += 1; - self.eval_nb = 0 - }; - Some(*self) + if (self.eval_index + 1) < self.domain_size { + self.eval_index += 1; + return Some(*self); + } else if (self.poly_index + 1) < self.n_polys { + self.poly_index += 1; + self.eval_index = 0; + return Some(*self); + } + None } } -impl Into for QueryBytes { - fn into(self) -> QueryField { - let n = 31 as usize; - let start_field_nb = self.start / n; - let start = FieldElt { - poly_nb: start_field_nb / (1 << 16), - eval_nb: start_field_nb % (1 << 16), +impl QueryBytes { + pub fn into_query_field( + &self, + domain_size: usize, + n_polys: usize, + ) -> QueryField { + let n = (F::MODULUS_BIT_SIZE / 8) as usize; + let start = { + let start_field_nb = self.start / n; + FieldElt { + poly_index: start_field_nb / domain_size, + eval_index: start_field_nb % domain_size, + domain_size, + n_polys, + } }; - let leftover_start = self.start % n; - let byte_end = self.start + self.len; - let end_field_nb = byte_end / n; - let end = FieldElt { - poly_nb: end_field_nb / (1 << 16), - eval_nb: end_field_nb % (1 << 16), + let end = { + let end_field_nb = byte_end / n; + FieldElt { + poly_index: end_field_nb / domain_size, + eval_index: end_field_nb % domain_size, + domain_size, + n_polys, + } }; + let leftover_start = self.start % n; let leftover_end = n - byte_end % n; + QueryField { start, leftover_start, end, leftover_end, + tag: std::marker::PhantomData, } } } @@ -156,6 +172,10 @@ pub mod test_utils { pub fn len(&self) -> usize { self.0.len() } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } #[derive(Clone, Debug)] @@ -228,48 +248,31 @@ pub mod test_utils { mod tests { use super::*; use ark_poly::Radix2EvaluationDomain; - use ark_std::UniformRand; use mina_curves::pasta::Fp; - use o1_utils::FieldHelpers; use once_cell::sync::Lazy; use proptest::prelude::*; use test_utils::UserData; - fn decode(x: Fp) -> Vec { - let mut buffer = vec![0u8; Fp::size_in_bytes()]; - decode_into(&mut buffer, x); - buffer - } - fn decode_from_field_elements(xs: Vec) -> Vec { - let n = (F::MODULUS_BIT_SIZE / 8) as usize; - let m = F::size_in_bytes(); - let mut buffer = vec![0u8; F::size_in_bytes()]; - xs.iter() - .flat_map(|x| { - decode_into(&mut buffer, *x); - buffer[(m - n)..m].to_vec() - }) - .collect() + xs.iter().flat_map(decode).collect() } - // Check that [u8] -> Fp -> [u8] is the identity function. + // Check that [u8] -> Fp -> [u8] is the identity function for any length 31 bytestring. proptest! { #[test] fn test_round_trip_from_bytes(xs in any::<[u8;31]>()) { let n : Fp = encode(&xs); - let ys : [u8; 31] = decode(n).as_slice()[1..32].try_into().unwrap(); + let ys : [u8; 31] = decode(&n).try_into().unwrap(); prop_assert_eq!(xs, ys); } } - // Check that Fp -> [u8] -> Fp is the identity function. + // Check that Fp -> [u8] -> Fp is the identity function when restricted to the range of encode proptest! { #[test] - fn test_round_trip_from_fp( - x in prop::strategy::Just(Fp::rand(&mut ark_std::rand::thread_rng())) - ) { - let bytes = decode(x); + fn test_round_trip_from_fp(x in any::<[u8;31]>().prop_map(|xs| encode::(&xs))) { + + let bytes = decode(&x); let y = encode(&bytes); prop_assert_eq!(x,y); } @@ -314,10 +317,11 @@ mod tests { }) ) { let chunked = encode_for_domain(&*DOMAIN, &xs); + let n_polys = chunked.len(); for query in queries { let expected = &xs[query.start..(query.start+query.len)]; - let field_query: QueryField = query.clone().into(); - let got_answer = field_query.apply(chunked.clone()); // Note: might need clone depending on your types + let field_query: QueryField = query.into_query_field(DOMAIN.size(), n_polys); + let got_answer = field_query.apply(&chunked); prop_assert_eq!(expected, got_answer); } }