diff --git a/Cargo.lock b/Cargo.lock index 177f323176..bb6af2766c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2675,6 +2675,7 @@ dependencies = [ "serde", "serde_with", "sha3", + "thiserror", "time", "tracing", "tracing-subscriber", diff --git a/saffron/Cargo.toml b/saffron/Cargo.toml index 75a057731f..7d1d1f3c77 100644 --- a/saffron/Cargo.toml +++ b/saffron/Cargo.toml @@ -34,6 +34,7 @@ rmp-serde.workspace = true serde.workspace = true serde_with.workspace = true sha3.workspace = true +thiserror.workspace = true time = { version = "0.3", features = ["macros"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = [ "ansi", "env-filter", "fmt", "time" ] } diff --git a/saffron/src/utils.rs b/saffron/src/utils.rs index 7b5484b1f7..5f3c6d0c0f 100644 --- a/saffron/src/utils.rs +++ b/saffron/src/utils.rs @@ -1,5 +1,10 @@ +use std::marker::PhantomData; + use ark_ff::{BigInteger, PrimeField}; use ark_poly::EvaluationDomain; +use o1_utils::FieldHelpers; +use thiserror::Error; +use tracing::instrument; // For injectivity, you can only use this on inputs of length at most // 'F::MODULUS_BIT_SIZE / 8', e.g. for Vesta this is 31. @@ -12,11 +17,6 @@ pub fn decode_into(buffer: &mut [u8], x: Fp) { buffer.copy_from_slice(&bytes); } -pub fn get_31_bytes(x: F) -> Vec { - let bytes = x.into_bigint().to_bytes_be(); - bytes[1..32].to_vec() -} - pub fn encode_as_field_elements(bytes: &[u8]) -> Vec { let n = (F::MODULUS_BIT_SIZE / 8) as usize; bytes @@ -59,89 +59,123 @@ pub struct QueryBytes { #[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Debug)] /// We store the data in a vector of vector of field element /// The inner vector represent polynomials -pub struct FieldElt { - /// the number of the polynomial the data point is attached too - pub poly_nb: usize, - /// the number of the root of unity the data point is attached too - pub eval_nb: usize, +struct FieldElt { + /// the index of the polynomial the data point is attached too + poly_index: usize, + /// the index of the root of unity the data point is attached too + eval_index: usize, + domain_size: usize, + n_polys: usize, } /// Represents a query in term of Field element #[derive(Debug)] -pub struct QueryField { - pub start: FieldElt, - /// how many bytes we need to trim from the first 31bytes chunk +pub struct QueryField { + start: FieldElt, + /// how many bytes we need to trim from the first chunk /// we get from the first field element we decode - pub leftover_start: usize, - pub end: FieldElt, - /// how many bytes we need to trim from the last 31bytes chunk + leftover_start: usize, + end: FieldElt, + /// how many bytes we need to trim from the last chunk /// we get from the last field element we decode - pub leftover_end: usize, + leftover_end: usize, + tag: PhantomData, } -impl QueryField { - pub fn is_valid(&self, nb_poly: usize) -> bool { - self.start.eval_nb < 1 << 16 - && self.end.eval_nb < 1 << 16 - && self.end.poly_nb < nb_poly - && self.start <= self.end - && self.leftover_end <= (F::MODULUS_BIT_SIZE as usize) / 8 - && self.leftover_start <= (F::MODULUS_BIT_SIZE as usize) / 8 - } - - pub fn apply(self, data: Vec>) -> Vec { - assert!(self.is_valid::(data.len()), "Invalid query"); - let mut answer: Vec = Vec::new(); - let mut field_elt = self.start; - while field_elt <= self.end { - if data[field_elt.poly_nb][field_elt.eval_nb] == F::zero() { - println!() - } - let mut to_append = get_31_bytes(data[field_elt.poly_nb][field_elt.eval_nb]); - answer.append(&mut to_append); - field_elt = field_elt.next().unwrap(); - } - let n = answer.len(); - // trimming the first and last 31bytes chunk - answer[(self.leftover_start)..(n - self.leftover_end)].to_vec() +impl QueryField { + #[instrument(skip_all, level = "debug")] + pub fn apply(self, data: &[Vec]) -> Vec { + let n = (F::MODULUS_BIT_SIZE / 8) as usize; + let m = F::size_in_bytes(); + let mut buffer = vec![0u8; m]; + let mut answer = Vec::new(); + self.start + .into_iter() + .take_while(|x| x <= &self.end) + .for_each(|x| { + let value = data[x.poly_index][x.eval_index]; + decode_into(&mut buffer, value); + answer.extend_from_slice(&buffer[(m - n)..m]); + }); + + answer[(self.leftover_start)..(answer.len() - self.leftover_end)].to_vec() } } impl Iterator for FieldElt { type Item = FieldElt; - fn next(&mut self) -> Option { - if self.eval_nb < (1 << 16) - 1 { - self.eval_nb += 1; + fn next(&mut self) -> Option { + let current = *self; + + if (self.eval_index + 1) < self.domain_size { + self.eval_index += 1; + } else if (self.poly_index + 1) < self.n_polys { + self.poly_index += 1; + self.eval_index = 0; } else { - self.poly_nb += 1; - self.eval_nb = 0 - }; - Some(*self) + return None; + } + + Some(current) } } -impl Into for QueryBytes { - fn into(self) -> QueryField { - let n = 31 as usize; - let start_field_nb = self.start / n; - let start = FieldElt { - poly_nb: start_field_nb / (1 << 16), - eval_nb: start_field_nb % (1 << 16), - }; - let leftover_start = self.start % n; +#[derive(Debug, Error, Clone, PartialEq)] +pub enum QueryError { + #[error("Query out of bounds: poly_index {poly_index} eval_index {eval_index} n_polys {n_polys} domain_size {domain_size}")] + QueryOutOfBounds { + poly_index: usize, + eval_index: usize, + n_polys: usize, + domain_size: usize, + }, +} +impl QueryBytes { + pub fn into_query_field( + &self, + domain_size: usize, + n_polys: usize, + ) -> Result, QueryError> { + let n = (F::MODULUS_BIT_SIZE / 8) as usize; + let start = { + let start_field_nb = self.start / n; + FieldElt { + poly_index: start_field_nb / domain_size, + eval_index: start_field_nb % domain_size, + domain_size, + n_polys, + } + }; let byte_end = self.start + self.len; - let end_field_nb = byte_end / n; - let end = FieldElt { - poly_nb: end_field_nb / (1 << 16), - eval_nb: end_field_nb % (1 << 16), + let end = { + let end_field_nb = byte_end / n; + FieldElt { + poly_index: end_field_nb / domain_size, + eval_index: end_field_nb % domain_size, + domain_size, + n_polys, + } }; + + if start.poly_index >= n_polys || end.poly_index >= n_polys { + return Err(QueryError::QueryOutOfBounds { + poly_index: end.poly_index, + eval_index: end.eval_index, + n_polys, + domain_size, + }); + }; + + let leftover_start = self.start % n; let leftover_end = n - byte_end % n; - QueryField { + + Ok(QueryField { start, leftover_start, end, leftover_end, - } + tag: std::marker::PhantomData, + }) } } @@ -156,6 +190,10 @@ pub mod test_utils { pub fn len(&self) -> usize { self.0.len() } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } #[derive(Clone, Debug)] @@ -230,10 +268,10 @@ mod tests { use ark_poly::Radix2EvaluationDomain; use ark_std::UniformRand; use mina_curves::pasta::Fp; - use o1_utils::FieldHelpers; use once_cell::sync::Lazy; use proptest::prelude::*; - use test_utils::UserData; + use test_utils::{DataSize, UserData}; + use tracing::debug; fn decode(x: Fp) -> Vec { let mut buffer = vec![0u8; Fp::size_in_bytes()]; @@ -299,6 +337,25 @@ mod tests { } } + fn padded_field_length(xs: &[u8]) -> usize { + let m = Fp::MODULUS_BIT_SIZE as usize / 8; + let n = xs.len(); + let num_field_elems = (n + m - 1) / m; + let num_polys = (num_field_elems + DOMAIN.size() - 1) / DOMAIN.size(); + DOMAIN.size() * num_polys + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(20))] + #[test] + fn test_padded_byte_length(UserData(xs) in UserData::arbitrary() + ) + { let chunked = encode_for_domain(&*DOMAIN, &xs); + let n = chunked.into_iter().flatten().count(); + prop_assert_eq!(n, padded_field_length(&xs)); + } + } + proptest! { #![proptest_config(ProptestConfig::with_cases(20))] #[test] @@ -316,10 +373,66 @@ mod tests { let chunked = encode_for_domain(&*DOMAIN, &xs); for query in queries { let expected = &xs[query.start..(query.start+query.len)]; - let field_query: QueryField = query.clone().into(); - let got_answer = field_query.apply(chunked.clone()); // Note: might need clone depending on your types + let field_query: QueryField = query.into_query_field(DOMAIN.size(), chunked.len()).unwrap(); + let got_answer = field_query.apply(&chunked); prop_assert_eq!(expected, got_answer); } } } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(20))] + #[test] + fn test_for_invalid_query_length( + (UserData(xs), mut query) in UserData::arbitrary() + .prop_flat_map(|UserData(xs)| { + let padded_len = { + let m = Fp::MODULUS_BIT_SIZE as usize / 8; + padded_field_length(&xs) * m + }; + let query_strategy = (0..xs.len()).prop_map(move |start| { + // this is the last valid end point + let end = padded_len - 1; + QueryBytes { start, len: end - start } + }); + (Just(UserData(xs)), query_strategy) + }) + ) { + debug!("check that first query is valid"); + let chunked = encode_for_domain(&*DOMAIN, &xs); + let n_polys = chunked.len(); + let query_field = query.into_query_field::(DOMAIN.size(), n_polys); + prop_assert!(query_field.is_ok()); + debug!("check that extending query length by 1 is invalid"); + query.len += 1; + let query_field = query.into_query_field::(DOMAIN.size(), n_polys); + prop_assert!(query_field.is_err()); + + } + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(20))] + #[test] + fn test_nil_query( + (UserData(xs), query) in UserData::arbitrary_with(DataSize::Small) + .prop_flat_map(|xs| { + let padded_len = { + let m = Fp::MODULUS_BIT_SIZE as usize / 8; + padded_field_length(&xs.0) * m + }; + let query_strategy = (0..padded_len).prop_map(move |start| { + QueryBytes { start, len: 0 } + }); + (Just(xs), query_strategy) + }) + ) { + let chunked = encode_for_domain(&*DOMAIN, &xs); + let n_polys = chunked.len(); + let field_query: QueryField = query.into_query_field(DOMAIN.size(), n_polys).unwrap(); + let got_answer = field_query.apply(&chunked); + prop_assert!(got_answer.is_empty()); + } + + } }