diff --git a/Cargo.toml b/Cargo.toml
index 6c76ba8f..9cf5f096 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,6 +7,77 @@ version = "0.2.0"
 edition = "2021"
 authors = ["Irreducible Team "]
 
+[workspace.lints.clippy]
+# These are some of clippy's nursery (i.e., experimental) lints that we like.
+# By default, nursery lints are allowed. Some of the lints below have made good
+# suggestions which we fixed. The others didn't have any findings, so we can
+# assume they don't have that many false positives. Let's enable them to
+# prevent future problems.
+borrow_as_ptr = "warn"
+branches_sharing_code = "warn"
+clear_with_drain = "warn"
+cloned_instead_of_copied = "warn"
+collection_is_never_read = "warn"
+dbg_macro = "warn"
+derive_partial_eq_without_eq = "warn"
+empty_line_after_doc_comments = "warn"
+empty_line_after_outer_attr = "warn"
+enum_glob_use = "warn"
+equatable_if_let = "warn"
+explicit_into_iter_loop = "warn"
+explicit_iter_loop = "warn"
+flat_map_option = "warn"
+from_iter_instead_of_collect = "warn"
+if_not_else = "warn"
+if_then_some_else_none = "warn"
+implicit_clone = "warn"
+imprecise_flops = "warn"
+iter_on_empty_collections = "warn"
+iter_on_single_items = "warn"
+iter_with_drain = "warn"
+iter_without_into_iter = "warn"
+large_stack_frames = "warn"
+manual_assert = "warn"
+manual_clamp = "warn"
+manual_is_variant_and = "warn"
+manual_string_new = "warn"
+match_same_arms = "warn"
+missing_const_for_fn = "warn"
+mutex_integer = "warn"
+naive_bytecount = "warn"
+needless_bitwise_bool = "warn"
+needless_continue = "warn"
+needless_for_each = "warn"
+needless_pass_by_ref_mut = "warn"
+nonstandard_macro_braces = "warn"
+option_as_ref_cloned = "warn"
+or_fun_call = "warn"
+path_buf_push_overwrite = "warn"
+read_zero_byte_vec = "warn"
+redundant_clone = "warn"
+redundant_else = "warn"
+single_char_pattern = "warn"
+string_lit_as_bytes = "warn"
+string_lit_chars_any = "warn"
+suboptimal_flops = "warn"
+suspicious_operation_groupings = "warn"
+trailing_empty_array = "warn"
+trait_duplication_in_bounds = "warn"
+transmute_undefined_repr = "warn"
+trivial_regex = "warn"
+tuple_array_conversions = "warn"
+type_repetition_in_bounds = "warn"
+uninhabited_references = "warn"
+unnecessary_self_imports = "warn"
+unnecessary_struct_initialization = "warn"
+unnested_or_patterns = "warn"
+unused_peekable = "warn"
+unused_rounding = "warn"
+use_self = "warn"
+useless_let_if_seq = "warn"
+while_float = "warn"
+zero_sized_map_values = "warn"
+
 [workspace.dependencies]
 anyhow = "1.0.81"
 assert_matches = "1.5.0"
diff --git a/crates/circuits/src/arithmetic/u32.rs b/crates/circuits/src/arithmetic/u32.rs
index ea1b9cb3..88a0c9a7 100644
--- a/crates/circuits/src/arithmetic/u32.rs
+++ b/crates/circuits/src/arithmetic/u32.rs
@@ -151,6 +151,73 @@ where
 	Ok(zout)
 }
 
+pub fn sub<U, F>(
+	builder: &mut ConstraintSystemBuilder<U, F>,
+	name: impl ToString,
+	zin: OracleId,
+	yin: OracleId,
+	flags: super::Flags,
+) -> Result<OracleId, anyhow::Error>
+where
+	U: PackScalar<F> + PackScalar<BinaryField1b> + Pod,
+	F: TowerField,
+{
+	builder.push_namespace(name);
+	let log_rows = builder.log_rows([zin, yin])?;
+	let cout = builder.add_committed("cout", log_rows, BinaryField1b::TOWER_LEVEL);
+	let cin = builder.add_shifted("cin", cout, 1, 5, ShiftVariant::LogicalLeft)?;
+	let xout = builder.add_committed("xin", log_rows, BinaryField1b::TOWER_LEVEL);
+
+	if let Some(witness) = builder.witness() {
+		(
+			witness.get::<BinaryField1b>(zin)?.as_slice::<u32>(),
+			witness.get::<BinaryField1b>(yin)?.as_slice::<u32>(),
+			witness
+				.new_column::<BinaryField1b>(xout)
+				.as_mut_slice::<u32>(),
+			witness
+				.new_column::<BinaryField1b>(cout)
+
.as_mut_slice::(), + witness + .new_column::(cin) + .as_mut_slice::(), + ) + .into_par_iter() + .for_each(|(zout, yin, xin, cout, cin)| { + let carry; + (*xin, carry) = (*zout).overflowing_sub(*yin); + *cin = (*xin) ^ (*yin) ^ (*zout); + *cout = ((carry as u32) << 31) | (*cin >> 1); + }); + } + + builder.assert_zero( + "sum", + [xout, yin, cin, zin], + arith_expr!([xout, yin, cin, zin] = xout + yin + cin - zin).convert_field(), + ); + + builder.assert_zero( + "carry", + [xout, yin, cin, cout], + arith_expr!([xout, yin, cin, cout] = (xout + cin) * (yin + cin) + cin - cout) + .convert_field(), + ); + + // Underflow checking + if matches!(flags, super::Flags::Checked) { + let last_cout = select_bit(builder, "last_cout", cout, 31)?; + builder.assert_zero( + "underflow", + [last_cout], + arith_expr!([last_cout] = last_cout).convert_field(), + ); + } + + builder.pop_namespace(); + Ok(xout) +} + pub fn half( builder: &mut ConstraintSystemBuilder, name: impl ToString, @@ -296,7 +363,7 @@ mod tests { use binius_core::constraint_system::validate::validate_witness; use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, TowerField}; - use crate::{arithmetic, builder::ConstraintSystemBuilder}; + use crate::{arithmetic, builder::ConstraintSystemBuilder, unconstrained::unconstrained}; type U = OptimalUnderlier; type F = BinaryField128b; @@ -323,4 +390,20 @@ mod tests { let boundaries = vec![]; validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } + + #[test] + fn test_sub() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let a = unconstrained::(&mut builder, "a", 7).unwrap(); + let b = unconstrained::(&mut builder, "a", 7).unwrap(); + let _c = + arithmetic::u32::sub(&mut builder, "c", a, b, arithmetic::Flags::Unchecked).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } } diff --git a/crates/core/benches/prodcheck.rs b/crates/core/benches/prodcheck.rs index d0e56269..bcfd2632 100644 --- a/crates/core/benches/prodcheck.rs +++ b/crates/core/benches/prodcheck.rs @@ -5,65 +5,185 @@ use std::iter::repeat_with; use binius_core::{ fiat_shamir::HasherChallenger, protocols::gkr_gpa::{self, GrandProductClaim, GrandProductWitness}, - transcript::TranscriptWriter, + transcript::ProverTranscript, }; use binius_field::{ - arch::packed_polyval_128::PackedBinaryPolyval1x128b, BinaryField128b, BinaryField128bPolyval, - PackedField, + arch::OptimalUnderlier, + as_packed_field::PackScalar, + linear_transformation::{PackedTransformationFactory, Transformation}, + BinaryField128b, BinaryField128bPolyval, PackedField, PackedFieldIndexable, TowerField, + BINARY_TO_POLYVAL_TRANSFORMATION, }; -use binius_hal::make_portable_backend; +use binius_hal::{make_portable_backend, CpuBackend}; use binius_math::{IsomorphicEvaluationDomainFactory, MultilinearExtension}; +use binius_maybe_rayon::iter::{IntoParallelIterator, ParallelIterator}; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use groestl_crypto::Groestl256; use rand::{rngs::StdRng, SeedableRng}; // Creates T(x), a multilinear with evaluations over the n-dimensional boolean hypercube -fn create_numerator(n_vars: usize) -> MultilinearExtension

<P> { +fn create_numerator<P: PackedField>(n_vars: usize) -> Vec<P>
{ let mut rng = StdRng::seed_from_u64(0); - let values = repeat_with(|| P::random(&mut rng)) - .take(1 << n_vars) - .collect::>(); + repeat_with(|| P::random(&mut rng)) + .take(1 << (n_vars - P::LOG_WIDTH)) + .collect() +} - MultilinearExtension::from_values(values).unwrap() +fn apply_transformation( + input: &[IP], + transformation: &impl Transformation, +) -> Vec { + input.iter().map(|x| transformation.transform(x)).collect() } -fn bench_polyval(c: &mut Criterion) { - type FDomain = BinaryField128b; +const N_VARS: [usize; 3] = [12, 16, 20]; +const N_CLAIMS: usize = 20; +type FDomain = BinaryField128b; - type P = PackedBinaryPolyval1x128b; - type FW = BinaryField128bPolyval; - let mut group = c.benchmark_group("prodcheck"); +fn bench_gpa_generic(name: &str, c: &mut Criterion, bench_fn: &BenchFn) +where + U: PackScalar, + F: TowerField + From, + BenchFn: Fn( + usize, + &mut ProverTranscript>, + &[U::Packed], + &IsomorphicEvaluationDomainFactory, + &CpuBackend, + ) -> R, +{ + let mut group = c.benchmark_group(name); let domain_factory = IsomorphicEvaluationDomainFactory::::default(); - for n in [12, 16, 20] { + for n_vars in N_VARS { group.throughput(Throughput::Bytes( - ((1 << n) * std::mem::size_of::()) as u64, + ((1 << n_vars) * std::mem::size_of::() * N_CLAIMS) as u64, )); - group.bench_function(format!("n_vars={n}"), |bench| { + group.bench_function(format!("n_vars={n_vars}"), |bench| { // Setup witness - let numerator = create_numerator::

<P>(n); - - let gpa_witness = - GrandProductWitness::<P>
::new(numerator.specialize_arc_dyn()).unwrap(); + let numerator = create_numerator::(n_vars); - let product: FW = FW::from(gpa_witness.grand_product_evaluation()); - let gpa_claim = GrandProductClaim { n_vars: n, product }; let backend = make_portable_backend(); - let mut prover_transcript = TranscriptWriter::>::default(); + let mut prover_transcript = ProverTranscript::>::default(); bench.iter(|| { - gkr_gpa::batch_prove::( - [gpa_witness.clone()], - &[gpa_claim.clone()], - domain_factory.clone(), - &mut prover_transcript, - &backend, - ) + bench_fn(n_vars, &mut prover_transcript, &numerator, &domain_factory, &backend) }); }); } group.finish() } +fn bench_gpa(name: &str, c: &mut Criterion) +where + U: PackScalar, + F: TowerField + From, +{ + bench_gpa_generic::( + name, + c, + &|n_vars, prover_transcript, numerator, domain_factory, backend| { + let (gpa_witnesses, gpa_claims): (Vec<_>, Vec<_>) = (0..N_CLAIMS) + .into_par_iter() + .map(|_| { + let numerator = + MultilinearExtension::::from_values_generic( + numerator, + ) + .unwrap(); + let gpa_witness = GrandProductWitness::<>::Packed>::new( + numerator.specialize_arc_dyn(), + ) + .unwrap(); + + let product = gpa_witness.grand_product_evaluation(); + + (gpa_witness, GrandProductClaim { n_vars, product }) + }) + .collect::>() + .into_iter() + .unzip(); + + gkr_gpa::batch_prove::( + gpa_witnesses, + &gpa_claims, + domain_factory.clone(), + prover_transcript, + backend, + ) + }, + ); +} + +fn bench_gpa_isomorphic(name: &str, c: &mut Criterion) +where + U: PackScalar< + BinaryField128b, + Packed: PackedFieldIndexable + + PackedTransformationFactory<>::Packed>, + >, + U: PackScalar< + BinaryField128bPolyval, + Packed: PackedFieldIndexable + + PackedTransformationFactory<>::Packed>, + >, +{ + let transform_to_polyval = + >::Packed::make_packed_transformation( + BINARY_TO_POLYVAL_TRANSFORMATION, + ); + + bench_gpa_generic::( + name, + c, + &|n_vars, prover_transcript, numerator, domain_factory, backend| { + let (gpa_witnesses, gpa_claims): (Vec<_>, Vec<_>) = (0..N_CLAIMS) + .into_par_iter() + .map(|_| { + let transformed_values = apply_transformation(numerator, &transform_to_polyval); + let numerator = MultilinearExtension::from_values(transformed_values).unwrap(); + + let gpa_witness = GrandProductWitness::< + >::Packed, + >::new(numerator.specialize_arc_dyn()) + .unwrap(); + + let product = gpa_witness.grand_product_evaluation(); + + (gpa_witness, GrandProductClaim { n_vars, product }) + }) + .collect::>() + .into_iter() + .unzip(); + + gkr_gpa::batch_prove::< + BinaryField128bPolyval, + >::Packed, + BinaryField128bPolyval, + _, + _, + >( + gpa_witnesses.clone(), + &gpa_claims, + domain_factory.clone(), + prover_transcript, + &backend, + ) + }, + ); +} + +fn bench_polyval(c: &mut Criterion) { + bench_gpa::("polyval", c); +} + +fn bench_binary(c: &mut Criterion) { + bench_gpa::("binary_128b", c); +} + +fn bench_binary_isomorphic(c: &mut Criterion) { + bench_gpa_isomorphic::("binary_128b_isomorphic", c); +} + criterion_main!(prodcheck); -criterion_group!(prodcheck, bench_polyval); +criterion_group!(prodcheck, bench_polyval, bench_binary, bench_binary_isomorphic); diff --git a/crates/core/src/constraint_system/common.rs b/crates/core/src/constraint_system/common.rs index 558b2659..8e5ddd68 100644 --- a/crates/core/src/constraint_system/common.rs +++ b/crates/core/src/constraint_system/common.rs @@ -1,10 +1,13 @@ // Copyright 2024-2025 Irreducible Inc. 
-use crate::tower::TowerFamily; +use crate::tower::{ProverTowerFamily, TowerFamily}; /// The cryptographic extension field that the constraint system protocol is defined over. pub type FExt = ::B128; +/// Field with fast multiplication and isomorphism to FExt. +pub type FFastExt = ::FastB128; + /// The evaluation domain used in sumcheck protocols. /// /// This is fixed to be 8-bits, which is large enough to handle all reasonable sumcheck @@ -14,6 +17,6 @@ pub type FDomain = ::B8; /// The Reed–Solomon alphabet used for FRI encoding. /// -/// This is fixed to be 32-bits, which is large enough to handle trace sizes up to 64 GiB +/// This is fixed to be 32-bits, which is large enough to handle trace sizes up to 512 GiB /// of committed data. pub type FEncode = ::B32; diff --git a/crates/core/src/constraint_system/mod.rs b/crates/core/src/constraint_system/mod.rs index 5ccfe80a..be17b6fd 100644 --- a/crates/core/src/constraint_system/mod.rs +++ b/crates/core/src/constraint_system/mod.rs @@ -46,11 +46,10 @@ impl ConstraintSystem { #[derive(Debug, Clone)] pub struct Proof { pub transcript: Vec, - pub advice: Vec, } impl Proof { pub fn get_proof_size(&self) -> usize { - self.transcript.len() + self.advice.len() + self.transcript.len() } } diff --git a/crates/core/src/constraint_system/prove.rs b/crates/core/src/constraint_system/prove.rs index 88dde23b..bb455489 100644 --- a/crates/core/src/constraint_system/prove.rs +++ b/crates/core/src/constraint_system/prove.rs @@ -4,13 +4,16 @@ use std::{cmp::Reverse, env, marker::PhantomData}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, + linear_transformation::{PackedTransformationFactory, Transformation}, + underlier::WithUnderlier, BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, RepackedExtension, TowerField, }; use binius_hal::ComputationBackend; use binius_hash::PseudoCompressionFunction; use binius_math::{ - ArithExpr, EvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, + ArithExpr, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, + MultilinearExtension, MultilinearPoly, }; use binius_maybe_rayon::prelude::*; use binius_utils::bail; @@ -29,7 +32,7 @@ use super::{ }; use crate::{ constraint_system::{ - common::{FDomain, FEncode, FExt}, + common::{FDomain, FEncode, FExt, FFastExt}, verify::{get_flush_dedup_sumcheck_metas, FlushSumcheckMeta, StepDownMeta}, }, fiat_shamir::{CanSample, Challenger}, @@ -50,8 +53,8 @@ use crate::{ }, }, ring_switch, - tower::{PackedTop, TowerFamily, TowerUnderlier}, - transcript::{AdviceWriter, CanWrite, Proof as ProofWriter, TranscriptWriter}, + tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier, TowerFamily}, + transcript::ProverTranscript, witness::{MultilinearExtensionIndex, MultilinearWitness}, }; @@ -66,8 +69,8 @@ pub fn prove( backend: &Backend, ) -> Result where - U: TowerUnderlier, - Tower: TowerFamily, + U: ProverTowerUnderlier, + Tower: ProverTowerFamily, Tower::B128: PackedTop, DomainFactory: EvaluationDomainFactory>, Hash: Digest + BlockSizeUser + FixedOutputReset, @@ -81,7 +84,10 @@ where + RepackedExtension> + RepackedExtension> + RepackedExtension> - + RepackedExtension>, + + RepackedExtension> + + PackedTransformationFactory>, + PackedType: + PackedFieldIndexable + PackedTransformationFactory>, PackedType: PackedFieldIndexable + PackedExtension, PackedSubfield: PackedFieldIndexable>, PackedType: PackedFieldIndexable @@ -97,8 +103,9 @@ where "using computation backend: {backend:?}" ); - 
let mut transcript = TranscriptWriter::::default(); - let mut advice = AdviceWriter::default(); + let fast_domain_factory = IsomorphicEvaluationDomainFactory::>::default(); + + let mut transcript = ProverTranscript::::new(); let ConstraintSystem { mut oracles, @@ -136,12 +143,18 @@ where } = piop::commit(&fri_params, &merkle_prover, &committed_multilins)?; // Observe polynomial commitment - transcript.write(&commitment); + let mut writer = transcript.message(); + writer.write(&commitment); // Grand product arguments // Grand products for non-zero checking - let non_zero_prodcheck_witnesses = - gkr_gpa::construct_grand_product_witnesses(&non_zero_oracle_ids, &witness)?; + let non_zero_fast_witnesses = + make_fast_masked_flush_witnesses(&oracles, &witness, &non_zero_oracle_ids, None)?; + let non_zero_prodcheck_witnesses = non_zero_fast_witnesses + .into_par_iter() + .map(GrandProductWitness::new) + .collect::, _>>()?; + let non_zero_products = gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses); if non_zero_products @@ -151,7 +164,7 @@ where bail!(Error::Zeros); } - transcript.write_scalar_slice(&non_zero_products); + writer.write_scalar_slice(&non_zero_products); let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims( &non_zero_oracle_ids, @@ -170,8 +183,12 @@ where make_unmasked_flush_witnesses(&oracles, &mut witness, &flush_oracle_ids)?; // there are no oracle ids associated with these flush_witnesses - let flush_witnesses = - make_masked_flush_witnesses(&oracles, &witness, &flush_oracle_ids, &flush_counts)?; + let flush_witnesses = make_fast_masked_flush_witnesses( + &oracles, + &witness, + &flush_oracle_ids, + Some(&flush_counts), + )?; // This is important to do in parallel. let flush_prodcheck_witnesses = flush_witnesses @@ -180,21 +197,31 @@ where .collect::, _>>()?; let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses); - transcript.write_scalar_slice(&flush_products); + transcript.message().write_scalar_slice(&flush_products); let flush_prodcheck_claims = gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?; // Prove grand products - let GrandProductBatchProveOutput { - mut final_layer_claims, - } = gkr_gpa::batch_prove::<_, _, FDomain, _, _>( - [flush_prodcheck_witnesses, non_zero_prodcheck_witnesses].concat(), - &[flush_prodcheck_claims, non_zero_prodcheck_claims].concat(), - &domain_factory, - &mut transcript, - backend, - )?; + let all_gpa_witnesses = [flush_prodcheck_witnesses, non_zero_prodcheck_witnesses].concat(); + let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims) + .map(|claim| claim.isomorphic()) + .collect::>(); + + let GrandProductBatchProveOutput { final_layer_claims } = + gkr_gpa::batch_prove::, _, FFastExt, _, _>( + all_gpa_witnesses, + &all_gpa_claims, + &fast_domain_factory, + &mut transcript, + backend, + )?; + + // Apply isomorphism to the layer claims + let mut final_layer_claims = final_layer_claims + .into_iter() + .map(|layer_claim| layer_claim.isomorphic()) + .collect::>(); let non_zero_final_layer_claims = final_layer_claims.split_off(flush_oracle_ids.len()); let flush_final_layer_claims = final_layer_claims; @@ -378,7 +405,6 @@ where .chain(zerocheck_eval_claims), switchover_fn, &mut transcript, - &mut advice, &domain_factory, backend, )?; @@ -387,22 +413,18 @@ where let system = ring_switch::EvalClaimSystem::new(&commit_meta, oracle_to_commit_index, &eval_claims)?; - let mut proof_writer = ProofWriter { - transcript: 
&mut transcript, - advice: &mut advice, - }; let ring_switch::ReducedWitness { transparents: transparent_multilins, sumcheck_claims: piop_sumcheck_claims, - } = ring_switch::prove::<_, _, _, Tower, _, _, _>( + } = ring_switch::prove::<_, _, _, Tower, _, _>( &system, &committed_multilins, - &mut proof_writer, + &mut transcript, backend, )?; // Prove evaluation claims using PIOP compiler - piop::prove::<_, FDomain, _, _, _, _, _, _, _, _, _>( + piop::prove::<_, FDomain, _, _, _, _, _, _, _, _>( &fri_params, &merkle_prover, domain_factory, @@ -412,13 +434,12 @@ where &committed_multilins, &transparent_multilins, &piop_sumcheck_claims, - &mut proof_writer, + &mut transcript, &backend, )?; Ok(Proof { transcript: transcript.finalize(), - advice: advice.finalize(), }) } @@ -518,8 +539,8 @@ fn make_unmasked_flush_witnesses<'a, U, Tower>( flush_oracle_ids: &[OracleId], ) -> Result<(), Error> where - U: TowerUnderlier, - Tower: TowerFamily, + U: ProverTowerUnderlier, + Tower: ProverTowerFamily, { // The function is on the critical path, parallelize. let flush_witnesses: Result>, Error> = flush_oracle_ids @@ -570,40 +591,76 @@ where #[allow(clippy::type_complexity)] #[instrument(skip_all, level = "debug")] -fn make_masked_flush_witnesses<'a, U, Tower>( +fn make_fast_masked_flush_witnesses<'a, U, Tower>( oracles: &MultilinearOracleSet>, witness: &MultilinearExtensionIndex<'a, U, FExt>, flush_oracles: &[OracleId], - flush_counts: &[usize], -) -> Result>>>, Error> + flush_counts: Option<&[usize]>, +) -> Result>>>, Error> where - U: TowerUnderlier, - Tower: TowerFamily, + U: ProverTowerUnderlier, + Tower: ProverTowerFamily, + PackedType: PackedTransformationFactory>, { + let to_fast = Tower::packed_transformation_to_fast(); + // The function is on the critical path, parallelize. 
flush_oracles .par_iter() - .zip(flush_counts.par_iter()) - .map(|(&flush_oracle_id, &flush_count)| { + .enumerate() + .map(|(i, &flush_oracle_id)| { let n_vars = oracles.n_vars(flush_oracle_id); - let packed_len = 1 << n_vars.saturating_sub(>>::LOG_WIDTH); - let mut result = vec![>>::one(); packed_len]; + let flush_count = flush_counts.map_or(1 << n_vars, |flush_counts| flush_counts[i]); + + debug_assert!(flush_count <= 1 << n_vars); + + let log_width = >>::LOG_WIDTH; + let width = 1 << log_width; + + let packed_len = 1 << n_vars.saturating_sub(log_width); + let mut fast_ext_result = vec![PackedType::>::one(); packed_len]; let poly = witness.get_multilin_poly(flush_oracle_id)?; - let width = >>::WIDTH; - let packed_index = flush_count / width; - for (i, result_val) in result.iter_mut().take(packed_index).enumerate() { - for j in 0..width { - let index = (i << >>::LOG_WIDTH) | j; - result_val.set(j, poly.evaluate_on_hypercube(index)?); - } - } - for j in 0..flush_count % width { - let index = packed_index << >>::LOG_WIDTH | j; - result[packed_index].set(j, poly.evaluate_on_hypercube(index)?); - } - let masked_poly = MultilinearExtension::new(n_vars, result) + const MAX_SUBCUBE_VARS: usize = 8; + let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars); + let subcubes_count = flush_count.div_ceil(1 << subcube_vars); + let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width); + + fast_ext_result[..subcube_packed_size * subcubes_count] + .par_chunks_mut(subcube_packed_size) + .enumerate() + .for_each(|(subcube_index, fast_subcube)| { + let underliers = + PackedType::>::to_underliers_ref_mut(fast_subcube); + + let subcube_evals = + PackedType::>::from_underliers_ref_mut(underliers); + poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals) + .expect("witness data populated by make_unmasked_flush_witnesses()"); + + for underlier in underliers.iter_mut() { + let src = PackedType::>::from_underlier(*underlier); + let dest = to_fast.transform(&src); + *underlier = PackedType::>::to_underlier(dest); + } + + let fast_subcube = + PackedType::>::from_underliers_ref_mut(underliers); + if flush_count < (subcube_index + 1) << subcube_vars { + let offset = flush_count - (subcube_index << subcube_vars); + fast_subcube[offset.div_ceil(width)..].fill(PackedField::one()); + + let scalar_offset = offset % width; + if scalar_offset != 0 { + for j in scalar_offset..width { + fast_subcube[offset / width].set(j, FFastExt::::ONE); + } + } + } + }); + + let masked_poly = MultilinearExtension::new(n_vars, fast_ext_result) .expect("data is constructed with the correct length with respect to n_vars"); Ok(MLEDirectAdapter::from(masked_poly).upcast_arc_dyn()) }) @@ -625,8 +682,8 @@ fn get_flush_sumcheck_provers<'a, 'b, U, Tower, FDomain, DomainFactory, Backend> Error, > where - U: TowerUnderlier + PackScalar, - Tower: TowerFamily, + U: ProverTowerUnderlier + PackScalar, + Tower: ProverTowerFamily, Tower::B128: ExtensionField, FDomain: Field, DomainFactory: EvaluationDomainFactory, diff --git a/crates/core/src/constraint_system/verify.rs b/crates/core/src/constraint_system/verify.rs index 0c3a1dc5..6a4005c4 100644 --- a/crates/core/src/constraint_system/verify.rs +++ b/crates/core/src/constraint_system/verify.rs @@ -39,7 +39,7 @@ use crate::{ }, ring_switch, tower::{PackedTop, TowerFamily, TowerUnderlier}, - transcript::{AdviceReader, CanRead, Proof as ProofReader, TranscriptReader}, + transcript::VerifierTranscript, transparent::{eq_ind::EqIndPartialEval, step_down}, }; @@ -72,10 +72,9 @@ where // Stable 
sort constraint sets in descending order by number of variables. table_constraints.sort_by_key(|constraint_set| Reverse(constraint_set.n_vars)); - let Proof { transcript, advice } = proof; + let Proof { transcript } = proof; - let mut transcript = TranscriptReader::::new(transcript); - let mut advice = AdviceReader::new(advice); + let mut transcript = VerifierTranscript::::new(transcript); let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default()); let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?; @@ -87,11 +86,12 @@ where )?; // Read polynomial commitment polynomials - let commitment = transcript.read::>()?; + let mut reader = transcript.message(); + let commitment = reader.read::>()?; // Grand product arguments // Grand products for non-zero checks - let non_zero_products = transcript.read_scalar_slice(non_zero_oracle_ids.len())?; + let non_zero_products = reader.read_scalar_slice(non_zero_oracle_ids.len())?; if non_zero_products .iter() .any(|count| *count == Tower::B128::zero()) @@ -115,7 +115,9 @@ where make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?; let flush_counts = flushes.iter().map(|flush| flush.count).collect::>(); - let flush_products = transcript.read_scalar_slice(flush_oracle_ids.len())?; + let flush_products = transcript + .message() + .read_scalar_slice(flush_oracle_ids.len())?; verify_channels_balance( &flushes, &flush_products, @@ -262,21 +264,16 @@ where .into_iter() .chain(zerocheck_eval_claims), &mut transcript, - &mut advice, )?; // Reduce committed evaluation claims to PIOP sumcheck claims let system = ring_switch::EvalClaimSystem::new(&commit_meta, oracle_to_commit_index, &eval_claims)?; - let mut proof_reader = ProofReader { - transcript: &mut transcript, - advice: &mut advice, - }; let ring_switch::ReducedClaim { transparents, sumcheck_claims: piop_sumcheck_claims, - } = ring_switch::verify::<_, Tower, _, _>(&system, &mut proof_reader)?; + } = ring_switch::verify::<_, Tower, _>(&system, &mut transcript)?; // Prove evaluation claims using PIOP compiler piop::verify( @@ -286,11 +283,10 @@ where &commitment, &transparents, &piop_sumcheck_claims, - &mut proof_reader, + &mut transcript, )?; transcript.finalize()?; - advice.finalize()?; Ok(()) } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index fae7d800..c7ebeab0 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -15,7 +15,6 @@ pub mod composition; pub mod constraint_system; pub mod fiat_shamir; -pub mod linear_code; pub mod merkle_tree; pub mod oracle; pub mod piop; diff --git a/crates/core/src/linear_code.rs b/crates/core/src/linear_code.rs deleted file mode 100644 index 7ce780e2..00000000 --- a/crates/core/src/linear_code.rs +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2023-2025 Irreducible Inc. - -//! Linear error-correcting code traits. - -use binius_field::{ExtensionField, PackedField, RepackedExtension}; -use binius_utils::checked_arithmetics::checked_log_2; -use tracing::instrument; - -/// An encodable [linear error-correcting code](https://en.wikipedia.org/wiki/Linear_code) intended -/// for use in a Brakedown-style polynomial commitment scheme. -/// -/// This trait represents linear codes with a dimension that is a power of 2, as that property is -/// required for the Brakedown polynomial commitment scheme. 
-/// -/// Requirements: -/// - `len()` is a multiple of `dim()` -/// - `dim()` is a power of 2 -/// - `dim()` is a multiple of `P::WIDTH` -#[allow(clippy::len_without_is_empty)] -pub trait LinearCode { - type P: PackedField; - type EncodeError: std::error::Error + Send + Sync + 'static; - - /// The block length. - fn len(&self) -> usize { - self.dim() * self.inv_rate() - } - - /// The base-2 log of the dimension. - fn dim_bits(&self) -> usize; - - /// The dimension. - fn dim(&self) -> usize { - 1 << self.dim_bits() - } - - /// The minimum distance between codewords. - fn min_dist(&self) -> usize; - - /// The reciprocal of the rate, ie. `self.len() / self.dim()`. - fn inv_rate(&self) -> usize; - - /// Encode a batch of interleaved messages in-place in a provided buffer. - /// - /// The message symbols are interleaved in the buffer, which improves the cache-efficiency of - /// the encoding procedure. The interleaved codeword is stored in the buffer when the method - /// completes. - /// - /// ## Throws - /// - /// * If the `code` buffer does not have capacity for `len() << log_batch_size` field - /// elements. - fn encode_batch_inplace( - &self, - code: &mut [Self::P], - log_batch_size: usize, - ) -> Result<(), Self::EncodeError>; - - /// Encode a message in-place in a provided buffer. - /// - /// ## Throws - /// - /// * If the `code` buffer does not have capacity for `len()` field elements. - fn encode_inplace(&self, code: &mut [Self::P]) -> Result<(), Self::EncodeError> { - self.encode_batch_inplace(code, 0) - } - - /// Encode a message provided as a vector of packed field elements. - fn encode(&self, mut msg: Vec) -> Result, Self::EncodeError> { - msg.resize(msg.len() * self.inv_rate(), Self::P::default()); - self.encode_inplace(&mut msg)?; - Ok(msg) - } - - /// Encode a batch of interleaved messages of extension field elements in-place in a provided - /// buffer. - /// - /// A linear code can be naturally extended to a code over extension fields by encoding each - /// dimension of the extension as a vector-space separately. - /// - /// ## Preconditions - /// - /// * `PE::Scalar::DEGREE` must be a power of two. - /// - /// ## Throws - /// - /// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements. - #[instrument(skip_all, level = "debug")] - fn encode_ext_batch_inplace( - &self, - code: &mut [PE], - log_batch_size: usize, - ) -> Result<(), Self::EncodeError> - where - PE: RepackedExtension, - PE::Scalar: ExtensionField<::Scalar>, - { - let log_degree = checked_log_2(PE::Scalar::DEGREE); - self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + log_degree) - } - - /// Encode a message of extension field elements in-place in a provided buffer. - /// - /// See [`Self::encode_ext_batch_inplace`] for more details. - /// - /// ## Throws - /// - /// * If the `code` buffer does not have capacity for `len()` field elements. - fn encode_ext_inplace(&self, code: &mut [PE]) -> Result<(), Self::EncodeError> - where - PE: RepackedExtension, - PE::Scalar: ExtensionField<::Scalar>, - { - self.encode_ext_batch_inplace(code, 0) - } - - /// Encode a message of extension field elements provided as a vector of packed field elements. - /// - /// See [`Self::encode_ext_inplace`] for more details. 
- fn encode_extension(&self, mut msg: Vec) -> Result, Self::EncodeError> - where - PE: RepackedExtension, - PE::Scalar: ExtensionField<::Scalar>, - { - msg.resize(msg.len() * self.inv_rate(), PE::default()); - self.encode_ext_inplace(&mut msg)?; - Ok(msg) - } -} diff --git a/crates/core/src/merkle_tree/merkle_tree_vcs.rs b/crates/core/src/merkle_tree/merkle_tree_vcs.rs index 9e0d6c9f..67eed306 100644 --- a/crates/core/src/merkle_tree/merkle_tree_vcs.rs +++ b/crates/core/src/merkle_tree/merkle_tree_vcs.rs @@ -1,9 +1,10 @@ // Copyright 2024-2025 Irreducible Inc. use binius_maybe_rayon::iter::IndexedParallelIterator; +use bytes::{Buf, BufMut}; use super::errors::Error; -use crate::transcript::{CanRead, CanWrite}; +use crate::transcript::{TranscriptReader, TranscriptWriter}; /// A Merkle tree commitment. /// @@ -53,14 +54,14 @@ pub trait MerkleTreeScheme: Sync { ) -> Result<(), Error>; /// Verify an opening proof for an entry in a committed vector at the given index. - fn verify_opening( + fn verify_opening( &self, index: usize, values: &[T], layer_depth: usize, tree_depth: usize, layer_digests: &[Self::Digest], - proof: Proof, + proof: &mut TranscriptReader, ) -> Result<(), Error>; } @@ -108,11 +109,11 @@ pub trait MerkleTreeProver: Sync { /// * `committed` - helper data generated during commitment /// * `layer_depth` - depth of the layer to prove inclusion in /// * `index` - the entry index - fn prove_opening( + fn prove_opening( &self, committed: &Self::Committed, layer_depth: usize, index: usize, - proof: Proof, + proof: &mut TranscriptWriter, ) -> Result<(), Error>; } diff --git a/crates/core/src/merkle_tree/prover.rs b/crates/core/src/merkle_tree/prover.rs index 5c06b5c3..27455cc1 100644 --- a/crates/core/src/merkle_tree/prover.rs +++ b/crates/core/src/merkle_tree/prover.rs @@ -3,6 +3,7 @@ use binius_field::TowerField; use binius_hash::PseudoCompressionFunction; use binius_maybe_rayon::iter::IndexedParallelIterator; +use bytes::BufMut; use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset, Output}; use getset::Getters; use tracing::instrument; @@ -13,7 +14,7 @@ use super::{ merkle_tree_vcs::{Commitment, MerkleTreeProver}, scheme::BinaryMerkleTreeScheme, }; -use crate::transcript::CanWrite; +use crate::transcript::TranscriptWriter; #[derive(Debug, Getters)] pub struct BinaryMerkleTreeProver { @@ -66,12 +67,12 @@ where committed.layer(depth) } - fn prove_opening( + fn prove_opening( &self, committed: &Self::Committed, layer_depth: usize, index: usize, - mut proof: Proof, + proof: &mut TranscriptWriter, ) -> Result<(), Error> { let branch = committed.branch(index, layer_depth)?; proof.write_slice(&branch); diff --git a/crates/core/src/merkle_tree/scheme.rs b/crates/core/src/merkle_tree/scheme.rs index 17ab10a1..f5fe52c4 100644 --- a/crates/core/src/merkle_tree/scheme.rs +++ b/crates/core/src/merkle_tree/scheme.rs @@ -8,6 +8,7 @@ use binius_utils::{ bail, checked_arithmetics::{log2_ceil_usize, log2_strict_usize}, }; +use bytes::Buf; use digest::{core_api::BlockSizeUser, Digest, Output}; use getset::Getters; @@ -15,7 +16,7 @@ use super::{ errors::{Error, VerificationError}, merkle_tree_vcs::MerkleTreeScheme, }; -use crate::transcript::CanRead; +use crate::transcript::TranscriptReader; #[derive(Debug, Getters)] pub struct BinaryMerkleTreeScheme { @@ -105,14 +106,14 @@ where Ok(()) } - fn verify_opening( + fn verify_opening( &self, index: usize, values: &[F], layer_depth: usize, tree_depth: usize, layer_digests: &[Self::Digest], - mut proof: Proof, + proof: &mut TranscriptReader, ) -> 
Result<(), Error> { if 1 << layer_depth != layer_digests.len() { bail!(VerificationError::IncorrectVectorLength) diff --git a/crates/core/src/merkle_tree/tests.rs b/crates/core/src/merkle_tree/tests.rs index 3d7803a6..2d20d98e 100644 --- a/crates/core/src/merkle_tree/tests.rs +++ b/crates/core/src/merkle_tree/tests.rs @@ -9,7 +9,7 @@ use groestl_crypto::Groestl256; use rand::{rngs::StdRng, SeedableRng}; use super::{BinaryMerkleTreeProver, MerkleTreeProver, MerkleTreeScheme}; -use crate::transcript::AdviceWriter; +use crate::{fiat_shamir::HasherChallenger, transcript::ProverTranscript}; #[test] fn test_binary_merkle_vcs_commit_prove_open_correctly() { @@ -25,15 +25,22 @@ fn test_binary_merkle_vcs_commit_prove_open_correctly() { assert_eq!(commitment.root, tree.root()); for (i, value) in data.iter().enumerate() { - let mut proof_writer = AdviceWriter::new(); + let mut proof_writer = ProverTranscript::>::new(); mr_prover - .prove_opening(&tree, 0, i, &mut proof_writer) + .prove_opening(&tree, 0, i, &mut proof_writer.message()) .unwrap(); - let mut proof_reader = proof_writer.into_reader(); + let mut proof_reader = proof_writer.into_verifier(); mr_prover .scheme() - .verify_opening(i, slice::from_ref(value), 0, 4, &[commitment.root], &mut proof_reader) + .verify_opening( + i, + slice::from_ref(value), + 0, + 4, + &[commitment.root], + &mut proof_reader.message(), + ) .unwrap(); } } @@ -57,15 +64,22 @@ fn test_binary_merkle_vcs_commit_layer_prove_open_correctly() { .verify_layer(&commitment.root, layer_depth, layer) .unwrap(); for (i, value) in data.iter().enumerate() { - let mut proof_writer = AdviceWriter::new(); + let mut proof_writer = ProverTranscript::>::new(); mr_prover - .prove_opening(&tree, layer_depth, i, &mut proof_writer) + .prove_opening(&tree, layer_depth, i, &mut proof_writer.message()) .unwrap(); - let mut proof_reader = proof_writer.into_reader(); + let mut proof_reader = proof_writer.into_verifier(); mr_prover .scheme() - .verify_opening(i, slice::from_ref(value), layer_depth, 5, layer, &mut proof_reader) + .verify_opening( + i, + slice::from_ref(value), + layer_depth, + 5, + layer, + &mut proof_reader.message(), + ) .unwrap(); } } diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index 41715353..d9864771 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -19,7 +19,7 @@ use super::{ verify::{make_sumcheck_claim_descs, PIOPSumcheckClaim}, }; use crate::{ - fiat_shamir::{CanSample, CanSampleBits}, + fiat_shamir::{CanSample, Challenger}, merkle_tree::{MerkleTreeProver, MerkleTreeScheme}, piop::CommitMeta, protocols::{ @@ -35,7 +35,7 @@ use crate::{ }, }, reed_solomon::reed_solomon::ReedSolomonCode, - transcript::{CanWrite, Proof}, + transcript::ProverTranscript, }; // ## Preconditions @@ -152,19 +152,7 @@ where /// The arguments corresponding to the committed multilinears must be the output of [`commit`]. 
#[allow(clippy::too_many_arguments)] #[tracing::instrument("piop::prove", skip_all)] -pub fn prove< - F, - FDomain, - FEncode, - P, - M, - DomainFactory, - MTScheme, - MTProver, - Transcript, - Advice, - Backend, ->( +pub fn prove( fri_params: &FRIParams, merkle_prover: &MTProver, domain_factory: DomainFactory, @@ -174,7 +162,7 @@ pub fn prove< committed_multilins: &[M], transparent_multilins: &[M], claims: &[PIOPSumcheckClaim], - proof: &mut Proof, + transcript: &mut ProverTranscript, backend: &Backend, ) -> Result<(), Error> where @@ -186,8 +174,7 @@ where DomainFactory: EvaluationDomainFactory, MTScheme: MerkleTreeScheme, MTProver: MerkleTreeProver, - Transcript: CanSample + CanWrite + CanSampleBits, - Advice: CanWrite, + Challenger_: Challenger, Backend: ComputationBackend, { // Map of n_vars to sumcheck claim descriptions @@ -245,20 +232,20 @@ where sumcheck_provers, codeword, committed, - proof, + transcript, )?; Ok(()) } -fn prove_interleaved_fri_sumcheck( +fn prove_interleaved_fri_sumcheck( n_rounds: usize, fri_params: &FRIParams, merkle_prover: &MTProver, sumcheck_provers: Vec>, codeword: &[P], committed: MTProver::Committed, - proof: &mut Proof, + transcript: &mut ProverTranscript, ) -> Result<(), Error> where F: TowerField + ExtensionField, @@ -266,30 +253,28 @@ where P: PackedFieldIndexable + PackedExtension, MTScheme: MerkleTreeScheme, MTProver: MerkleTreeProver, - Transcript: CanSample + CanWrite + CanSampleBits, - Advice: CanWrite, + Challenger_: Challenger, { let mut fri_prover = FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), &committed)?; - let mut sumcheck_batch_prover = - SumcheckBatchProver::new(sumcheck_provers, &mut proof.transcript)?; + let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?; for _ in 0..n_rounds { - sumcheck_batch_prover.send_round_proof(&mut proof.transcript)?; - let challenge = proof.transcript.sample(); + sumcheck_batch_prover.send_round_proof(&mut transcript.message())?; + let challenge = transcript.sample(); sumcheck_batch_prover.receive_challenge(challenge)?; match fri_prover.execute_fold_round(challenge)? 
{ FoldRoundOutput::NoCommitment => {} FoldRoundOutput::Commitment(round_commitment) => { - proof.transcript.write(&round_commitment); + transcript.message().write(&round_commitment); } } } - sumcheck_batch_prover.finish(&mut proof.transcript)?; - fri_prover.finish_proof(&mut proof.advice, &mut proof.transcript)?; + sumcheck_batch_prover.finish(&mut transcript.message())?; + fri_prover.finish_proof(transcript)?; Ok(()) } diff --git a/crates/core/src/piop/tests.rs b/crates/core/src/piop/tests.rs index 89d9f9bd..69f8a054 100644 --- a/crates/core/src/piop/tests.rs +++ b/crates/core/src/piop/tests.rs @@ -27,7 +27,7 @@ use crate::{ merkle_tree::{BinaryMerkleTreeProver, MerkleTreeProver, MerkleTreeScheme}, polynomial::MultivariatePoly, protocols::fri::CommitOutput, - transcript::{AdviceWriter, CanRead, CanWrite, Proof, TranscriptWriter}, + transcript::ProverTranscript, transparent, }; @@ -150,11 +150,8 @@ fn commit_prove_verify( let sumcheck_claims = make_sumcheck_claims(&committed_multilins, transparent_multilins.as_slice()); - let mut proof = Proof { - transcript: TranscriptWriter::>::default(), - advice: AdviceWriter::default(), - }; - proof.transcript.write(&commitment); + let mut proof = ProverTranscript::>::new(); + proof.message().write(&commitment); let domain_factory = DefaultEvaluationDomainFactory::::default(); prove( @@ -189,7 +186,7 @@ fn commit_prove_verify( .map(|poly| poly as &dyn MultivariatePoly) .collect::>(); - let commitment = proof.transcript.read().unwrap(); + let commitment = proof.message().read().unwrap(); verify( commit_meta, merkle_scheme, diff --git a/crates/core/src/piop/verify.rs b/crates/core/src/piop/verify.rs index 818d8146..96858f7f 100644 --- a/crates/core/src/piop/verify.rs +++ b/crates/core/src/piop/verify.rs @@ -12,7 +12,7 @@ use tracing::instrument; use super::error::{Error, VerificationError}; use crate::{ composition::{BivariateProduct, IndexComposition}, - fiat_shamir::{CanSample, CanSampleBits}, + fiat_shamir::{CanSample, Challenger}, merkle_tree::MerkleTreeScheme, piop::util::ResizeableIndex, polynomial::MultivariatePoly, @@ -23,7 +23,7 @@ use crate::{ }, }, reed_solomon::reed_solomon::ReedSolomonCode, - transcript::{CanRead, Proof}, + transcript::VerifierTranscript, }; /// Metadata about a batch of committed multilinear polynomials. 
@@ -278,20 +278,19 @@ pub fn make_sumcheck_claim_descs( /// described by `commit_meta` and the transparent polynomials in `transparents` /// * `proof` - the proof reader #[instrument("piop::verify", skip_all)] -pub fn verify<'a, F, FEncode, Transcript, Advice, MTScheme>( +pub fn verify<'a, F, FEncode, Challenger_, MTScheme>( commit_meta: &CommitMeta, merkle_scheme: &MTScheme, fri_params: &FRIParams, commitment: &MTScheme::Digest, transparents: &[impl Borrow + 'a>], claims: &[PIOPSumcheckClaim], - proof: &mut Proof, + transcript: &mut VerifierTranscript, ) -> Result<(), Error> where F: TowerField + ExtensionField, FEncode: BinaryField, - Transcript: CanSample + CanRead + CanSampleBits, - Advice: CanRead, + Challenger_: Challenger, MTScheme: MerkleTreeScheme, { // Map of n_vars to sumcheck claim descriptions @@ -329,7 +328,7 @@ where merkle_scheme, &sumcheck_claims, commitment, - proof, + transcript, )?; let mut piecewise_evals = verify_transparent_evals( @@ -401,37 +400,35 @@ struct BatchInterleavedSumcheckFRIOutput { /// * `n_rounds` is greater than or equal to the maximum number of variables of any claim /// * `claims` are sorted in ascending order by number of variables #[instrument(skip_all)] -fn verify_interleaved_fri_sumcheck( +fn verify_interleaved_fri_sumcheck( n_rounds: usize, fri_params: &FRIParams, merkle_scheme: &MTScheme, claims: &[SumcheckClaim>], codeword_commitment: &MTScheme::Digest, - proof: &mut Proof, + proof: &mut VerifierTranscript, ) -> Result, Error> where F: TowerField + ExtensionField, FEncode: BinaryField, - Transcript: CanSample + CanRead + CanSampleBits, - Advice: CanRead, + Challenger_: Challenger, MTScheme: MerkleTreeScheme, { let mut arities_iter = fri_params.fold_arities().iter(); let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles()); let mut next_commit_round = arities_iter.next().copied(); - let mut sumcheck_verifier = SumcheckBatchVerifier::new(claims, &mut proof.transcript)?; + let mut sumcheck_verifier = SumcheckBatchVerifier::new(claims, proof)?; let mut multilinear_evals = Vec::with_capacity(claims.len()); let mut challenges = Vec::with_capacity(n_rounds); for round_no in 0..n_rounds { - while let Some(claim_multilinear_evals) = - sumcheck_verifier.try_finish_claim(&mut proof.transcript)? - { + let mut reader = proof.message(); + while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? { multilinear_evals.push(claim_multilinear_evals); } - sumcheck_verifier.receive_round_proof(&mut proof.transcript)?; + sumcheck_verifier.receive_round_proof(&mut reader)?; - let challenge = proof.transcript.sample(); + let challenge = proof.sample(); challenges.push(challenge); sumcheck_verifier.finish_round(challenge)?; @@ -439,7 +436,7 @@ where let observe_fri_comm = next_commit_round.is_some_and(|round| round == round_no + 1); if observe_fri_comm { let comm = proof - .transcript + .message() .read() .map_err(VerificationError::Transcript)?; fri_commitments.push(comm); @@ -447,9 +444,8 @@ where } } - while let Some(claim_multilinear_evals) = - sumcheck_verifier.try_finish_claim(&mut proof.transcript)? - { + let mut reader = proof.message(); + while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? 
{ multilinear_evals.push(claim_multilinear_evals); } sumcheck_verifier.finish()?; @@ -461,7 +457,7 @@ where &fri_commitments, &challenges, )?; - let fri_final = verifier.verify(&mut proof.advice, &mut proof.transcript)?; + let fri_final = verifier.verify(proof)?; Ok(BatchInterleavedSumcheckFRIOutput { challenges, diff --git a/crates/core/src/polynomial/arith_circuit.rs b/crates/core/src/polynomial/arith_circuit.rs index b406bfa2..ec5c4e57 100644 --- a/crates/core/src/polynomial/arith_circuit.rs +++ b/crates/core/src/polynomial/arith_circuit.rs @@ -6,6 +6,8 @@ use binius_field::{ExtensionField, Field, PackedField, TowerField}; use binius_math::{ArithExpr, CompositionPoly, CompositionPolyOS, Error}; use stackalloc::{helpers::slice_assume_init, stackalloc_uninit}; +use super::MultivariatePoly; + /// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence. fn circuit_steps_for_expr( expr: &ArithExpr, @@ -319,6 +321,24 @@ impl>> CompositionPolyOS } } +impl MultivariatePoly for ArithCircuitPoly { + fn degree(&self) -> usize { + CompositionPoly::degree(&self) + } + + fn n_vars(&self) -> usize { + CompositionPoly::n_vars(&self) + } + + fn binary_tower_level(&self) -> usize { + CompositionPoly::binary_tower_level(&self) + } + + fn evaluate(&self, query: &[F]) -> Result { + CompositionPoly::evaluate(&self, query).map_err(|e| e.into()) + } +} + /// Apply a binary operation to two arguments and store the result in `current_evals`. /// `op` must be a function that takes two arguments and initialized the result with the third argument. fn apply_binary_op>>( diff --git a/crates/core/src/protocols/evalcheck/evalcheck.rs b/crates/core/src/protocols/evalcheck/evalcheck.rs index 186f06c1..32e36384 100644 --- a/crates/core/src/protocols/evalcheck/evalcheck.rs +++ b/crates/core/src/protocols/evalcheck/evalcheck.rs @@ -7,11 +7,12 @@ use std::{ }; use binius_field::{Field, TowerField}; +use bytes::{Buf, BufMut}; use super::error::Error; use crate::{ oracle::{MultilinearPolyOracle, OracleId}, - transcript::{CanRead, CanWrite}, + transcript::{TranscriptReader, TranscriptWriter}, }; #[derive(Debug, Clone)] @@ -88,8 +89,8 @@ impl EvalcheckNumerics { } /// Serializes the `EvalcheckProof` into the transcript -pub fn serialize_evalcheck_proof( - transcript: &mut Transcript, +pub fn serialize_evalcheck_proof( + transcript: &mut TranscriptWriter, evalcheck: &EvalcheckProof, ) { match evalcheck { @@ -127,8 +128,8 @@ pub fn serialize_evalcheck_proof( } /// Deserializes the `EvalcheckProof` object from the given transcript. 
-pub fn deserialize_evalcheck_proof( - transcript: &mut Transcript, +pub fn deserialize_evalcheck_proof( + transcript: &mut TranscriptReader, ) -> Result, Error> { let mut ty = 0; transcript.read_bytes(slice::from_mut(&mut ty))?; diff --git a/crates/core/src/protocols/evalcheck/subclaims.rs b/crates/core/src/protocols/evalcheck/subclaims.rs index 541f3367..c093d76a 100644 --- a/crates/core/src/protocols/evalcheck/subclaims.rs +++ b/crates/core/src/protocols/evalcheck/subclaims.rs @@ -25,7 +25,7 @@ use binius_utils::bail; use super::{error::Error, evalcheck::EvalcheckMultilinearClaim}; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::Challenger, oracle::{ ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, OracleId, Packed, ProjectionVariant, Shifted, @@ -36,7 +36,7 @@ use crate::{ prove::oracles::{constraint_sets_sumcheck_provers_metas, SumcheckProversWithMetas}, Error as SumcheckError, }, - transcript::CanWrite, + transcript::ProverTranscript, transparent::{shift_ind::ShiftIndPartialEval, tower_basis::TowerBasis}, witness::{MultilinearExtensionIndex, MultilinearWitness}, }; @@ -398,11 +398,11 @@ impl MemoizedQueries { type SumcheckProofEvalcheckClaims = Vec>; -pub fn prove_bivariate_sumchecks_with_switchover( +pub fn prove_bivariate_sumchecks_with_switchover( oracles: &MultilinearOracleSet, witness: &MultilinearExtensionIndex, constraint_sets: Vec>, - transcript: &mut Transcript, + transcript: &mut ProverTranscript, switchover_fn: impl Fn(usize) -> usize + 'static, domain_factory: impl EvaluationDomainFactory, backend: &Backend, @@ -411,7 +411,7 @@ where U: UnderlierType + PackScalar + PackScalar, F: TowerField + ExtensionField, DomainField: Field, - Transcript: CanSample + CanWrite, + Challenger_: Challenger, Backend: ComputationBackend, { let SumcheckProversWithMetas { provers, metas } = constraint_sets_sumcheck_provers_metas( diff --git a/crates/core/src/protocols/evalcheck/tests.rs b/crates/core/src/protocols/evalcheck/tests.rs index 80f221f2..18179fbd 100644 --- a/crates/core/src/protocols/evalcheck/tests.rs +++ b/crates/core/src/protocols/evalcheck/tests.rs @@ -500,15 +500,16 @@ fn test_evalcheck_serialization() { .chain(transparent[..20].iter()), ); - let mut transcript = crate::transcript::TranscriptWriter::< + let mut transcript = crate::transcript::ProverTranscript::< crate::fiat_shamir::HasherChallenger, - >::default(); + >::new(); - serialize_evalcheck_proof(&mut transcript, &second_linear_combination); - let mut transcript = transcript.into_reader(); + let mut writer = transcript.message(); + serialize_evalcheck_proof(&mut writer, &second_linear_combination); + let mut transcript = transcript.into_verifier(); + let mut reader = transcript.message(); - let out_deserialized = - deserialize_evalcheck_proof::<_, BinaryField128b>(&mut transcript).unwrap(); + let out_deserialized = deserialize_evalcheck_proof::<_, BinaryField128b>(&mut reader).unwrap(); assert_eq!(out_deserialized, second_linear_combination); diff --git a/crates/core/src/protocols/fri/common.rs b/crates/core/src/protocols/fri/common.rs index d4c0a2be..383f3c1a 100644 --- a/crates/core/src/protocols/fri/common.rs +++ b/crates/core/src/protocols/fri/common.rs @@ -9,7 +9,7 @@ use binius_utils::bail; use getset::{CopyGetters, Getters}; use crate::{ - linear_code::LinearCode, merkle_tree::MerkleTreeScheme, protocols::fri::Error, + merkle_tree::MerkleTreeScheme, protocols::fri::Error, reed_solomon::reed_solomon::ReedSolomonCode, }; diff --git a/crates/core/src/protocols/fri/prove.rs 
b/crates/core/src/protocols/fri/prove.rs index 74dff046..f80f0dec 100644 --- a/crates/core/src/protocols/fri/prove.rs +++ b/crates/core/src/protocols/fri/prove.rs @@ -5,6 +5,7 @@ use binius_hal::{make_portable_backend, ComputationBackend}; use binius_maybe_rayon::prelude::*; use binius_utils::{bail, serialization::SerializeBytes}; use bytemuck::zeroed_vec; +use bytes::BufMut; use itertools::izip; use tracing::instrument; @@ -14,12 +15,11 @@ use super::{ TerminateCodeword, }; use crate::{ - fiat_shamir::CanSampleBits, - linear_code::LinearCode, + fiat_shamir::{CanSampleBits, Challenger}, merkle_tree::{MerkleTreeProver, MerkleTreeScheme}, protocols::fri::common::{fold_chunk, fold_interleaved_chunk}, reed_solomon::reed_solomon::ReedSolomonCode, - transcript::CanWrite, + transcript::{ProverTranscript, TranscriptWriter}, }; #[instrument(skip_all, level = "debug")] @@ -435,16 +435,15 @@ where Ok((terminate_codeword, query_prover)) } - pub fn finish_proof( + pub fn finish_proof( self, - advice: &mut Advice, - transcript: &mut Transcript, + transcript: &mut ProverTranscript, ) -> Result<(), Error> where - Transcript: CanSampleBits, - Advice: CanWrite, + Challenger_: Challenger, { let (terminate_codeword, query_prover) = self.finalize()?; + let mut advice = transcript.decommitment(); advice.write_scalar_slice(&terminate_codeword); let layers = query_prover.vcs_optimal_layers()?; @@ -453,11 +452,10 @@ where } let params = query_prover.params; - let indexes_iter = std::iter::repeat_with(|| transcript.sample_bits(params.index_bits())) - .take(params.n_test_queries()); - for index in indexes_iter { - query_prover.prove_query(index, advice)?; + for _ in 0..params.n_test_queries() { + let index = transcript.sample_bits(params.index_bits()); + query_prover.prove_query(index, transcript.decommitment())?; } Ok(()) @@ -497,9 +495,13 @@ where /// /// * `index` - an index into the original codeword domain #[instrument(skip_all, name = "fri::FRIQueryProver::prove_query", level = "debug")] - pub fn prove_query(&self, mut index: usize, advice: &mut Advice) -> Result<(), Error> + pub fn prove_query( + &self, + mut index: usize, + mut advice: TranscriptWriter, + ) -> Result<(), Error> where - Advice: CanWrite, + B: BufMut, { let mut arities_and_optimal_layers_depths = self .params @@ -524,7 +526,7 @@ where index, first_fold_arity, first_optimal_layer_depth, - advice, + &mut advice, )?; for ((codeword, committed), (arity, optimal_layer_depth)) in @@ -538,7 +540,7 @@ where index, arity, optimal_layer_depth, - advice, + &mut advice, )?; } @@ -561,19 +563,19 @@ where } } -fn prove_coset_opening( +fn prove_coset_opening( merkle_prover: &MTProver, codeword: &[F], committed: &MTProver::Committed, coset_index: usize, log_coset_size: usize, optimal_layer_depth: usize, - advice: &mut Advice, + advice: &mut TranscriptWriter, ) -> Result<(), Error> where F: TowerField, MTProver: MerkleTreeProver, - Advice: CanWrite, + B: BufMut, { let values = &codeword[(coset_index << log_coset_size)..((coset_index + 1) << log_coset_size)]; advice.write_scalar_slice(values); diff --git a/crates/core/src/protocols/fri/tests.rs b/crates/core/src/protocols/fri/tests.rs index ab6b90b9..43ac3281 100644 --- a/crates/core/src/protocols/fri/tests.rs +++ b/crates/core/src/protocols/fri/tests.rs @@ -20,14 +20,13 @@ use rand::prelude::*; use super::to_par_scalar_big_chunks; use crate::{ fiat_shamir::{CanSample, HasherChallenger}, - linear_code::LinearCode, merkle_tree::BinaryMerkleTreeProver, protocols::fri::{ self, to_par_scalar_small_chunks, 
CommitOutput, FRIFolder, FRIParams, FRIVerifier, FoldRoundOutput, }, reed_solomon::reed_solomon::ReedSolomonCode, - transcript::{AdviceWriter, CanRead, CanWrite, TranscriptWriter}, + transcript::ProverTranscript, }; fn test_commit_prove_verify_success( @@ -85,48 +84,35 @@ fn test_commit_prove_verify_success( ) .unwrap(); - let mut prover_challenger = crate::transcript::Proof { - transcript: TranscriptWriter::>::default(), - advice: AdviceWriter::default(), - }; - prover_challenger.transcript.write(&codeword_commitment); + let mut prover_challenger = ProverTranscript::>::new(); + prover_challenger.message().write(&codeword_commitment); let mut round_commitments = Vec::with_capacity(params.n_oracles()); for _i in 0..params.n_fold_rounds() { - let challenge = prover_challenger.transcript.sample(); + let challenge = prover_challenger.sample(); let fold_round_output = round_prover.execute_fold_round(challenge).unwrap(); match fold_round_output { FoldRoundOutput::NoCommitment => {} FoldRoundOutput::Commitment(round_commitment) => { - prover_challenger.transcript.write(&round_commitment); + prover_challenger.message().write(&round_commitment); round_commitments.push(round_commitment); } } } - round_prover - .finish_proof(&mut prover_challenger.advice, &mut prover_challenger.transcript) - .unwrap(); + round_prover.finish_proof(&mut prover_challenger).unwrap(); // Now run the verifier let mut verifier_challenger = prover_challenger.into_verifier(); - codeword_commitment = verifier_challenger.transcript.read().unwrap(); + codeword_commitment = verifier_challenger.message().read().unwrap(); let mut verifier_challenges = Vec::with_capacity(params.n_fold_rounds()); assert_eq!(round_commitments.len(), n_round_commitments); for (i, commitment) in round_commitments.iter().enumerate() { - verifier_challenges.append( - &mut verifier_challenger - .transcript - .sample_vec(params.fold_arities()[i]), - ); + verifier_challenges.append(&mut verifier_challenger.sample_vec(params.fold_arities()[i])); let mut _commitment = *commitment; - _commitment = verifier_challenger.transcript.read().unwrap(); + _commitment = verifier_challenger.message().read().unwrap(); } - verifier_challenges.append( - &mut verifier_challenger - .transcript - .sample_vec(params.n_final_challenges()), - ); + verifier_challenges.append(&mut verifier_challenger.sample_vec(params.n_final_challenges())); assert_eq!(verifier_challenges.len(), params.n_fold_rounds()); @@ -149,9 +135,7 @@ fn test_commit_prove_verify_success( ) .unwrap(); - let final_fri_value = verifier - .verify(&mut verifier_challenger.advice, &mut verifier_challenger.transcript) - .unwrap(); + let final_fri_value = verifier.verify(&mut verifier_challenger).unwrap(); assert_eq!(computed_eval, final_fri_value); } diff --git a/crates/core/src/protocols/fri/verify.rs b/crates/core/src/protocols/fri/verify.rs index 6dfa9662..5fc9285d 100644 --- a/crates/core/src/protocols/fri/verify.rs +++ b/crates/core/src/protocols/fri/verify.rs @@ -1,19 +1,20 @@ // Copyright 2024-2025 Irreducible Inc. 
-use std::{iter, iter::repeat_with}; +use std::iter; use binius_field::{BinaryField, ExtensionField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; use binius_utils::{bail, serialization::DeserializeBytes}; +use bytes::Buf; use itertools::izip; use tracing::instrument; use super::{common::vcs_optimal_layers_depths_iter, error::Error, VerificationError}; use crate::{ - fiat_shamir::CanSampleBits, + fiat_shamir::{CanSampleBits, Challenger}, merkle_tree::MerkleTreeScheme, protocols::fri::common::{fold_chunk, fold_interleaved_chunk, FRIParams}, - transcript::CanRead, + transcript::{TranscriptReader, VerifierTranscript}, }; /// A verifier for the FRI query phase. @@ -91,18 +92,17 @@ where self.params.n_oracles() } - pub fn verify( + pub fn verify( &self, - advice: &mut Advice, - transcript: &mut Transcript, + transcript: &mut VerifierTranscript, ) -> Result where - Transcript: CanSampleBits, - Advice: CanRead, + Challenger_: Challenger, { // Verify that the last oracle sent is a codeword. let terminate_codeword_len = 1 << (self.params.n_final_challenges() + self.params.rs_code().log_inv_rate()); + let mut advice = transcript.decommitment(); let terminate_codeword = advice .read_scalar_slice(terminate_codeword_len) .map_err(Error::TranscriptError)?; @@ -123,16 +123,15 @@ where } // Verify the random openings against the decommitted layers. - let indexes_iter = repeat_with(|| transcript.sample_bits(self.params.index_bits())) - .take(self.params.n_test_queries()); let mut scratch_buffer = self.create_scratch_buffer(); - for index in indexes_iter { + for _ in 0..self.params.n_test_queries() { + let index = transcript.sample_bits(self.params.index_bits()); self.verify_query_internal( index, &terminate_codeword, &layers, - advice, + &mut transcript.decommitment(), &mut scratch_buffer, )? } @@ -220,12 +219,12 @@ where /// /// * `index` - an index into the original codeword domain /// * `proof` - a query proof - pub fn verify_query( + pub fn verify_query( &self, index: usize, terminate_codeword: &[F], layers: &[Vec], - advice: &mut Advice, + advice: &mut TranscriptReader, ) -> Result<(), Error> { self.verify_query_internal( index, @@ -237,12 +236,12 @@ where } #[instrument(skip_all, name = "fri::FRIVerifier::verify_query", level = "debug")] - fn verify_query_internal( + fn verify_query_internal( &self, mut index: usize, terminate_codeword: &[F], layers: &[Vec], - advice: &mut Advice, + advice: &mut TranscriptReader, scratch_buffer: &mut [F], ) -> Result<(), Error> { let mut arities_iter = self.params.fold_arities().iter().copied(); @@ -352,19 +351,19 @@ where /// Verifies that the coset opening provided in the proof is consistent with the VCS commitment. 
#[allow(clippy::too_many_arguments)] -fn verify_coset_opening( +fn verify_coset_opening( vcs: &MTScheme, coset_index: usize, log_coset_size: usize, optimal_layer_depth: usize, tree_depth: usize, layer_digests: &[MTScheme::Digest], - advice: &mut Advice, + advice: &mut TranscriptReader, ) -> Result, Error> where F: TowerField, MTScheme: MerkleTreeScheme, - Advice: CanRead, + B: Buf, { let values = advice.read_scalar_slice::(1 << log_coset_size)?; vcs.verify_opening( diff --git a/crates/core/src/protocols/gkr_gpa/gkr_gpa.rs b/crates/core/src/protocols/gkr_gpa/gkr_gpa.rs index 97bb51fc..1706c698 100644 --- a/crates/core/src/protocols/gkr_gpa/gkr_gpa.rs +++ b/crates/core/src/protocols/gkr_gpa/gkr_gpa.rs @@ -21,6 +21,15 @@ pub struct GrandProductClaim { pub product: F, } +impl GrandProductClaim { + pub fn isomorphic>(self) -> GrandProductClaim { + GrandProductClaim { + n_vars: self.n_vars, + product: self.product.into(), + } + } +} + #[derive(Debug, Clone)] pub struct GrandProductWitness { n_vars: usize, @@ -149,6 +158,18 @@ pub struct LayerClaim { pub eval: F, } +impl LayerClaim { + pub fn isomorphic(self) -> LayerClaim + where + F: Into, + { + LayerClaim { + eval_point: self.eval_point.into_iter().map(Into::into).collect(), + eval: self.eval.into(), + } + } +} + #[derive(Debug, Default)] pub struct GrandProductBatchProveOutput { // Reduced evalcheck claims for all the initial grand product claims diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/tests.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/tests.rs index ef076870..837e2f84 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/tests.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/tests.rs @@ -25,7 +25,7 @@ use crate::{ sumcheck::{self, zerocheck::ExtraProduct, CompositeSumClaim, SumcheckClaim}, test_utils::AddOneComposition, }, - transcript::TranscriptWriter, + transcript::ProverTranscript, }; fn test_prove_verify_bivariate_product_helper( @@ -60,7 +60,7 @@ fn test_prove_verify_bivariate_product_helper( .evaluate(MultilinearQuery::expand(&gpa_round_challenges).to_ref()) .unwrap(); - let mut transcript = TranscriptWriter::>::default(); + let mut transcript = ProverTranscript::>::new(); let backend = make_portable_backend(); let evaluation_domain_factory = DefaultEvaluationDomainFactory::::default(); @@ -84,7 +84,7 @@ fn test_prove_verify_bivariate_product_helper( let _sumcheck_proof_output = sumcheck::batch_prove(vec![prover], &mut transcript).unwrap(); - let mut verifier_transcript = transcript.into_reader(); + let mut verifier_transcript = transcript.into_verifier(); let verifier_composite_claim = CompositeSumClaim { sum, diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/verify.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/verify.rs index 41790f62..0e64ceaf 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/verify.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/verify.rs @@ -138,7 +138,7 @@ mod tests { }, sumcheck, }, - transcript::TranscriptWriter, + transcript::ProverTranscript, }; fn generate_poly_helper

( @@ -184,7 +184,7 @@ mod tests { .collect::>(); let prod_multilin = prod_mle.specialize_arc_dyn::>(); - let mut prove_transcript = TranscriptWriter::>::default(); + let mut prove_transcript = ProverTranscript::>::new(); let challenges: Vec = prove_transcript.sample_vec(n_vars); let sum = prod_multilin @@ -215,7 +215,7 @@ mod tests { let sumcheck_claim = reduce_to_sumcheck(&[claim]).unwrap(); let sumcheck_claims = [sumcheck_claim]; - let mut verify_challenger = prove_transcript.into_reader(); + let mut verify_challenger = prove_transcript.into_verifier(); let _: Vec = verify_challenger.sample_vec(n_vars); let batch_output = sumcheck::batch_verify(&sumcheck_claims, &mut verify_challenger).unwrap(); diff --git a/crates/core/src/protocols/gkr_gpa/oracles.rs b/crates/core/src/protocols/gkr_gpa/oracles.rs index b135be1c..6d7d0824 100644 --- a/crates/core/src/protocols/gkr_gpa/oracles.rs +++ b/crates/core/src/protocols/gkr_gpa/oracles.rs @@ -38,8 +38,8 @@ where pub fn get_grand_products_from_witnesses(witnesses: &[GrandProductWitness]) -> Vec where - PW: PackedField, - F: Field + From, + PW: PackedField>, + F: Field, { witnesses .iter() diff --git a/crates/core/src/protocols/gkr_gpa/prove.rs b/crates/core/src/protocols/gkr_gpa/prove.rs index 1ea9e055..6f37681b 100644 --- a/crates/core/src/protocols/gkr_gpa/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/prove.rs @@ -22,9 +22,9 @@ use super::{ }; use crate::{ composition::{BivariateProduct, IndexComposition}, - fiat_shamir::CanSample, + fiat_shamir::{CanSample, Challenger}, protocols::sumcheck::{self, CompositeSumClaim}, - transcript::CanWrite, + transcript::ProverTranscript, }; /// Proves batch reduction turning each GrandProductClaim into an EvalcheckMultilinearClaim @@ -33,11 +33,11 @@ use crate::{ /// * witnesses and claims are of the same length /// * The ith witness corresponds to the ith claim #[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")] -pub fn batch_prove( +pub fn batch_prove( witnesses: impl IntoIterator>, claims: &[GrandProductClaim], evaluation_domain_factory: impl EvaluationDomainFactory, - mut transcript: Transcript, + transcript: &mut ProverTranscript, backend: &Backend, ) -> Result, Error> where @@ -45,7 +45,7 @@ where P: PackedFieldIndexable + PackedExtension, FDomain: Field, P::Scalar: Field + ExtensionField, - Transcript: CanSample + CanWrite, + Challenger_: Challenger, Backend: ComputationBackend, { // Ensure witnesses and claims are of the same length, zip them together @@ -94,7 +94,7 @@ where evaluation_domain_factory.clone(), )?; - sumcheck::batch_prove(vec![gpa_sumcheck_prover], &mut transcript)? + sumcheck::batch_prove(vec![gpa_sumcheck_prover], transcript)? 
}; // Step 3: Sample a challenge for the next layer diff --git a/crates/core/src/protocols/gkr_gpa/tests.rs b/crates/core/src/protocols/gkr_gpa/tests.rs index ad602f9c..4ea1c347 100644 --- a/crates/core/src/protocols/gkr_gpa/tests.rs +++ b/crates/core/src/protocols/gkr_gpa/tests.rs @@ -20,7 +20,7 @@ use crate::{ fiat_shamir::HasherChallenger, oracle::MultilinearOracleSet, protocols::gkr_gpa::{batch_prove, batch_verify, GrandProductBatchProveOutput}, - transcript::TranscriptWriter, + transcript::ProverTranscript, witness::MultilinearExtensionIndex, }; @@ -172,7 +172,7 @@ where // Prove and Verify let _ = (oracle_set, witness_index, rng); - let mut prover_transcript = TranscriptWriter::>::default(); + let mut prover_transcript = ProverTranscript::>::new(); let GrandProductBatchProveOutput { final_layer_claims: final_layer_claim, } = batch_prove::<_, _, FS, _, _>( @@ -184,7 +184,7 @@ where ) .unwrap(); - let mut verify_transcript = prover_transcript.into_reader(); + let mut verify_transcript = prover_transcript.into_verifier(); let verified_evalcheck_multilinear_claims = batch_verify(claims.clone(), &mut verify_transcript).unwrap(); diff --git a/crates/core/src/protocols/gkr_gpa/verify.rs b/crates/core/src/protocols/gkr_gpa/verify.rs index c06fe030..056f4cb7 100644 --- a/crates/core/src/protocols/gkr_gpa/verify.rs +++ b/crates/core/src/protocols/gkr_gpa/verify.rs @@ -13,17 +13,21 @@ use super::{ gpa_sumcheck::verify::{reduce_to_sumcheck, verify_sumcheck_outputs, GPASumcheckClaim}, Error, GrandProductClaim, }; -use crate::{fiat_shamir::CanSample, protocols::sumcheck, transcript::CanRead}; +use crate::{ + fiat_shamir::{CanSample, Challenger}, + protocols::sumcheck, + transcript::VerifierTranscript, +}; /// Verifies batch reduction turning each GrandProductClaim into an EvalcheckMultilinearClaim #[instrument(skip_all, name = "gkr_gpa::batch_verify", level = "debug")] -pub fn batch_verify( +pub fn batch_verify( claims: impl IntoIterator>, - mut transcript: Transcript, + transcript: &mut VerifierTranscript, ) -> Result>, Error> where F: TowerField, - Transcript: CanSample + CanRead, + Challenger_: Challenger, { let (original_indices, mut sorted_claims) = stable_sort(claims, |claim| claim.n_vars, true); let max_n_vars = sorted_claims.first().map(|claim| claim.n_vars).unwrap_or(0); @@ -50,7 +54,7 @@ where &mut reverse_sorted_evalcheck_claims, ); - layer_claims = reduce_layer_claim_batch(layer_claims, &mut transcript)?; + layer_claims = reduce_layer_claim_batch(layer_claims, transcript)?; } process_finished_claims( n_claims, @@ -97,13 +101,13 @@ fn process_finished_claims( /// * `claims` - The kth layer LayerClaims /// * `proof` - The batch layer proof that reduces the kth layer claims of the product circuits to the (k+1)th /// * `transcript` - The verifier transcript -fn reduce_layer_claim_batch( +fn reduce_layer_claim_batch( claims: Vec>, - mut transcript: Transcript, + transcript: &mut VerifierTranscript, ) -> Result>, Error> where F: TowerField, - Transcript: CanSample + CanRead, + Challenger_: Challenger, { // Validation if claims.is_empty() { @@ -127,7 +131,7 @@ where let sumcheck_claim = reduce_to_sumcheck(&gpa_sumcheck_claims)?; let sumcheck_claims = [sumcheck_claim]; - let batch_sumcheck_output = sumcheck::batch_verify(&sumcheck_claims, &mut transcript)?; + let batch_sumcheck_output = sumcheck::batch_verify(&sumcheck_claims, transcript)?; let batch_sumcheck_output = verify_sumcheck_outputs(&gpa_sumcheck_claims, curr_layer_challenge, batch_sumcheck_output)?; diff --git 
a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs index 6a03c9e0..24176c73 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs @@ -14,12 +14,12 @@ use super::{ utils::first_layer_inverse, witness::GeneratorExponentWitness, }; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::Challenger, protocols::{ gkr_gpa::{gpa_sumcheck::prove::GPAProver, Error, LayerClaim}, sumcheck::{self, CompositeSumClaim}, }, - transcript::CanWrite, + transcript::ProverTranscript, }; pub fn prove< @@ -29,14 +29,14 @@ pub fn prove< PChallenge, PGenerator, FDomain, - Transcript, + Challenger_, Backend, const EXPONENT_BIT_WIDTH: usize, >( witness: &GeneratorExponentWitness<'_, PBits, PGenerator, PChallenge, EXPONENT_BIT_WIDTH>, claim: &LayerClaim, // this is a claim about the evaluation of the result layer at a random point evaluation_domain_factory: impl EvaluationDomainFactory, - mut transcript: Transcript, + transcript: &mut ProverTranscript, backend: &Backend, ) -> Result, Error> where @@ -53,7 +53,7 @@ where F: ExtensionField + ExtensionField + BinaryField + TowerField, FGenerator: Field + TowerField, Backend: ComputationBackend, - Transcript: CanSample + CanWrite, + Challenger_: Challenger, { let mut eval_claims_on_bit_columns: [_; EXPONENT_BIT_WIDTH] = array::from_fn(|_| LayerClaim::::default()); @@ -91,8 +91,7 @@ where backend, )?; - let sumcheck_proof_output = - sumcheck::batch_prove(vec![this_round_prover], &mut transcript)?; + let sumcheck_proof_output = sumcheck::batch_prove(vec![this_round_prover], transcript)?; eval_point = sumcheck_proof_output.challenges.clone(); eval = sumcheck_proof_output.multilinear_evals[0][0]; diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs index 4134a1f6..044e6778 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs @@ -19,12 +19,12 @@ use rand::{thread_rng, Rng}; use super::{common::GeneratorExponentReductionOutput, prove}; use crate::{ - fiat_shamir::{CanSample, HasherChallenger}, + fiat_shamir::{Challenger, HasherChallenger}, protocols::{ gkr_gpa::LayerClaim, gkr_int_mul::generator_exponent::{verify, witness::GeneratorExponentWitness}, }, - transcript::{CanWrite, TranscriptWriter}, + transcript::ProverTranscript, }; type PBits = PackedBinaryField128x1b; @@ -66,16 +66,16 @@ fn generate_witness_and_prove< const LOG_SIZE: usize, const COLUMN_LEN: usize, const EXPONENT_BIT_WIDTH: usize, - Transcript, + Challenger_, >( - transcript: &mut Transcript, + transcript: &mut ProverTranscript, ) -> ( LayerClaim, GeneratorExponentWitness<'static, PBits, PGenerator, PChallenge, 64>, GeneratorExponentReductionOutput, ) where - Transcript: CanSample + CanWrite, + Challenger_: Challenger, { let mut rng = thread_rng(); @@ -135,7 +135,7 @@ fn prove_reduces_to_correct_claims() { const COLUMN_LEN: usize = 1usize << LOG_SIZE; const EXPONENT_BIT_WIDTH: usize = 64usize; - let mut transcript = TranscriptWriter::>::default(); + let mut transcript = ProverTranscript::>::new(); let reduced_claims = generate_witness_and_prove::(&mut transcript); @@ -164,18 +164,20 @@ fn good_proof_verifies() { const COLUMN_LEN: usize = 1usize << LOG_SIZE; const EXPONENT_BIT_WIDTH: usize = 64usize; - let mut transcript = 
TranscriptWriter::>::default(); + let mut transcript = ProverTranscript::>::new(); let reduced_claims = generate_witness_and_prove::(&mut transcript); let (claim, _, _) = reduced_claims; - let verifier_transcript = transcript.into_reader(); + let mut verifier_transcript = transcript.into_verifier(); let _reduced_claims = verify::verify::( &claim, - verifier_transcript, + &mut verifier_transcript, LOG_SIZE, ) .unwrap(); + + verifier_transcript.finalize().unwrap() } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs index 3a26a57e..1024481a 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs @@ -9,14 +9,14 @@ use super::{ super::error::Error, common::GeneratorExponentReductionOutput, utils::first_layer_inverse, }; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::Challenger, polynomial::MultivariatePoly, protocols::{ gkr_gpa::LayerClaim, gkr_int_mul::generator_exponent::compositions::MultiplyOrDont, sumcheck::{self, zerocheck::ExtraProduct, CompositeSumClaim, SumcheckClaim}, }, - transcript::CanRead, + transcript::VerifierTranscript, transparent::eq_ind::EqIndPartialEval, }; @@ -35,15 +35,15 @@ use crate::{ /// Input: One evaluation claim on n /// /// Output: EXPONENT_BITS_WIDTH separate claims (at different points) on each of the a_i's -pub fn verify( +pub fn verify( claim: &LayerClaim, - mut transcript: Transcript, + transcript: &mut VerifierTranscript, log_size: usize, ) -> Result, Error> where FGenerator: TowerField, F: TowerField + ExtensionField, - Transcript: CanSample + CanRead, + Challenger_: Challenger, { let mut eval_claims_on_bit_columns: [_; EXPONENT_BIT_WIDTH] = array::from_fn(|_| LayerClaim::::default()); @@ -68,7 +68,7 @@ where )?; let sumcheck_verification_output = - sumcheck::batch_verify(&[this_round_sumcheck_claim], &mut transcript)?; + sumcheck::batch_verify(&[this_round_sumcheck_claim], transcript)?; // Verify claims on transparent polynomials diff --git a/crates/core/src/protocols/greedy_evalcheck/prove.rs b/crates/core/src/protocols/greedy_evalcheck/prove.rs index 339c0c69..a9e640de 100644 --- a/crates/core/src/protocols/greedy_evalcheck/prove.rs +++ b/crates/core/src/protocols/greedy_evalcheck/prove.rs @@ -11,25 +11,24 @@ use tracing::instrument; use super::error::Error; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::Challenger, oracle::MultilinearOracleSet, protocols::evalcheck::{ serialize_evalcheck_proof, subclaims::prove_bivariate_sumchecks_with_switchover, EvalcheckMultilinearClaim, EvalcheckProver, }, - transcript::{write_u64, AdviceWriter, CanWrite}, + transcript::{write_u64, ProverTranscript}, witness::MultilinearExtensionIndex, }; #[allow(clippy::too_many_arguments)] #[instrument(skip_all, name = "greedy_evalcheck::prove")] -pub fn prove( +pub fn prove( oracles: &mut MultilinearOracleSet, witness_index: &mut MultilinearExtensionIndex, claims: impl IntoIterator>, switchover_fn: impl Fn(usize) -> usize + Clone + 'static, - transcript: &mut Transcript, - advice: &mut AdviceWriter, + transcript: &mut ProverTranscript, domain_factory: impl EvaluationDomainFactory, backend: &Backend, ) -> Result>, Error> @@ -38,7 +37,7 @@ where F: TowerField + ExtensionField, PackedType: PackedFieldIndexable, DomainField: TowerField, - Transcript: CanSample + CanWrite, + Challenger_: Challenger, Backend: ComputationBackend, { let mut evalcheck_prover = @@ -48,12 +47,12 @@ 
where // Prove the initial evalcheck claims let evalcheck_proofs = evalcheck_prover.prove(claims)?; - write_u64(advice, evalcheck_proofs.len() as u64); + write_u64(&mut transcript.decommitment(), evalcheck_proofs.len() as u64); + let mut writer = transcript.message(); for evalcheck_proof in evalcheck_proofs.iter() { - serialize_evalcheck_proof(transcript, evalcheck_proof) + serialize_evalcheck_proof(&mut writer, evalcheck_proof) } - let mut virtual_opening_proofs_len = 0; loop { let new_sumchecks = evalcheck_prover.take_new_sumchecks_constraints().unwrap(); if new_sumchecks.is_empty() { @@ -74,12 +73,11 @@ where let new_evalcheck_proofs = evalcheck_prover.prove(new_evalcheck_claims)?; + let mut writer = transcript.message(); for evalcheck_proof in new_evalcheck_proofs.iter() { - serialize_evalcheck_proof(transcript, evalcheck_proof); + serialize_evalcheck_proof(&mut writer, evalcheck_proof); } - virtual_opening_proofs_len += 1; } - write_u64(advice, virtual_opening_proofs_len); let committed_claims = evalcheck_prover .committed_eval_claims_mut() diff --git a/crates/core/src/protocols/greedy_evalcheck/verify.rs b/crates/core/src/protocols/greedy_evalcheck/verify.rs index f49a4fc9..c3ef6b24 100644 --- a/crates/core/src/protocols/greedy_evalcheck/verify.rs +++ b/crates/core/src/protocols/greedy_evalcheck/verify.rs @@ -5,47 +5,46 @@ use binius_utils::bail; use super::error::Error; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::Challenger, oracle::MultilinearOracleSet, protocols::{ evalcheck::{deserialize_evalcheck_proof, EvalcheckMultilinearClaim, EvalcheckVerifier}, sumcheck::{self, batch_verify, constraint_set_sumcheck_claims, SumcheckClaimsWithMeta}, }, - transcript::{read_u64, AdviceReader, CanRead}, + transcript::{read_u64, VerifierTranscript}, }; -pub fn verify( +pub fn verify( oracles: &mut MultilinearOracleSet, claims: impl IntoIterator>, - transcript: &mut Transcript, - advice: &mut AdviceReader, + transcript: &mut VerifierTranscript, ) -> Result>, Error> where F: TowerField, - Transcript: CanSample + CanRead, + Challenger_: Challenger, { let mut evalcheck_verifier = EvalcheckVerifier::new(oracles); // Verify the initial evalcheck claims let claims = claims.into_iter().collect::>(); - let len_initial_evalcheck_proofs = read_u64(advice)? as usize; + let len_initial_evalcheck_proofs = read_u64(&mut transcript.decommitment())? as usize; let mut initial_evalcheck_proofs = Vec::with_capacity(len_initial_evalcheck_proofs); + let mut reader = transcript.message(); for _ in 0..len_initial_evalcheck_proofs { - let eval_check_proof = deserialize_evalcheck_proof(transcript)?; + let eval_check_proof = deserialize_evalcheck_proof(&mut reader)?; initial_evalcheck_proofs.push(eval_check_proof); } evalcheck_verifier.verify(claims, initial_evalcheck_proofs)?; - let len_virtual_opening_proofs = read_u64(advice)? as usize; - for _ in 0..len_virtual_opening_proofs { + loop { let SumcheckClaimsWithMeta { claims, metas } = constraint_set_sumcheck_claims( evalcheck_verifier.take_new_sumcheck_constraints().unwrap(), )?; if claims.is_empty() { - bail!(Error::ExtraVirtualOpeningProof); + break; } // Reduce the new sumcheck claims for virtual polynomial openings to new evalcheck claims. 
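Note on the hunks above: the count of initial evalcheck proofs still travels on the unobserved decommitment tape via `write_u64`/`read_u64`, the proofs themselves move to the observed `message()` stream, and the explicit count of virtual-opening rounds is dropped entirely, since prover and verifier both loop until the evalcheck reduction yields no new sumcheck constraints. Below is a minimal, self-contained sketch of the matching write/read pattern using only transcript methods introduced in this change; the concrete field and challenger (`BinaryField128b`, `HasherChallenger<Groestl256>`) are placeholder choices for illustration, not part of the change.

use binius_core::{
    fiat_shamir::HasherChallenger,
    transcript::{read_u64, write_u64, ProverTranscript},
};
use binius_field::{BinaryField128b, Field};
use groestl_crypto::Groestl256;

#[test]
fn advice_length_prefix_round_trip() {
    let proofs = [BinaryField128b::ONE; 3];

    let mut prover = ProverTranscript::<HasherChallenger<Groestl256>>::new();
    // The length prefix goes to the decommitment tape: it is written to the
    // proof but never observed by the challenger.
    write_u64(&mut prover.decommitment(), proofs.len() as u64);
    // The (stand-in) proofs are observed messages.
    prover.message().write_scalar_slice(&proofs);

    // The verifier mirrors the same sequence and must drain the tape.
    let mut verifier = prover.into_verifier();
    let len = read_u64(&mut verifier.decommitment()).unwrap() as usize;
    let read: Vec<BinaryField128b> = verifier.message().read_scalar_slice(len).unwrap();
    assert_eq!(read.len(), 3);
    verifier.finalize().unwrap();
}

Because the length prefix never enters the challenger, it cannot influence later sampled challenges; only `message()` writes (and explicit `observe()` calls) do.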
@@ -55,8 +54,9 @@ where sumcheck::make_eval_claims(evalcheck_verifier.oracles, metas, sumcheck_output)?; let mut evalcheck_proofs = Vec::with_capacity(new_evalcheck_claims.len()); + let mut reader = transcript.message(); for _ in 0..new_evalcheck_claims.len() { - let evalcheck_proof = deserialize_evalcheck_proof(transcript)?; + let evalcheck_proof = deserialize_evalcheck_proof(&mut reader)?; evalcheck_proofs.push(evalcheck_proof) } diff --git a/crates/core/src/protocols/sumcheck/front_loaded.rs b/crates/core/src/protocols/sumcheck/front_loaded.rs index 2cfff144..bdba716a 100644 --- a/crates/core/src/protocols/sumcheck/front_loaded.rs +++ b/crates/core/src/protocols/sumcheck/front_loaded.rs @@ -5,6 +5,7 @@ use std::{cmp, cmp::Ordering, collections::VecDeque, iter}; use binius_field::{Field, TowerField}; use binius_math::{evaluate_univariate, CompositionPolyOS}; use binius_utils::sorting::is_sorted_ascending; +use bytes::Buf; use super::{ common::batch_weighted_value, @@ -12,7 +13,9 @@ use super::{ verify::compute_expected_batch_composite_evaluation_single_claim, RoundCoeffs, RoundProof, }; -use crate::{fiat_shamir::CanSample, protocols::sumcheck::SumcheckClaim, transcript::CanRead}; +use crate::{ + fiat_shamir::CanSample, protocols::sumcheck::SumcheckClaim, transcript::TranscriptReader, +}; #[derive(Debug)] enum CoeffsOrSums { @@ -128,12 +131,12 @@ where } /// Processes the next finished sumcheck claim, if all of its rounds are complete. - pub fn try_finish_claim( + pub fn try_finish_claim( &mut self, - transcript: &mut Transcript, + transcript: &mut TranscriptReader, ) -> Result>, Error> where - Transcript: CanRead, + B: Buf, { let Some(SumcheckClaimWithContext { claim, .. }) = self.claims.front() else { return Ok(None); @@ -177,12 +180,12 @@ where } /// Reads the round message from the proof transcript. - pub fn receive_round_proof( + pub fn receive_round_proof( &mut self, - transcript: &mut Transcript, + transcript: &mut TranscriptReader, ) -> Result<(), Error> where - Transcript: CanRead, + B: Buf, { let degree = match self.claims.front() { Some(SumcheckClaimWithContext { diff --git a/crates/core/src/protocols/sumcheck/prove/batch_prove.rs b/crates/core/src/protocols/sumcheck/prove/batch_prove.rs index e0cb4554..5148cfe8 100644 --- a/crates/core/src/protocols/sumcheck/prove/batch_prove.rs +++ b/crates/core/src/protocols/sumcheck/prove/batch_prove.rs @@ -7,12 +7,12 @@ use binius_utils::{bail, sorting::is_sorted_ascending}; use tracing::instrument; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::{CanSample, Challenger}, protocols::sumcheck::{ common::{BatchSumcheckOutput, RoundCoeffs}, error::Error, }, - transcript::CanWrite, + transcript::ProverTranscript, }; /// A sumcheck prover with a round-by-round execution interface. @@ -90,14 +90,14 @@ impl + ?Sized> SumcheckProver for Box( +pub fn batch_prove( provers: Vec, - transcript: Transcript, + transcript: &mut ProverTranscript, ) -> Result, Error> where F: TowerField, Prover: SumcheckProver, - Transcript: CanSample + CanWrite, + Challenger_: Challenger, { let start = BatchProveStart { batch_coeffs: Vec::new(), @@ -118,15 +118,15 @@ pub struct BatchProveStart { /// Prove a batched sumcheck protocol execution, but after some rounds have been processed. 
#[instrument(skip_all, name = "sumcheck::batch_prove")] -pub fn batch_prove_with_start( +pub fn batch_prove_with_start( start: BatchProveStart, mut provers: Vec, - mut transcript: Transcript, + transcript: &mut ProverTranscript, ) -> Result, Error> where F: TowerField, Prover: SumcheckProver, - Transcript: CanSample + CanWrite, + Challenger_: Challenger, { let BatchProveStart { mut batch_coeffs, @@ -169,7 +169,7 @@ where break; } - let next_batch_coeff = transcript.sample(); + let next_batch_coeff: F = transcript.sample(); batch_coeffs.push(next_batch_coeff); active_index += 1; } @@ -184,7 +184,9 @@ where } let round_proof = round_coeffs.truncate(); - transcript.write_scalar_slice(round_proof.coeffs()); + transcript + .message() + .write_scalar_slice(round_proof.coeffs()); let challenge = transcript.sample(); challenges.push(challenge); @@ -198,7 +200,7 @@ where while let Some(prover) = provers.get(active_index) { debug_assert_eq!(prover.n_vars(), 0); - let _next_batch_coeff = transcript.sample(); + let _next_batch_coeff: F = transcript.sample(); active_index += 1; } @@ -207,8 +209,9 @@ where .map(|prover| Box::new(prover).finish()) .collect::, _>>()?; + let mut writer = transcript.message(); for multilinear_evals in multilinear_evals.iter() { - transcript.write_scalar_slice(multilinear_evals); + writer.write_scalar_slice(multilinear_evals); } let output = BatchSumcheckOutput { diff --git a/crates/core/src/protocols/sumcheck/prove/batch_prove_univariate_zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/batch_prove_univariate_zerocheck.rs index c588e2f2..2470611f 100644 --- a/crates/core/src/protocols/sumcheck/prove/batch_prove_univariate_zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/batch_prove_univariate_zerocheck.rs @@ -5,13 +5,13 @@ use binius_utils::{bail, sorting::is_sorted_ascending}; use tracing::instrument; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::{CanSample, Challenger}, protocols::sumcheck::{ prove::{batch_prove::BatchProveStart, SumcheckProver}, univariate::LagrangeRoundEvals, Error, }, - transcript::CanWrite, + transcript::ProverTranscript, }; /// A univariate zerocheck prover interface. @@ -101,15 +101,15 @@ pub struct BatchZerocheckUnivariateProveOutput { /// verification. 
#[allow(clippy::type_complexity)] #[instrument(skip_all, level = "debug")] -pub fn batch_prove_zerocheck_univariate_round<'a, F, Prover, Transcript>( +pub fn batch_prove_zerocheck_univariate_round<'a, F, Prover, Challenger_>( mut provers: Vec, skip_rounds: usize, - mut transcript: Transcript, + transcript: &mut ProverTranscript, ) -> Result + 'a>>, Error> where F: TowerField, Prover: UnivariateZerocheckProver<'a, F>, - Transcript: CanSample + CanWrite, + Challenger_: Challenger, { // Check that the provers are in descending order by n_vars if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) { @@ -149,7 +149,7 @@ where bail!(Error::IncorrectZerosPrefixLen); } - transcript.write_scalar_slice(&round_evals.evals); + transcript.message().write_scalar_slice(&round_evals.evals); let univariate_challenge = transcript.sample(); let mut reduction_provers = Vec::with_capacity(provers.len()); diff --git a/crates/core/src/protocols/sumcheck/prove/front_loaded.rs b/crates/core/src/protocols/sumcheck/prove/front_loaded.rs index c0087c53..1ae5d2bc 100644 --- a/crates/core/src/protocols/sumcheck/prove/front_loaded.rs +++ b/crates/core/src/protocols/sumcheck/prove/front_loaded.rs @@ -4,12 +4,13 @@ use std::{collections::VecDeque, iter}; use binius_field::{Field, TowerField}; use binius_utils::sorting::is_sorted_ascending; +use bytes::BufMut; use super::batch_prove::SumcheckProver; use crate::{ fiat_shamir::CanSample, protocols::sumcheck::{Error, RoundCoeffs}, - transcript::CanWrite, + transcript::TranscriptWriter, }; /// Prover for a front-loaded batch sumcheck protocol execution. @@ -61,9 +62,9 @@ where Ok(Self { provers, round: 0 }) } - fn finish_claim_provers(&mut self, transcript: &mut Transcript) -> Result<(), Error> + fn finish_claim_provers(&mut self, transcript: &mut TranscriptWriter) -> Result<(), Error> where - Transcript: CanWrite, + B: BufMut, { while let Some((prover, _)) = self.provers.front() { if prover.n_vars() != self.round { @@ -77,9 +78,9 @@ where } /// Computes the round message and writes it to the proof transcript. - pub fn send_round_proof(&mut self, transcript: &mut Transcript) -> Result<(), Error> + pub fn send_round_proof(&mut self, transcript: &mut TranscriptWriter) -> Result<(), Error> where - Transcript: CanWrite, + B: BufMut, { self.finish_claim_provers(transcript)?; @@ -104,9 +105,9 @@ where } /// Finishes the remaining instance provers and checks that all rounds are completed. 
- pub fn finish(mut self, transcript: &mut Transcript) -> Result<(), Error> + pub fn finish(mut self, transcript: &mut TranscriptWriter) -> Result<(), Error> where - Transcript: CanWrite, + B: BufMut, { self.finish_claim_provers(transcript)?; if !self.provers.is_empty() { diff --git a/crates/core/src/protocols/sumcheck/tests.rs b/crates/core/src/protocols/sumcheck/tests.rs index 820f4cad..86739fd4 100644 --- a/crates/core/src/protocols/sumcheck/tests.rs +++ b/crates/core/src/protocols/sumcheck/tests.rs @@ -42,7 +42,7 @@ use crate::{ sumcheck::prove::SumcheckProver, test_utils::{AddOneComposition, TestProductComposition}, }, - transcript::TranscriptWriter, + transcript::ProverTranscript, }; #[derive(Debug, Clone)] @@ -167,12 +167,12 @@ fn test_prove_verify_product_helper( ) .unwrap(); - let mut prover_transcript = TranscriptWriter::>::default(); + let mut prover_transcript = ProverTranscript::>::new(); let prover_reduced_claims = batch_prove(vec![prover], &mut prover_transcript).expect("failed to prove sumcheck"); let prover_sample = CanSample::::sample(&mut prover_transcript); - let mut verifier_transcript = prover_transcript.into_reader(); + let mut verifier_transcript = prover_transcript.into_verifier(); let verifier_reduced_claims = batch_verify(&[claim], &mut verifier_transcript).unwrap(); // Check that challengers are in the same state @@ -373,13 +373,13 @@ fn prove_verify_batch(claim_shapes: &[TestSumcheckClaimShape]) { provers.push(prover); } - let mut prover_transcript = TranscriptWriter::>::default(); + let mut prover_transcript = ProverTranscript::>::new(); let prover_output = batch_prove(provers, &mut prover_transcript).expect("failed to prove sumcheck"); let prover_sample = CanSample::::sample(&mut prover_transcript); - let mut verifier_transcript = prover_transcript.into_reader(); + let mut verifier_transcript = prover_transcript.into_verifier(); let verifier_output = batch_verify(&claims, &mut verifier_transcript).unwrap(); assert_eq!(prover_output, verifier_output); @@ -452,35 +452,37 @@ fn prove_verify_batch_front_loaded(claim_shapes: &[TestSumcheckClaimShape]) { .max() .unwrap_or(0); - let mut transcript = TranscriptWriter::>::default(); + let mut transcript = ProverTranscript::>::new(); let mut batch_prover = FrontLoadedBatchProver::new(provers, &mut transcript).unwrap(); for _ in 0..n_rounds { - batch_prover.send_round_proof(&mut transcript).unwrap(); + batch_prover + .send_round_proof(&mut transcript.message()) + .unwrap(); let challenge = transcript.sample(); batch_prover.receive_challenge(challenge).unwrap(); } - batch_prover.finish(&mut transcript).unwrap(); + batch_prover.finish(&mut transcript.message()).unwrap(); - let mut transcript = transcript.into_reader(); + let mut transcript = transcript.into_verifier(); let mut challenges = Vec::with_capacity(n_rounds); let mut multilinear_evals = Vec::with_capacity(claims.len()); let mut verifier = FrontLoadedBatchVerifier::new(&claims, &mut transcript).unwrap(); for _ in 0..n_rounds { - while let Some(claim_multilinear_evals) = - verifier.try_finish_claim(&mut transcript).unwrap() - { + let mut writer = transcript.message(); + while let Some(claim_multilinear_evals) = verifier.try_finish_claim(&mut writer).unwrap() { multilinear_evals.push(claim_multilinear_evals); } - verifier.receive_round_proof(&mut transcript).unwrap(); + verifier.receive_round_proof(&mut writer).unwrap(); let challenge = transcript.sample(); verifier.finish_round(challenge).unwrap(); challenges.push(challenge); } - while let 
Some(claim_multilinear_evals) = verifier.try_finish_claim(&mut transcript).unwrap() { + let mut writer = transcript.message(); + while let Some(claim_multilinear_evals) = verifier.try_finish_claim(&mut writer).unwrap() { multilinear_evals.push(claim_multilinear_evals); } verifier.finish().unwrap(); diff --git a/crates/core/src/protocols/sumcheck/univariate.rs b/crates/core/src/protocols/sumcheck/univariate.rs index 963af0db..97c8f85e 100644 --- a/crates/core/src/protocols/sumcheck/univariate.rs +++ b/crates/core/src/protocols/sumcheck/univariate.rs @@ -256,7 +256,7 @@ mod tests { }, test_utils::generate_zero_product_multilinears, }, - transcript::{AdviceWriter, Proof, TranscriptWriter}, + transcript::ProverTranscript, }; #[test] @@ -326,12 +326,8 @@ mod tests { provers.push(prover); } - let mut prove_challenger = Proof { - transcript: TranscriptWriter::>::default(), - advice: AdviceWriter::default(), - }; - let batch_sumcheck_output_prove = - batch_prove(provers, &mut prove_challenger.transcript).unwrap(); + let mut prove_challenger = ProverTranscript::>::new(); + let batch_sumcheck_output_prove = batch_prove(provers, &mut prove_challenger).unwrap(); for ((skip_rounds, multilinears), multilinear_evals) in iter::zip(&all_multilinears, batch_sumcheck_output_prove.multilinear_evals) @@ -358,7 +354,7 @@ mod tests { let mut verify_challenger = prove_challenger.into_verifier(); let batch_sumcheck_output_verify = - batch_verify(claims.as_slice(), &mut verify_challenger.transcript).unwrap(); + batch_verify(claims.as_slice(), &mut verify_challenger).unwrap(); let batch_sumcheck_output_post = verify_sumcheck_outputs( claims.as_slice(), univariate_challenge, @@ -470,13 +466,9 @@ mod tests { ]; for skip_rounds in 0..=max_n_vars { - let mut proof = Proof { - transcript: TranscriptWriter::>::new(), - advice: AdviceWriter::new(), - }; + let mut proof = ProverTranscript::>::new(); - let prover_zerocheck_challenges: Vec = - proof.transcript.sample_vec(max_n_vars - skip_rounds); + let prover_zerocheck_challenges: Vec = proof.sample_vec(max_n_vars - skip_rounds); let mut prover_zerocheck_claims = Vec::new(); let mut univariate_provers = Vec::new(); @@ -530,25 +522,21 @@ mod tests { }) .collect::>(); - let prover_univariate_output = batch_prove_zerocheck_univariate_round( - univariate_provers, - skip_rounds, - &mut proof.transcript, - ) - .unwrap(); + let prover_univariate_output = + batch_prove_zerocheck_univariate_round(univariate_provers, skip_rounds, &mut proof) + .unwrap(); let _ = batch_prove_with_start( prover_univariate_output.batch_prove_start, tail_zerocheck_provers, - &mut proof.transcript, + &mut proof, ) .unwrap(); let mut verifier_proof = proof.into_verifier(); - let verifier_zerocheck_challenges: Vec = verifier_proof - .transcript - .sample_vec(max_n_vars - skip_rounds); + let verifier_zerocheck_challenges: Vec = + verifier_proof.sample_vec(max_n_vars - skip_rounds); assert_eq!( prover_zerocheck_challenges .into_iter() @@ -571,7 +559,7 @@ mod tests { let verifier_univariate_output = batch_verify_zerocheck_univariate_round( &verifier_zerocheck_claims[..univariate_cnt], skip_rounds, - &mut verifier_proof.transcript, + &mut verifier_proof, ) .unwrap(); @@ -579,7 +567,7 @@ mod tests { let _verifier_sumcheck_output = batch_verify_with_start( verifier_univariate_output.batch_verify_start, &verifier_sumcheck_claims, - &mut verifier_proof.transcript, + &mut verifier_proof, ) .unwrap(); diff --git a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs 
b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs index 46d8946c..c81e69f3 100644 --- a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs @@ -10,7 +10,10 @@ use super::{ verify::BatchVerifyStart, zerocheck::ZerocheckClaim, }; -use crate::{fiat_shamir::CanSample, transcript::CanRead}; +use crate::{ + fiat_shamir::{CanSample, Challenger}, + transcript::VerifierTranscript, +}; #[derive(Debug)] pub struct BatchZerocheckUnivariateOutput { @@ -40,15 +43,15 @@ pub fn extrapolated_scalars_count(composition_degree: usize, skip_rounds: usize) /// of the underlying composites, checks that univariatized round polynomial agrees with them on /// challenge point, and outputs sumcheck claims for `batch_verify` on the remaining variables. #[instrument(skip_all, level = "debug")] -pub fn batch_verify_zerocheck_univariate_round( +pub fn batch_verify_zerocheck_univariate_round( claims: &[ZerocheckClaim], skip_rounds: usize, - mut transcript: Transcript, + transcript: &mut VerifierTranscript, ) -> Result, Error> where F: TowerField, Composition: CompositionPolyOS, - Transcript: CanRead + CanSample, + Challenger_: Challenger, { // Check that the claims are in descending order by n_vars if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) { @@ -79,7 +82,9 @@ where max_degree = max_degree.max(claim.max_individual_degree() + 1); } - let round_evals = transcript.read_scalar_slice(max_domain_size - zeros_prefix_len)?; + let round_evals = transcript + .message() + .read_scalar_slice(max_domain_size - zeros_prefix_len)?; let univariate_challenge = transcript.sample(); let evaluation_domain = EvaluationDomainFactory::::create( diff --git a/crates/core/src/protocols/sumcheck/verify.rs b/crates/core/src/protocols/sumcheck/verify.rs index ebf4d31d..f73be079 100644 --- a/crates/core/src/protocols/sumcheck/verify.rs +++ b/crates/core/src/protocols/sumcheck/verify.rs @@ -11,7 +11,10 @@ use super::{ error::{Error, VerificationError}, RoundCoeffs, }; -use crate::{fiat_shamir::CanSample, transcript::CanRead}; +use crate::{ + fiat_shamir::{CanSample, Challenger}, + transcript::VerifierTranscript, +}; /// Verify a batched sumcheck protocol execution. /// @@ -25,14 +28,14 @@ use crate::{fiat_shamir::CanSample, transcript::CanRead}; /// For each sumcheck claim, we sample one random mixing coefficient. The multiple composite claims /// within each claim over a group of multilinears are mixed using the powers of the mixing /// coefficient. -pub fn batch_verify( +pub fn batch_verify( claims: &[SumcheckClaim], - transcript: &mut Transcript, + transcript: &mut VerifierTranscript, ) -> Result, Error> where F: TowerField, Composition: CompositionPolyOS, - Transcript: CanSample + CanRead, + Challenger_: Challenger, { let start = BatchVerifyStart { batch_coeffs: Vec::new(), @@ -59,15 +62,15 @@ pub struct BatchVerifyStart { /// Verify a batched sumcheck protocol execution, but after some rounds have been processed. 
#[instrument(skip_all, level = "debug")] -pub fn batch_verify_with_start( +pub fn batch_verify_with_start( start: BatchVerifyStart, claims: &[SumcheckClaim], - transcript: &mut Transcript, + transcript: &mut VerifierTranscript, ) -> Result, Error> where F: TowerField, Composition: CompositionPolyOS, - Transcript: CanSample + CanRead, + Challenger_: Challenger, { let BatchVerifyStart { mut batch_coeffs, @@ -118,7 +121,7 @@ where active_index += 1; } - let coeffs = transcript.read_scalar_slice(max_degree)?; + let coeffs = transcript.message().read_scalar_slice(max_degree)?; let round_proof = RoundProof(RoundCoeffs(coeffs)); let challenge = transcript.sample(); @@ -146,8 +149,9 @@ where } let mut multilinear_evals = Vec::with_capacity(claims.len()); + let mut reader = transcript.message(); for claim in claims.iter() { - let evals = transcript.read_scalar_slice::(claim.n_multilinears())?; + let evals = reader.read_scalar_slice::(claim.n_multilinears())?; multilinear_evals.push(evals); } diff --git a/crates/core/src/protocols/sumcheck/zerocheck.rs b/crates/core/src/protocols/sumcheck/zerocheck.rs index 2938bbd8..8d09dbd1 100644 --- a/crates/core/src/protocols/sumcheck/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/zerocheck.rs @@ -217,7 +217,7 @@ mod tests { }, test_utils::{generate_zero_product_multilinears, TestProductComposition}, }, - transcript::TranscriptWriter, + transcript::ProverTranscript, transparent::eq_ind::EqIndPartialEval, witness::MultilinearWitness, }; @@ -289,7 +289,7 @@ mod tests { let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))]; zerocheck::validate_witness(&multilins, &binding).unwrap(); - let mut prove_transcript_1 = TranscriptWriter::>::default(); + let mut prove_transcript_1 = ProverTranscript::>::new(); let backend = make_portable_backend(); let challenges = prove_transcript_1.sample_vec(n_vars); @@ -321,7 +321,7 @@ mod tests { .into_regular_zerocheck() .unwrap(); - let mut prove_transcript_2 = TranscriptWriter::>::default(); + let mut prove_transcript_2 = ProverTranscript::>::new(); let _: Vec = prove_transcript_2.sample_vec(n_vars); let BatchSumcheckOutput { challenges: sumcheck_challenges_2, @@ -350,7 +350,7 @@ mod tests { let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))]; zerocheck::validate_witness(&multilins, &binding).unwrap(); - let mut prove_transcript = TranscriptWriter::>::default(); + let mut prove_transcript = ProverTranscript::>::new(); let challenges = prove_transcript.sample_vec(n_vars); let domain_factory = IsomorphicEvaluationDomainFactory::::default(); @@ -391,7 +391,7 @@ mod tests { .unwrap(); let prover_sample = CanSample::::sample(&mut prove_transcript); - let mut verify_transcript = prove_transcript.into_reader(); + let mut verify_transcript = prove_transcript.into_verifier(); let _: Vec = verify_transcript.sample_vec(n_vars); let sumcheck_claims = reduce_to_sumchecks(&zerocheck_claims).unwrap(); diff --git a/crates/core/src/reed_solomon/reed_solomon.rs b/crates/core/src/reed_solomon/reed_solomon.rs index 03f7899a..771a306f 100644 --- a/crates/core/src/reed_solomon/reed_solomon.rs +++ b/crates/core/src/reed_solomon/reed_solomon.rs @@ -12,13 +12,12 @@ use std::marker::PhantomData; -use binius_field::{BinaryField, PackedField}; +use binius_field::{BinaryField, ExtensionField, PackedField, RepackedExtension}; use binius_maybe_rayon::prelude::*; use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings}; -use binius_utils::bail; +use 
binius_utils::{bail, checked_arithmetics::checked_log_2}; use getset::CopyGetters; - -use crate::linear_code::LinearCode; +use tracing::instrument; #[derive(Debug, CopyGetters)] pub struct ReedSolomonCode

@@ -70,53 +69,55 @@ where }) } - pub fn get_ntt(&self) -> &impl AdditiveNTT

{ + pub const fn get_ntt(&self) -> &impl AdditiveNTT

{ &self.ntt } - pub fn log_dim(&self) -> usize { + /// The dimension of the code, i.e. the number of message symbols. + pub const fn dim(&self) -> usize { + 1 << self.dim_bits() + } + + pub const fn log_dim(&self) -> usize { self.log_dimension } - pub fn log_len(&self) -> usize { + pub const fn log_len(&self) -> usize { self.log_dimension + self.log_inv_rate } -} -impl LinearCode for ReedSolomonCode

-where - P: PackedField, - F: BinaryField, -{ - type P = P; - type EncodeError = Error; - - fn len(&self) -> usize { + /// The block length. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { 1 << (self.log_dimension + self.log_inv_rate) } - fn dim_bits(&self) -> usize { + /// The base-2 log of the dimension. + const fn dim_bits(&self) -> usize { self.log_dimension } - fn min_dist(&self) -> usize { - self.len() - self.dim() + 1 - } - - fn inv_rate(&self) -> usize { + /// The reciprocal of the rate, ie. `self.len() / self.dim()`. + pub const fn inv_rate(&self) -> usize { 1 << self.log_inv_rate } - fn encode_batch_inplace( - &self, - code: &mut [Self::P], - log_batch_size: usize, - ) -> Result<(), Self::EncodeError> { + /// Encode a batch of interleaved messages in-place in a provided buffer. + /// + /// The message symbols are interleaved in the buffer, which improves the cache-efficiency of + /// the encoding procedure. The interleaved codeword is stored in the buffer when the method + /// completes. + /// + /// ## Throws + /// + /// * If the `code` buffer does not have capacity for `len() << log_batch_size` field + /// elements. + fn encode_batch_inplace(&self, code: &mut [P], log_batch_size: usize) -> Result<(), Error> { let _scope = tracing::trace_span!( "Reed–Solomon encode", log_len = self.log_len(), log_batch_size = log_batch_size, - symbol_bits = F::N_BITS, + symbol_bits = P::Scalar::N_BITS, ) .entered(); if (code.len() << log_batch_size) < self.len() { @@ -144,4 +145,31 @@ where .try_for_each(|(i, data)| self.ntt.forward_transform(data, i, log_batch_size)) } } + + /// Encode a batch of interleaved messages of extension field elements in-place in a provided + /// buffer. + /// + /// A linear code can be naturally extended to a code over extension fields by encoding each + /// dimension of the extension as a vector-space separately. + /// + /// ## Preconditions + /// + /// * `PE::Scalar::DEGREE` must be a power of two. + /// + /// ## Throws + /// + /// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements. + #[instrument(skip_all, level = "debug")] + pub fn encode_ext_batch_inplace( + &self, + code: &mut [PE], + log_batch_size: usize, + ) -> Result<(), Error> + where + PE: RepackedExtension

, + PE::Scalar: ExtensionField<

::Scalar>, + { + let log_degree = checked_log_2(PE::Scalar::DEGREE); + self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + log_degree) + } } diff --git a/crates/core/src/ring_switch/prove.rs b/crates/core/src/ring_switch/prove.rs index 75e07bab..ab2d6aee 100644 --- a/crates/core/src/ring_switch/prove.rs +++ b/crates/core/src/ring_switch/prove.rs @@ -15,11 +15,11 @@ use super::{ tower_tensor_algebra::TowerTensorAlgebra, }; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::{CanSample, Challenger}, piop::PIOPSumcheckClaim, ring_switch::{common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd}, tower::{PackedTop, TowerFamily}, - transcript::{CanWrite, Proof}, + transcript::ProverTranscript, witness::MultilinearWitness, }; @@ -32,10 +32,10 @@ pub struct ReducedWitness { } #[tracing::instrument("ring_switch::prove", skip_all)] -pub fn prove( +pub fn prove( system: &EvalClaimSystem, witnesses: &[M], - proof: &mut Proof, + transcript: &mut ProverTranscript, backend: &Backend, ) -> Result, Error> where @@ -44,7 +44,7 @@ where M: MultilinearPoly

+ Sync, Tower: TowerFamily, F: PackedTop, - Transcript: CanWrite + CanSample, + Challenger_: Challenger, Backend: ComputationBackend, { if witnesses.len() != system.commit_meta.total_multilins() { @@ -56,7 +56,7 @@ where // Sample enough randomness to batch tensor elements corresponding to claims that share an // evaluation point prefix. let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len()); - let mixing_challenges = proof.transcript.sample_vec(n_mixing_challenges); + let mixing_challenges = transcript.sample_vec(n_mixing_challenges); let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion(); // For each evaluation point prefix, send one batched partial evaluation. @@ -67,21 +67,20 @@ where &system.prefix_descs, &system.eval_claim_to_prefix_desc_index, )?; + let mut writer = transcript.message(); for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) { debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa()); - proof - .transcript - .write_scalar_slice(mixed_tensor_elem.vertical_elems()); + writer.write_scalar_slice(mixed_tensor_elem.vertical_elems()); } // Sample the row-batching randomness. - let row_batch_challenges = proof.transcript.sample_vec(system.max_claim_kappa()); + let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); let row_batch_coeffs = Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); let row_batched_evals = compute_row_batched_sumcheck_evals(scaled_tensor_elems, &row_batch_coeffs); - proof.transcript.write_scalar_slice(&row_batched_evals); + transcript.message().write_scalar_slice(&row_batched_evals); // Create the reduced PIOP sumcheck witnesses. let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>( diff --git a/crates/core/src/ring_switch/tests.rs b/crates/core/src/ring_switch/tests.rs index 1e86498b..34e4cb0a 100644 --- a/crates/core/src/ring_switch/tests.rs +++ b/crates/core/src/ring_switch/tests.rs @@ -31,7 +31,7 @@ use crate::{ protocols::{evalcheck::EvalcheckMultilinearClaim, fri::CommitOutput}, ring_switch::prove::ReducedWitness, tower::{CanonicalTowerFamily, PackedTop, TowerFamily, TowerUnderlier}, - transcript::{AdviceWriter, CanRead, CanWrite, Proof, TranscriptWriter}, + transcript::ProverTranscript, witness::{MultilinearExtensionIndex, MultilinearWitness}, }; @@ -233,22 +233,19 @@ fn test_prove_verify_claim_reduction_with_naive_validation() { let oracles = make_test_oracle_set(); with_test_instance_from_oracles::(rng, &oracles, |_rng, system, witnesses| { - let mut proof = Proof { - transcript: TranscriptWriter::>::default(), - advice: AdviceWriter::default(), - }; + let mut proof = ProverTranscript::>::new(); let backend = make_portable_backend(); let ReducedWitness { transparents: transparent_witnesses, sumcheck_claims: prover_sumcheck_claims, - } = prove::<_, _, _, Tower, _, _, _>(&system, &witnesses, &mut proof, &backend).unwrap(); + } = prove::<_, _, _, Tower, _, _>(&system, &witnesses, &mut proof, &backend).unwrap(); let mut proof = proof.into_verifier(); let ReducedClaim { transparents: _, sumcheck_claims: verifier_sumcheck_claims, - } = verify::<_, Tower, _, _>(&system, &mut proof).unwrap(); + } = verify::<_, Tower, _>(&system, &mut proof).unwrap(); assert_eq!(prover_sumcheck_claims, verifier_sumcheck_claims); @@ -307,17 +304,14 @@ fn commit_prove_verify_piop( let system = EvalClaimSystem::new(&commit_meta, oracle_to_commit_index, &eval_claims).unwrap(); 
check_eval_point_consistency(&system); - let mut proof = Proof { - transcript: TranscriptWriter::>::default(), - advice: AdviceWriter::default(), - }; - proof.transcript.write(&commitment); + let mut proof = ProverTranscript::>::new(); + proof.message().write(&commitment); let backend = make_portable_backend(); let ReducedWitness { transparents: transparent_multilins, sumcheck_claims, - } = prove::<_, _, _, Tower, _, _, _>(&system, &committed_multilins, &mut proof, &backend).unwrap(); + } = prove::<_, _, _, Tower, _, _>(&system, &committed_multilins, &mut proof, &backend).unwrap(); let domain_factory = DefaultEvaluationDomainFactory::::default(); piop::prove( @@ -336,12 +330,12 @@ fn commit_prove_verify_piop( .unwrap(); let mut proof = proof.into_verifier(); - let commitment = proof.transcript.read().unwrap(); + let commitment = proof.message().read().unwrap(); let ReducedClaim { transparents, sumcheck_claims, - } = verify::<_, Tower, _, _>(&system, &mut proof).unwrap(); + } = verify::<_, Tower, _>(&system, &mut proof).unwrap(); piop::verify( &commit_meta, diff --git a/crates/core/src/ring_switch/verify.rs b/crates/core/src/ring_switch/verify.rs index 2b0ebcdc..fdf7f303 100644 --- a/crates/core/src/ring_switch/verify.rs +++ b/crates/core/src/ring_switch/verify.rs @@ -5,10 +5,11 @@ use std::{iter, sync::Arc}; use binius_field::{Field, TowerField}; use binius_math::{MultilinearExtension, MultilinearQuery}; use binius_utils::checked_arithmetics::log2_ceil_usize; +use bytes::Buf; use itertools::izip; use crate::{ - fiat_shamir::CanSample, + fiat_shamir::{CanSample, Challenger}, piop::PIOPSumcheckClaim, polynomial::MultivariatePoly, ring_switch::{ @@ -16,7 +17,7 @@ use crate::{ EvalClaimSuffixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc, VerificationError, }, tower::{PackedTop, TowerFamily}, - transcript::{CanRead, Proof}, + transcript::{TranscriptReader, VerifierTranscript}, }; type FExt = ::B128; @@ -27,34 +28,35 @@ pub struct ReducedClaim<'a, F: Field> { pub sumcheck_claims: Vec>, } -pub fn verify<'a, F, Tower, Transcript, Advice>( +pub fn verify<'a, F, Tower, Challenger_>( system: &'a EvalClaimSystem, - proof: &mut Proof, + transcript: &mut VerifierTranscript, ) -> Result, Error> where F: TowerField, Tower: TowerFamily, F: PackedTop, - Transcript: CanRead + CanSample, + Challenger_: Challenger, { // Sample enough randomness to batch tensor elements corresponding to claims that share an // evaluation point prefix. let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len()); - let mixing_challenges = proof.transcript.sample_vec(n_mixing_challenges); + let mixing_challenges = transcript.sample_vec(n_mixing_challenges); let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion(); // For each evaluation point prefix, receive one batched tensor algebra element and verify // that it is consistent with the evaluation claims. - let tensor_elems = verify_receive_tensor_elems(system, &mixing_coeffs, &mut proof.transcript)?; + let tensor_elems = + verify_receive_tensor_elems(system, &mixing_coeffs, &mut transcript.message())?; // Sample the row-batching randomness. - let row_batch_challenges = proof.transcript.sample_vec(system.max_claim_kappa()); + let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); let row_batch_coeffs = Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); // For each original evaluation claim, receive the row-batched evaluation claim. 
- let row_batched_evals = proof - .transcript + let row_batched_evals = transcript + .message() .read_scalar_slice(system.sumcheck_claim_descs.len())?; // Check that the row-batched evaluation claims sent by the prover are consistent with the @@ -96,16 +98,16 @@ where }) } -fn verify_receive_tensor_elems( +fn verify_receive_tensor_elems( system: &EvalClaimSystem, mixing_coeffs: &[F], - transcript: &mut Transcript, + transcript: &mut TranscriptReader, ) -> Result>, Error> where F: TowerField, Tower: TowerFamily, F: PackedTop, - Transcript: CanRead, + B: Buf, { let expected_tensor_elem_evals = compute_mixed_evaluations( system diff --git a/crates/core/src/tower.rs b/crates/core/src/tower.rs index a11f390a..656bd0b7 100644 --- a/crates/core/src/tower.rs +++ b/crates/core/src/tower.rs @@ -3,15 +3,18 @@ //! Traits for working with field towers. use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, AESTowerField128b, AESTowerField16b, - AESTowerField32b, AESTowerField64b, AESTowerField8b, BinaryField128b, BinaryField16b, - BinaryField1b, BinaryField32b, BinaryField64b, BinaryField8b, ExtensionField, PackedExtension, - PackedField, TowerField, + as_packed_field::PackScalar, + linear_transformation::{PackedTransformationFactory, Transformation}, + polyval::{AES_TO_POLYVAL_TRANSFORMATION, BINARY_TO_POLYVAL_TRANSFORMATION}, + underlier::UnderlierType, + AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b, + BinaryField128b, BinaryField128bPolyval, BinaryField16b, BinaryField1b, BinaryField32b, + BinaryField64b, BinaryField8b, ExtensionField, PackedExtension, PackedField, TowerField, }; use trait_set::trait_set; /// A trait that groups a family of related [`TowerField`]s as associated types. -pub trait TowerFamily { +pub trait TowerFamily: Sized { type B1: TowerField + TryFrom; type B8: TowerField + TryFrom + ExtensionField; type B16: TowerField + TryFrom + ExtensionField + ExtensionField; @@ -34,6 +37,15 @@ pub trait TowerFamily { + ExtensionField; } +pub trait ProverTowerFamily: TowerFamily { + type FastB128: TowerField + From + Into; + + fn packed_transformation_to_fast() -> impl Transformation + where + Top: PackedTop + PackedTransformationFactory, + FastTop: PackedField; +} + /// The canonical Fan-Paar tower family. #[derive(Debug)] pub struct CanonicalTowerFamily; @@ -47,6 +59,18 @@ impl TowerFamily for CanonicalTowerFamily { type B128 = BinaryField128b; } +impl ProverTowerFamily for CanonicalTowerFamily { + type FastB128 = BinaryField128bPolyval; + + fn packed_transformation_to_fast() -> impl Transformation + where + Top: PackedTop + PackedTransformationFactory, + FastTop: PackedField, + { + Top::make_packed_transformation(BINARY_TO_POLYVAL_TRANSFORMATION) + } +} + /// The tower defined by Fan-Paar extensions built on top of the Rijndael field. #[derive(Debug)] pub struct AESTowerFamily; @@ -60,6 +84,18 @@ impl TowerFamily for AESTowerFamily { type B128 = AESTowerField128b; } +impl ProverTowerFamily for AESTowerFamily { + type FastB128 = BinaryField128bPolyval; + + fn packed_transformation_to_fast() -> impl Transformation + where + Top: PackedTop + PackedTransformationFactory, + FastTop: PackedField, + { + Top::make_packed_transformation(AES_TO_POLYVAL_TRANSFORMATION) + } +} + trait_set! { /// An underlier with associated packed types for fields in a tower. pub trait TowerUnderlier = @@ -71,6 +107,9 @@ trait_set! 
{ + PackScalar + PackScalar; + pub trait ProverTowerUnderlier = + TowerUnderlier + PackScalar; + /// A packed field type that is the top packed field in a tower. pub trait PackedTop = PackedField diff --git a/crates/core/src/transcript/mod.rs b/crates/core/src/transcript/mod.rs index 2809cea2..d59af04a 100644 --- a/crates/core/src/transcript/mod.rs +++ b/crates/core/src/transcript/mod.rs @@ -24,51 +24,26 @@ use tracing::warn; use crate::fiat_shamir::{CanSample, CanSampleBits, Challenger}; -/// Writable(Prover) transcript over some Challenger that `CanWrite` and `CanSample` +/// Prover transcript over some Challenger that writes to the internal tape and `CanSample` /// /// A Transcript is an abstraction over Fiat-Shamir so the prover and verifier can send and receive -/// data, everything that gets written to or read from the transcript will be observed -#[derive(Debug, Default)] -pub struct TranscriptWriter { +/// data. +#[derive(Debug)] +pub struct ProverTranscript { combined: FiatShamirBuf, debug_assertions: bool, } -/// Writable(Prover) advice that `CanWrite` -/// -/// Advice holds meta-data to the transcript that need not be Fiat-Shamir'ed -#[derive(Debug, Default)] -pub struct AdviceWriter { - buffer: BytesMut, - debug_assertions: bool, -} - -/// Readable(Verifier) transcript over some Challenger that `CanRead` and `CanSample` +/// Verifier transcript over some Challenger that reads from the internal tape and `CanSample` /// /// You must manually call the destructor with `finalize()` to check anything that's written is /// fully read out #[derive(Debug)] -pub struct TranscriptReader { +pub struct VerifierTranscript { combined: FiatShamirBuf, debug_assertions: bool, } -/// Readable(Verifier) advice that `CanRead` -/// -/// You must manually call the destructor with `finalize()` to check anything that's written is -/// fully read out -#[derive(Debug)] -pub struct AdviceReader { - buffer: Bytes, - debug_assertions: bool, -} - -/// Helper struct combining Transcript and Advice data to create a Proof object -pub struct Proof { - pub transcript: Transcript, - pub advice: Advice, -} - #[derive(Debug, Default)] struct FiatShamirBuf { buffer: Inner, @@ -107,7 +82,7 @@ unsafe impl BufMut for FiatShamirBuf BufMut for FiatShamirBuf TranscriptWriter { +impl ProverTranscript { + /// Creates a new prover transcript. + /// + /// By default debug assertions are set to the feature flag `debug_assertions`. You may also + /// change the debug flag with [`Self::set_debug`]. pub fn new() -> Self { Self { combined: Default::default(), @@ -127,67 +106,85 @@ impl TranscriptWriter { } } - pub fn finalize(self) -> Vec { - self.combined.buffer.to_vec() - } - - pub fn into_reader(self) -> TranscriptReader { - TranscriptReader::new(self.finalize()) - } - - pub fn set_debug(&mut self, debug: bool) { - self.debug_assertions = debug; + pub fn into_verifier(self) -> VerifierTranscript { + VerifierTranscript::new(self.finalize()) } } -impl AdviceWriter { - pub fn new() -> Self { - Self { - buffer: Default::default(), - debug_assertions: cfg!(debug_assertions), - } +impl Default for ProverTranscript { + fn default() -> Self { + Self::new() } +} +impl ProverTranscript { pub fn finalize(self) -> Vec { - self.buffer.to_vec() - } - - pub fn into_reader(self) -> AdviceReader { - AdviceReader::new(self.finalize()) + self.combined.buffer.to_vec() } + /// Sets the debug flag. + /// + /// This flag is used to enable debug assertions in the [`TranscriptReader`] and + /// [`TranscriptWriter`] methods. 
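	/// A minimal usage sketch; the challenger type is borrowed from the tests below and the
	/// marker string is purely illustrative. Debug markers are written and checked only while
	/// the flag is on:
	///
	///     let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
	///     transcript.set_debug(true);
	///     transcript.message().write_debug("sumcheck round 0");
	///     // the verifier calls `read_debug("sumcheck round 0")` at the same point and
	///     // panics on a mismatch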
pub fn set_debug(&mut self, debug: bool) { self.debug_assertions = debug; } -} -impl Proof, AdviceWriter> { - pub fn into_verifier(self) -> Proof, AdviceReader> { - Proof { - transcript: self.transcript.into_reader(), - advice: self.advice.into_reader(), + /// Returns a writeable buffer that only observes the data written, without writing it to the + /// proof tape. + /// + /// This method should be used to observe the input statement. + pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b> + where + 'a: 'b, + { + TranscriptWriter { + buffer: self.combined.challenger.observer(), + debug_assertions: self.debug_assertions, } } -} -impl Proof, AdviceReader> { - pub fn finalize(self) -> Result<(), Error> { - self.transcript.finalize()?; - self.advice.finalize() + /// Returns a writeable buffer that only writes the data to the proof tape, without observing it. + /// + /// This method should only be used to write openings of commitments that were already written + /// to the transcript as an observed message. For example, in the FRI protocol, the prover sends + /// a Merkle tree root as a commitment, and later sends leaf openings. The leaf openings should + /// be written using [`Self::decommitment`] because they are verified with respect to the + /// previously sent Merkle root. + pub fn decommitment(&mut self) -> TranscriptWriter { + TranscriptWriter { + buffer: &mut self.combined.buffer, + debug_assertions: self.debug_assertions, + } + } + + /// Returns a writeable buffer that observes the data written and writes it to the proof tape. + /// + /// This method should be used by default to write prover messages in an interactive protocol. + pub fn message<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut> + where + 'a: 'b, + { + TranscriptWriter { + buffer: &mut self.combined, + debug_assertions: self.debug_assertions, + } } } -impl TranscriptReader { +impl VerifierTranscript { pub fn new(vec: Vec) -> Self { Self { combined: FiatShamirBuf { - challenger: Challenger::default(), + challenger: Challenger_::default(), buffer: Bytes::from(vec), }, debug_assertions: cfg!(debug_assertions), } } +} +impl VerifierTranscript { pub fn finalize(self) -> Result<(), Error> { if self.combined.buffer.has_remaining() { return Err(Error::TranscriptNotEmpty { @@ -200,10 +197,47 @@ impl TranscriptReader { pub fn set_debug(&mut self, debug: bool) { self.debug_assertions = debug; } + + /// Returns a writable buffer that only observes the data written, without reading it from the + /// proof tape. + /// + /// This method should be used to observe the input statement. + pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b> + where + 'a: 'b, + { + TranscriptWriter { + buffer: self.combined.challenger.observer(), + debug_assertions: self.debug_assertions, + } + } + + /// Returns a readable buffer that only reads the data from the proof tape, without observing it. + /// + /// This method should only be used to read advice that was previously written to the transcript as an observed message. + pub fn decommitment(&mut self) -> TranscriptReader { + TranscriptReader { + buffer: &mut self.combined.buffer, + debug_assertions: self.debug_assertions, + } + } + + /// Returns a readable buffer that observes the data read. + /// + /// This method should be used by default to read verifier messages in an interactive protocol. 
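	/// A round-trip sketch mirroring the `test_transcripting` test below; the concrete
	/// challenger and field types are assumptions taken from that test:
	///
	///     use binius_core::{fiat_shamir::{CanSample, HasherChallenger}, transcript::ProverTranscript};
	///     use binius_field::{BinaryField128b, BinaryField8b};
	///     use groestl_crypto::Groestl256;
	///
	///     let mut prover = ProverTranscript::<HasherChallenger<Groestl256>>::new();
	///     prover.message().write_scalar(BinaryField8b::new(0x96));
	///     let c1: BinaryField128b = prover.sample();
	///
	///     let mut verifier = prover.into_verifier();
	///     let x: BinaryField8b = verifier.message().read_scalar().unwrap();
	///     let c2: BinaryField128b = verifier.sample();
	///     assert_eq!(c1, c2); // both sides observed the same message bytes
	///     verifier.finalize().unwrap(); // fails if any tape bytes were left unread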
+ pub fn message<'a, 'b>(&'a mut self) -> TranscriptReader<'b, impl Buf> + where + 'a: 'b, + { + TranscriptReader { + buffer: &mut self.combined, + debug_assertions: self.debug_assertions, + } + } } // Useful warnings to see if we are neglecting to read any advice or transcript entirely -impl Drop for TranscriptReader { +impl Drop for VerifierTranscript { fn drop(&mut self) { if self.combined.buffer.has_remaining() { warn!( @@ -214,54 +248,29 @@ impl Drop for TranscriptReader { } } -impl Drop for AdviceReader { - fn drop(&mut self) { - if self.buffer.has_remaining() { - warn!("Advice reader is not fully read out: {:?} bytes left", self.buffer.remaining()) - } - } +pub struct TranscriptReader<'a, B: Buf> { + buffer: &'a mut B, + debug_assertions: bool, } -impl AdviceReader { - pub fn new(vec: Vec) -> Self { - Self { - buffer: Bytes::from(vec), - debug_assertions: cfg!(debug_assertions), - } - } - - pub fn finalize(self) -> Result<(), Error> { - if self.buffer.has_remaining() { - return Err(Error::TranscriptNotEmpty { - remaining: self.buffer.remaining(), - }); - } - Ok(()) - } - - pub fn set_debug(&mut self, debug: bool) { - self.debug_assertions = debug; +impl TranscriptReader<'_, B> { + pub fn buffer(&mut self) -> &mut B { + self.buffer } -} -/// Trait that is used to read bytes and field elements from transcript/advice -#[auto_impl::auto_impl(&mut)] -pub trait CanRead { - fn buffer(&mut self) -> impl Buf + '_; - - fn read(&mut self) -> Result { + pub fn read(&mut self) -> Result { T::deserialize(self.buffer()).map_err(Into::into) } - fn read_vec(&mut self, n: usize) -> Result, Error> { + pub fn read_vec(&mut self, n: usize) -> Result, Error> { let mut buffer = self.buffer(); repeat_with(move || T::deserialize(&mut buffer).map_err(Into::into)) .take(n) .collect() } - fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), Error> { - let mut buffer = self.buffer(); + pub fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), Error> { + let buffer = self.buffer(); if buffer.remaining() < buf.len() { return Err(Error::NotEnoughBytes); } @@ -269,13 +278,13 @@ pub trait CanRead { Ok(()) } - fn read_scalar(&mut self) -> Result { + pub fn read_scalar(&mut self) -> Result { let mut out = F::default(); self.read_scalar_slice_into(slice::from_mut(&mut out))?; Ok(out) } - fn read_scalar_slice_into(&mut self, buf: &mut [F]) -> Result<(), Error> { + pub fn read_scalar_slice_into(&mut self, buf: &mut [F]) -> Result<(), Error> { let mut buffer = self.buffer(); for elem in buf { *elem = deserialize_canonical(&mut buffer)?; @@ -283,17 +292,17 @@ pub trait CanRead { Ok(()) } - fn read_scalar_slice(&mut self, len: usize) -> Result, Error> { + pub fn read_scalar_slice(&mut self, len: usize) -> Result, Error> { let mut elems = vec![F::default(); len]; self.read_scalar_slice_into(&mut elems)?; Ok(elems) } - fn read_packed>(&mut self) -> Result { + pub fn read_packed>(&mut self) -> Result { P::try_from_fn(|_| self.read_scalar()) } - fn read_packed_slice>( + pub fn read_packed_slice>( &mut self, len: usize, ) -> Result, Error> { @@ -304,15 +313,7 @@ pub trait CanRead { Ok(packed) } - fn read_debug(&mut self, msg: &str); -} - -impl CanRead for TranscriptReader { - fn buffer(&mut self) -> impl Buf + '_ { - &mut self.combined - } - - fn read_debug(&mut self, msg: &str) { + pub fn read_debug(&mut self, msg: &str) { if self.debug_assertions { let msg_bytes = msg.as_bytes(); let mut buffer = vec![0; msg_bytes.len()]; @@ -322,94 +323,64 @@ impl CanRead for TranscriptReader { } } -impl CanRead for AdviceReader { - fn 
buffer(&mut self) -> impl Buf + '_ { - &mut self.buffer - } - - fn read_debug(&mut self, msg: &str) { - if self.debug_assertions { - let msg_bytes = msg.as_bytes(); - let mut buffer = vec![0; msg_bytes.len()]; - assert!(self.read_bytes(&mut buffer).is_ok()); - assert_eq!(msg_bytes, buffer); - } - } +pub struct TranscriptWriter<'a, B: BufMut> { + buffer: &'a mut B, + debug_assertions: bool, } -/// Trait that is used to write bytes and field elements to transcript/advice -#[auto_impl::auto_impl(&mut)] -pub trait CanWrite { - fn buffer(&mut self) -> impl BufMut + '_; +impl TranscriptWriter<'_, B> { + pub fn buffer(&mut self) -> &mut B { + self.buffer + } - fn write(&mut self, value: &T) { + pub fn write(&mut self, value: &T) { value .serialize(self.buffer()) .expect("TODO: propagate error") } - fn write_slice(&mut self, values: &[T]) { + pub fn write_slice(&mut self, values: &[T]) { let mut buffer = self.buffer(); for value in values { value.serialize(&mut buffer).expect("TODO: propagate error") } } - fn write_bytes(&mut self, data: &[u8]) { + pub fn write_bytes(&mut self, data: &[u8]) { self.buffer().put_slice(data); } - fn write_scalar(&mut self, f: F) { + pub fn write_scalar(&mut self, f: F) { self.write_scalar_slice(slice::from_ref(&f)); } - fn write_scalar_slice(&mut self, elems: &[F]) { + pub fn write_scalar_slice(&mut self, elems: &[F]) { let mut buffer = self.buffer(); for elem in elems { serialize_canonical(*elem, &mut buffer).expect("TODO: propagate error"); } } - fn write_packed>(&mut self, packed: P) { + pub fn write_packed>(&mut self, packed: P) { for scalar in packed.iter() { self.write_scalar(scalar); } } - fn write_packed_slice>(&mut self, packed_slice: &[P]) { + pub fn write_packed_slice>(&mut self, packed_slice: &[P]) { for &packed in packed_slice { self.write_packed(packed) } } - fn write_debug(&mut self, msg: &str); -} - -impl CanWrite for TranscriptWriter { - fn buffer(&mut self) -> impl BufMut + '_ { - &mut self.combined - } - - fn write_debug(&mut self, msg: &str) { + pub fn write_debug(&mut self, msg: &str) { if self.debug_assertions { self.write_bytes(msg.as_bytes()) } } } -impl CanWrite for AdviceWriter { - fn buffer(&mut self) -> impl BufMut + '_ { - &mut self.buffer - } - - fn write_debug(&mut self, msg: &str) { - if self.debug_assertions { - self.write_bytes(msg.as_bytes()) - } - } -} - -impl CanSample for TranscriptReader +impl CanSample for VerifierTranscript where F: TowerField, Challenger_: Challenger, @@ -420,7 +391,7 @@ where } } -impl CanSample for TranscriptWriter +impl CanSample for ProverTranscript where F: TowerField, Challenger_: Challenger, @@ -449,7 +420,7 @@ fn sample_bits_reader(mut reader: Reader, bits: usize) -> usize { mask & unmasked } -impl CanSampleBits for TranscriptReader +impl CanSampleBits for VerifierTranscript where Challenger_: Challenger, { @@ -458,7 +429,7 @@ where } } -impl CanSampleBits for TranscriptWriter +impl CanSampleBits for ProverTranscript where Challenger_: Challenger, { @@ -468,13 +439,13 @@ where } /// Helper functions for serializing native types -pub fn read_u64(transcript: &mut Transcript) -> Result { +pub fn read_u64(transcript: &mut TranscriptReader) -> Result { let mut as_bytes = [0; size_of::()]; transcript.read_bytes(&mut as_bytes)?; Ok(u64::from_le_bytes(as_bytes)) } -pub fn write_u64(transcript: &mut Transcript, n: u64) { +pub fn write_u64(transcript: &mut TranscriptWriter, n: u64) { transcript.write_bytes(&n.to_le_bytes()); } @@ -492,28 +463,33 @@ mod tests { #[test] fn test_transcripting() { - let mut 
prover_transcript = TranscriptWriter::>::new(); + let mut prover_transcript = ProverTranscript::>::new(); + let mut writable = prover_transcript.message(); - prover_transcript.write_scalar(BinaryField8b::new(0x96)); - prover_transcript.write_scalar(BinaryField32b::new(0xDEADBEEF)); - prover_transcript.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB)); + writable.write_scalar(BinaryField8b::new(0x96)); + writable.write_scalar(BinaryField32b::new(0xDEADBEEF)); + writable.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB)); let sampled_fanpaar1: BinaryField128b = prover_transcript.sample(); - prover_transcript.write_scalar(AESTowerField8b::new(0x52)); - prover_transcript.write_scalar(AESTowerField32b::new(0x12345678)); - prover_transcript.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777)); + let mut writable = prover_transcript.message(); + + writable.write_scalar(AESTowerField8b::new(0x52)); + writable.write_scalar(AESTowerField32b::new(0x12345678)); + writable.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777)); let sampled_aes1: AESTowerField16b = prover_transcript.sample(); prover_transcript + .message() .write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA)); let sampled_polyval1: BinaryField128bPolyval = prover_transcript.sample(); - let mut verifier_transcript = prover_transcript.into_reader(); + let mut verifier_transcript = prover_transcript.into_verifier(); + let mut readable = verifier_transcript.message(); - let fp_8: BinaryField8b = verifier_transcript.read_scalar().unwrap(); - let fp_32: BinaryField32b = verifier_transcript.read_scalar().unwrap(); - let fp_128: BinaryField128b = verifier_transcript.read_scalar().unwrap(); + let fp_8: BinaryField8b = readable.read_scalar().unwrap(); + let fp_32: BinaryField32b = readable.read_scalar().unwrap(); + let fp_128: BinaryField128b = readable.read_scalar().unwrap(); assert_eq!(fp_8.val(), 0x96); assert_eq!(fp_32.val(), 0xDEADBEEF); @@ -523,9 +499,11 @@ mod tests { assert_eq!(sampled_fanpaar1_res, sampled_fanpaar1); - let aes_8: AESTowerField8b = verifier_transcript.read_scalar().unwrap(); - let aes_32: AESTowerField32b = verifier_transcript.read_scalar().unwrap(); - let aes_128: AESTowerField128b = verifier_transcript.read_scalar().unwrap(); + let mut readable = verifier_transcript.message(); + + let aes_8: AESTowerField8b = readable.read_scalar().unwrap(); + let aes_32: AESTowerField32b = readable.read_scalar().unwrap(); + let aes_128: AESTowerField128b = readable.read_scalar().unwrap(); assert_eq!(aes_8.val(), 0x52); assert_eq!(aes_32.val(), 0x12345678); @@ -535,7 +513,8 @@ mod tests { assert_eq!(sampled_aes_res, sampled_aes1); - let polyval_128: BinaryField128bPolyval = verifier_transcript.read_scalar().unwrap(); + let polyval_128: BinaryField128bPolyval = + verifier_transcript.message().read_scalar().unwrap(); assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA)); let sampled_polyval_res: BinaryField128bPolyval = verifier_transcript.sample(); @@ -546,7 +525,8 @@ mod tests { #[test] fn test_advicing() { - let mut advice_writer = AdviceWriter::new(); + let mut prover_transcript = ProverTranscript::>::new(); + let mut advice_writer = prover_transcript.decommitment(); advice_writer.write_scalar(BinaryField8b::new(0x96)); advice_writer.write_scalar(BinaryField32b::new(0xDEADBEEF)); @@ -558,7 +538,8 @@ mod tests { 
advice_writer.write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA)); - let mut advice_reader = advice_writer.into_reader(); + let mut verifier_transcript = prover_transcript.into_verifier(); + let mut advice_reader = verifier_transcript.decommitment(); let fp_8: BinaryField8b = advice_reader.read_scalar().unwrap(); let fp_32: BinaryField32b = advice_reader.read_scalar().unwrap(); @@ -579,12 +560,13 @@ mod tests { let polyval_128: BinaryField128bPolyval = advice_reader.read_scalar().unwrap(); assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA)); - advice_reader.finalize().unwrap(); + verifier_transcript.finalize().unwrap(); } #[test] - fn test_challenger() { - let mut transcript = TranscriptWriter::>::new(); + fn test_challenger_and_observing() { + let mut taped_transcript = ProverTranscript::>::new(); + let mut untaped_transcript = ProverTranscript::>::new(); let mut challenger = HasherChallenger::::default(); const NUM_SAMPLING: usize = 32; @@ -593,48 +575,63 @@ mod tests { let mut sampled_arrays = [[0u8; 8]; NUM_SAMPLING]; for i in 0..NUM_SAMPLING { - transcript.write_scalar(BinaryField64b::new(u64::from_le_bytes( - random_bytes[i * 8..i * 8 + 8].to_vec().try_into().unwrap(), - ))); + taped_transcript + .message() + .write_scalar(BinaryField64b::new(u64::from_le_bytes( + random_bytes[i * 8..i * 8 + 8].to_vec().try_into().unwrap(), + ))); + untaped_transcript + .observe() + .write_scalar(BinaryField64b::new(u64::from_le_bytes( + random_bytes[i * 8..i * 8 + 8].to_vec().try_into().unwrap(), + ))); challenger .observer() .put_slice(&random_bytes[i * 8..i * 8 + 8]); - let sampled_out_transcript: BinaryField64b = transcript.sample(); + let sampled_out_transcript1: BinaryField64b = taped_transcript.sample(); + let sampled_out_transcript2: BinaryField64b = untaped_transcript.sample(); let mut challenger_out = [0u8; 8]; challenger.sampler().copy_to_slice(&mut challenger_out); - assert_eq!(challenger_out, sampled_out_transcript.val().to_le_bytes()); + assert_eq!(challenger_out, sampled_out_transcript1.val().to_le_bytes()); + assert_eq!(challenger_out, sampled_out_transcript2.val().to_le_bytes()); sampled_arrays[i] = challenger_out; } - let mut transcript = transcript.into_reader(); + let mut taped_transcript = taped_transcript.into_verifier(); + + assert!(untaped_transcript.finalize().is_empty()); for array in sampled_arrays.into_iter() { - let _: BinaryField64b = transcript.read_scalar().unwrap(); - let sampled_out_transcript: BinaryField64b = transcript.sample(); + let _: BinaryField64b = taped_transcript.message().read_scalar().unwrap(); + let sampled_out_transcript: BinaryField64b = taped_transcript.sample(); assert_eq!(array, sampled_out_transcript.val().to_le_bytes()); } - transcript.finalize().unwrap(); + taped_transcript.finalize().unwrap(); } #[test] fn test_transcript_debug() { - let mut transcript = TranscriptWriter::>::new(); + let mut transcript = ProverTranscript::>::new(); - transcript.write_debug("test_transcript_debug"); - transcript.into_reader().read_debug("test_transcript_debug"); + transcript.message().write_debug("test_transcript_debug"); + transcript + .into_verifier() + .message() + .read_debug("test_transcript_debug"); } #[test] #[should_panic] fn test_transcript_debug_fail() { - let mut transcript = TranscriptWriter::>::new(); + let mut transcript = ProverTranscript::>::new(); - transcript.write_debug("test_transcript_debug"); + transcript.message().write_debug("test_transcript_debug"); transcript - 
.into_reader() + .into_verifier() + .message() .read_debug("test_transcript_debug_should_fail"); } } diff --git a/crates/field/Cargo.toml b/crates/field/Cargo.toml index deea2882..36de13de 100644 --- a/crates/field/Cargo.toml +++ b/crates/field/Cargo.toml @@ -4,6 +4,9 @@ version.workspace = true edition.workspace = true authors.workspace = true +[lints] +workspace = true + [dependencies] binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } binius_utils = { path = "../utils", default-features = false } diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index f3753301..4b2321f9 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -145,7 +145,7 @@ struct SubfieldTransformer { } impl SubfieldTransformer { - fn new(inner_transform: T) -> Self { + const fn new(inner_transform: T) -> Self { Self { inner_transform, _ip_pd: PhantomData, diff --git a/crates/field/src/arch/aarch64/m128.rs b/crates/field/src/arch/aarch64/m128.rs index cdcdb1ea..a32a82b7 100644 --- a/crates/field/src/arch/aarch64/m128.rs +++ b/crates/field/src/arch/aarch64/m128.rs @@ -41,7 +41,7 @@ impl M128 { #[inline] pub fn shuffle_u8(self, src: [u8; 16]) -> Self { - unsafe { vqtbl1q_u8(self.into(), M128::from_le_bytes(src).into()).into() } + unsafe { vqtbl1q_u8(self.into(), Self::from_le_bytes(src).into()).into() } } } @@ -405,7 +405,7 @@ impl UnderlierWithBitConstants for M128 { impl From for PackedPrimitiveType { fn from(value: u128) -> Self { - PackedPrimitiveType::from(M128::from(value)) + Self::from(M128::from(value)) } } diff --git a/crates/field/src/arch/aarch64/packed_128.rs b/crates/field/src/arch/aarch64/packed_128.rs index 837578a4..80778be4 100644 --- a/crates/field/src/arch/aarch64/packed_128.rs +++ b/crates/field/src/arch/aarch64/packed_128.rs @@ -121,6 +121,6 @@ impl_transformation_with_strategy!(PackedBinaryField1x128b, PairwiseStrategy); impl From for PackedBinaryField16x8b { fn from(value: PackedAESBinaryField16x8b) -> Self { - PackedBinaryField16x8b::from_underlier(packed_aes_16x8b_into_tower(value.to_underlier())) + Self::from_underlier(packed_aes_16x8b_into_tower(value.to_underlier())) } } diff --git a/crates/field/src/arch/aarch64/packed_aes_128.rs b/crates/field/src/arch/aarch64/packed_aes_128.rs index 985740c9..1a7fb69f 100644 --- a/crates/field/src/arch/aarch64/packed_aes_128.rs +++ b/crates/field/src/arch/aarch64/packed_aes_128.rs @@ -107,6 +107,6 @@ impl_transformation_with_strategy!(PackedAESBinaryField1x128b, PairwiseStrategy) impl From for PackedAESBinaryField16x8b { fn from(value: PackedBinaryField16x8b) -> Self { - PackedAESBinaryField16x8b::from_underlier(packed_tower_16x8b_into_aes(value.to_underlier())) + Self::from_underlier(packed_tower_16x8b_into_aes(value.to_underlier())) } } diff --git a/crates/field/src/arch/portable/packed.rs b/crates/field/src/arch/portable/packed.rs index c62a801d..ad372dd9 100644 --- a/crates/field/src/arch/portable/packed.rs +++ b/crates/field/src/arch/portable/packed.rs @@ -52,7 +52,7 @@ impl PackedPrimitiveType { }; #[inline] - pub fn from_underlier(val: U) -> Self { + pub const fn from_underlier(val: U) -> Self { Self(val, PhantomData) } @@ -279,12 +279,11 @@ unsafe impl Zeroable unsafe impl Pod for PackedPrimitiveType {} -impl PackedField for PackedPrimitiveType +impl PackedField for PackedPrimitiveType where Self: Broadcast + Square + InvertOrZero + Mul, - U: UnderlierWithBitConstants + Send + Sync + 'static, - Scalar: WithUnderlier, - U: From, + U: UnderlierWithBitConstants + From + Send + 
Sync + 'static, + Scalar: BinaryField + WithUnderlier, Scalar::Underlier: NumCast, IterationMethods: IterationStrategy, { diff --git a/crates/field/src/arch/portable/packed_arithmetic.rs b/crates/field/src/arch/portable/packed_arithmetic.rs index 366a3999..c15f65fa 100644 --- a/crates/field/src/arch/portable/packed_arithmetic.rs +++ b/crates/field/src/arch/portable/packed_arithmetic.rs @@ -330,7 +330,7 @@ impl PackedTransformation where OP: PackedBinaryField, { - pub fn new>( + pub fn new + Sync>( transformation: FieldLinearTransformation, ) -> Self { Self { @@ -365,7 +365,7 @@ where let ones = OP::one().to_underlier(); let mut input = input.to_underlier(); - for base in self.bases.iter() { + for base in &self.bases { let base_component = input & ones; // contains ones at positions which correspond to non-zero components let mask = broadcast_lowest_bit(base_component, OF::LOG_DEGREE); @@ -382,9 +382,9 @@ where IP: PackedBinaryField + WithUnderlier, OP: PackedBinaryField + WithUnderlier, { - type PackedTransformation> = PackedTransformation; + type PackedTransformation + Sync> = PackedTransformation; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { PackedTransformation::new(transformation) diff --git a/crates/field/src/arch/portable/packed_polyval_128.rs b/crates/field/src/arch/portable/packed_polyval_128.rs index 87cc186a..cbde5354 100644 --- a/crates/field/src/arch/portable/packed_polyval_128.rs +++ b/crates/field/src/arch/portable/packed_polyval_128.rs @@ -156,7 +156,7 @@ fn bmul64(x: u64, y: u64) -> u64 { } /// Bit-reverse a `u64` in constant time -fn rev64(mut x: u64) -> u64 { +const fn rev64(mut x: u64) -> u64 { x = ((x & 0x5555_5555_5555_5555) << 1) | ((x >> 1) & 0x5555_5555_5555_5555); x = ((x & 0x3333_3333_3333_3333) << 2) | ((x >> 2) & 0x3333_3333_3333_3333); x = ((x & 0x0f0f_0f0f_0f0f_0f0f) << 4) | ((x >> 4) & 0x0f0f_0f0f_0f0f_0f0f); diff --git a/crates/field/src/arch/portable/packed_scaled.rs b/crates/field/src/arch/portable/packed_scaled.rs index 6cada6d7..a4b55a2c 100644 --- a/crates/field/src/arch/portable/packed_scaled.rs +++ b/crates/field/src/arch/portable/packed_scaled.rs @@ -313,7 +313,7 @@ pub struct ScaledTransformation { } impl ScaledTransformation { - fn new(inner: I) -> Self { + const fn new(inner: I) -> Self { Self { inner } } } @@ -331,15 +331,15 @@ where impl PackedTransformationFactory> for ScaledPackedField where - ScaledPackedField: PackedBinaryField, + Self: PackedBinaryField, ScaledPackedField: PackedBinaryField, OP: PackedBinaryField, IP: PackedTransformationFactory, { - type PackedTransformation> = + type PackedTransformation + Sync> = ScaledTransformation>; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation< as PackedField>::Scalar, Data, @@ -475,8 +475,7 @@ impl PackScalar for ScaledUnderlier where U: PackScalar + UnderlierType + Pod, F: Field, - ScaledPackedField: - PackedField + WithUnderlier>, + ScaledPackedField: PackedField + WithUnderlier, { type Packed = ScaledPackedField; } diff --git a/crates/field/src/arch/portable/pairwise_arithmetic.rs b/crates/field/src/arch/portable/pairwise_arithmetic.rs index aed414a9..96e92c48 100644 --- a/crates/field/src/arch/portable/pairwise_arithmetic.rs +++ b/crates/field/src/arch/portable/pairwise_arithmetic.rs @@ -75,7 +75,7 @@ pub struct PairwiseTransformation { } impl PairwiseTransformation { - pub fn new(inner: I) -> Self { + pub const fn 
new(inner: I) -> Self { Self { inner } } } @@ -96,10 +96,10 @@ where IP: PackedBinaryField, OP: PackedBinaryField, { - type PackedTransformation> = + type PackedTransformation + Sync> = PairwiseTransformation>; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { PairwiseTransformation::new(transformation) diff --git a/crates/field/src/arch/portable/pairwise_table_arithmetic.rs b/crates/field/src/arch/portable/pairwise_table_arithmetic.rs index d10ffcbc..b008d5b0 100644 --- a/crates/field/src/arch/portable/pairwise_table_arithmetic.rs +++ b/crates/field/src/arch/portable/pairwise_table_arithmetic.rs @@ -29,7 +29,7 @@ where } } -fn mul_binary_tower_4b(a: u8, b: u8) -> u8 { +const fn mul_binary_tower_4b(a: u8, b: u8) -> u8 { #[rustfmt::skip] const MUL_4B_LOOKUP: [u8; 128] = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/crates/field/src/arch/portable/underlier_constants.rs b/crates/field/src/arch/portable/underlier_constants.rs index ef19526f..f834e74d 100644 --- a/crates/field/src/arch/portable/underlier_constants.rs +++ b/crates/field/src/arch/portable/underlier_constants.rs @@ -31,90 +31,90 @@ impl UnderlierWithBitConstants for U4 { impl UnderlierWithBitConstants for u8 { const INTERLEAVE_EVEN_MASK: &'static [Self] = &[ - interleave_mask_even!(u8, 0), - interleave_mask_even!(u8, 1), - interleave_mask_even!(u8, 2), + interleave_mask_even!(Self, 0), + interleave_mask_even!(Self, 1), + interleave_mask_even!(Self, 2), ]; const INTERLEAVE_ODD_MASK: &'static [Self] = &[ - interleave_mask_odd!(u8, 0), - interleave_mask_odd!(u8, 1), - interleave_mask_odd!(u8, 2), + interleave_mask_odd!(Self, 0), + interleave_mask_odd!(Self, 1), + interleave_mask_odd!(Self, 2), ]; } impl UnderlierWithBitConstants for u16 { const INTERLEAVE_EVEN_MASK: &'static [Self] = &[ - interleave_mask_even!(u16, 0), - interleave_mask_even!(u16, 1), - interleave_mask_even!(u16, 2), - interleave_mask_even!(u16, 3), + interleave_mask_even!(Self, 0), + interleave_mask_even!(Self, 1), + interleave_mask_even!(Self, 2), + interleave_mask_even!(Self, 3), ]; const INTERLEAVE_ODD_MASK: &'static [Self] = &[ - interleave_mask_odd!(u16, 0), - interleave_mask_odd!(u16, 1), - interleave_mask_odd!(u16, 2), - interleave_mask_odd!(u16, 3), + interleave_mask_odd!(Self, 0), + interleave_mask_odd!(Self, 1), + interleave_mask_odd!(Self, 2), + interleave_mask_odd!(Self, 3), ]; } impl UnderlierWithBitConstants for u32 { const INTERLEAVE_EVEN_MASK: &'static [Self] = &[ - interleave_mask_even!(u32, 0), - interleave_mask_even!(u32, 1), - interleave_mask_even!(u32, 2), - interleave_mask_even!(u32, 3), - interleave_mask_even!(u32, 4), + interleave_mask_even!(Self, 0), + interleave_mask_even!(Self, 1), + interleave_mask_even!(Self, 2), + interleave_mask_even!(Self, 3), + interleave_mask_even!(Self, 4), ]; const INTERLEAVE_ODD_MASK: &'static [Self] = &[ - interleave_mask_odd!(u32, 0), - interleave_mask_odd!(u32, 1), - interleave_mask_odd!(u32, 2), - interleave_mask_odd!(u32, 3), - interleave_mask_odd!(u32, 4), + interleave_mask_odd!(Self, 0), + interleave_mask_odd!(Self, 1), + interleave_mask_odd!(Self, 2), + interleave_mask_odd!(Self, 3), + interleave_mask_odd!(Self, 4), ]; } impl UnderlierWithBitConstants for u64 { const INTERLEAVE_EVEN_MASK: &'static [Self] = &[ - interleave_mask_even!(u64, 0), - interleave_mask_even!(u64, 1), - interleave_mask_even!(u64, 2), - interleave_mask_even!(u64, 3), - interleave_mask_even!(u64, 4), - 
interleave_mask_even!(u64, 5), + interleave_mask_even!(Self, 0), + interleave_mask_even!(Self, 1), + interleave_mask_even!(Self, 2), + interleave_mask_even!(Self, 3), + interleave_mask_even!(Self, 4), + interleave_mask_even!(Self, 5), ]; const INTERLEAVE_ODD_MASK: &'static [Self] = &[ - interleave_mask_odd!(u64, 0), - interleave_mask_odd!(u64, 1), - interleave_mask_odd!(u64, 2), - interleave_mask_odd!(u64, 3), - interleave_mask_odd!(u64, 4), - interleave_mask_odd!(u64, 5), + interleave_mask_odd!(Self, 0), + interleave_mask_odd!(Self, 1), + interleave_mask_odd!(Self, 2), + interleave_mask_odd!(Self, 3), + interleave_mask_odd!(Self, 4), + interleave_mask_odd!(Self, 5), ]; } impl UnderlierWithBitConstants for u128 { const INTERLEAVE_EVEN_MASK: &'static [Self] = &[ - interleave_mask_even!(u128, 0), - interleave_mask_even!(u128, 1), - interleave_mask_even!(u128, 2), - interleave_mask_even!(u128, 3), - interleave_mask_even!(u128, 4), - interleave_mask_even!(u128, 5), - interleave_mask_even!(u128, 6), + interleave_mask_even!(Self, 0), + interleave_mask_even!(Self, 1), + interleave_mask_even!(Self, 2), + interleave_mask_even!(Self, 3), + interleave_mask_even!(Self, 4), + interleave_mask_even!(Self, 5), + interleave_mask_even!(Self, 6), ]; const INTERLEAVE_ODD_MASK: &'static [Self] = &[ - interleave_mask_odd!(u128, 0), - interleave_mask_odd!(u128, 1), - interleave_mask_odd!(u128, 2), - interleave_mask_odd!(u128, 3), - interleave_mask_odd!(u128, 4), - interleave_mask_odd!(u128, 5), - interleave_mask_odd!(u128, 6), + interleave_mask_odd!(Self, 0), + interleave_mask_odd!(Self, 1), + interleave_mask_odd!(Self, 2), + interleave_mask_odd!(Self, 3), + interleave_mask_odd!(Self, 4), + interleave_mask_odd!(Self, 5), + interleave_mask_odd!(Self, 6), ]; } diff --git a/crates/field/src/arch/x86_64/gfni/gfni_arithmetics.rs b/crates/field/src/arch/x86_64/gfni/gfni_arithmetics.rs index a0d595d6..3aebce24 100644 --- a/crates/field/src/arch/x86_64/gfni/gfni_arithmetics.rs +++ b/crates/field/src/arch/x86_64/gfni/gfni_arithmetics.rs @@ -137,7 +137,7 @@ pub(super) fn get_8x8_matrix( ) -> i64 where OF: BinaryField>, - Data: Deref, + Data: Deref + Sync, { transpose_8x8(i64::from_le_bytes(array::from_fn(|k| { transformation.bases()[k + 8 * col] @@ -152,7 +152,7 @@ where OP: WithUnderlier + PackedBinaryField>, { - pub fn new>( + pub fn new + Sync>( transformation: FieldLinearTransformation, ) -> Self { debug_assert_eq!(OP::Scalar::N_BITS, 8); @@ -185,9 +185,9 @@ where + WithUnderlier, U: GfniType, { - type PackedTransformation::Scalar]>> = GfniTransformation; + type PackedTransformation::Scalar]> + Sync> = GfniTransformation; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { GfniTransformation::new(transformation) @@ -221,7 +221,7 @@ where + PackedBinaryField>>, [[OP::Underlier; MATRICES]; BLOCKS]: Default, { - pub fn new>( + pub fn new + Sync>( transformation: FieldLinearTransformation, ) -> Self { debug_assert_eq!(OP::Scalar::N_BITS, BLOCKS * 8); @@ -249,7 +249,7 @@ where } let byte_indices = array::from_fn(|i| { - // all shuffle indices are repated with cycle 8. + // all shuffle indices are repeated with cycle 8. half_u128_lane[i % 8] }); let mask_u128 = u128::from_le_bytes(byte_indices); @@ -324,14 +324,14 @@ macro_rules! 
impl_transformation_with_gfni_nxn { >, { type PackedTransformation< - Data: std::ops::Deref::Scalar]>, + Data: std::ops::Deref::Scalar]> + Sync, > = $crate::arch::x86_64::gfni::gfni_arithmetics::GfniTransformationNxN< OP, $blocks, { $blocks / 2 }, >; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: $crate::linear_transformation::FieldLinearTransformation< OP::Scalar, Data, diff --git a/crates/field/src/arch/x86_64/gfni/m256.rs b/crates/field/src/arch/x86_64/gfni/m256.rs index ee50b236..79b3ff36 100644 --- a/crates/field/src/arch/x86_64/gfni/m256.rs +++ b/crates/field/src/arch/x86_64/gfni/m256.rs @@ -46,7 +46,7 @@ impl GfniTransformation256b { pub fn new(transformation: FieldLinearTransformation) -> Self where OP: PackedField> + WithUnderlier, - Data: Deref, + Data: Deref + Sync, { let bases_8x8 = array::from_fn(|col| { array::from_fn(|row| unsafe { @@ -134,9 +134,9 @@ where OP: PackedField> + WithUnderlier, { - type PackedTransformation::Scalar]>> = GfniTransformation256b; + type PackedTransformation::Scalar]> + Sync> = GfniTransformation256b; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { GfniTransformation256b::new::(transformation) diff --git a/crates/field/src/arch/x86_64/gfni/m512.rs b/crates/field/src/arch/x86_64/gfni/m512.rs index 6d35d56a..1d6b67f7 100644 --- a/crates/field/src/arch/x86_64/gfni/m512.rs +++ b/crates/field/src/arch/x86_64/gfni/m512.rs @@ -45,7 +45,7 @@ impl GfniTransformation512b { pub fn new(transformation: FieldLinearTransformation) -> Self where OP: PackedField> + WithUnderlier, - Data: Deref, + Data: Deref + Sync, { let bases_8x8 = array::from_fn(|col| { array::from_fn(|row| unsafe { @@ -126,9 +126,9 @@ where OP: PackedField> + WithUnderlier, { - type PackedTransformation::Scalar]>> = GfniTransformation512b; + type PackedTransformation::Scalar]> + Sync> = GfniTransformation512b; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { GfniTransformation512b::new::(transformation) diff --git a/crates/field/src/arch/x86_64/m128.rs b/crates/field/src/arch/x86_64/m128.rs index 79576b39..f75d152d 100644 --- a/crates/field/src/arch/x86_64/m128.rs +++ b/crates/field/src/arch/x86_64/m128.rs @@ -92,7 +92,7 @@ impl From> for M128 { impl From for u128 { fn from(value: M128) -> Self { let mut result = 0u128; - unsafe { _mm_storeu_si128(&mut result as *mut u128 as *mut __m128i, value.0) }; + unsafe { _mm_storeu_si128(&mut result as *mut Self as *mut __m128i, value.0) }; result } @@ -499,9 +499,8 @@ impl UnderlierWithBitOps for M128 { #[inline(always)] unsafe fn spread(self, log_block_len: usize, block_idx: usize) -> Self where - T: UnderlierWithBitOps, + T: UnderlierWithBitOps + NumCast, Self: From, - T: NumCast, { match T::LOG_BITS { 0 => match log_block_len { @@ -782,24 +781,24 @@ impl UnderlierWithBitConstants for M128 { fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) { unsafe { let (c, d) = interleave_bits( - Into::::into(self).into(), - Into::::into(other).into(), + Into::::into(self).into(), + Into::::into(other).into(), log_block_len, ); - (M128::from(c), M128::from(d)) + (Self::from(c), Self::from(d)) } } } impl From<__m128i> for PackedPrimitiveType { fn from(value: __m128i) -> Self { - PackedPrimitiveType::from(M128::from(value)) + M128::from(value).into() } } impl From for PackedPrimitiveType { fn 
from(value: u128) -> Self { - PackedPrimitiveType::from(M128::from(value)) + M128::from(value).into() } } diff --git a/crates/field/src/arch/x86_64/simd/simd_arithmetic.rs b/crates/field/src/arch/x86_64/simd/simd_arithmetic.rs index 55d15c13..a131f774 100644 --- a/crates/field/src/arch/x86_64/simd/simd_arithmetic.rs +++ b/crates/field/src/arch/x86_64/simd/simd_arithmetic.rs @@ -301,7 +301,7 @@ impl SimdTransformation where OP: PackedBinaryField + WithUnderlier, { - pub fn new>( + pub fn new + Sync>( transformation: FieldLinearTransformation, ) -> Self { Self { @@ -349,9 +349,9 @@ where OP: PackedBinaryField + WithUnderlier, IP::Underlier: TowerSimdType, { - type PackedTransformation::Scalar]>> = SimdTransformation; + type PackedTransformation::Scalar]> + Sync> = SimdTransformation; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { SimdTransformation::new(transformation) diff --git a/crates/field/src/arithmetic_traits.rs b/crates/field/src/arithmetic_traits.rs index 68e84d0b..c1d320ef 100644 --- a/crates/field/src/arithmetic_traits.rs +++ b/crates/field/src/arithmetic_traits.rs @@ -147,9 +147,9 @@ pub trait TaggedPackedTransformationFactory: PackedBinaryField where OP: PackedBinaryField, { - type PackedTransformation>: Transformation; + type PackedTransformation + Sync>: Transformation; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation; } @@ -164,13 +164,13 @@ macro_rules! impl_transformation_with_strategy { >, Self: $crate::arithmetic_traits::TaggedPackedTransformationFactory<$strategy, OP>, { - type PackedTransformation> = + type PackedTransformation + Sync> = >::PackedTransformation; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: $crate::linear_transformation::FieldLinearTransformation< OP::Scalar, Data, diff --git a/crates/field/src/binary_field.rs b/crates/field/src/binary_field.rs index fac5df97..be6f8ea8 100644 --- a/crates/field/src/binary_field.rs +++ b/crates/field/src/binary_field.rs @@ -94,7 +94,7 @@ macro_rules! binary_field { Self(value) } - pub fn val(self) -> $typ { + pub const fn val(self) -> $typ { self.0 } } @@ -780,7 +780,7 @@ pub fn deserialize_canonical( impl From for Choice { fn from(val: BinaryField1b) -> Self { - Choice::from(val.val().val()) + Self::from(val.val().val()) } } diff --git a/crates/field/src/linear_transformation.rs b/crates/field/src/linear_transformation.rs index b5ea8c1a..f6cd32d3 100644 --- a/crates/field/src/linear_transformation.rs +++ b/crates/field/src/linear_transformation.rs @@ -7,7 +7,7 @@ use rand::RngCore; use crate::{packed::PackedBinaryField, BinaryField, BinaryField1b, ExtensionField}; /// Generic transformation trait that is used both for scalars and packed fields -pub trait Transformation { +pub trait Transformation: Sync { fn transform(&self, data: &Input) -> Output; } @@ -17,7 +17,10 @@ pub trait Transformation { /// parameter because we want to be able both to have const instances that reference static arrays /// and owning vector elements. 
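/// Roughly, the transformation is linear over `BinaryField1b`: the output is the sum of
/// `bases[i]` over the set bits `i` of the input's bit expansion. A usage sketch with one
/// of the constant instances defined in `polyval.rs` (types from `binius_field`):
///
///     let a = BinaryField128b::new(0x1234);
///     let b: BinaryField128bPolyval = BINARY_TO_POLYVAL_TRANSFORMATION.transform(&a);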
#[derive(Debug, Clone)] -pub struct FieldLinearTransformation = &'static [OF]> { +pub struct FieldLinearTransformation< + OF: BinaryField, + Data: Deref + Sync = &'static [OF], +> { bases: Data, } @@ -29,7 +32,7 @@ impl FieldLinearTransformation { } } -impl> FieldLinearTransformation { +impl + Sync> FieldLinearTransformation { pub fn new(bases: Data) -> Self { debug_assert_eq!(bases.deref().len(), OF::DEGREE); @@ -41,7 +44,7 @@ impl> FieldLinearTransformation> Transformation +impl + Sync> Transformation for FieldLinearTransformation { fn transform(&self, data: &IF) -> OF { @@ -68,9 +71,9 @@ pub trait PackedTransformationFactory: PackedBinaryField where OP: PackedBinaryField, { - type PackedTransformation>: Transformation; + type PackedTransformation + Sync>: Transformation; - fn make_packed_transformation>( + fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation; } diff --git a/crates/field/src/packed.rs b/crates/field/src/packed.rs index 8255a1ff..8d807121 100644 --- a/crates/field/src/packed.rs +++ b/crates/field/src/packed.rs @@ -340,7 +340,7 @@ pub fn set_packed_slice_checked( }) } -pub fn len_packed_slice(packed: &[P]) -> usize { +pub const fn len_packed_slice(packed: &[P]) -> usize { packed.len() * P::WIDTH } diff --git a/crates/field/src/packed_aes_field.rs b/crates/field/src/packed_aes_field.rs index 3f567b71..92005154 100644 --- a/crates/field/src/packed_aes_field.rs +++ b/crates/field/src/packed_aes_field.rs @@ -443,7 +443,7 @@ mod tests { /// Compile-time test to ensure packed fields implement `PackedTransformationFactory`. #[allow(unused)] - fn test_implement_transformation_factory() { + const fn test_implement_transformation_factory() { // 8 bit packed aes tower implements_transformation_factory::(); implements_transformation_factory::(); diff --git a/crates/field/src/packed_binary_field.rs b/crates/field/src/packed_binary_field.rs index dffaff16..b22f48fe 100644 --- a/crates/field/src/packed_binary_field.rs +++ b/crates/field/src/packed_binary_field.rs @@ -773,7 +773,7 @@ pub mod test_utils { /// Helper function for compile-time checks #[allow(unused)] - pub fn implements_transformation_factory< + pub const fn implements_transformation_factory< P1: PackedField, P2: PackedTransformationFactory, >() { @@ -1020,7 +1020,7 @@ mod tests { /// Compile-time test to ensure packed fields implement `PackedTransformationFactory`. #[allow(unused)] - fn test_implement_transformation_factory() { + const fn test_implement_transformation_factory() { // 1 bit packed binary tower implements_transformation_factory::(); diff --git a/crates/field/src/packed_polyval.rs b/crates/field/src/packed_polyval.rs index 06390041..a0f0cdf8 100644 --- a/crates/field/src/packed_polyval.rs +++ b/crates/field/src/packed_polyval.rs @@ -194,7 +194,7 @@ mod tests { /// Compile-time test to ensure packed fields implement `PackedTransformationFactory`. 
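	/// The `implements_transformation_factory` helper from `packed_binary_field::test_utils`
	/// is effectively a no-op; instantiating it with a concrete pair of packed types simply
	/// forces the compiler to prove the `PackedTransformationFactory` bound, so a missing
	/// impl surfaces as a compile error rather than a runtime failure.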
#[allow(unused)] - fn test_implement_transformation_factory() { + const fn test_implement_transformation_factory() { // 128 bit packed polyval implements_transformation_factory::(); implements_transformation_factory::(); diff --git a/crates/field/src/polyval.rs b/crates/field/src/polyval.rs index 40984ec0..989e12ee 100644 --- a/crates/field/src/polyval.rs +++ b/crates/field/src/polyval.rs @@ -15,6 +15,7 @@ use rand::{Rng, RngCore}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use super::{ + aes_field::AESTowerField128b, arithmetic_traits::InvertOrZero, binary_field::{BinaryField, BinaryField128b, BinaryField1b, TowerField}, error::Error, @@ -253,8 +254,8 @@ impl Square for BinaryField128bPolyval { } impl Field for BinaryField128bPolyval { - const ZERO: Self = BinaryField128bPolyval(0); - const ONE: Self = BinaryField128bPolyval(0xc2000000000000000000000000000001); + const ZERO: Self = Self(0); + const ONE: Self = Self(0xc2000000000000000000000000000001); fn random(mut rng: impl RngCore) -> Self { Self(rng.gen()) @@ -445,8 +446,7 @@ impl ExtensionField for BinaryField128bPolyval { } impl BinaryField for BinaryField128bPolyval { - const MULTIPLICATIVE_GENERATOR: BinaryField128bPolyval = - BinaryField128bPolyval(0x72bdf2504ce49c03105433c1c25a4a7); + const MULTIPLICATIVE_GENERATOR: Self = Self(0x72bdf2504ce49c03105433c1c25a4a7); } impl TowerField for BinaryField128bPolyval { @@ -730,11 +730,287 @@ pub const POLYVAL_TO_BINARY_TRANSFORMATION: FieldLinearTransformation for BinaryField128b { - fn from(value: BinaryField128bPolyval) -> BinaryField128b { + fn from(value: BinaryField128bPolyval) -> Self { POLYVAL_TO_BINARY_TRANSFORMATION.transform(&value) } } +pub const AES_TO_POLYVAL_TRANSFORMATION: FieldLinearTransformation = + FieldLinearTransformation::new_const(&[ + BinaryField128bPolyval(0xc2000000000000000000000000000001), + BinaryField128bPolyval(0xe632e878241983acfe888a04c4d9a761), + BinaryField128bPolyval(0xac11ddf4a4b79d5c48ac4c527597b579), + BinaryField128bPolyval(0x6b9e5d3f1b690b05f3313f030e46356c), + BinaryField128bPolyval(0x2b04f6e5ed1de8f556e7d64ddd06e9cb), + BinaryField128bPolyval(0x31001e7abbe11a74c26378b8a5589564), + BinaryField128bPolyval(0xa7698d9fd5f16f53cb2ea07a2e92f955), + BinaryField128bPolyval(0xfc2bf21f1b48c91511a841fb19894992), + BinaryField128bPolyval(0x586704bda927015fedb8ddceb7f825d6), + BinaryField128bPolyval(0x141f1af5b6fc687390fa434e9b3df535), + BinaryField128bPolyval(0xe2fab31ae2a86c482d15591868e50692), + BinaryField128bPolyval(0x5b1ab4f647466009452d4152d4a2d9b7), + BinaryField128bPolyval(0x5e0f7136a0b09b8039655d2dea094bf2), + BinaryField128bPolyval(0x6f2075bc8788f28152a66d96ce4680bb), + BinaryField128bPolyval(0x4140c7bd1f7aedd86b92e5fd101ee1c6), + BinaryField128bPolyval(0xdde2a8ec4d0e54eeb5a4a25f51c6e4fa), + BinaryField128bPolyval(0xc50e3a207f91d7cd6dd1e116d55455fb), + BinaryField128bPolyval(0xfa1ac734a9812f783652ef8b68356a41), + BinaryField128bPolyval(0x36513023ad98424cb71c04fe89e160a7), + BinaryField128bPolyval(0x049537ba21f47f9b04d482dde77f1d35), + BinaryField128bPolyval(0x62da2377f5423631f244f3eb099cf2b7), + BinaryField128bPolyval(0x4a23f578b5ea2846dcc6c290ef1e8aaa), + BinaryField128bPolyval(0x6e95c9eedf7f47b3d594d365d23f0664), + BinaryField128bPolyval(0xd2a1e8b6757668d5c29a321b50d6f02d), + BinaryField128bPolyval(0xf4400f37acedc7d9502abeff6cead84c), + BinaryField128bPolyval(0xf200b20bda2bad094b52961d78c3b76d), + BinaryField128bPolyval(0x6d80e955976082ba5db58a84889d3418), + BinaryField128bPolyval(0x44c1b7aca2c318b69501d626d8e3e1be), + 
BinaryField128bPolyval(0xfc140c4a4801a6f1ca47bea4142a8e09), + BinaryField128bPolyval(0xe7dc049975b85a68922acd362cba0aae), + BinaryField128bPolyval(0x0d2181f080634f18c69d05ea5068dcc7), + BinaryField128bPolyval(0x66b185642341f6a71c11a443ec30bcfa), + BinaryField128bPolyval(0xa26153cb8c150af8243ecbd46378e59e), + BinaryField128bPolyval(0x8f3b44831e624145fb0b4dffddbd0338), + BinaryField128bPolyval(0x238a42154a23b1278ba6133fa32887d2), + BinaryField128bPolyval(0xaea5e0bd0f23bb3755ca8a198e51a02c), + BinaryField128bPolyval(0x382af0a162eb58f6888bd591d34850ee), + BinaryField128bPolyval(0x7c7ab1035fd703fdaef544d7f152f9ff), + BinaryField128bPolyval(0xd81f70c2928c2a2e45c3ff8900f225b7), + BinaryField128bPolyval(0x05d5f641d32186b75064f07fefaade44), + BinaryField128bPolyval(0x00a3a4d8c163c95ac7ac9a5b424e1c65), + BinaryField128bPolyval(0xdc951260493c96fca603481ab501d438), + BinaryField128bPolyval(0x99400402d352c6a6879277fa8e022149), + BinaryField128bPolyval(0x3bebf7af750eace1e434f9a5925288ee), + BinaryField128bPolyval(0x9f171f736eff43513721ae2942afe01d), + BinaryField128bPolyval(0xe3e184abe50f7387c5fdd01faf6c95d3), + BinaryField128bPolyval(0x1fae32af2dc4238dcce57975be1b2400), + BinaryField128bPolyval(0x282116c0f04b6707698f1ea25790ea10), + BinaryField128bPolyval(0x564032a0d5d3773b7b7ed18bebf1c668), + BinaryField128bPolyval(0xe40d8bdaad8fdd00b3f004e706e35b10), + BinaryField128bPolyval(0x675892c7e2ead5594ef74c71079069d2), + BinaryField128bPolyval(0x25e285580c933861739d2b031eb4b2d3), + BinaryField128bPolyval(0x51e78b32d4dd25445b2d1f30689d6abb), + BinaryField128bPolyval(0x133994dbfd8ce08ae714538d557eb150), + BinaryField128bPolyval(0x278247597906a3b1f990119cde5ffb24), + BinaryField128bPolyval(0xef387e28a39a63ed81710b9bdbc74005), + BinaryField128bPolyval(0x41bbc7902188f4e9880f1d81fa580ffb), + BinaryField128bPolyval(0x415afa934ada855ba61bbef36d27db58), + BinaryField128bPolyval(0xe477f6fac6a1c2057662a149aa0ae061), + BinaryField128bPolyval(0xe98aee0eeb136ee02e2be740d058fe5d), + BinaryField128bPolyval(0x777ead11e8dc98c5ffd710b823fc5093), + BinaryField128bPolyval(0xa4e9927e7da484643651b5532106b1e3), + BinaryField128bPolyval(0x9fdfc576fcf0b33b829f1a052e9355f2), + BinaryField128bPolyval(0x342bbe0ad297a239b415fc050f1a23f7), + BinaryField128bPolyval(0x756209f0893e96877833c194b9c943a0), + BinaryField128bPolyval(0xc763ed07e05784c3ca283e12f9f22368), + BinaryField128bPolyval(0xb04147b7592c8c80508e1ee45b2c4806), + BinaryField128bPolyval(0x4f0557d1988f518b39bd6fc3c2fba372), + BinaryField128bPolyval(0x7259d355775035632bfb7b003178ae0e), + BinaryField128bPolyval(0x7826c6a2e9e37bbd991ef41faa246832), + BinaryField128bPolyval(0x4d12f861b14f57602cd121d6622efcf6), + BinaryField128bPolyval(0x47979b6dc50802e344b543260f9f14e4), + BinaryField128bPolyval(0xf330cd9a74a5e884988c3f36567d26f4), + BinaryField128bPolyval(0xb9d4fca088a982f6a9add501644e56b2), + BinaryField128bPolyval(0xee9e3f0fbab5cc33378947fd04769519), + BinaryField128bPolyval(0x0819f0cb4253a5ab7cb10f583ce13537), + BinaryField128bPolyval(0x0ba10d161c76ba69a4b68d140886097a), + BinaryField128bPolyval(0x7850cec4236f7ea4698d1b15707a8bf8), + BinaryField128bPolyval(0x5a712763179ba0a99e8bbe5e3b73146f), + BinaryField128bPolyval(0x8d825f45c33f7a1f45be2e3938a0fcfc), + BinaryField128bPolyval(0xf8f3c35c671bac841b14381a592e6cdd), + BinaryField128bPolyval(0x96d15acd59e4c4852735c30c972140ee), + BinaryField128bPolyval(0x87d0c6a97af12deec54ecfd097c5a4ff), + BinaryField128bPolyval(0x4152fe90327dcbe147e8771ba0334ba1), + BinaryField128bPolyval(0xb6793169bc400bfc14ed05b58db2b472), + 
BinaryField128bPolyval(0x071df72b3ce39b686d6f52b53e608c7a), + BinaryField128bPolyval(0xcf97a03df90400718aff5257888970f5), + BinaryField128bPolyval(0xa05e7602d3f74d882329c158a0ad9f37), + BinaryField128bPolyval(0x69d3b9b1d9483174b3c38d8f95ce7a5f), + BinaryField128bPolyval(0x54882e1b0f3f749397cbeff7c70f0c73), + BinaryField128bPolyval(0x1df3271f8f5398ff2937a4f0fd041cfa), + BinaryField128bPolyval(0xc964a4d09783b7483c1845943333022b), + BinaryField128bPolyval(0x64991e811d81abc5cef407bebff096bd), + BinaryField128bPolyval(0x12a6345cacc97a7992ad6647c2833af8), + BinaryField128bPolyval(0xe04112ac095f1c67b9792b0fb82cc4fe), + BinaryField128bPolyval(0x744c1206d550d565d994fe159c5cd699), + BinaryField128bPolyval(0xcd8b92c8a6e0e65e167a2b851f32fd9c), + BinaryField128bPolyval(0x994beaf6226f215c7b4187e0c36e08a6), + BinaryField128bPolyval(0x71d346897647348c7eb7752a3a893424), + BinaryField128bPolyval(0xb0f49cce03da921b5f8999cc18311b43), + BinaryField128bPolyval(0x45b6960f54a16e547a329f79210629d0), + BinaryField128bPolyval(0x4535377d8b9718b1f6991e57fea36922), + BinaryField128bPolyval(0xbf97b9a9bda638f42cc98233021d69fd), + BinaryField128bPolyval(0x02e346c80352ad186c77b88c9f9317d0), + BinaryField128bPolyval(0x2eae07273b8c4fc5cef553a4a46cde5b), + BinaryField128bPolyval(0x86c51811ef12c72fcabfc51b2a2e0c2a), + BinaryField128bPolyval(0x8a828ea9f6b97e5c3b0b4c6faed116e6), + BinaryField128bPolyval(0x0f69898966b112c45f01e70a26142623), + BinaryField128bPolyval(0x109173f0af80af37cf9d6ef791d2feed), + BinaryField128bPolyval(0x984655091ad81e2befa87a6688ddc784), + BinaryField128bPolyval(0x21851b1a985e40395abe3d15fff2d770), + BinaryField128bPolyval(0x045610d75e4ee7b53deb1c4149179a3a), + BinaryField128bPolyval(0x3c37dbe331de4b0dc2f5db315d5e7fda), + BinaryField128bPolyval(0x09ca74968860fc3d723b7966d8574ce1), + BinaryField128bPolyval(0x8892f14b27f8e4d1b01efa51eeaa4ad4), + BinaryField128bPolyval(0xc7339f4d332b0fa99f58a62453d76401), + BinaryField128bPolyval(0xfc81ac07b51d1d165b5525b77bf5f969), + BinaryField128bPolyval(0x7bdcb39270e3891d486160e47bc4015c), + BinaryField128bPolyval(0x7967964f6e9d62b6b50a50ee51c927d2), + BinaryField128bPolyval(0xd11b8526eed516e3dfa7b8e2b17bbf40), + BinaryField128bPolyval(0xe47b83247cb2b162c7d6d15c84cca8ce), + BinaryField128bPolyval(0x5136c420c1a70a4b697e000c637ec876), + BinaryField128bPolyval(0x2114cffeda72b157abb70ae549b39e97), + BinaryField128bPolyval(0x7f72edec22f7d7caac7b78cbca5ce3bb), + BinaryField128bPolyval(0xfb5ac3eb65636373828e242c79ef5046), + BinaryField128bPolyval(0x8819e336afff44542a76ee524a033645), + BinaryField128bPolyval(0x8be0251a2790b20b19f6343efaf425e7), + BinaryField128bPolyval(0x2a49adc1114d5dcf91783fafe0542c8a), + ]); + +impl From for BinaryField128bPolyval { + fn from(value: AESTowerField128b) -> Self { + AES_TO_POLYVAL_TRANSFORMATION.transform(&value) + } +} + +pub const POLYVAL_TO_AES_TRANSFORMARION: FieldLinearTransformation = + FieldLinearTransformation::new_const(&[ + AESTowerField128b(0xaffaa99fa8aa55f93974735e68d0882a), + AESTowerField128b(0x402655567dde6c49c7aea09cc7d73e01), + AESTowerField128b(0x83d724035fe42d2bd4ad27c1ad3be9ae), + AESTowerField128b(0x39940944fe609647237fb001386aff50), + AESTowerField128b(0xce1a381a1790a4d70cbbd7389dd705cd), + AESTowerField128b(0x0ca7f0b9fdade73367e9d9ba5ac3cfbe), + AESTowerField128b(0x70a744fa2401c4fddb871879d718ee08), + AESTowerField128b(0x275ddf74916ecd03aa3c243f74bf3461), + AESTowerField128b(0xcedb8685d1e86e6a74b79cd21c271a02), + AESTowerField128b(0x7a451fd334177a5bdb9e82c9fa373e88), + 
AESTowerField128b(0xcd4617d8c786c2a3824ae7335ec6b418), + AESTowerField128b(0x32feae47f77fa9c6bb6b917fbb2d96d7), + AESTowerField128b(0xcb2fef14aeabf5b41cfd3ab6774c95b7), + AESTowerField128b(0x3029123a0510641d238bea4746cd5e4b), + AESTowerField128b(0xcfdc169c294eec4bb1eab4bc88a505de), + AESTowerField128b(0x139206e1ace72eed8b9d52fc020e2d9f), + AESTowerField128b(0x891e28b32d8a8320a0a2eb295953bc42), + AESTowerField128b(0xb8704d0efb4be36ea6282c5aaf67fa9e), + AESTowerField128b(0x272f2013fe10244185ad0d672eafa581), + AESTowerField128b(0x28614505a9df5f55ef5bb97c4521eace), + AESTowerField128b(0x6dbdd43dcc19626629ac7f2638e73fff), + AESTowerField128b(0x4d687c2abd4aba97692db8c2a4eb267c), + AESTowerField128b(0x19613bca40bd82828d8255f50d271135), + AESTowerField128b(0xf2772ff3c8d95eb9252d8bd01419641e), + AESTowerField128b(0x6611a71908f4aac4fdb11c08cbedfc8c), + AESTowerField128b(0x10ad82530c31e3a8f757424dca80798a), + AESTowerField128b(0x2fc9972bf59fe5c624714cb8466249ed), + AESTowerField128b(0x9c5c3d8b954a27231e16e4a77fbe4369), + AESTowerField128b(0x2565fdbb105f787cccddf28b530af4ee), + AESTowerField128b(0x473e8ad3e0e8e46e611080cc9350a590), + AESTowerField128b(0x525d2e9d24611e2aa37a1d9a2d42377d), + AESTowerField128b(0x4967e7d5067c2ace89648bf6d95de637), + AESTowerField128b(0x2689f814fa63c8a16ddf2374eeb7c3e7), + AESTowerField128b(0x5e9246c3d2cab88ba27d7cf8ba18c4b4), + AESTowerField128b(0x53616806f5402bc897499fac6e27da63), + AESTowerField128b(0x6c6dbb145a21d1c2d87a93e779f87aba), + AESTowerField128b(0x0f7164762ed37685fd82b05e800cebec), + AESTowerField128b(0x0d01e8acede09d57da4a7325af4b04fc), + AESTowerField128b(0x336b892198ac639af40c74c9252eb99c), + AESTowerField128b(0xf282b8b368b39203e0bbe6246b1b0951), + AESTowerField128b(0x090ca21ee9872dc5a00e669729a69750), + AESTowerField128b(0xa037d253b55003a611faf883fcc8f35e), + AESTowerField128b(0xe4b8a9796561b1ba1f0970a0b26f832a), + AESTowerField128b(0xda203a31e0d6ace125e027a2df265b59), + AESTowerField128b(0xe0ebb1a107ded2a6b0916eff84c18fb4), + AESTowerField128b(0xa9998b7d2cded9a5269f6f8b25147b08), + AESTowerField128b(0x2e8337dd13a279f78ac5d327ec36f632), + AESTowerField128b(0x6264c35e09c7b3bc9b80aa886b194025), + AESTowerField128b(0x9ff92674cc64e8c2e1ee5093298382e8), + AESTowerField128b(0x3e196976f0a90cf1f71847b0fce3a0ac), + AESTowerField128b(0xbdad299364e420e1663e2c09db59634d), + AESTowerField128b(0x9046a0de7ea82a4e8c8a75e001bfdf0e), + AESTowerField128b(0xc6762e8ee83287a13a66789ae533c938), + AESTowerField128b(0x1e17ed399374b0c47e9726bfc71e0c8a), + AESTowerField128b(0xf46ce30110ce034b6a0ba8b8d0af93c6), + AESTowerField128b(0x825f7d3cef67cec50370363e2a6e502c), + AESTowerField128b(0xbdfe9b9ae82bfeff8a58710addb13695), + AESTowerField128b(0x23002be0599a589f6e30a3069cbc71bd), + AESTowerField128b(0x4468951dba52ab1e06efe0dea6d01fa0), + AESTowerField128b(0x28752c7da3ed6a83ca09163f3186b862), + AESTowerField128b(0xe9dd33560ea4a316fdee161ba4946fb1), + AESTowerField128b(0x7e0df8223f37449f266bc8fa70de1ecb), + AESTowerField128b(0x88578550f872e4c8e975a2b66c70cde8), + AESTowerField128b(0xa11ea5aebfe37694ca5ff46e28faf100), + AESTowerField128b(0x3df877fc12016ef181fbf63bf87f5e7c), + AESTowerField128b(0x0a5ca382e7cc37ceb234a5d08d3206ac), + AESTowerField128b(0x6d24b53f98df2626e8e37f977013dbaa), + AESTowerField128b(0x51cc686f72fd7f264962407270cd9394), + AESTowerField128b(0x9749ba0da32ec603a0b342e93049e1fb), + AESTowerField128b(0xed05d63413627e1efd2f3757802a12fa), + AESTowerField128b(0x0e4b2e70136dae8d61528cc479f3aeb0), + AESTowerField128b(0x3e2bda6193a5c936f2c8dc53bf2375b1), + 
AESTowerField128b(0x9f336d2f107bb812f4d39fb05f19a231), + AESTowerField128b(0x9d21c1be60eba516920f52582c709535), + AESTowerField128b(0x39a51756da0aee24ab5ea3ed62afce31), + AESTowerField128b(0x4404c057a9425458d7fd72eb9e23ac50), + AESTowerField128b(0xe2e5839f2ca60e2f20ad3b15676f583f), + AESTowerField128b(0x8326ccb5e936f3a223d2dada1c00efe0), + AESTowerField128b(0x4293fe13b61b834cf2af7ccea5ac07f4), + AESTowerField128b(0x3c36e03518756760624be5278c4ad469), + AESTowerField128b(0x8ccd1c1dbb224aa30ce78e9062de5884), + AESTowerField128b(0x4c7442a391d3fa91581d07fe2114eea1), + AESTowerField128b(0x7f73613a8ac49cd2c31260dd9835b790), + AESTowerField128b(0x5ea3097b9e7f2734249d6777b4028f95), + AESTowerField128b(0xd4ddcc844b626d7e122c431e3a2e9393), + AESTowerField128b(0xc19d3ffeceed10457bbfd7a9b9064779), + AESTowerField128b(0xb7f15cf7bf9c3b2a87a1e461370be7f5), + AESTowerField128b(0xf20ca8dffdcc5433561e487513103aa8), + AESTowerField128b(0xef6a936a1d9bd32d4502b8c4c5e7ce60), + AESTowerField128b(0x27c89de97a00f322d3118c2b094c06f2), + AESTowerField128b(0x8cd3bbb73240861ad260798b7c232ea2), + AESTowerField128b(0x3e61230942e2c19b9ecb0f80a1b20423), + AESTowerField128b(0x30c59cf7d564c2bc9371d28522419283), + AESTowerField128b(0x17c781e73773c72b8e18d0e6f85fe1ac), + AESTowerField128b(0xd58960c501256b149802cd19ddb19a4b), + AESTowerField128b(0x4cfb558f43862eb923981eacc0719fab), + AESTowerField128b(0xff8c63cbb23240af2d02d0c875335abb), + AESTowerField128b(0x9219023fcc6c5b1154020e9685681b25), + AESTowerField128b(0x57ec8c84b214456eb2b2c42e08cc7529), + AESTowerField128b(0xbda0ccb3b1231f075f78b27dcd578a40), + AESTowerField128b(0xb9396785c80962cfdb0a4117868ab3f6), + AESTowerField128b(0xf3b2bea115d44e307a113ebfdf27ccff), + AESTowerField128b(0x66c226397a967901e4bc6ce235bf4feb), + AESTowerField128b(0x05dd05a7c5e9ac15442189f090a80f9e), + AESTowerField128b(0xed7bea5ee1e167921014c2c7853a679d), + AESTowerField128b(0x37fafed03dfb701af939eafaeeb02074), + AESTowerField128b(0x23491f63f098d208343d5591b7bd626f), + AESTowerField128b(0x3cf04e2f641c505c3e110e87f3e9af91), + AESTowerField128b(0x076e26a8d7181c6de575685a1fe939f9), + AESTowerField128b(0x7b4e4f1b2780c2a5cbf75fe85fc94a58), + AESTowerField128b(0xccd26fddd8f624c8b3f7a2e53a0ae8df), + AESTowerField128b(0xb422ef66e72dbe3798fda5509f63fed2), + AESTowerField128b(0x436f0b1488c5a0680f57919dd4b8fa30), + AESTowerField128b(0xfb3808c6bcacc74fab269021cc58c9df), + AESTowerField128b(0x77bf7b6affefb00594c0b1209a37c97d), + AESTowerField128b(0x36776863a0ce234546d735734b90b7b1), + AESTowerField128b(0x9c0013e65e524467294aa8c70ff414dc), + AESTowerField128b(0xc2c6aeb796ca121f09708acca73499a2), + AESTowerField128b(0xac57847d964c41c97ce4ed9fa3417e90), + AESTowerField128b(0x4c62531aa3c2e5320761c8c64b690e1e), + AESTowerField128b(0xf61ab3912aed1d889336ded4ef4fbae8), + AESTowerField128b(0x5aa06080ab76d88dc5a8a01f48d11ee2), + AESTowerField128b(0x0fc8b68dbc323616ba5a66dcac10f733), + AESTowerField128b(0x7afcf993b86c827a7e290b5e21f0ce48), + AESTowerField128b(0xe7fbd490470b7d4ddd8ef44c2f0ece93), + AESTowerField128b(0xd51ff53d804403832d740176a0cddde2), + AESTowerField128b(0x33c2b575cc0be097362f21d506e9fa17), + AESTowerField128b(0xc6987c6acfd76de3caf3f29426e86cde), + ]); + +impl From for AESTowerField128b { + fn from(value: BinaryField128bPolyval) -> Self { + POLYVAL_TO_AES_TRANSFORMARION.transform(&value) + } +} + #[inline(always)] pub fn is_polyval_tower() -> bool { TypeId::of::() == TypeId::of::() @@ -756,8 +1032,9 @@ mod tests { binary_field::tests::is_binary_field_valid_generator, deserialize_canonical, 
 		linear_transformation::PackedTransformationFactory,
-		serialize_canonical, PackedBinaryField1x128b, PackedBinaryField2x128b,
-		PackedBinaryField4x128b, PackedField,
+		serialize_canonical, AESTowerField128b, PackedAESBinaryField1x128b,
+		PackedAESBinaryField2x128b, PackedAESBinaryField4x128b, PackedBinaryField1x128b,
+		PackedBinaryField2x128b, PackedBinaryField4x128b, PackedField,
 	};
 
 	#[test]
@@ -831,24 +1108,37 @@ mod tests {
 		let a_val = BinaryField128bPolyval(a);
 		let converted = BinaryField128b::from(a_val);
 		assert_eq!(a_val, BinaryField128bPolyval::from(converted));
+
+		let a_val = AESTowerField128b(a);
+		let converted = BinaryField128bPolyval::from(a_val);
+		assert_eq!(a_val, AESTowerField128b::from(converted));
 	}
 
 	#[test]
 	fn test_conversion_128b(a in any::<u128>()) {
 		test_packed_conversion::<PackedBinaryPolyval1x128b, PackedBinaryField1x128b>(a.into(), POLYVAL_TO_BINARY_TRANSFORMATION);
 		test_packed_conversion::<PackedBinaryField1x128b, PackedBinaryPolyval1x128b>(a.into(), BINARY_TO_POLYVAL_TRANSFORMATION);
+
+		test_packed_conversion::<PackedBinaryPolyval1x128b, PackedAESBinaryField1x128b>(a.into(), POLYVAL_TO_AES_TRANSFORMARION);
+		test_packed_conversion::<PackedAESBinaryField1x128b, PackedBinaryPolyval1x128b>(a.into(), AES_TO_POLYVAL_TRANSFORMATION);
 	}
 
 	#[test]
 	fn test_conversion_256b(a in any::<[u128; 2]>()) {
 		test_packed_conversion::<PackedBinaryPolyval2x128b, PackedBinaryField2x128b>(a.into(), POLYVAL_TO_BINARY_TRANSFORMATION);
 		test_packed_conversion::<PackedBinaryField2x128b, PackedBinaryPolyval2x128b>(a.into(), BINARY_TO_POLYVAL_TRANSFORMATION);
+
+		test_packed_conversion::<PackedBinaryPolyval2x128b, PackedAESBinaryField2x128b>(a.into(), POLYVAL_TO_AES_TRANSFORMARION);
+		test_packed_conversion::<PackedAESBinaryField2x128b, PackedBinaryPolyval2x128b>(a.into(), AES_TO_POLYVAL_TRANSFORMATION);
 	}
 
 	#[test]
 	fn test_conversion_512b(a in any::<[u128; 4]>()) {
 		test_packed_conversion::<PackedBinaryPolyval4x128b, PackedBinaryField4x128b>(PackedBinaryPolyval4x128b::from_underlier(a.into()), POLYVAL_TO_BINARY_TRANSFORMATION);
 		test_packed_conversion::<PackedBinaryField4x128b, PackedBinaryPolyval4x128b>(PackedBinaryField4x128b::from_underlier(a.into()), BINARY_TO_POLYVAL_TRANSFORMATION);
+
+		test_packed_conversion::<PackedBinaryPolyval4x128b, PackedAESBinaryField4x128b>(PackedBinaryPolyval4x128b::from_underlier(a.into()), POLYVAL_TO_AES_TRANSFORMARION);
+		test_packed_conversion::<PackedAESBinaryField4x128b, PackedBinaryPolyval4x128b>(PackedAESBinaryField4x128b::from_underlier(a.into()), AES_TO_POLYVAL_TRANSFORMATION);
 	}
@@ -856,10 +1146,10 @@
 	fn test_invert_or_zero(a_val in any::<u128>()) {
 		let a = BinaryField128bPolyval::new(a_val);
 		let a_invert = InvertOrZero::invert_or_zero(a);
-		if a != BinaryField128bPolyval::ZERO {
-			assert_eq!(a * a_invert, BinaryField128bPolyval::ONE);
-		} else {
+		if a == BinaryField128bPolyval::ZERO {
 			assert_eq!(a_invert, BinaryField128bPolyval::ZERO);
+		} else {
+			assert_eq!(a * a_invert, BinaryField128bPolyval::ONE);
 		}
 	}
 }
diff --git a/crates/field/src/tower_levels.rs b/crates/field/src/tower_levels.rs
index 66eb2cfc..bdac60d0 100644
--- a/crates/field/src/tower_levels.rs
+++ b/crates/field/src/tower_levels.rs
@@ -361,7 +361,7 @@ where
 	const WIDTH: usize = 1;
 	type Data = [T; 1];
 
-	type Base = TowerLevel1;
+	type Base = Self; // Level 1 is the atomic unit of backing data and must not be split.
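The `From` impls added above are thin wrappers around the two constant GF(2)-linear maps (`AES_TO_POLYVAL_TRANSFORMATION` and `POLYVAL_TO_AES_TRANSFORMARION`). As an illustration only (not part of the patch; the helper names below are hypothetical, and the sketch assumes the new impls are exported from `binius_field` as shown), the two conversions should be mutually inverse, and, being linear, should distribute over field addition:

```rust
use binius_field::{AESTowerField128b, BinaryField128bPolyval};

// Round trip through the two new conversions: AES tower -> POLYVAL -> AES tower.
// This mirrors the property the proptest hunk above asserts with `assert_eq!`.
fn roundtrip_is_identity(a: AESTowerField128b) -> bool {
	let p = BinaryField128bPolyval::from(a);
	AESTowerField128b::from(p) == a
}

// The conversion matrices are GF(2)-linear, so converting a sum should equal
// the sum of the conversions (addition is XOR in both representations).
fn conversion_is_additive(a: AESTowerField128b, b: AESTowerField128b) -> bool {
	BinaryField128bPolyval::from(a + b)
		== BinaryField128bPolyval::from(a) + BinaryField128bPolyval::from(b)
}
```

Either check could be dropped into the existing proptest block alongside the round-trip assertion.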
diff --git a/crates/field/src/underlier/scaled.rs b/crates/field/src/underlier/scaled.rs index 4f3f1a11..4210bc65 100644 --- a/crates/field/src/underlier/scaled.rs +++ b/crates/field/src/underlier/scaled.rs @@ -17,13 +17,13 @@ pub struct ScaledUnderlier(pub [U; N]); impl Default for ScaledUnderlier { fn default() -> Self { - ScaledUnderlier(array::from_fn(|_| U::default())) + Self(array::from_fn(|_| U::default())) } } impl Random for ScaledUnderlier { fn random(mut rng: impl RngCore) -> Self { - ScaledUnderlier(array::from_fn(|_| U::random(&mut rng))) + Self(array::from_fn(|_| U::random(&mut rng))) } } @@ -61,7 +61,7 @@ impl UnderlierType for ScaledUnderlier Divisible for ScaledUnderlier where - ScaledUnderlier: UnderlierType, + Self: UnderlierType, U: UnderlierType, { type Array = [U; N]; @@ -84,7 +84,7 @@ where unsafe impl Divisible for ScaledUnderlier, 2> where - ScaledUnderlier, 2>: UnderlierType + NoUninit, + Self: UnderlierType + NoUninit, U: UnderlierType + Pod, { type Array = [U; 4]; diff --git a/crates/field/src/underlier/underlier_with_bit_ops.rs b/crates/field/src/underlier/underlier_with_bit_ops.rs index 9c242f9b..e151f501 100644 --- a/crates/field/src/underlier/underlier_with_bit_ops.rs +++ b/crates/field/src/underlier/underlier_with_bit_ops.rs @@ -113,9 +113,8 @@ pub trait UnderlierWithBitOps: #[inline] unsafe fn spread(self, log_block_len: usize, block_idx: usize) -> Self where - T: UnderlierWithBitOps, + T: UnderlierWithBitOps + NumCast, Self: From, - T: NumCast, { spread_fallback(self, log_block_len, block_idx) } @@ -138,10 +137,8 @@ where /// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`. pub(crate) unsafe fn spread_fallback(value: U, log_block_len: usize, block_idx: usize) -> U where - U: UnderlierWithBitOps, - T: UnderlierWithBitOps, - U: From, - T: NumCast, + U: UnderlierWithBitOps + From, + T: UnderlierWithBitOps + NumCast, { debug_assert!( log_block_len + T::LOG_BITS <= U::LOG_BITS, diff --git a/crates/math/src/multilinear_extension.rs b/crates/math/src/multilinear_extension.rs index ee2c8e75..f663f159 100644 --- a/crates/math/src/multilinear_extension.rs +++ b/crates/math/src/multilinear_extension.rs @@ -10,7 +10,7 @@ use binius_field::{ ExtensionField, Field, PackedField, }; use binius_maybe_rayon::prelude::*; -use binius_utils::{bail, checked_arithmetics::log2_strict_usize}; +use binius_utils::bail; use bytemuck::zeroed_vec; use tracing::instrument; @@ -32,22 +32,20 @@ pub struct MultilinearExtension = Vec< impl MultilinearExtension

 {
 	pub fn zeros(n_vars: usize) -> Result<Self, Error> {
-		assert!(P::WIDTH.is_power_of_two());
-		if n_vars < log2_strict_usize(P::WIDTH) {
+		if n_vars < P::LOG_WIDTH {
 			bail!(Error::ArgumentRangeError {
-				arg: "n_vars".to_string(),
-				range: log2_strict_usize(P::WIDTH)..32,
+				arg: "n_vars".into(),
+				range: P::LOG_WIDTH..32,
 			});
 		}
-
-		Ok(MultilinearExtension {
+		Ok(Self {
 			mu: n_vars,
-			evals: vec![P::default(); 1 << (n_vars - log2(P::WIDTH))],
+			evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
 		})
 	}
 
 	pub fn from_values(v: Vec<P>
) -> Result { - MultilinearExtension::from_values_generic(v) + Self::from_values_generic(v) } } @@ -95,7 +93,7 @@ where Data: Deref, { pub fn from_underliers(v: Data) -> Result { - MultilinearExtension::from_values_generic(PackingDeref::new(v)) + Self::from_values_generic(PackingDeref::new(v)) } } diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 576994bf..0d04d291 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -4,6 +4,9 @@ version.workspace = true edition.workspace = true authors.workspace = true +[lints] +workspace = true + [dependencies] binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } bytes.workspace = true diff --git a/crates/utils/src/array_2d.rs b/crates/utils/src/array_2d.rs index 22ff6a22..38ec2551 100644 --- a/crates/utils/src/array_2d.rs +++ b/crates/utils/src/array_2d.rs @@ -42,7 +42,7 @@ impl> Array2D { } /// Returns the number of columns in the array. - pub fn cols(&self) -> usize { + pub const fn cols(&self) -> usize { self.cols } diff --git a/crates/utils/src/checked_arithmetics.rs b/crates/utils/src/checked_arithmetics.rs index b46dd19d..53ab304f 100644 --- a/crates/utils/src/checked_arithmetics.rs +++ b/crates/utils/src/checked_arithmetics.rs @@ -48,7 +48,7 @@ mod tests { #[test] #[should_panic] - fn test_checked_int_div_fail() { + const fn test_checked_int_div_fail() { _ = checked_int_div(5, 2); } @@ -62,7 +62,7 @@ mod tests { #[test] #[should_panic] - fn test_checked_log2_fail() { + const fn test_checked_log2_fail() { _ = checked_log_2(6) } } diff --git a/crates/utils/src/graph.rs b/crates/utils/src/graph.rs index bcd9412a..3f24ccfb 100644 --- a/crates/utils/src/graph.rs +++ b/crates/utils/src/graph.rs @@ -38,7 +38,7 @@ pub fn connected_components(data: &[&[usize]]) -> Vec { for ids in data { if ids.len() > 1 { let &base = ids.iter().min().unwrap(); - for &node in ids.iter() { + for &node in *ids { if node != base { uf.union(base, node); } @@ -66,7 +66,7 @@ struct UnionFind { impl UnionFind { fn new(n: usize) -> Self { - UnionFind { + Self { parent: (0..n).collect(), rank: vec![0; n], min_element: (0..n).collect(), diff --git a/crates/utils/src/iter.rs b/crates/utils/src/iter.rs index d57a5aab..3de0c36c 100644 --- a/crates/utils/src/iter.rs +++ b/crates/utils/src/iter.rs @@ -26,7 +26,7 @@ pub struct SkippableMap { } impl SkippableMap { - fn new(iter: I, func: F) -> Self { + const fn new(iter: I, func: F) -> Self { Self { iter, func } } } diff --git a/crates/utils/src/thread_local_mut.rs b/crates/utils/src/thread_local_mut.rs index f4e844fc..ebb91168 100644 --- a/crates/utils/src/thread_local_mut.rs +++ b/crates/utils/src/thread_local_mut.rs @@ -14,7 +14,7 @@ use thread_local::ThreadLocal; pub struct ThreadLocalMut(ThreadLocal>); impl ThreadLocalMut { - pub fn new() -> Self { + pub const fn new() -> Self { Self(ThreadLocal::new()) }
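The remaining `crates/utils` hunks are mechanical fixes for the workspace lints now enabled via `[lints] workspace = true`, chiefly `clippy::use_self` and `clippy::missing_const_for_fn`. A condensed illustration of both patterns, using a hypothetical `Counter` type rather than the actual types touched above:

```rust
struct Counter {
	count: usize,
}

impl Counter {
	// `clippy::use_self`: refer to the type as `Self` inside its own impl block.
	// `clippy::missing_const_for_fn`: trivial constructors and getters can be
	// `const fn`, which also makes them usable in const contexts at no runtime cost.
	const fn new() -> Self {
		Self { count: 0 }
	}

	const fn count(&self) -> usize {
		self.count
	}
}

// A `const fn` constructor can now initialize a constant.
const ZERO: Counter = Counter::new();

fn main() {
	assert_eq!(ZERO.count(), 0);
}
```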