Skip to content

Commit

Permalink
Merge pull request #2975 from o1-labs/martin/marc/query-saffron-cleanup
Browse files Browse the repository at this point in the history
[saffron] cleanup query
  • Loading branch information
martyall authored Jan 30, 2025
2 parents e3253d3 + 47fa03d commit abf307b
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 68 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions saffron/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ rmp-serde.workspace = true
serde.workspace = true
serde_with.workspace = true
sha3.workspace = true
thiserror.workspace = true
time = { version = "0.3", features = ["macros"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = [ "ansi", "env-filter", "fmt", "time" ] }
Expand Down
249 changes: 181 additions & 68 deletions saffron/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use std::marker::PhantomData;

use ark_ff::{BigInteger, PrimeField};
use ark_poly::EvaluationDomain;
use o1_utils::FieldHelpers;
use thiserror::Error;
use tracing::instrument;

// For injectivity, you can only use this on inputs of length at most
// 'F::MODULUS_BIT_SIZE / 8', e.g. for Vesta this is 31.
Expand All @@ -12,11 +17,6 @@ pub fn decode_into<Fp: PrimeField>(buffer: &mut [u8], x: Fp) {
buffer.copy_from_slice(&bytes);
}

/// Returns bytes `1..32` of the big-endian byte encoding of `x`.
///
/// The encoding is produced by `x.into_bigint().to_bytes_be()`; the byte at
/// index 0 is dropped and the next 31 bytes are copied into a fresh vector.
///
/// # Panics
/// Panics if the encoding holds fewer than 32 bytes.
pub fn get_31_bytes<F: PrimeField>(x: F) -> Vec<u8> {
    let big_endian = x.into_bigint().to_bytes_be();
    Vec::from(&big_endian[1..32])
}

pub fn encode_as_field_elements<F: PrimeField>(bytes: &[u8]) -> Vec<F> {
let n = (F::MODULUS_BIT_SIZE / 8) as usize;
bytes
Expand Down Expand Up @@ -59,89 +59,123 @@ pub struct QueryBytes {
#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Debug)]
/// We store the data in a vector of vectors of field elements.
/// The inner vectors represent polynomials.
pub struct FieldElt {
/// the number of the polynomial the data point is attached to
pub poly_nb: usize,
/// the number of the root of unity the data point is attached to
pub eval_nb: usize,
struct FieldElt {
/// the index of the polynomial the data point is attached to
poly_index: usize,
/// the index of the root of unity the data point is attached to
eval_index: usize,
domain_size: usize,
n_polys: usize,
}
/// Represents a query in terms of field elements
#[derive(Debug)]
pub struct QueryField {
pub start: FieldElt,
/// how many bytes we need to trim from the first 31bytes chunk
pub struct QueryField<F> {
start: FieldElt,
/// how many bytes we need to trim from the first chunk
/// we get from the first field element we decode
pub leftover_start: usize,
pub end: FieldElt,
/// how many bytes we need to trim from the last 31bytes chunk
leftover_start: usize,
end: FieldElt,
/// how many bytes we need to trim from the last chunk
/// we get from the last field element we decode
pub leftover_end: usize,
leftover_end: usize,
tag: PhantomData<F>,
}

impl QueryField {
pub fn is_valid<F: PrimeField>(&self, nb_poly: usize) -> bool {
self.start.eval_nb < 1 << 16
&& self.end.eval_nb < 1 << 16
&& self.end.poly_nb < nb_poly
&& self.start <= self.end
&& self.leftover_end <= (F::MODULUS_BIT_SIZE as usize) / 8
&& self.leftover_start <= (F::MODULUS_BIT_SIZE as usize) / 8
}

pub fn apply<F: PrimeField>(self, data: Vec<Vec<F>>) -> Vec<u8> {
assert!(self.is_valid::<F>(data.len()), "Invalid query");
let mut answer: Vec<u8> = Vec::new();
let mut field_elt = self.start;
while field_elt <= self.end {
if data[field_elt.poly_nb][field_elt.eval_nb] == F::zero() {
println!()
}
let mut to_append = get_31_bytes(data[field_elt.poly_nb][field_elt.eval_nb]);
answer.append(&mut to_append);
field_elt = field_elt.next().unwrap();
}
let n = answer.len();
// trimming the first and last 31bytes chunk
answer[(self.leftover_start)..(n - self.leftover_end)].to_vec()
impl<F: PrimeField> QueryField<F> {
#[instrument(skip_all, level = "debug")]
pub fn apply(self, data: &[Vec<F>]) -> Vec<u8> {
let n = (F::MODULUS_BIT_SIZE / 8) as usize;
let m = F::size_in_bytes();
let mut buffer = vec![0u8; m];
let mut answer = Vec::new();
self.start
.into_iter()
.take_while(|x| x <= &self.end)
.for_each(|x| {
let value = data[x.poly_index][x.eval_index];
decode_into(&mut buffer, value);
answer.extend_from_slice(&buffer[(m - n)..m]);
});

answer[(self.leftover_start)..(answer.len() - self.leftover_end)].to_vec()
}
}

impl Iterator for FieldElt {
type Item = FieldElt;
fn next(&mut self) -> Option<FieldElt> {
if self.eval_nb < (1 << 16) - 1 {
self.eval_nb += 1;
fn next(&mut self) -> Option<Self::Item> {
let current = *self;

if (self.eval_index + 1) < self.domain_size {
self.eval_index += 1;
} else if (self.poly_index + 1) < self.n_polys {
self.poly_index += 1;
self.eval_index = 0;
} else {
self.poly_nb += 1;
self.eval_nb = 0
};
Some(*self)
return None;
}

Some(current)
}
}

impl Into<QueryField> for QueryBytes {
fn into(self) -> QueryField {
let n = 31 as usize;
let start_field_nb = self.start / n;
let start = FieldElt {
poly_nb: start_field_nb / (1 << 16),
eval_nb: start_field_nb % (1 << 16),
};
let leftover_start = self.start % n;
/// Errors reported when a byte-range query cannot be expressed as a
/// field-element query.
#[derive(Debug, Error, Clone, PartialEq)]
pub enum QueryError {
/// The query refers to a polynomial index that is not below `n_polys`.
#[error("Query out of bounds: poly_index {poly_index} eval_index {eval_index} n_polys {n_polys} domain_size {domain_size}")]
QueryOutOfBounds {
/// Polynomial index reported for the offending query.
poly_index: usize,
/// Evaluation index reported for the offending query.
eval_index: usize,
/// Number of polynomials the query was checked against.
n_polys: usize,
/// Domain size used to split a field-element number into
/// (`poly_index`, `eval_index`).
domain_size: usize,
},
}

impl QueryBytes {
pub fn into_query_field<F: PrimeField>(
&self,
domain_size: usize,
n_polys: usize,
) -> Result<QueryField<F>, QueryError> {
let n = (F::MODULUS_BIT_SIZE / 8) as usize;
let start = {
let start_field_nb = self.start / n;
FieldElt {
poly_index: start_field_nb / domain_size,
eval_index: start_field_nb % domain_size,
domain_size,
n_polys,
}
};
let byte_end = self.start + self.len;
let end_field_nb = byte_end / n;
let end = FieldElt {
poly_nb: end_field_nb / (1 << 16),
eval_nb: end_field_nb % (1 << 16),
let end = {
let end_field_nb = byte_end / n;
FieldElt {
poly_index: end_field_nb / domain_size,
eval_index: end_field_nb % domain_size,
domain_size,
n_polys,
}
};

if start.poly_index >= n_polys || end.poly_index >= n_polys {
return Err(QueryError::QueryOutOfBounds {
poly_index: end.poly_index,
eval_index: end.eval_index,
n_polys,
domain_size,
});
};

let leftover_start = self.start % n;
let leftover_end = n - byte_end % n;
QueryField {

Ok(QueryField {
start,
leftover_start,
end,
leftover_end,
}
tag: std::marker::PhantomData,
})
}
}

Expand All @@ -156,6 +190,10 @@ pub mod test_utils {
/// Returns the length of the wrapped value (`self.0`).
pub fn len(&self) -> usize {
self.0.len()
}

/// Returns `true` when the wrapped value (`self.0`) is empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -230,10 +268,10 @@ mod tests {
use ark_poly::Radix2EvaluationDomain;
use ark_std::UniformRand;
use mina_curves::pasta::Fp;
use o1_utils::FieldHelpers;
use once_cell::sync::Lazy;
use proptest::prelude::*;
use test_utils::UserData;
use test_utils::{DataSize, UserData};
use tracing::debug;

fn decode<Fp: PrimeField>(x: Fp) -> Vec<u8> {
let mut buffer = vec![0u8; Fp::size_in_bytes()];
Expand Down Expand Up @@ -299,6 +337,25 @@ mod tests {
}
}

/// Computes how many field elements `xs` occupies once padded to a whole
/// number of polynomials.
///
/// The bytes are grouped into chunks of `Fp::MODULUS_BIT_SIZE / 8` bytes (one
/// chunk per field element, rounding up), and the element count is then
/// rounded up to the next multiple of `DOMAIN.size()`.
fn padded_field_length(xs: &[u8]) -> usize {
    // Integer division rounding towards positive infinity.
    let ceil_div = |num: usize, den: usize| (num + den - 1) / den;

    let bytes_per_elem = Fp::MODULUS_BIT_SIZE as usize / 8;
    let domain_size = DOMAIN.size();

    let field_elems = ceil_div(xs.len(), bytes_per_elem);
    let polys = ceil_div(field_elems, domain_size);
    domain_size * polys
}

// Property: flattening the output of `encode_for_domain` yields exactly
// `padded_field_length(&xs)` field elements.
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn test_padded_byte_length(UserData(xs) in UserData::arbitrary()
)
{ let chunked = encode_for_domain(&*DOMAIN, &xs);
// Count every element across all inner collections of the encoding.
let n = chunked.into_iter().flatten().count();
prop_assert_eq!(n, padded_field_length(&xs));
}
}

proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
Expand All @@ -316,10 +373,66 @@ mod tests {
let chunked = encode_for_domain(&*DOMAIN, &xs);
for query in queries {
let expected = &xs[query.start..(query.start+query.len)];
let field_query: QueryField = query.clone().into();
let got_answer = field_query.apply(chunked.clone()); // Note: might need clone depending on your types
let field_query: QueryField<Fp> = query.into_query_field(DOMAIN.size(), chunked.len()).unwrap();
let got_answer = field_query.apply(&chunked);
prop_assert_eq!(expected, got_answer);
}
}
}

// Property: a query with `start + len == padded_len - 1` converts
// successfully via `into_query_field`, while the same query with `len`
// increased by one is rejected.
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn test_for_invalid_query_length(
(UserData(xs), mut query) in UserData::arbitrary()
.prop_flat_map(|UserData(xs)| {
// Byte length of the padded data: the padded field-element count
// multiplied by the `m` bytes carried per field element.
let padded_len = {
let m = Fp::MODULUS_BIT_SIZE as usize / 8;
padded_field_length(&xs) * m
};
// Pick `start` inside the user data and stretch the query up to
// `padded_len - 1`.
let query_strategy = (0..xs.len()).prop_map(move |start| {
// this is the last valid end point
let end = padded_len - 1;
QueryBytes { start, len: end - start }
});
(Just(UserData(xs)), query_strategy)
})
) {
debug!("check that first query is valid");
let chunked = encode_for_domain(&*DOMAIN, &xs);
// The number of encoded polynomials is the bound the query is checked against.
let n_polys = chunked.len();
let query_field = query.into_query_field::<Fp>(DOMAIN.size(), n_polys);
prop_assert!(query_field.is_ok());
debug!("check that extending query length by 1 is invalid");
// One more byte makes `start + len` equal to `padded_len`.
query.len += 1;
let query_field = query.into_query_field::<Fp>(DOMAIN.size(), n_polys);
prop_assert!(query_field.is_err());

}
}

// Property: a zero-length query starting anywhere in `0..padded_len`
// converts successfully and produces an empty answer.
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn test_nil_query(
(UserData(xs), query) in UserData::arbitrary_with(DataSize::Small)
.prop_flat_map(|xs| {
// Byte length of the padded data: the padded field-element count
// multiplied by the `m` bytes carried per field element.
let padded_len = {
let m = Fp::MODULUS_BIT_SIZE as usize / 8;
padded_field_length(&xs.0) * m
};
// Any start offset inside the padded range, always with `len: 0`.
let query_strategy = (0..padded_len).prop_map(move |start| {
QueryBytes { start, len: 0 }
});
(Just(xs), query_strategy)
})
) {
let chunked = encode_for_domain(&*DOMAIN, &xs);
let n_polys = chunked.len();
// The conversion is expected to succeed, hence the `unwrap`.
let field_query: QueryField<Fp> = query.into_query_field(DOMAIN.size(), n_polys).unwrap();
let got_answer = field_query.apply(&chunked);
prop_assert!(got_answer.is_empty());
}

}
}

0 comments on commit abf307b

Please sign in to comment.