From 082235c159c8d5c9fabe9e8591b8e9ac3ae2b7f0 Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Tue, 25 Feb 2025 15:37:47 +0200 Subject: [PATCH] example: Blake3 permutation using channels API (#10) * Support for Karatsuba "infinity" point in evaluation & interpolation domains * [sumcheck] Small field zerocheck and its HAL support removed * [ring_switch] Optimize RingSwitchEqInd::multilinear_extension * [clippy]: avoid needless pass by value * [math] Fix `fold_right` crash on big multilinears and make it single threaded * [math] Use specialized zero variable folding in the first sumcheck round. * [test]: add test coverage for eq_ind_partial_eval * [test]: add test coverage for inner_product_par * [test]: test coverage for MultilinearQuery update * [tracing] Display proof size in graph * [ci]: Setting Up GitHub Pipelines * [ci]: Setting Up Mirror to GitLab (#8) * Fix typos (#2) [nicetohave] fix typos * [ci]: Improvements (#17) [ci]: Removing continue or error, and depricating Gitlab pipelines * Improve test compilation time (#10) Co-authored-by: Dmytro Gordon This MR addresses tow issues that make cargo test slow: Thin LTO slows down compilation of all the crates a bit. It takes quite a time to compile and link all the examples with test profile which are not actually executed. So I've added an alias to compile and run tests only for fast local usage. * [serialization] Add canonical serialize/deserialize traits + derive macros Introduces the following traits: SerializeCanonical (which replaces most uses of SerializeBytes) DeserializeCanonical (which replaces most uses of DeserializeBytes) Conveniently, this also comes with proc-macros for deriving these traits for an arbitrary struct/enum (unions are not supported). * [security]: Add CODEOWNERS file for GitHub * [scripts] Added benchmarking script This adds the script to benchmark various of our examples, default sampling is set to 5 to reduce total time to benchmark. * [field] Implement PackedField::unzip * [cleanup]: Remove some useless checked_log_2 calls * [field] Add TowerField::min_tower_level(self), and use it to derive ArithExpr tower_level from its constants (#6) In contrast to TowerField::TOWER_LEVEL, TowerField::binary_tower_level(self) returns the smallest tower level that can fit the current value. This can be useful for shrinking field values to the smaller container that fits them, for the purpose of making arithmetic operations (in particular multiplication) cheaper. * [core]: simplify merkle tree `verify_opening` (#14) * [ci] Adjusting nightly benchmark repository (#23) * [ci]: Adjusting nightly benchmark repository * [ci]: Adjusting CODEOWNERS for .github/ subdir * [circuits] Simplify usage of ConstraintSystemBuilder by making it less generic (#22) [circuits] Simplfy ConstraintSystemBuilder to only support BinaryField128b for the top field. * [field] Simplify usage of PackedExtension, RepackedExtension by making each trait imply its bounds (#24) * [macros] Remove unused IterOracles, IterPolys derive proc macros (#25) * [matrix]: simplify scale_row (#31) * [field] Remove unnecessary `WithUnderlier` trait bound (#32) * [field] Optimize SIMD element access for Zen4 architecture as well. (#28) * refactor: Use binary_tower_level for base field detection (#30) * [serialization] impl SerializeCanonical, DeserializeCanonical for ConstraintSystem (#11) * [circuits] Optimize plain_lookup using selector flushing (#29) * [scripts] Remove groestl run from benchmark script (#26) * [arith_expr]: Statically compile exponentiation in ArithCircuitPoly (#15) * [serialization] Introduce SerializationMode (#36) Changes: Adds SerializaitonMode that specifies whether to use native (fast) or canonical (needed for transcript) serialization/deserializtion. You need to use the same mode for serialization and deserialization. SerializeCanonical is renamed to SerializeBytes, and takes an extra argument of type SerializationMode DeserializeCanonical is renamed to DeserializeBytes and takes an extra argument of type SerializationMode SerializeBytes and DeserializeBytes are now required bounds for the Field trait, rather than being generically implemented for TowerField. u16, u32, u64, u128 now serialize to/deserialize from little-endian rather than big-endian byte order, to be consistent with BinaryField*b serialization. The serialization traits are moved back to binius_utils Automatic implementations of SerializeBytes for Box and &(T: SerializeBytes) Automatic implementation of DeserializeBytes for Box * [gkr_int_mul] Fix type bounds (#34) * feat: Blake3 G function gadget (#16) * [circuits] Add test_circuit helper (#27) * Leave only the object-safe version of the `CompositionPoly` trait (#43) * ]field] Byte-sliced fields changes (#21) * Refactor a bit TowerLevels to remove packed field parameter from the TowerLevel to the Data associated type. This also makes generic bounds a bit more clean, since TowerLevel itself doesn't depend on a concrete packed field type. * Add support of byte-sliced fields with arbitrary register size, i.e. 128b, 256b, 512b. * Add shifts and unpack low/high within 128-bit lanes to UnderlierWithBitOps. This allows implementing transposition in an efficient way. * Add the transparent implementation of UnderlierWithBitOps for PackedScaledUnderlier as we need it to re-use PackedScaledField. * feat: Add example of LinearCombination column usage * ci: Add basic Rust CI (#2) * ci: Add basic Rust CI * Fix test flags * example: Linear combination with offset (#4) * example: Add linear-combination-with-offset usage example * chore: Add example for bit masking using LinearCombination * chore: Add byte decomposition constraint * example: Implement bit-shifting/rotating and packing (#5) * example: Add example of Shifted column usage * example: Add example of Packed column usage * chore: Add 'unconstrained gadgets' warning * example: Projected / Repeated columns usage (#6) * example: Add example of Projected column usage * example: Add example of Repeated column usage * example: Add example of ZeroPadded column usage * examples: Transparent columns usage (part 1) (#8) * feat: Add example of Transparent (Constant) column usage * example: Add example of Transparent (Powers) column usage * example: Add example of Transparent (DisjointProduct) column usage * example: Add example of Transparent (EqIndPartialEval) column usage * examples: Transparent columns usage (part 2) (#9) * example: Add example of Transparent (MultilinearExtensionTransparent) column usage * example: Add example of Transparent (SelectRow) column usage * example: Add example of Transparent (ShiftIndPartialEval) column usage * example: Add example of Transparent (StepDown) column usage * example: Add example of Transparent (StepUp) column usage * example: Add example of Transparent (TowerBasis) column usage * chore: Forward port * feat: Blake3 permutation using channels API * chore: Formatting --------- Co-authored-by: Nikita Lesnikov Co-authored-by: Dmitry Gordon Co-authored-by: Thomas Coratger Co-authored-by: Aliaksei Dziadziuk Co-authored-by: Milos Backonja Co-authored-by: Milos Backonja <35807060+milosbackonja@users.noreply.github.com> Co-authored-by: chloefeal <188809157+chloefeal@users.noreply.github.com> Co-authored-by: Dmytro Gordon Co-authored-by: Tobias Bergkvist Co-authored-by: Anex007 Co-authored-by: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Co-authored-by: Nikita Lesnikov Co-authored-by: Joseph Johnston Co-authored-by: Samuel Burnham <45365069+samuelburnham@users.noreply.github.com> --- .cargo/config.toml | 3 + .gitlab-ci.yml | 203 ------- CODEOWNERS | 5 + Cargo.toml | 7 +- crates/circuits/src/arithmetic/u32.rs | 181 +++--- crates/circuits/src/bitwise.rs | 59 +- crates/circuits/src/blake3.rs | 209 +++++++ .../circuits/src/builder/constraint_system.rs | 27 +- crates/circuits/src/builder/mod.rs | 2 + crates/circuits/src/builder/test_utils.rs | 23 + crates/circuits/src/builder/types.rs | 4 + crates/circuits/src/builder/witness.rs | 108 ++-- crates/circuits/src/collatz.rs | 77 +-- crates/circuits/src/groestl.rs | 25 + crates/circuits/src/keccakf.rs | 57 +- crates/circuits/src/lasso/batch.rs | 20 +- .../lasso/big_integer_ops/byte_sliced_add.rs | 39 +- .../byte_sliced_add_carryfree.rs | 39 +- ...yte_sliced_double_conditional_increment.rs | 37 +- .../byte_sliced_modular_mul.rs | 59 +- .../lasso/big_integer_ops/byte_sliced_mul.rs | 54 +- .../big_integer_ops/byte_sliced_test_utils.rs | 428 ++++++-------- .../circuits/src/lasso/big_integer_ops/mod.rs | 53 ++ crates/circuits/src/lasso/lasso.rs | 23 +- .../src/lasso/lookups/u8_arithmetic.rs | 184 ++++-- crates/circuits/src/lasso/sha256.rs | 133 +++-- crates/circuits/src/lasso/u32add.rs | 114 ++-- .../lasso/u8_double_conditional_increment.rs | 30 +- crates/circuits/src/lasso/u8add.rs | 30 +- crates/circuits/src/lasso/u8add_carryfree.rs | 30 +- crates/circuits/src/lasso/u8mul.rs | 36 +- crates/circuits/src/lib.rs | 556 +----------------- crates/circuits/src/pack.rs | 17 +- crates/circuits/src/plain_lookup.rs | 245 ++++---- crates/circuits/src/sha256.rs | 106 +++- crates/circuits/src/transparent.rs | 39 +- crates/circuits/src/u32fib.rs | 37 +- crates/circuits/src/unconstrained.rs | 43 +- crates/circuits/src/vision.rs | 79 +-- crates/core/Cargo.toml | 2 + crates/core/benches/composition_poly.rs | 4 +- crates/core/src/composition/index.rs | 6 +- .../src/composition/product_composition.rs | 4 +- crates/core/src/constraint_system/channel.rs | 30 +- crates/core/src/constraint_system/mod.rs | 24 +- crates/core/src/constraint_system/prove.rs | 43 +- crates/core/src/constraint_system/verify.rs | 25 +- crates/core/src/lib.rs | 4 +- .../src/merkle_tree/binary_merkle_tree.rs | 9 +- crates/core/src/merkle_tree/scheme.rs | 45 +- crates/core/src/oracle/composite.rs | 10 +- crates/core/src/oracle/constraint.rs | 17 +- crates/core/src/oracle/multilinear.rs | 118 +++- crates/core/src/piop/prove.rs | 18 +- crates/core/src/piop/tests.rs | 8 +- crates/core/src/piop/verify.rs | 4 +- crates/core/src/polynomial/arith_circuit.rs | 535 +++++++++++++---- crates/core/src/polynomial/cached.rs | 272 --------- crates/core/src/polynomial/mod.rs | 2 - crates/core/src/polynomial/multivariate.rs | 54 +- crates/core/src/protocols/evalcheck/error.rs | 2 +- crates/core/src/protocols/fri/common.rs | 6 +- crates/core/src/protocols/fri/prove.rs | 6 +- crates/core/src/protocols/fri/tests.rs | 4 +- crates/core/src/protocols/fri/verify.rs | 2 +- .../protocols/gkr_gpa/gpa_sumcheck/prove.rs | 36 +- crates/core/src/protocols/gkr_gpa/prove.rs | 6 +- crates/core/src/protocols/gkr_gpa/tests.rs | 14 +- crates/core/src/protocols/gkr_gpa/verify.rs | 4 +- .../core/src/protocols/gkr_int_mul/error.rs | 8 + .../generator_exponent/compositions.rs | 4 +- .../gkr_int_mul/generator_exponent/prove.rs | 20 +- .../gkr_int_mul/generator_exponent/tests.rs | 2 +- .../gkr_int_mul/generator_exponent/verify.rs | 5 +- .../gkr_int_mul/generator_exponent/witness.rs | 31 +- crates/core/src/protocols/gkr_int_mul/mod.rs | 2 +- crates/core/src/protocols/sumcheck/common.rs | 4 +- crates/core/src/protocols/sumcheck/error.rs | 2 +- .../src/protocols/sumcheck/front_loaded.rs | 4 +- .../sumcheck/prove/concrete_prover.rs | 65 -- .../core/src/protocols/sumcheck/prove/mod.rs | 2 - .../src/protocols/sumcheck/prove/oracles.rs | 2 +- .../protocols/sumcheck/prove/prover_state.rs | 45 +- .../sumcheck/prove/regular_sumcheck.rs | 30 +- .../protocols/sumcheck/prove/univariate.rs | 34 +- .../src/protocols/sumcheck/prove/zerocheck.rs | 188 +++--- crates/core/src/protocols/sumcheck/tests.rs | 15 +- .../core/src/protocols/sumcheck/univariate.rs | 26 +- .../sumcheck/univariate_zerocheck.rs | 4 +- crates/core/src/protocols/sumcheck/verify.rs | 10 +- .../core/src/protocols/sumcheck/zerocheck.rs | 20 +- crates/core/src/protocols/test_utils.rs | 10 +- crates/core/src/reed_solomon/reed_solomon.rs | 19 +- crates/core/src/ring_switch/common.rs | 2 +- crates/core/src/ring_switch/eq_ind.rs | 94 ++- crates/core/src/ring_switch/prove.rs | 12 +- crates/core/src/ring_switch/tests.rs | 6 +- crates/core/src/ring_switch/verify.rs | 14 +- crates/core/src/tensor_algebra.rs | 10 +- crates/core/src/transcript/error.rs | 4 +- crates/core/src/transcript/mod.rs | 60 +- crates/core/src/transparent/constant.rs | 17 +- crates/core/src/transparent/mod.rs | 1 + .../src/transparent/multilinear_extension.rs | 75 ++- crates/core/src/transparent/powers.rs | 19 +- crates/core/src/transparent/select_row.rs | 15 +- crates/core/src/transparent/serialization.rs | 82 +++ crates/core/src/transparent/step_down.rs | 15 +- crates/core/src/transparent/step_up.rs | 15 +- crates/core/src/transparent/tower_basis.rs | 15 +- crates/field/Cargo.toml | 1 - crates/field/benches/packed_extension_mul.rs | 1 - .../benches/packed_field_element_access.rs | 54 +- crates/field/benches/packed_field_init.rs | 54 +- .../benches/packed_field_subfield_ops.rs | 10 +- crates/field/benches/packed_field_utils.rs | 12 + crates/field/src/aes_field.rs | 107 +++- crates/field/src/arch/aarch64/m128.rs | 69 ++- .../src/arch/portable/byte_sliced/invert.rs | 45 +- .../src/arch/portable/byte_sliced/mod.rs | 28 +- .../src/arch/portable/byte_sliced/multiply.rs | 55 +- .../byte_sliced/packed_byte_sliced.rs | 126 ++-- .../src/arch/portable/byte_sliced/square.rs | 35 +- crates/field/src/arch/portable/packed.rs | 8 + .../src/arch/portable/packed_arithmetic.rs | 16 +- .../field/src/arch/portable/packed_scaled.rs | 88 ++- crates/field/src/arch/x86_64/m128.rs | 239 +++++--- crates/field/src/arch/x86_64/m256.rs | 245 ++++++-- crates/field/src/arch/x86_64/m512.rs | 291 +++++++-- crates/field/src/binary_field.rs | 183 +++--- crates/field/src/binary_field_arithmetic.rs | 4 + crates/field/src/byte_iteration.rs | 440 ++++++++++++++ crates/field/src/extension.rs | 16 +- crates/field/src/field.rs | 3 + crates/field/src/lib.rs | 1 + crates/field/src/packed.rs | 39 +- crates/field/src/packed_binary_field.rs | 146 +++++ crates/field/src/packed_extension.rs | 21 +- crates/field/src/packed_extension_ops.rs | 45 +- crates/field/src/polyval.rs | 78 ++- crates/field/src/tower_levels.rs | 262 ++++----- crates/field/src/transpose.rs | 4 +- crates/field/src/underlier/scaled.rs | 193 +++++- crates/field/src/underlier/small_uint.rs | 34 +- crates/field/src/underlier/underlier_impls.rs | 10 + .../src/underlier/underlier_with_bit_ops.rs | 81 +++ crates/field/src/util.rs | 97 ++- crates/hal/src/backend.rs | 76 +-- crates/hal/src/cpu.rs | 51 +- crates/hal/src/sumcheck_evaluator.rs | 10 +- crates/hal/src/sumcheck_round_calculator.rs | 149 ++--- crates/hash/src/groestl/hasher.rs | 3 +- crates/hash/src/serialization.rs | 4 +- crates/hash/src/vision.rs | 1 - crates/macros/Cargo.toml | 1 + crates/macros/src/arith_circuit_poly.rs | 393 +------------ crates/macros/src/composition_poly.rs | 41 +- crates/macros/src/lib.rs | 369 +++++++----- crates/macros/tests/arithmetic_circuit.rs | 2 +- crates/math/Cargo.toml | 1 + crates/math/src/arith_expr.rs | 20 +- crates/math/src/composition_poly.rs | 26 +- crates/math/src/deinterleave.rs | 13 +- crates/math/src/error.rs | 2 + crates/math/src/fold.rs | 374 +++++------- crates/math/src/matrix.rs | 2 - crates/math/src/mle_adapters.rs | 66 +-- crates/math/src/multilinear_extension.rs | 32 +- crates/math/src/multilinear_query.rs | 104 +++- crates/math/src/tensor_prod_eq_ind.rs | 65 +- crates/math/src/univariate.rs | 191 ++++-- crates/ntt/src/additive_ntt.rs | 35 +- crates/ntt/src/dynamic_dispatch.rs | 18 +- crates/ntt/src/single_threaded.rs | 6 +- crates/ntt/src/tests/ntt_tests.rs | 13 +- crates/utils/Cargo.toml | 3 +- crates/utils/src/lib.rs | 3 + crates/utils/src/serialization.rs | 411 ++++++++++++- crates/utils/src/thread_local_mut.rs | 2 +- examples/Cargo.toml | 15 +- examples/acc-constants.rs | 10 +- examples/acc-disjoint-product.rs | 8 +- examples/acc-eq-ind-partial-eval.rs | 6 +- ...inear-combination-with-offset.rs.disabled} | 0 examples/acc-linear-combination.rs | 13 +- .../acc-multilinear-extension-transparent.rs | 6 +- examples/acc-packed.rs | 21 +- examples/acc-permutation-channels.rs | 111 ++++ examples/acc-powers.rs | 13 +- examples/acc-projected.rs | 10 +- examples/acc-repeated.rs | 15 +- examples/acc-select-row.rs | 6 +- examples/acc-shift-ind-partial-eq.rs | 5 +- examples/acc-shifted.rs | 22 +- examples/acc-step-down.rs | 6 +- examples/acc-step-up.rs | 6 +- examples/acc-tower-basis.rs | 5 +- examples/acc-zeropadded.rs | 8 +- examples/b32_mul.rs | 11 +- examples/bitwise_ops.rs | 13 +- examples/collatz.rs | 10 +- ...circuit.rs => groestl_circuit.rs.disabled} | 0 examples/keccakf_circuit.rs | 8 +- examples/modular_mul.rs | 11 +- examples/sha256_circuit.rs | 16 +- examples/sha256_circuit_with_lookup.rs | 18 +- examples/u32_add.rs | 14 +- examples/u32_mul.rs | 20 +- examples/u32add_with_lookup.rs | 14 +- examples/u8mul.rs | 13 +- examples/vision32b_circuit.rs | 9 +- scripts/nightly_benchmarks.py | 265 +++++++++ scripts/run_tests_and_examples.sh | 2 +- 213 files changed, 6922 insertions(+), 5159 deletions(-) delete mode 100644 .gitlab-ci.yml create mode 100644 CODEOWNERS create mode 100644 crates/circuits/src/blake3.rs create mode 100644 crates/circuits/src/builder/test_utils.rs create mode 100644 crates/circuits/src/builder/types.rs delete mode 100644 crates/core/src/polynomial/cached.rs delete mode 100644 crates/core/src/protocols/sumcheck/prove/concrete_prover.rs create mode 100644 crates/core/src/transparent/serialization.rs create mode 100644 crates/field/src/byte_iteration.rs rename examples/{acc-linear-combination-with-offset.rs => acc-linear-combination-with-offset.rs.disabled} (100%) create mode 100644 examples/acc-permutation-channels.rs rename examples/{groestl_circuit.rs => groestl_circuit.rs.disabled} (100%) create mode 100755 scripts/nightly_benchmarks.py diff --git a/.cargo/config.toml b/.cargo/config.toml index 7acd84b0..6287bb56 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,5 @@ [build] rustdocflags = ["-Dwarnings", "--html-in-header", "doc/katex-header.html"] + +[alias] +fast_test = "test --tests" diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index d2a03abf..00000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,203 +0,0 @@ -workflow: - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == 'main' - -variables: - CARGO_HOME: "$CI_PROJECT_DIR/toolchains/cargo" - RUSTUP_HOME: "$CI_PROJECT_DIR/toolchains" - GIT_CLEAN_FLAGS: "-ffdx --exclude toolchains" - FF_TIMESTAMPS: true - - -stages: - - lint - - build - - test - - deploy - -# AMD job configuration template -.job_template_amd: - image: rustlang/rust:nightly - variables: - KUBERNETES_NODE_SELECTOR_INSTANCE_TYPE: "ulvt-node-pool=ulvt-c7i-2xlarge" - KUBERNETES_CPU_REQUEST: "6" - KUBERNETES_MEMORY_REQUEST: "14Gi" - GIT_CLONE_PATH: "$CI_BUILDS_DIR/binius_amd" - tags: - - k8s - -# AMD job configuration template with stable Rust -.job_template_amd_stable: - extends: .test_job_template_amd - variables: - RUST_VERSION: "1.83.0" - before_script: - # workaround for https://github.com/rust-lang/rustup/issues/2886 - - rustup set auto-self-update disable - - rustup toolchain install $RUST_VERSION - -# ARM job configuration template -.job_template_arm: - image: rustlang/rust:nightly - variables: - KUBERNETES_NODE_SELECTOR_INSTANCE_TYPE: "ulvt-node-pool=ulvt-c8g-2xlarge" - KUBERNETES_NODE_SELECTOR_ARCH: 'kubernetes.io/arch=arm64' - KUBERNETES_CPU_REQUEST: "6" - KUBERNETES_MEMORY_REQUEST: "14Gi" - GIT_CLONE_PATH: "$CI_BUILDS_DIR/binius_arm" - before_script: - - if [ "$(uname -m)" != "aarch64" ]; then echo "This job is intended to run on ARM architecture only."; exit 1; fi - tags: - - k8s - -# Linting jobs -copyright-check: - extends: .job_template_amd - stage: lint - script: - - ./scripts/check_copyright_notice.sh - -cargofmt: - extends: .job_template_amd - stage: lint - script: - - cargo fmt --check - -clippy: - extends: .job_template_amd - stage: lint - script: - - cargo clippy --all --all-features --tests --benches --examples -- -D warnings - -# Building jobs - -# TODO: use a docker image with `wasm32-unknown-unknown` target preinstalled -build-debug-wasm: - extends: .job_template_amd - stage: build - script: - - rustup target add wasm32-unknown-unknown - - cargo build --package binius_field --target wasm32-unknown-unknown - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-amd: - extends: .job_template_amd - stage: build - script: - - cargo build --tests --benches --examples - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -# Build without default features -# This checks if build without `rayon` feature works. -build-debug-amd-no-default-features: - extends: .job_template_amd - stage: build - script: - - cargo build --tests --benches --examples --no-default-features - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-amd-stable: - extends: .job_template_amd_stable - stage: build - script: - - cargo +$RUST_VERSION build --tests --benches --examples -p binius_core --features stable_only - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-arm: - extends: .job_template_arm - stage: build - script: - - cargo build --tests --benches --examples - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -.test_job_template_amd: - extends: .job_template_amd - dependencies: - - build-debug-amd - -.test_job_template_amd_stable: - extends: .job_template_amd_stable - dependencies: - - build-debug-amd-stable - -.test_job_template_arm: - extends: .job_template_arm - dependencies: - - build-debug-arm - -unit-test-amd-portable: - extends: .test_job_template_amd - script: - - RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh - -unit-test-arm-portable: - extends: .test_job_template_arm - script: - - RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh - -unit-test-single-threaded: - extends: .test_job_template_amd - script: - - RAYON_NUM_THREADS=1 RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-no-default-features: - extends: .test_job_template_amd - script: - - CARGO_EXTRA_FLAGS="--no-default-features" RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-amd: - extends: .test_job_template_amd - script: - - RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-amd-stable: - extends: .test_job_template_amd_stable - script: - - RUSTFLAGS="-C target-cpu=native" CARGO_STABLE=true ./scripts/run_tests_and_examples.sh - -unit-test-arm: - extends: .test_job_template_arm - script: - - RUSTFLAGS="-C target-cpu=native -C target-feature=+aes" ./scripts/run_tests_and_examples.sh - -# Documentation and pages jobs -build-docs: - extends: .job_template_amd - stage: build - script: - - cargo doc --no-deps - artifacts: - paths: - - target/doc - expire_in: 1 week - -pages: - extends: .job_template_amd - stage: deploy - dependencies: - - build-docs - script: - - mv target/doc public - - echo "/ /binius_core 302" > public/_redirects - artifacts: - paths: - - public - only: - refs: - - main # Deploy for every push to the main branch, for now diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..61d77251 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,5 @@ +# Code owners for the entire repository +* @jimpo-ulvt @onesk + +# Code owners for the .github path +/.github/ @IrreducibleOSS/Infrastructure diff --git a/Cargo.toml b/Cargo.toml index 9cf5f096..2ee58dfc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ generic-array = "0.14.7" getset = "0.1.2" groestl_crypto = { package = "groestl", version = "0.10.1" } hex-literal = "0.4.1" +inventory = "0.3.19" itertools = "0.13.0" lazy_static = "1.5.0" paste = "1.0.15" @@ -112,13 +113,13 @@ seq-macro = "0.3.5" sha2 = "0.10.8" stackalloc = "1.2.1" subtle = "2.5.0" -syn = { version = "2.0.60", features = ["full"] } +syn = { version = "2.0.98", features = ["extra-traits"] } thiserror = "2.0.3" thread_local = "1.1.7" tiny-keccak = { version = "2.0.2", features = ["keccak"] } trait-set = "0.3.0" tracing = "0.1.38" -tracing-profile = "0.9.0" +tracing-profile = "0.10.1" transpose = "0.2.2" [profile.release] @@ -137,4 +138,4 @@ opt-level = 1 debug = true debug-assertions = true overflow-checks = true -lto = false +lto = "off" diff --git a/crates/circuits/src/arithmetic/u32.rs b/crates/circuits/src/arithmetic/u32.rs index 88a0c9a7..0dcdd1ae 100644 --- a/crates/circuits/src/arithmetic/u32.rs +++ b/crates/circuits/src/arithmetic/u32.rs @@ -1,25 +1,17 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::{OracleId, ProjectionVariant, ShiftVariant}; -use binius_field::{ - as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, BinaryField32b, - ExtensionField, Field, TowerField, -}; +use binius_field::{packed::set_packed_slice, BinaryField1b, BinaryField32b, Field, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use crate::{builder::ConstraintSystemBuilder, transparent}; -pub fn packed( - builder: &mut ConstraintSystemBuilder, +pub fn packed( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, -) -> Result -where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, -{ +) -> Result { let packed = builder.add_packed(name, input, 5)?; if let Some(witness) = builder.witness() { witness.set( @@ -32,17 +24,13 @@ where Ok(packed) } -pub fn mul_const( - builder: &mut ConstraintSystemBuilder, +pub fn mul_const( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, value: u32, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if value == 0 { let log_rows = builder.log_rows([input])?; return transparent::constant(builder, name, log_rows, BinaryField1b::ZERO); @@ -85,17 +73,13 @@ where Ok(result) } -pub fn add( - builder: &mut ConstraintSystemBuilder, +pub fn add( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let cout = builder.add_committed("cout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -151,17 +135,13 @@ where Ok(zout) } -pub fn sub( - builder: &mut ConstraintSystemBuilder, +pub fn sub( + builder: &mut ConstraintSystemBuilder, name: impl ToString, zin: OracleId, yin: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([zin, yin])?; let cout = builder.add_committed("cout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -218,16 +198,12 @@ where Ok(xout) } -pub fn half( - builder: &mut ConstraintSystemBuilder, +pub fn half( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if matches!(flags, super::Flags::Checked) { // Assert that the number is even let lsb = select_bit(builder, "lsb", input, 0)?; @@ -236,23 +212,24 @@ where shr(builder, name, input, 1) } -pub fn shl( - builder: &mut ConstraintSystemBuilder, +pub fn shl( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, offset: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if offset == 0 { return Ok(input); } let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalLeft)?; if let Some(witness) = builder.witness() { - (witness.new_column(shifted).as_mut_slice::(), witness.get(input)?.as_slice::()) + ( + witness + .new_column::(shifted) + .as_mut_slice::(), + witness.get::(input)?.as_slice::(), + ) .into_par_iter() .for_each(|(shifted, input)| *shifted = *input << offset); } @@ -260,23 +237,24 @@ where Ok(shifted) } -pub fn shr( - builder: &mut ConstraintSystemBuilder, +pub fn shr( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, offset: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if offset == 0 { return Ok(input); } let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalRight)?; if let Some(witness) = builder.witness() { - (witness.new_column(shifted).as_mut_slice::(), witness.get(input)?.as_slice::()) + ( + witness + .new_column::(shifted) + .as_mut_slice::(), + witness.get::(input)?.as_slice::(), + ) .into_par_iter() .for_each(|(shifted, input)| *shifted = *input >> offset); } @@ -284,16 +262,12 @@ where Ok(shifted) } -pub fn select_bit( - builder: &mut ConstraintSystemBuilder, +pub fn select_bit( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, index: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let log_rows = builder.log_rows([input])?; anyhow::ensure!(log_rows >= 5, "Polynomial must have n_vars >= 5. Got {log_rows}"); anyhow::ensure!(index < 32, "Only index values between 0 and 32 are allowed. Got {index}"); @@ -304,7 +278,7 @@ where if let Some(witness) = builder.witness() { let mut bits = witness.new_column::(bits); let bits = bits.packed(); - let input = witness.get(input)?.as_slice::(); + let input = witness.get::(input)?.as_slice::(); input.iter().enumerate().for_each(|(i, &val)| { let value = match (val >> index) & 1 { 0 => BinaryField1b::ZERO, @@ -317,16 +291,12 @@ where Ok(bits) } -pub fn constant( - builder: &mut ConstraintSystemBuilder, +pub fn constant( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_count: usize, value: u32, -) -> Result -where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); // This would not need to be committed if we had `builder.add_unpacked(..)` let output = builder.add_committed("output", log_count + 5, BinaryField1b::TOWER_LEVEL); @@ -360,50 +330,47 @@ where #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; - use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, TowerField}; + use binius_field::{BinaryField1b, TowerField}; - use crate::{arithmetic, builder::ConstraintSystemBuilder, unconstrained::unconstrained}; - - type U = OptimalUnderlier; - type F = BinaryField128b; + use crate::{arithmetic, builder::test_utils::test_circuit, unconstrained::unconstrained}; #[test] fn test_mul_const() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL); - if let Some(witness) = builder.witness() { - witness - .new_column::(a) - .as_mut_slice::() - .iter_mut() - .for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32); - } - - let _c = arithmetic::u32::mul_const(&mut builder, "mul3", a, 3, arithmetic::Flags::Checked) - .unwrap(); + test_circuit(|builder| { + let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL); + if let Some(witness) = builder.witness() { + witness + .new_column::(a) + .as_mut_slice::() + .iter_mut() + .for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32); + } + let _c = arithmetic::u32::mul_const(builder, "mul3", a, 3, arithmetic::Flags::Checked)?; + Ok(vec![]) + }) + .unwrap(); + } - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + #[test] + fn test_add() { + test_circuit(|builder| { + let log_size = 14; + let a = unconstrained::(builder, "a", log_size)?; + let b = unconstrained::(builder, "b", log_size)?; + let _c = arithmetic::u32::add(builder, "u32add", a, b, arithmetic::Flags::Unchecked)?; + Ok(vec![]) + }) + .unwrap(); } #[test] fn test_sub() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let a = unconstrained::(&mut builder, "a", 7).unwrap(); - let b = unconstrained::(&mut builder, "a", 7).unwrap(); - let _c = - arithmetic::u32::sub(&mut builder, "c", a, b, arithmetic::Flags::Unchecked).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let a = unconstrained::(builder, "a", 7).unwrap(); + let b = unconstrained::(builder, "a", 7).unwrap(); + let _c = arithmetic::u32::sub(builder, "c", a, b, arithmetic::Flags::Unchecked)?; + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/bitwise.rs b/crates/circuits/src/bitwise.rs index 33e2d3a1..42d7cbd7 100644 --- a/crates/circuits/src/bitwise.rs +++ b/crates/circuits/src/bitwise.rs @@ -1,25 +1,18 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField, -}; +use binius_field::{BinaryField1b, Field, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use crate::builder::ConstraintSystemBuilder; -pub fn and( - builder: &mut ConstraintSystemBuilder, +pub fn and( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let zout = builder.add_committed("zout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -45,19 +38,16 @@ where Ok(zout) } -pub fn xor( - builder: &mut ConstraintSystemBuilder, +pub fn xor( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; - let zout = builder.add_linear_combination("zout", log_rows, [(xin, F::ONE), (yin, F::ONE)])?; + let zout = + builder.add_linear_combination("zout", log_rows, [(xin, Field::ONE), (yin, Field::ONE)])?; if let Some(witness) = builder.witness() { ( witness.get::(xin)?.as_slice::(), @@ -75,16 +65,12 @@ where Ok(zout) } -pub fn or( - builder: &mut ConstraintSystemBuilder, +pub fn or( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let zout = builder.add_committed("zout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -109,3 +95,24 @@ where builder.pop_namespace(); Ok(zout) } + +#[cfg(test)] +mod tests { + use binius_field::BinaryField1b; + + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; + + #[test] + fn test_bitwise() { + test_circuit(|builder| { + let log_size = 6; + let a = unconstrained::(builder, "a", log_size)?; + let b = unconstrained::(builder, "b", log_size)?; + let _and = super::and(builder, "and", a, b)?; + let _xor = super::xor(builder, "xor", a, b)?; + let _or = super::or(builder, "or", a, b)?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/blake3.rs b/crates/circuits/src/blake3.rs new file mode 100644 index 00000000..4d532558 --- /dev/null +++ b/crates/circuits/src/blake3.rs @@ -0,0 +1,209 @@ +// Copyright 2024-2025 Irreducible Inc. + +use binius_core::oracle::{OracleId, ShiftVariant}; +use binius_field::{BinaryField1b, Field}; +use binius_utils::checked_arithmetics::checked_log_2; + +use crate::{ + arithmetic, + arithmetic::Flags, + builder::{types::F, ConstraintSystemBuilder}, +}; + +type F1 = BinaryField1b; +const LOG_U32_BITS: usize = checked_log_2(32); + +// Gadget that performs two u32 variables XOR and then rotates the result +fn xor_rotate_right( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + a: OracleId, + b: OracleId, + rotate_right_offset: u32, +) -> Result { + assert!(rotate_right_offset <= 32); + + builder.push_namespace(name); + + let xor = builder + .add_linear_combination("xor", log_size, [(a, F::ONE), (b, F::ONE)]) + .unwrap(); + + let rotate = builder.add_shifted( + "rotate", + xor, + 32 - rotate_right_offset as usize, + LOG_U32_BITS, + ShiftVariant::CircularLeft, + )?; + + if let Some(witness) = builder.witness() { + let a_value = witness.get::(a)?.as_slice::(); + let b_value = witness.get::(b)?.as_slice::(); + + let mut xor_witness = witness.new_column::(xor); + let xor_value = xor_witness.as_mut_slice::(); + + for (idx, v) in xor_value.iter_mut().enumerate() { + *v = a_value[idx] ^ b_value[idx]; + } + + let mut rotate_witness = witness.new_column::(rotate); + let rotate_value = rotate_witness.as_mut_slice::(); + for (idx, v) in rotate_value.iter_mut().enumerate() { + *v = xor_value[idx].rotate_right(rotate_right_offset); + } + } + + builder.pop_namespace(); + + Ok(rotate) +} + +#[allow(clippy::too_many_arguments)] +pub fn blake3_g( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + a_in: OracleId, + b_in: OracleId, + c_in: OracleId, + d_in: OracleId, + mx: OracleId, + my: OracleId, + log_size: usize, +) -> Result<[OracleId; 4], anyhow::Error> { + builder.push_namespace(name); + + let ab = arithmetic::u32::add(builder, "a_in + b_in", a_in, b_in, Flags::Unchecked)?; + let a1 = arithmetic::u32::add(builder, "a_in + b_in + mx", ab, mx, Flags::Unchecked)?; + + let d1 = xor_rotate_right(builder, "(d_in ^ a1).rotate_right(16)", log_size, d_in, a1, 16u32)?; + + let c1 = arithmetic::u32::add(builder, "c_in + d1", c_in, d1, Flags::Unchecked)?; + + let b1 = xor_rotate_right(builder, "(b_in ^ c1).rotate_right(12)", log_size, b_in, c1, 12u32)?; + + let a1b1 = arithmetic::u32::add(builder, "a1 + b1", a1, b1, Flags::Unchecked)?; + let a2 = arithmetic::u32::add(builder, "a1 + b1 + my_in", a1b1, my, Flags::Unchecked)?; + + let d2 = xor_rotate_right(builder, "(d1 ^ a2).rotate_right(8)", log_size, d1, a2, 8u32)?; + + let c2 = arithmetic::u32::add(builder, "c1 + d2", c1, d2, Flags::Unchecked)?; + + let b2 = xor_rotate_right(builder, "(b1 ^ c2).rotate_right(7)", log_size, b1, c2, 7u32)?; + + builder.pop_namespace(); + + Ok([a2, b2, c2, d2]) +} + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::BinaryField1b; + use binius_maybe_rayon::prelude::*; + + use crate::{ + blake3::blake3_g, + builder::ConstraintSystemBuilder, + unconstrained::{fixed_u32, unconstrained}, + }; + + type F1 = BinaryField1b; + + const LOG_SIZE: usize = 5; + + // The Blake3 mixing function, G, which mixes either a column or a diagonal. + // https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs + const fn g( + a_in: u32, + b_in: u32, + c_in: u32, + d_in: u32, + mx: u32, + my: u32, + ) -> (u32, u32, u32, u32) { + let a1 = a_in.wrapping_add(b_in).wrapping_add(mx); + let d1 = (d_in ^ a1).rotate_right(16); + let c1 = c_in.wrapping_add(d1); + let b1 = (b_in ^ c1).rotate_right(12); + + let a2 = a1.wrapping_add(b1).wrapping_add(my); + let d2 = (d1 ^ a2).rotate_right(8); + let c2 = c1.wrapping_add(d2); + let b2 = (b1 ^ c2).rotate_right(7); + + (a2, b2, c2, d2) + } + + #[test] + fn test_vector() { + // Let's use some fixed data input to check that our in-circuit computation + // produces same output as out-of-circuit one + let a = 0xaaaaaaaau32; + let b = 0xbbbbbbbbu32; + let c = 0xccccccccu32; + let d = 0xddddddddu32; + let mx = 0xffff00ffu32; + let my = 0xff00ffffu32; + + let (expected_0, expected_1, expected_2, expected_3) = g(a, b, c, d, mx, my); + + let size = 1 << LOG_SIZE; + + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let a_in = fixed_u32::(&mut builder, "a", LOG_SIZE, vec![a; size]).unwrap(); + let b_in = fixed_u32::(&mut builder, "b", LOG_SIZE, vec![b; size]).unwrap(); + let c_in = fixed_u32::(&mut builder, "c", LOG_SIZE, vec![c; size]).unwrap(); + let d_in = fixed_u32::(&mut builder, "d", LOG_SIZE, vec![d; size]).unwrap(); + let mx_in = fixed_u32::(&mut builder, "mx", LOG_SIZE, vec![mx; size]).unwrap(); + let my_in = fixed_u32::(&mut builder, "my", LOG_SIZE, vec![my; size]).unwrap(); + + let output = + blake3_g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap(); + + if let Some(witness) = builder.witness() { + ( + witness.get::(output[0]).unwrap().as_slice::(), + witness.get::(output[1]).unwrap().as_slice::(), + witness.get::(output[2]).unwrap().as_slice::(), + witness.get::(output[3]).unwrap().as_slice::(), + ) + .into_par_iter() + .for_each(|(actual_0, actual_1, actual_2, actual_3)| { + assert_eq!(*actual_0, expected_0); + assert_eq!(*actual_1, expected_1); + assert_eq!(*actual_2, expected_2); + assert_eq!(*actual_3, expected_3); + }); + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } + + #[test] + fn test_random_input() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let a_in = unconstrained::(&mut builder, "a", LOG_SIZE).unwrap(); + let b_in = unconstrained::(&mut builder, "b", LOG_SIZE).unwrap(); + let c_in = unconstrained::(&mut builder, "c", LOG_SIZE).unwrap(); + let d_in = unconstrained::(&mut builder, "d", LOG_SIZE).unwrap(); + let mx_in = unconstrained::(&mut builder, "mx", LOG_SIZE).unwrap(); + let my_in = unconstrained::(&mut builder, "my", LOG_SIZE).unwrap(); + + blake3_g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } +} diff --git a/crates/circuits/src/builder/constraint_system.rs b/crates/circuits/src/builder/constraint_system.rs index ba8a2c86..e4010caf 100644 --- a/crates/circuits/src/builder/constraint_system.rs +++ b/crates/circuits/src/builder/constraint_system.rs @@ -16,35 +16,28 @@ use binius_core::{ transparent::step_down::StepDown, witness::MultilinearExtensionIndex, }; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, BinaryField1b}; use binius_math::ArithExpr; use binius_utils::bail; -use crate::builder::witness; +use crate::builder::{ + types::{F, U}, + witness, +}; #[derive(Default)] -pub struct ConstraintSystemBuilder<'arena, U, F> -where - U: UnderlierType + PackScalar, - F: TowerField, -{ +pub struct ConstraintSystemBuilder<'arena> { oracles: Rc>>, constraints: ConstraintSetBuilder, non_zero_oracle_ids: Vec, flushes: Vec, step_down_dedup: HashMap<(usize, usize), OracleId>, - witness: Option>, + witness: Option>, next_channel_id: ChannelId, namespace_path: Vec, } -impl<'arena, U, F> ConstraintSystemBuilder<'arena, U, F> -where - U: UnderlierType + PackScalar, - F: TowerField, -{ +impl<'arena> ConstraintSystemBuilder<'arena> { pub fn new() -> Self { Self::default() } @@ -79,7 +72,7 @@ where }) } - pub fn witness(&mut self) -> Option<&mut witness::Builder<'arena, U, F>> { + pub fn witness(&mut self) -> Option<&mut witness::Builder<'arena>> { self.witness.as_mut() } @@ -354,7 +347,7 @@ where /// /// let log_size = 14; /// - /// let mut builder = ConstraintSystemBuilder::::new(); + /// let mut builder = ConstraintSystemBuilder::new(); /// builder.push_namespace("a"); /// let x = builder.add_committed("x", log_size, BinaryField1b::TOWER_LEVEL); /// builder.push_namespace("b"); diff --git a/crates/circuits/src/builder/mod.rs b/crates/circuits/src/builder/mod.rs index c0b59130..7ee629f3 100644 --- a/crates/circuits/src/builder/mod.rs +++ b/crates/circuits/src/builder/mod.rs @@ -1,6 +1,8 @@ // Copyright 2024-2025 Irreducible Inc. pub mod constraint_system; +pub mod test_utils; +pub mod types; pub mod witness; pub use constraint_system::ConstraintSystemBuilder; diff --git a/crates/circuits/src/builder/test_utils.rs b/crates/circuits/src/builder/test_utils.rs new file mode 100644 index 00000000..0aa80ad9 --- /dev/null +++ b/crates/circuits/src/builder/test_utils.rs @@ -0,0 +1,23 @@ +// Copyright 2025 Irreducible Inc. + +use binius_core::constraint_system::{channel::Boundary, validate::validate_witness}; + +use super::{types::F, ConstraintSystemBuilder}; + +pub fn test_circuit( + build_circuit: fn(&mut ConstraintSystemBuilder) -> Result>, anyhow::Error>, +) -> Result<(), anyhow::Error> { + let mut verifier_builder = ConstraintSystemBuilder::new(); + let verifier_boundaries = build_circuit(&mut verifier_builder)?; + let verifier_constraint_system = verifier_builder.build()?; + + let allocator = bumpalo::Bump::new(); + let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let prover_boundaries = build_circuit(&mut prover_builder)?; + let prover_witness = prover_builder.take_witness()?; + let _prover_constraint_system = prover_builder.build()?; + + assert_eq!(verifier_boundaries, prover_boundaries); + validate_witness(&verifier_constraint_system, &verifier_boundaries, &prover_witness)?; + Ok(()) +} diff --git a/crates/circuits/src/builder/types.rs b/crates/circuits/src/builder/types.rs new file mode 100644 index 00000000..ae33e7e4 --- /dev/null +++ b/crates/circuits/src/builder/types.rs @@ -0,0 +1,4 @@ +// Copyright 2025 Irreducible Inc. + +pub type F = binius_field::BinaryField128b; +pub type U = binius_field::arch::OptimalUnderlier; diff --git a/crates/circuits/src/builder/witness.rs b/crates/circuits/src/builder/witness.rs index 2209a36f..77e83868 100644 --- a/crates/circuits/src/builder/witness.rs +++ b/crates/circuits/src/builder/witness.rs @@ -10,35 +10,33 @@ use binius_core::{ use binius_field::{ as_packed_field::{PackScalar, PackedType}, underlier::WithUnderlier, - ExtensionField, Field, PackedField, TowerField, + ExtensionField, PackedField, TowerField, }; use binius_math::MultilinearExtension; use binius_utils::bail; use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod}; -pub struct Builder<'arena, U: PackScalar, FW: TowerField> { +use super::types::{F, U}; + +pub struct Builder<'arena> { bump: &'arena bumpalo::Bump, - oracles: Rc>>, + oracles: Rc>>, #[allow(clippy::type_complexity)] - entries: Rc>>>>, + entries: Rc>>>>, } -struct WitnessBuilderEntry<'arena, U: PackScalar, FW: Field> { - witness: Result>, binius_math::Error>, +struct WitnessBuilderEntry<'arena> { + witness: Result>, binius_math::Error>, tower_level: usize, data: &'arena [U], } -impl<'arena, U, FW> Builder<'arena, U, FW> -where - U: PackScalar, - FW: TowerField, -{ +impl<'arena> Builder<'arena> { pub fn new( allocator: &'arena bumpalo::Bump, - oracles: Rc>>, + oracles: Rc>>, ) -> Self { Self { bump: allocator, @@ -47,10 +45,10 @@ where } } - pub fn new_column(&self, id: OracleId) -> EntryBuilder<'arena, U, FW, FS> + pub fn new_column(&self, id: OracleId) -> EntryBuilder<'arena, FS> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); let log_rows = oracles.n_vars(id); @@ -69,10 +67,10 @@ where &self, id: OracleId, default: FS, - ) -> EntryBuilder<'arena, U, FW, FS> + ) -> EntryBuilder<'arena, FS> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); let log_rows = oracles.n_vars(id); @@ -88,10 +86,11 @@ where } } - pub fn get(&self, id: OracleId) -> Result, Error> + pub fn get(&self, id: OracleId) -> Result, Error> where + FS: TowerField, U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let entries = self.entries.borrow(); let oracles = self.oracles.borrow(); @@ -122,11 +121,11 @@ where pub fn set( &self, id: OracleId, - entry: WitnessEntry<'arena, U, FS>, + entry: WitnessEntry<'arena, FS>, ) -> Result<(), Error> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); if !oracles.is_valid_oracle_id(id) { @@ -145,7 +144,7 @@ where Ok(()) } - pub fn build(self) -> Result, Error> { + pub fn build(self) -> Result, Error> { let mut result = MultilinearExtensionIndex::new(); let entries = Rc::into_inner(self.entries) .ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))? @@ -160,26 +159,37 @@ where } #[derive(Debug, Clone, Copy)] -pub struct WitnessEntry<'arena, U: PackScalar, FS: TowerField> { +pub struct WitnessEntry<'arena, FS: TowerField> +where + U: PackScalar, +{ data: &'arena [U], log_rows: usize, _marker: PhantomData, } -impl<'arena, U: PackScalar, FS: TowerField> WitnessEntry<'arena, U, FS> { +impl<'arena, FS: TowerField> WitnessEntry<'arena, FS> +where + U: PackScalar, +{ #[inline] pub fn packed(&self) -> &'arena [PackedType] { WithUnderlier::from_underliers_ref(self.data) } - pub const fn repacked(&self) -> WitnessEntry<'arena, U, FW> + #[inline] + pub const fn as_slice(&self) -> &'arena [T] { + must_cast_slice(self.data) + } + + pub const fn repacked(&self) -> WitnessEntry<'arena, FE> where - FW: TowerField + ExtensionField, - U: PackScalar, + FE: TowerField + ExtensionField, + U: PackScalar, { WitnessEntry { data: self.data, - log_rows: self.log_rows - >::LOG_DEGREE, + log_rows: self.log_rows - >::LOG_DEGREE, _marker: PhantomData, } } @@ -189,38 +199,36 @@ impl<'arena, U: PackScalar, FS: TowerField> WitnessEntry<'arena, U, FS> { } } -impl<'arena, U: PackScalar + Pod, FS: TowerField> WitnessEntry<'arena, U, FS> { - #[inline] - pub const fn as_slice(&self) -> &'arena [T] { - must_cast_slice(self.data) - } -} - -pub struct EntryBuilder<'arena, U, FW, FS> +pub struct EntryBuilder<'arena, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { _marker: PhantomData, #[allow(clippy::type_complexity)] - entries: Rc>>>>, + entries: Rc>>>>, id: OracleId, log_rows: usize, data: Option<&'arena mut [U]>, } -impl EntryBuilder<'_, U, FW, FS> +impl EntryBuilder<'_, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { #[inline] pub fn packed(&mut self) -> &mut [PackedType] { PackedType::::from_underliers_ref_mut(self.underliers()) } + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + must_cast_slice_mut(self.underliers()) + } + #[inline] fn underliers(&mut self) -> &mut [U] { self.data @@ -229,23 +237,11 @@ where } } -impl EntryBuilder<'_, U, FW, FS> -where - U: PackScalar + PackScalar + Pod, - FS: TowerField, - FW: TowerField + ExtensionField, -{ - #[inline] - pub fn as_mut_slice(&mut self) -> &mut [T] { - must_cast_slice_mut(self.underliers()) - } -} - -impl Drop for EntryBuilder<'_, U, FW, FS> +impl Drop for EntryBuilder<'_, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { fn drop(&mut self) { let data = Option::take(&mut self.data).expect("data is always Some until this point"); diff --git a/crates/circuits/src/collatz.rs b/crates/circuits/src/collatz.rs index df1f2d05..b76fdf2f 100644 --- a/crates/circuits/src/collatz.rs +++ b/crates/circuits/src/collatz.rs @@ -10,7 +10,14 @@ use binius_field::{ use binius_macros::arith_expr; use bytemuck::Pod; -use crate::{arithmetic, builder::ConstraintSystemBuilder, transparent}; +use crate::{ + arithmetic, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent, +}; pub type Advice = (usize, usize); @@ -37,9 +44,9 @@ impl Collatz { (self.evens.len(), self.odds.len()) } - pub fn build( + pub fn build( self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, advice: Advice, ) -> Result>, anyhow::Error> where @@ -58,16 +65,12 @@ impl Collatz { Ok(boundaries) } - fn even( + fn even( &self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, channel: ChannelId, count: usize, - ) -> Result<(), anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let log_1b_rows = 5 + binius_utils::checked_arithmetics::log2_ceil_usize(count); let even = builder.add_committed("even", log_1b_rows, BinaryField1b::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -90,16 +93,12 @@ impl Collatz { Ok(()) } - fn odd( + fn odd( &self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, channel: ChannelId, count: usize, - ) -> Result<(), anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let log_32b_rows = binius_utils::checked_arithmetics::log2_ceil_usize(count); let log_1b_rows = 5 + log_32b_rows; @@ -136,10 +135,7 @@ impl Collatz { Ok(()) } - fn get_boundaries(&self, channel_id: usize) -> Vec> - where - F: TowerField + From, - { + fn get_boundaries(&self, channel_id: usize) -> Vec> { vec![ Boundary { channel_id, @@ -179,15 +175,11 @@ pub fn collatz_orbit(x0: u32) -> Vec { res } -pub fn ensure_odd( - builder: &mut ConstraintSystemBuilder, +pub fn ensure_odd( + builder: &mut ConstraintSystemBuilder, input: OracleId, count: usize, -) -> Result<(), anyhow::Error> -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result<(), anyhow::Error> { let log_32b_rows = builder.log_rows([input])? - 5; let lsb = arithmetic::u32::select_bit(builder, "lsb", input, 0)?; let selector = transparent::step_down(builder, "count", log_32b_rows, count)?; @@ -201,28 +193,17 @@ where #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; - use binius_field::{arch::OptimalUnderlier, BinaryField128b}; - - use crate::{builder::ConstraintSystemBuilder, collatz::Collatz}; + use crate::{builder::test_utils::test_circuit, collatz::Collatz}; #[test] fn test_collatz() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - - let x0 = 9999999; - - let mut collatz = Collatz::new(x0); - let advice = collatz.init_prover(); - - let boundaries = collatz.build(&mut builder, advice).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let x0 = 9999999; + let mut collatz = Collatz::new(x0); + let advice = collatz.init_prover(); + let boundaries = collatz.build(builder, advice)?; + Ok(boundaries) + }) + .unwrap(); } } diff --git a/crates/circuits/src/groestl.rs b/crates/circuits/src/groestl.rs index d1ab36d3..598e6de1 100644 --- a/crates/circuits/src/groestl.rs +++ b/crates/circuits/src/groestl.rs @@ -378,3 +378,28 @@ fn s_box(x: AESTowerField8b) -> AESTowerField8b { let idx = u8::from(x) as usize; AESTowerField8b::from(S_BOX[idx]) } + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::{arch::OptimalUnderlier, AESTowerField16b}; + + use super::groestl_p_permutation; + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_groestl() { + let allocator = bumpalo::Bump::new(); + let mut builder = + ConstraintSystemBuilder::::new_with_witness( + &allocator, + ); + let log_size = 9; + let _state_out = groestl_p_permutation(&mut builder, log_size).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/keccakf.rs b/crates/circuits/src/keccakf.rs index a73c5cea..89d11d56 100644 --- a/crates/circuits/src/keccakf.rs +++ b/crates/circuits/src/keccakf.rs @@ -8,14 +8,19 @@ use binius_core::{ transparent::multilinear_extension::MultilinearExtensionTransparent, }; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, BinaryField64b, ExtensionField, PackedField, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, BinaryField64b, Field, + PackedField, TowerField, }; use binius_macros::arith_expr; use bytemuck::{pod_collect_to_vec, Pod}; -use crate::{builder::ConstraintSystemBuilder, transparent::step_down}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent::step_down, +}; #[derive(Default, Clone, Copy)] pub struct KeccakfState(pub [u64; STATE_SIZE]); @@ -25,15 +30,11 @@ pub struct KeccakfOracles { pub output: [OracleId; STATE_SIZE], } -pub fn keccakf( - builder: &mut ConstraintSystemBuilder, - input_witness: Option>, +pub fn keccakf( + builder: &mut ConstraintSystemBuilder, + input_witness: &Option>, log_size: usize, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField, -{ +) -> Result { let internal_log_size = log_size + LOG_BIT_ROWS_PER_PERMUTATION; let round_consts_single: [OracleId; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| { @@ -124,7 +125,7 @@ where builder.add_projected( "output", packed_state_out[xy], - vec![F::ONE; LOG_STATE_ROWS_PER_PERMUTATION], + vec![Field::ONE; LOG_STATE_ROWS_PER_PERMUTATION], ProjectionVariant::FirstVars, ) })?; @@ -135,7 +136,7 @@ where "c", internal_log_size, array::from_fn::<_, 5, _>(|offset| { - (state[round_within_row][x + 5 * offset], F::ONE) + (state[round_within_row][x + 5 * offset], Field::ONE) }), ) }) @@ -159,8 +160,8 @@ where "d", internal_log_size, [ - (c[round_within_row][(x + 4) % 5], F::ONE), - (c_shift[round_within_row][(x + 1) % 5], F::ONE), + (c[round_within_row][(x + 4) % 5], Field::ONE), + (c_shift[round_within_row][(x + 1) % 5], Field::ONE), ], ) }) @@ -174,8 +175,8 @@ where format!("a_theta[{xy}]"), internal_log_size, [ - (state[round_within_row][xy], F::ONE), - (d[round_within_row][x], F::ONE), + (state[round_within_row][xy], Field::ONE), + (d[round_within_row][x], Field::ONE), ], ) }) @@ -504,3 +505,23 @@ const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [ 0x0000000080000001, 0x8000000080008008, ]; + +#[cfg(test)] +mod tests { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::{keccakf, KeccakfState}; + use crate::builder::test_utils::test_circuit; + + #[test] + fn test_keccakf() { + test_circuit(|builder| { + let log_size = 5; + let mut rng = StdRng::seed_from_u64(0); + let input_states = vec![KeccakfState(rng.gen())]; + let _state_out = keccakf(builder, &Some(input_states), log_size)?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/lasso/batch.rs b/crates/circuits/src/lasso/batch.rs index 8fb6d069..52c98e29 100644 --- a/crates/circuits/src/lasso/batch.rs +++ b/crates/circuits/src/lasso/batch.rs @@ -4,12 +4,15 @@ use anyhow::Ok; use binius_core::oracle::OracleId; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField, + ExtensionField, PackedFieldIndexable, TowerField, }; use itertools::Itertools; use super::lasso::lasso; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; pub struct LookupBatch { lookup_us: Vec>, u_to_t_mappings: Vec>, @@ -48,19 +51,16 @@ impl LookupBatch { self.lookup_col_lens.push(lookup_u_col_len); } - pub fn execute( - mut self, - builder: &mut ConstraintSystemBuilder, - ) -> Result<(), anyhow::Error> + pub fn execute(mut self, builder: &mut ConstraintSystemBuilder) -> Result<(), anyhow::Error> where - U: PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, FC: TowerField, - F: ExtensionField + TowerField, + U: PackScalar, + F: ExtensionField, + PackedType: PackedFieldIndexable, { let channel = builder.add_channel(); - lasso::<_, _, FC>( + lasso::( builder, "batched lasso", &self.lookup_col_lens, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs index 43349e09..302d2384 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use crate::{ builder::ConstraintSystemBuilder, @@ -19,32 +12,16 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; -pub fn byte_sliced_add>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_add: Sized>>( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, - x_in: &Level::Data, - y_in: &Level::Data, + x_in: &Level::Data, + y_in: &Level::Data, carry_in: OracleId, log_size: usize, lookup_batch_add: &mut LookupBatch, -) -> Result<(OracleId, Level::Data), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8add(builder, lookup_batch_add, name, x_in[0], y_in[0], carry_in, log_size)?; @@ -58,7 +35,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); let (lower_half_y, upper_half_y) = Level::split(y_in); - let (internal_carry, lower_sum) = byte_sliced_add::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_add::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -68,7 +45,7 @@ where lookup_batch_add, )?; - let (carry_out, upper_sum) = byte_sliced_add::<_, _, Level::Base>( + let (carry_out, upper_sum) = byte_sliced_add::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs index 881aa1b9..ab7b864c 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use super::byte_sliced_add; use crate::{ @@ -20,34 +13,18 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_add_carryfree>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_add_carryfree: Sized>>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - x_in: &Level::Data, - y_in: &Level::Data, + x_in: &Level::Data, + y_in: &Level::Data, carry_in: OracleId, log_size: usize, lookup_batch_add: &mut LookupBatch, lookup_batch_add_carryfree: &mut LookupBatch, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result, anyhow::Error> { if Level::WIDTH == 1 { let sum = u8add_carryfree( builder, @@ -68,7 +45,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); let (lower_half_y, upper_half_y) = Level::split(y_in); - let (internal_carry, lower_sum) = byte_sliced_add::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_add::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -78,7 +55,7 @@ where lookup_batch_add, )?; - let upper_sum = byte_sliced_add_carryfree::<_, _, Level::Base>( + let upper_sum = byte_sliced_add_carryfree::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs index b14baa3e..9000c457 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use crate::{ builder::ConstraintSystemBuilder, @@ -19,34 +12,18 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_double_conditional_increment>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_double_conditional_increment: Sized>>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - x_in: &Level::Data, + x_in: &Level::Data, first_carry_in: OracleId, second_carry_in: OracleId, log_size: usize, zero_oracle_carry: usize, lookup_batch_dci: &mut LookupBatch, -) -> Result<(OracleId, Level::Data), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8_double_conditional_increment( builder, @@ -66,7 +43,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); - let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -77,7 +54,7 @@ where lookup_batch_dci, )?; - let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::<_, _, Level::Base>( + let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs index ead52ccc..d480de5c 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs @@ -4,59 +4,32 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::{oracle::OracleId, transparent::constant::Constant}; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::{UnderlierType, WithUnderlier}, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, + tower_levels::TowerLevel, underlier::WithUnderlier, BinaryField32b, BinaryField8b, TowerField, }; use binius_macros::arith_expr; -use bytemuck::Pod; use super::{byte_sliced_add_carryfree, byte_sliced_mul}; use crate::{ - builder::ConstraintSystemBuilder, + builder::{types::F, ConstraintSystemBuilder}, lasso::{ batch::LookupBatch, lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup}, }, }; -type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_modular_mul< - U, - F, - LevelIn: TowerLevel, - LevelOut: TowerLevel, ->( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_modular_mul>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - mult_a: &LevelIn::Data, - mult_b: &LevelIn::Data, + mult_a: &LevelIn::Data, + mult_b: &LevelIn::Data, modulus_input: &[u8], log_size: usize, zero_byte_oracle: OracleId, zero_carry_oracle: OracleId, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - ::Underlier: From, -{ +) -> Result, anyhow::Error> { builder.push_namespace(name); let lookup_t_mul = mul_lookup(builder, "mul table")?; @@ -88,12 +61,14 @@ where "modulus", Constant::new( log_size, - F::from_underlier(>::into(modulus_input[byte_idx])), + ::from_underlier(::Underlier, + >>::into(modulus_input[byte_idx])), ), )?; } - let ab = byte_sliced_mul::<_, _, LevelIn, LevelOut>( + let ab = byte_sliced_mul::( builder, "ab", mult_a, @@ -166,7 +141,7 @@ where } } - let qm = byte_sliced_mul::<_, _, LevelIn, LevelOut>( + let qm = byte_sliced_mul::( builder, "qm", "ient, @@ -183,7 +158,7 @@ where repeating_zero[byte_idx] = zero_byte_oracle; } - let qm_plus_r = byte_sliced_add_carryfree::<_, _, LevelOut>( + let qm_plus_r = byte_sliced_add_carryfree::( builder, "hi*lo", &qm, @@ -194,12 +169,12 @@ where &mut lookup_batch_add_carryfree, )?; - lookup_batch_mul.execute::<_, _, BinaryField32b>(builder)?; - lookup_batch_add.execute::<_, _, BinaryField32b>(builder)?; - lookup_batch_add_carryfree.execute::<_, _, BinaryField32b>(builder)?; + lookup_batch_mul.execute::(builder)?; + lookup_batch_add.execute::(builder)?; + lookup_batch_add_carryfree.execute::(builder)?; if LevelIn::WIDTH != 1 { - lookup_batch_dci.execute::<_, _, BinaryField32b>(builder)?; + lookup_batch_dci.execute::(builder)?; } let consistency = arith_expr!([x, y] = x - y); diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs index 03742d34..de386d7a 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField8b}; use super::{byte_sliced_add, byte_sliced_double_conditional_increment}; use crate::{ @@ -18,41 +11,20 @@ use crate::{ lasso::{batch::LookupBatch, u8mul::u8mul_bytesliced}, }; -type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_mul< - U, - F, - LevelIn: TowerLevel, - LevelOut: TowerLevel, ->( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_mul>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - mult_a: &LevelIn::Data, - mult_b: &LevelIn::Data, + mult_a: &LevelIn::Data, + mult_b: &LevelIn::Data, log_size: usize, zero_carry_oracle: OracleId, lookup_batch_mul: &mut LookupBatch, lookup_batch_add: &mut LookupBatch, lookup_batch_dci: &mut LookupBatch, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result, anyhow::Error> { if LevelIn::WIDTH == 1 { let result_of_u8mul = u8mul_bytesliced( builder, @@ -77,7 +49,7 @@ where let (mult_a_low, mult_a_high) = LevelIn::split(mult_a); let (mult_b_low, mult_b_high) = LevelIn::split(mult_b); - let a_lo_b_lo = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_lo_b_lo = byte_sliced_mul::( builder, format!("lo*lo {}b", LevelIn::Base::WIDTH), mult_a_low, @@ -88,7 +60,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_lo_b_hi = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_lo_b_hi = byte_sliced_mul::( builder, format!("lo*hi {}b", LevelIn::Base::WIDTH), mult_a_low, @@ -99,7 +71,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_hi_b_lo = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_hi_b_lo = byte_sliced_mul::( builder, format!("hi*lo {}b", LevelIn::Base::WIDTH), mult_a_high, @@ -110,7 +82,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_hi_b_hi = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_hi_b_hi = byte_sliced_mul::( builder, format!("hi*hi {}b", LevelIn::Base::WIDTH), mult_a_high, @@ -122,7 +94,7 @@ where lookup_batch_dci, )?; - let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::<_, _, LevelIn>( + let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::( builder, format!("karastsuba addition {}b", LevelIn::WIDTH), &a_lo_b_hi, @@ -135,7 +107,7 @@ where let (a_lo_b_lo_lower_half, a_lo_b_lo_upper_half) = LevelIn::split(&a_lo_b_lo); let (a_hi_b_hi_lower_half, a_hi_b_hi_upper_half) = LevelIn::split(&a_hi_b_hi); - let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::<_, _, LevelIn>( + let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::( builder, format!("post kartsuba middle term addition {}b", LevelIn::WIDTH), &karatsuba_term, @@ -145,7 +117,7 @@ where lookup_batch_add, )?; - let (_, final_high_chunk) = byte_sliced_double_conditional_increment::<_, _, LevelIn::Base>( + let (_, final_high_chunk) = byte_sliced_double_conditional_increment::( builder, format!("high chunk DCI {}b", LevelIn::Base::WIDTH), a_hi_b_hi_upper_half, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs index 1ceaae04..411a5c15 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs @@ -3,19 +3,18 @@ use std::{array, fmt::Debug}; use alloy_primitives::U512; -use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; +use binius_core::oracle::OracleId; use binius_field::{ - arch::OptimalUnderlier, tower_levels::TowerLevel, BinaryField128b, BinaryField1b, - BinaryField32b, BinaryField8b, Field, TowerField, + tower_levels::TowerLevel, BinaryField1b, BinaryField32b, BinaryField8b, Field, TowerField, }; -use rand::{rngs::ThreadRng, thread_rng, Rng}; +use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; use super::{ byte_sliced_add, byte_sliced_add_carryfree, byte_sliced_double_conditional_increment, byte_sliced_modular_mul, byte_sliced_mul, }; use crate::{ - builder::ConstraintSystemBuilder, + builder::test_utils::test_circuit, lasso::{ batch::LookupBatch, lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup}, @@ -27,297 +26,230 @@ use crate::{ type B8 = BinaryField8b; type B32 = BinaryField32b; -pub fn random_u512(rng: &mut ThreadRng) -> U512 { +pub fn random_u512(rng: &mut impl Rng) -> U512 { let limbs = array::from_fn(|_| rng.gen()); U512::from_limbs(limbs) } pub fn test_bytesliced_add() where - TL: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let x_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap() - }); - let y_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "y", log_size).unwrap() - }); - let c_in = unconstrained::<_, _, BinaryField1b>(&mut builder, "cin first", log_size).unwrap(); - - let lookup_t_add = add_lookup(&mut builder, "add table").unwrap(); - - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let _sum_and_cout = byte_sliced_add::<_, _, TL>( - &mut builder, - "lasso_bytesliced_add", - &x_in, - &y_in, - c_in, - log_size, - &mut lookup_batch_add, - ) + test_circuit(|builder| { + let log_size = 14; + let x_in = TL::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let y_in = TL::from_fn(|_| unconstrained::(builder, "y", log_size).unwrap()); + let c_in = unconstrained::(builder, "cin first", log_size)?; + let lookup_t_add = add_lookup(builder, "add table")?; + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let _sum_and_cout = byte_sliced_add::( + builder, + "lasso_bytesliced_add", + &x_in, + &y_in, + c_in, + log_size, + &mut lookup_batch_add, + )?; + lookup_batch_add.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_add.execute::<_, _, B32>(&mut builder).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_add_carryfree() where - TL: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let x_in = array::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); - let y_in = array::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); - let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL); - - if let Some(witness) = builder.witness() { - let mut x_in: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(x_in[byte_idx])); - let mut y_in: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(y_in[byte_idx])); - let mut c_in = witness.new_column::(c_in); - - let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::()); - let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::()); - let c_in_u8 = c_in.as_mut_slice::(); - - for row_idx in 0..1 << log_size { - let mut rng = thread_rng(); - let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); - let mut x = random_u512(&mut rng); - x &= input_bitmask; - let mut y = random_u512(&mut rng); - y &= input_bitmask; - - let mut c: bool = rng.gen(); - - while (x + y + U512::from(c)) > input_bitmask { - x = random_u512(&mut rng); + test_circuit(|builder| { + let log_size = 14; + let x_in = + TL::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); + let y_in = + TL::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); + let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + let mut x_in: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(x_in[byte_idx])); + let mut y_in: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(y_in[byte_idx])); + let mut c_in = witness.new_column::(c_in); + + let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::()); + let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::()); + let c_in_u8 = c_in.as_mut_slice::(); + + for row_idx in 0..1 << log_size { + let mut rng = thread_rng(); + let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + let mut x = random_u512(&mut rng); x &= input_bitmask; - y = random_u512(&mut rng); + let mut y = random_u512(&mut rng); y &= input_bitmask; - c = rng.gen(); - } - for byte_idx in 0..WIDTH { - x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx); + let mut c: bool = rng.gen(); - y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx); - } + while (x + y + U512::from(c)) > input_bitmask { + x = random_u512(&mut rng); + x &= input_bitmask; + y = random_u512(&mut rng); + y &= input_bitmask; + c = rng.gen(); + } - c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8); - } - } + for byte_idx in 0..WIDTH { + x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx); - let lookup_t_add = add_lookup(&mut builder, "add table").unwrap(); - let lookup_t_add_carryfree = add_carryfree_lookup(&mut builder, "add table").unwrap(); + y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx); + } - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]); + c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8); + } + } - let _sum_and_cout = byte_sliced_add_carryfree::<_, _, TL>( - &mut builder, - "lasso_bytesliced_add_carryfree", - &x_in, - &y_in, - c_in, - log_size, - &mut lookup_batch_add, - &mut lookup_batch_add_carryfree, - ) + let lookup_t_add = add_lookup(builder, "add table")?; + let lookup_t_add_carryfree = add_carryfree_lookup(builder, "add table")?; + + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]); + + let _sum_and_cout = byte_sliced_add_carryfree::( + builder, + "lasso_bytesliced_add_carryfree", + &x_in, + &y_in, + c_in, + log_size, + &mut lookup_batch_add, + &mut lookup_batch_add_carryfree, + )?; + + lookup_batch_add.execute::(builder)?; + lookup_batch_add_carryfree.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_add.execute::<_, _, B32>(&mut builder).unwrap(); - lookup_batch_add_carryfree - .execute::<_, _, B32>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_double_conditional_increment() where - TL: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let x_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap() - }); - - let first_c_in = - unconstrained::<_, _, BinaryField1b>(&mut builder, "cin first", log_size).unwrap(); - - let second_c_in = - unconstrained::<_, _, BinaryField1b>(&mut builder, "cin second", log_size).unwrap(); - - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - let lookup_t_dci = dci_lookup(&mut builder, "add table").unwrap(); - - let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - - let _sum_and_cout = byte_sliced_double_conditional_increment::<_, _, TL>( - &mut builder, - "lasso_bytesliced_DCI", - &x_in, - first_c_in, - second_c_in, - log_size, - zero_oracle_carry, - &mut lookup_batch_dci, - ) + test_circuit(|builder| { + let log_size = 14; + let x_in = TL::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let first_c_in = unconstrained::(builder, "cin first", log_size)?; + let second_c_in = unconstrained::(builder, "cin second", log_size)?; + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let lookup_t_dci = dci_lookup(builder, "add table")?; + let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); + let _sum_and_cout = byte_sliced_double_conditional_increment::( + builder, + "lasso_bytesliced_DCI", + &x_in, + first_c_in, + second_c_in, + log_size, + zero_oracle_carry, + &mut lookup_batch_dci, + )?; + lookup_batch_dci.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_dci.execute::<_, _, B32>(&mut builder).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_mul() where - TL: TowerLevel, - TL::Base: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let mult_a = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "a", log_size).unwrap() - }); - let mult_b = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "b", log_size).unwrap() - }); - - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - - let lookup_t_mul = mul_lookup(&mut builder, "mul lookup").unwrap(); - let lookup_t_add = add_lookup(&mut builder, "add lookup").unwrap(); - let lookup_t_dci = dci_lookup(&mut builder, "dci lookup").unwrap(); - - let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - - let _sum_and_cout = byte_sliced_mul::<_, _, TL::Base, TL>( - &mut builder, - "lasso_bytesliced_mul", - &mult_a, - &mult_b, - log_size, - zero_oracle_carry, - &mut lookup_batch_mul, - &mut lookup_batch_add, - &mut lookup_batch_dci, - ) + test_circuit(|builder| { + let log_size = 14; + let mult_a = + TL::Base::from_fn(|_| unconstrained::(builder, "a", log_size).unwrap()); + let mult_b = + TL::Base::from_fn(|_| unconstrained::(builder, "b", log_size).unwrap()); + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let lookup_t_mul = mul_lookup(builder, "mul lookup")?; + let lookup_t_add = add_lookup(builder, "add lookup")?; + let lookup_t_dci = dci_lookup(builder, "dci lookup")?; + let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); + let _sum_and_cout = byte_sliced_mul::( + builder, + "lasso_bytesliced_mul", + &mult_a, + &mult_b, + log_size, + zero_oracle_carry, + &mut lookup_batch_mul, + &mut lookup_batch_add, + &mut lookup_batch_dci, + )?; + Ok(vec![]) + }) .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_modular_mul() where - TL: TowerLevel, - TL::Base: TowerLevel, - >::Data: Debug, + TL: TowerLevel: Debug>, + TL::Base: TowerLevel = [OracleId; WIDTH]>, { - type U = OptimalUnderlier; - type F = BinaryField128b; - - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let mut rng = thread_rng(); - - let mult_a = builder.add_committed_multiple::("a", log_size, B8::TOWER_LEVEL); - let mult_b = builder.add_committed_multiple::("b", log_size, B8::TOWER_LEVEL); - - let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + test_circuit(|builder| { + let log_size = 14; + let mut rng = thread_rng(); + let mult_a = builder.add_committed_multiple::("a", log_size, B8::TOWER_LEVEL); + let mult_b = builder.add_committed_multiple::("b", log_size, B8::TOWER_LEVEL); + let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + let modulus = + (random_u512(&mut StdRng::from_seed([42; 32])) % input_bitmask) + U512::from(1u8); - let modulus = (random_u512(&mut rng) % input_bitmask) + U512::from(1u8); + if let Some(witness) = builder.witness() { + let mut mult_a: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(mult_a[byte_idx])); - if let Some(witness) = builder.witness() { - let mut mult_a: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(mult_a[byte_idx])); + let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::()); - let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::()); + let mut mult_b: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(mult_b[byte_idx])); - let mut mult_b: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(mult_b[byte_idx])); + let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::()); - let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::()); + for row_idx in 0..1 << log_size { + let mut a = random_u512(&mut rng); + let mut b = random_u512(&mut rng); - for row_idx in 0..1 << log_size { - let mut a = random_u512(&mut rng); - let mut b = random_u512(&mut rng); + a %= modulus; + b %= modulus; - a %= modulus; - b %= modulus; - - for byte_idx in 0..WIDTH { - mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx); - mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx); + for byte_idx in 0..WIDTH { + mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx); + mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx); + } } } - } - - let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx)); - - let zero_oracle_byte = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField8b::ZERO).unwrap(); - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - - let _modded_product = byte_sliced_modular_mul::<_, _, TL::Base, TL>( - &mut builder, - "lasso_bytesliced_mul", - &mult_a, - &mult_b, - &modulus_input, - log_size, - zero_oracle_byte, - zero_oracle_carry, - ) + let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx)); + let zero_oracle_byte = + transparent::constant(builder, "zero carry", log_size, BinaryField8b::ZERO)?; + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let _modded_product = byte_sliced_modular_mul::( + builder, + "lasso_bytesliced_mul", + &mult_a, + &mult_b, + &modulus_input, + log_size, + zero_oracle_byte, + zero_oracle_carry, + )?; + Ok(vec![]) + }) .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } diff --git a/crates/circuits/src/lasso/big_integer_ops/mod.rs b/crates/circuits/src/lasso/big_integer_ops/mod.rs index 3046af9b..4b41b524 100644 --- a/crates/circuits/src/lasso/big_integer_ops/mod.rs +++ b/crates/circuits/src/lasso/big_integer_ops/mod.rs @@ -12,3 +12,56 @@ pub use byte_sliced_add_carryfree::byte_sliced_add_carryfree; pub use byte_sliced_double_conditional_increment::byte_sliced_double_conditional_increment; pub use byte_sliced_modular_mul::byte_sliced_modular_mul; pub use byte_sliced_mul::byte_sliced_mul; + +#[cfg(test)] +mod tests { + use binius_field::tower_levels::{ + TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8, + }; + + use super::byte_sliced_test_utils::{ + test_bytesliced_add, test_bytesliced_add_carryfree, + test_bytesliced_double_conditional_increment, test_bytesliced_modular_mul, + test_bytesliced_mul, + }; + + #[test] + fn test_lasso_add_bytesliced() { + test_bytesliced_add::<1, TowerLevel1>(); + test_bytesliced_add::<2, TowerLevel2>(); + test_bytesliced_add::<4, TowerLevel4>(); + test_bytesliced_add::<8, TowerLevel8>(); + } + + #[test] + fn test_lasso_mul_bytesliced() { + test_bytesliced_mul::<1, TowerLevel2>(); + test_bytesliced_mul::<2, TowerLevel4>(); + test_bytesliced_mul::<4, TowerLevel8>(); + test_bytesliced_mul::<8, TowerLevel16>(); + } + + #[test] + fn test_lasso_modular_mul_bytesliced() { + test_bytesliced_modular_mul::<1, TowerLevel2>(); + test_bytesliced_modular_mul::<2, TowerLevel4>(); + test_bytesliced_modular_mul::<4, TowerLevel8>(); + test_bytesliced_modular_mul::<8, TowerLevel16>(); + } + + #[test] + fn test_lasso_bytesliced_double_conditional_increment() { + test_bytesliced_double_conditional_increment::<1, TowerLevel1>(); + test_bytesliced_double_conditional_increment::<2, TowerLevel2>(); + test_bytesliced_double_conditional_increment::<4, TowerLevel4>(); + test_bytesliced_double_conditional_increment::<8, TowerLevel8>(); + } + + #[test] + fn test_lasso_bytesliced_add_carryfree() { + test_bytesliced_add_carryfree::<1, TowerLevel1>(); + test_bytesliced_add_carryfree::<2, TowerLevel2>(); + test_bytesliced_add_carryfree::<4, TowerLevel4>(); + test_bytesliced_add_carryfree::<8, TowerLevel8>(); + } +} diff --git a/crates/circuits/src/lasso/lasso.rs b/crates/circuits/src/lasso/lasso.rs index 1ecde278..d93e2d48 100644 --- a/crates/circuits/src/lasso/lasso.rs +++ b/crates/circuits/src/lasso/lasso.rs @@ -4,15 +4,20 @@ use anyhow::{ensure, Error, Result}; use binius_core::{constraint_system::channel::ChannelId, oracle::OracleId}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField, + ExtensionField, Field, PackedFieldIndexable, TowerField, }; use itertools::{izip, Itertools}; -use crate::{builder::ConstraintSystemBuilder, transparent}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent, +}; -pub fn lasso( - builder: &mut ConstraintSystemBuilder, +pub fn lasso( + builder: &mut ConstraintSystemBuilder, name: impl ToString, n_lookups: &[usize], u_to_t_mappings: &[impl AsRef<[usize]>], @@ -21,10 +26,10 @@ pub fn lasso( channel: ChannelId, ) -> Result<()> where - U: UnderlierType + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField + From, - PackedType: PackedFieldIndexable, FC: TowerField, + U: PackScalar, + F: ExtensionField + From, + PackedType: PackedFieldIndexable, { if n_lookups.len() != lookups_u.len() { Err(anyhow::Error::msg("n_vars and lookups_u must be of the same length"))?; @@ -55,7 +60,7 @@ where } let t_log_rows = builder.log_rows(lookup_t.as_ref().iter().copied())?; - let lookup_o = transparent::constant(builder, "lookup_o", t_log_rows, F::ONE)?; + let lookup_o = transparent::constant(builder, "lookup_o", t_log_rows, Field::ONE)?; let lookup_f = builder.add_committed("lookup_f", t_log_rows, FC::TOWER_LEVEL); let lookups_r = u_log_rows .iter() diff --git a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs index 5fdbe0c7..bb9a4e6c 100644 --- a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs +++ b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs @@ -2,34 +2,19 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField32b, TowerField}; use crate::builder::ConstraintSystemBuilder; -type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; const T_LOG_SIZE_MUL: usize = 16; const T_LOG_SIZE_ADD: usize = 17; const T_LOG_SIZE_DCI: usize = 10; -pub fn mul_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn mul_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_MUL, B32::TOWER_LEVEL); @@ -53,17 +38,10 @@ where Ok(lookup_t) } -pub fn add_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn add_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL); @@ -95,17 +73,10 @@ where Ok(lookup_t) } -pub fn add_carryfree_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn add_carryfree_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL); @@ -139,17 +110,10 @@ where Ok(lookup_t) } -pub fn dci_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn dci_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_DCI, B32::TOWER_LEVEL); @@ -182,3 +146,125 @@ where builder.pop_namespace(); Ok(lookup_t) } + +#[cfg(test)] +mod tests { + use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b}; + + use crate::{ + builder::test_utils::test_circuit, + lasso::{self, batch::LookupBatch}, + unconstrained::unconstrained, + }; + + #[test] + fn test_lasso_u8add_carryfree_rejects_carry() { + // TODO: Make this test 100% certain to pass instead of 2^14 bits of security from randomness + test_circuit(|builder| { + let log_size = 14; + let x_in = unconstrained::(builder, "x", log_size)?; + let y_in = unconstrained::(builder, "y", log_size)?; + let c_in = unconstrained::(builder, "c", log_size)?; + + let lookup_t = super::add_carryfree_lookup(builder, "add cf table")?; + let mut lookup_batch = LookupBatch::new([lookup_t]); + let _sum_and_cout = lasso::u8add_carryfree( + builder, + &mut lookup_batch, + "lasso_u8add", + x_in, + y_in, + c_in, + log_size, + )?; + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .expect_err("Rejected overflowing add"); + } + + #[test] + fn test_lasso_u8mul() { + test_circuit(|builder| { + let log_size = 10; + + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + + let mul_lookup_table = super::mul_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .unwrap(); + } + + #[test] + fn test_lasso_batched_u8mul() { + test_circuit(|builder| { + let log_size = 10; + let mul_lookup_table = super::mul_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + for _ in 0..10 { + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + } + + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .unwrap(); + } + + #[test] + fn test_lasso_batched_u8mul_rejects() { + test_circuit(|builder| { + let log_size = 10; + + // We try to feed in the add table instead + let mul_lookup_table = super::add_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng + for _ in 0..10 { + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + } + + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .expect_err("Channels should be unbalanced"); + } +} diff --git a/crates/circuits/src/lasso/sha256.rs b/crates/circuits/src/lasso/sha256.rs index c8677acb..1b8f3c3d 100644 --- a/crates/circuits/src/lasso/sha256.rs +++ b/crates/circuits/src/lasso/sha256.rs @@ -1,21 +1,19 @@ // Copyright 2024-2025 Irreducible Inc. -use std::marker::PhantomData; - use anyhow::Result; use binius_core::oracle::OracleId; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b, BinaryField8b, ExtensionField, + as_packed_field::PackedType, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b, PackedFieldIndexable, TowerField, }; -use bytemuck::Pod; use itertools::izip; use super::{lasso::lasso, u32add::SeveralU32add}; use crate::{ - builder::ConstraintSystemBuilder, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, pack::pack, sha256::{rotate_and_xor, u32const_repeating, RotateRightType, INIT, ROUND_CONSTS_K}, }; @@ -24,36 +22,19 @@ pub const CH_MAJ_T_LOG_SIZE: usize = 12; type B1 = BinaryField1b; type B4 = BinaryField4b; -type B8 = BinaryField8b; type B16 = BinaryField16b; type B32 = BinaryField32b; -struct SeveralBitwise { +struct SeveralBitwise { n_lookups: Vec, lookup_t: OracleId, lookups_u: Vec<[OracleId; 1]>, u_to_t_mappings: Vec>, f: fn(u32, u32, u32) -> u32, - _phantom: PhantomData<(U, F)>, } -impl SeveralBitwise -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + ExtensionField + ExtensionField + ExtensionField, -{ - pub fn new( - builder: &mut ConstraintSystemBuilder, - f: fn(u32, u32, u32) -> u32, - ) -> Result { +impl SeveralBitwise { + pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result { let lookup_t = builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL); @@ -80,13 +61,12 @@ where lookups_u: Vec::new(), u_to_t_mappings: Vec::new(), f, - _phantom: PhantomData, }) } pub fn calculate( &mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, params: [OracleId; 3], ) -> Result { @@ -94,9 +74,9 @@ where let log_size = builder.log_rows(params)?; - let xin_packed = pack::(xin, builder, "xin_packed")?; - let yin_packed = pack::(yin, builder, "yin_packed")?; - let zin_packed = pack::(zin, builder, "zin_packed")?; + let xin_packed = pack::(xin, builder, "xin_packed")?; + let yin_packed = pack::(yin, builder, "yin_packed")?; + let zin_packed = pack::(zin, builder, "zin_packed")?; let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL); @@ -160,12 +140,12 @@ where pub fn finalize( self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result<()> { let channel = builder.add_channel(); - lasso::<_, _, B32>( + lasso::( builder, name, &self.n_lookups, @@ -177,29 +157,11 @@ where } } -pub fn sha256( - builder: &mut ConstraintSystemBuilder, +pub fn sha256( + builder: &mut ConstraintSystemBuilder, input: [OracleId; 16], log_size: usize, -) -> Result<[OracleId; 8], anyhow::Error> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField - + ExtensionField - + ExtensionField - + ExtensionField - + ExtensionField, -{ +) -> Result<[OracleId; 8], anyhow::Error> { let n_vars = log_size; let mut several_u32_add = SeveralU32add::new(builder)?; @@ -309,3 +271,64 @@ where Ok(output) } + +#[cfg(test)] +mod tests { + use binius_core::oracle::OracleId; + use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField}; + use sha2::{compress256, digest::generic_array::GenericArray}; + + use crate::{ + builder::{test_utils::test_circuit, types::U}, + unconstrained::unconstrained, + }; + + #[test] + fn test_sha256_lasso() { + test_circuit(|builder| { + let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(builder, i, log_size).unwrap() + }); + let state_output = super::sha256(builder, input, log_size).unwrap(); + + if let Some(witness) = builder.witness() { + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } + } + } + + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/lasso/u32add.rs b/crates/circuits/src/lasso/u32add.rs index dca11d72..9b260b8c 100644 --- a/crates/circuits/src/lasso/u32add.rs +++ b/crates/circuits/src/lasso/u32add.rs @@ -7,14 +7,19 @@ use binius_core::oracle::{OracleId, ShiftVariant}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, packed::set_packed_slice, - underlier::{UnderlierType, U1}, + underlier::U1, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, PackedFieldIndexable, TowerField, }; -use bytemuck::Pod; use itertools::izip; use super::lasso::lasso; -use crate::{builder::ConstraintSystemBuilder, pack::pack}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + pack::pack, +}; const ADD_T_LOG_SIZE: usize = 17; @@ -22,32 +27,18 @@ type B1 = BinaryField1b; type B8 = BinaryField8b; type B32 = BinaryField32b; -pub fn u32add( - builder: &mut ConstraintSystemBuilder, +pub fn u32add( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, xin: OracleId, yin: OracleId, ) -> Result where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - B8: ExtensionField + ExtensionField, - F: TowerField - + ExtensionField - + ExtensionField - + ExtensionField - + ExtensionField, FInput: TowerField, FOutput: TowerField, - B32: TowerField, + U: PackScalar + PackScalar, + B8: ExtensionField + ExtensionField, + F: ExtensionField + ExtensionField, { let mut several = SeveralU32add::new(builder)?; let sum = several.u32add::(builder, name.clone(), xin, yin)?; @@ -55,7 +46,7 @@ where Ok(sum) } -pub struct SeveralU32add { +pub struct SeveralU32add { n_lookups: Vec, lookup_t: OracleId, lookups_u: Vec<[OracleId; 1]>, @@ -64,19 +55,8 @@ pub struct SeveralU32add { _phantom: PhantomData<(U, F)>, } -impl SeveralU32add -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + ExtensionField + ExtensionField, -{ - pub fn new(builder: &mut ConstraintSystemBuilder) -> Result { +impl SeveralU32add { + pub fn new(builder: &mut ConstraintSystemBuilder) -> Result { let lookup_t = builder.add_committed("lookup_t", ADD_T_LOG_SIZE, B32::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -111,15 +91,15 @@ where pub fn u32add( &mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, ) -> Result where - U: PackScalar + PackScalar, FInput: TowerField, FOutput: TowerField, + U: PackScalar + PackScalar, F: ExtensionField + ExtensionField, B8: ExtensionField + ExtensionField, { @@ -143,8 +123,8 @@ where let cin = builder.add_shifted("cin", cout, 1, 2, ShiftVariant::LogicalLeft)?; - let xin_u8 = pack::<_, _, FInput, B8>(xin, builder, "repacked xin")?; - let yin_u8 = pack::<_, _, FInput, B8>(yin, builder, "repacked yin")?; + let xin_u8 = pack::(xin, builder, "repacked xin")?; + let yin_u8 = pack::(yin, builder, "repacked yin")?; let lookup_u = builder.add_linear_combination( "lookup_u", @@ -231,12 +211,12 @@ where pub fn finalize( mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result<()> { let channel = builder.add_channel(); self.finalized = true; - lasso::<_, _, B32>( + lasso::( builder, name, &self.n_lookups, @@ -248,8 +228,56 @@ where } } -impl Drop for SeveralU32add { +impl Drop for SeveralU32add { fn drop(&mut self) { assert!(self.finalized) } } + +#[cfg(test)] +mod tests { + use binius_field::{BinaryField1b, BinaryField8b}; + + use super::SeveralU32add; + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; + + #[test] + fn test_several_lasso_u32add() { + test_circuit(|builder| { + let mut several_u32_add = SeveralU32add::new(builder).unwrap(); + for log_size in [11, 12, 13] { + // BinaryField8b is used here because we utilize an 8x8x1→8 table + let add_a_u8 = unconstrained::(builder, "add_a", log_size).unwrap(); + let add_b_u8 = unconstrained::(builder, "add_b", log_size).unwrap(); + let _sum = several_u32_add + .u32add::( + builder, + "lasso_u32add", + add_a_u8, + add_b_u8, + ) + .unwrap(); + } + several_u32_add.finalize(builder, "lasso_u32add").unwrap(); + Ok(vec![]) + }) + .unwrap(); + } + + #[test] + fn test_lasso_u32add() { + test_circuit(|builder| { + let log_size = 14; + let add_a = unconstrained::(builder, "add_a", log_size)?; + let add_b = unconstrained::(builder, "add_b", log_size)?; + let _sum = super::u32add::( + builder, + "lasso_u32add", + add_a, + add_b, + )?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/lasso/u8_double_conditional_increment.rs b/crates/circuits/src/lasso/u8_double_conditional_increment.rs index 907ac418..1f10b443 100644 --- a/crates/circuits/src/lasso/u8_double_conditional_increment.rs +++ b/crates/circuits/src/lasso/u8_double_conditional_increment.rs @@ -2,44 +2,24 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8_double_conditional_increment( - builder: &mut ConstraintSystemBuilder, +pub fn u8_double_conditional_increment( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, first_carry_in: OracleId, second_carry_in: OracleId, log_size: usize, -) -> Result<(OracleId, OracleId), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<(OracleId, OracleId), anyhow::Error> { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8add.rs b/crates/circuits/src/lasso/u8add.rs index fe42995d..7e3dd0e5 100644 --- a/crates/circuits/src/lasso/u8add.rs +++ b/crates/circuits/src/lasso/u8add.rs @@ -4,44 +4,24 @@ use std::vec; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8add( - builder: &mut ConstraintSystemBuilder, +pub fn u8add( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, y_in: OracleId, carry_in: OracleId, log_size: usize, -) -> Result<(OracleId, OracleId), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<(OracleId, OracleId), anyhow::Error> { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8add_carryfree.rs b/crates/circuits/src/lasso/u8add_carryfree.rs index 45bebbd8..fd195959 100644 --- a/crates/circuits/src/lasso/u8add_carryfree.rs +++ b/crates/circuits/src/lasso/u8add_carryfree.rs @@ -2,44 +2,24 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8add_carryfree( - builder: &mut ConstraintSystemBuilder, +pub fn u8add_carryfree( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, y_in: OracleId, carry_in: OracleId, log_size: usize, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8mul.rs b/crates/circuits/src/lasso/u8mul.rs index 589bf7a4..d0c93ecd 100644 --- a/crates/circuits/src/lasso/u8mul.rs +++ b/crates/circuits/src/lasso/u8mul.rs @@ -2,37 +2,24 @@ use anyhow::{ensure, Result}; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField16b, BinaryField32b, BinaryField8b, TowerField}; use itertools::izip; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B8 = BinaryField8b; type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8mul_bytesliced( - builder: &mut ConstraintSystemBuilder, +pub fn u8mul_bytesliced( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, mult_a: OracleId, mult_b: OracleId, n_multiplications: usize, -) -> Result<[OracleId; 2], anyhow::Error> -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<[OracleId; 2], anyhow::Error> { builder.push_namespace(name); let log_rows = builder.log_rows([mult_a, mult_b])?; let product = builder.add_committed_multiple("product", log_rows, B8::TOWER_LEVEL); @@ -92,21 +79,14 @@ where Ok(product) } -pub fn u8mul( - builder: &mut ConstraintSystemBuilder, +pub fn u8mul( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, mult_a: OracleId, mult_b: OracleId, n_multiplications: usize, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name.clone()); let product_bytesliced = diff --git a/crates/circuits/src/lib.rs b/crates/circuits/src/lib.rs index c76857c6..59ae8195 100644 --- a/crates/circuits/src/lib.rs +++ b/crates/circuits/src/lib.rs @@ -11,9 +11,9 @@ pub mod arithmetic; pub mod bitwise; +pub mod blake3; pub mod builder; pub mod collatz; -pub mod groestl; pub mod keccakf; pub mod lasso; mod pack; @@ -26,504 +26,32 @@ pub mod vision; #[cfg(test)] mod tests { - use std::array; - use binius_core::{ constraint_system::{ self, channel::{Boundary, FlushDirection}, - validate::validate_witness, }, fiat_shamir::HasherChallenger, - oracle::OracleId, tower::CanonicalTowerFamily, }; use binius_field::{ - arch::OptimalUnderlier, - as_packed_field::PackedType, - tower_levels::{TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8}, - underlier::WithUnderlier, - AESTowerField16b, BinaryField128b, BinaryField1b, BinaryField32b, BinaryField64b, - BinaryField8b, Field, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField8b, Field, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; use groestl_crypto::Groestl256; - use rand::{rngs::StdRng, Rng, SeedableRng}; - use sha2::{compress256, digest::generic_array::GenericArray}; - - use crate::{ - arithmetic, bitwise, - builder::ConstraintSystemBuilder, - groestl::groestl_p_permutation, - keccakf::{keccakf, KeccakfState}, - lasso::{ - self, - batch::LookupBatch, - big_integer_ops::byte_sliced_test_utils::{ - test_bytesliced_add, test_bytesliced_add_carryfree, - test_bytesliced_double_conditional_increment, test_bytesliced_modular_mul, - test_bytesliced_mul, - }, - lookups, - u32add::SeveralU32add, - }, - plain_lookup, - sha256::sha256, - u32fib::u32fib, - unconstrained::unconstrained, - vision::vision_permutation, - }; - - type U = OptimalUnderlier; - type F = BinaryField128b; - - #[test] - fn test_lasso_u8add_carryfree_rejects_carry() { - // TODO: Make this test 100% certain to pass instead of 2^14 bits of security from randomness - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let x_in = unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap(); - let y_in = unconstrained::<_, _, BinaryField8b>(&mut builder, "y", log_size).unwrap(); - let c_in = unconstrained::<_, _, BinaryField1b>(&mut builder, "c", log_size).unwrap(); - - let lookup_t = - lookups::u8_arithmetic::add_carryfree_lookup(&mut builder, "add cf table").unwrap(); - let mut lookup_batch = LookupBatch::new([lookup_t]); - let _sum_and_cout = lasso::u8add_carryfree( - &mut builder, - &mut lookup_batch, - "lasso_u8add", - x_in, - y_in, - c_in, - log_size, - ) - .unwrap(); - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Rejected overflowing add"); - } - - #[test] - fn test_lasso_add_bytesliced() { - test_bytesliced_add::<1, TowerLevel1>(); - test_bytesliced_add::<2, TowerLevel2>(); - test_bytesliced_add::<4, TowerLevel4>(); - test_bytesliced_add::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_mul_bytesliced() { - test_bytesliced_mul::<1, TowerLevel2>(); - test_bytesliced_mul::<2, TowerLevel4>(); - test_bytesliced_mul::<4, TowerLevel8>(); - test_bytesliced_mul::<8, TowerLevel16>(); - } - - #[test] - fn test_lasso_modular_mul_bytesliced() { - test_bytesliced_modular_mul::<1, TowerLevel2>(); - test_bytesliced_modular_mul::<2, TowerLevel4>(); - test_bytesliced_modular_mul::<4, TowerLevel8>(); - test_bytesliced_modular_mul::<8, TowerLevel16>(); - } - - #[test] - fn test_lasso_bytesliced_double_conditional_increment() { - test_bytesliced_double_conditional_increment::<1, TowerLevel1>(); - test_bytesliced_double_conditional_increment::<2, TowerLevel2>(); - test_bytesliced_double_conditional_increment::<4, TowerLevel4>(); - test_bytesliced_double_conditional_increment::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_bytesliced_add_carryfree() { - test_bytesliced_add_carryfree::<1, TowerLevel1>(); - test_bytesliced_add_carryfree::<2, TowerLevel2>(); - test_bytesliced_add_carryfree::<4, TowerLevel4>(); - test_bytesliced_add_carryfree::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let mul_lookup_table = - lookups::u8_arithmetic::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_batched_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - let mul_lookup_table = - lookups::u8_arithmetic::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - for _ in 0..10 { - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - } - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_batched_u8mul_rejects() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - - // We try to feed in the add table instead - let mul_lookup_table = - lookups::u8_arithmetic::add_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng - for _ in 0..10 { - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - } - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Channels should be unbalanced"); - } - - #[test] - fn test_several_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let mut several_u32_add = SeveralU32add::new(&mut builder).unwrap(); - - for log_size in [11, 12, 13] { - // BinaryField8b is used here because we utilize an 8x8x1→8 table - let add_a_u8 = - unconstrained::<_, _, BinaryField8b>(&mut builder, "add_a", log_size).unwrap(); - let add_b_u8 = - unconstrained::<_, _, BinaryField8b>(&mut builder, "add_b", log_size).unwrap(); - let _sum = several_u32_add - .u32add::( - &mut builder, - "lasso_u32add", - add_a_u8, - add_b_u8, - ) - .unwrap(); - } - - several_u32_add - .finalize(&mut builder, "lasso_u32add") - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let add_a = unconstrained::<_, _, BinaryField1b>(&mut builder, "add_a", log_size).unwrap(); - let add_b = unconstrained::<_, _, BinaryField1b>(&mut builder, "add_b", log_size).unwrap(); - let _sum = lasso::u32add::<_, _, BinaryField1b, BinaryField1b>( - &mut builder, - "lasso_u32add", - add_a, - add_b, - ) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let a = unconstrained::<_, _, BinaryField1b>(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::<_, _, BinaryField1b>(&mut builder, "b", log_size).unwrap(); - let _c = arithmetic::u32::add(&mut builder, "u32add", a, b, arithmetic::Flags::Unchecked) - .unwrap(); - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_u32fib() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size_1b = 14; - let _ = u32fib(&mut builder, "u32fib", log_size_1b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_bitwise() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 6; - let a = unconstrained::<_, _, BinaryField1b>(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::<_, _, BinaryField1b>(&mut builder, "b", log_size).unwrap(); - let _and = bitwise::and(&mut builder, "and", a, b).unwrap(); - let _xor = bitwise::xor(&mut builder, "xor", a, b).unwrap(); - let _or = bitwise::or(&mut builder, "or", a, b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_keccakf() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 5; - - let mut rng = StdRng::seed_from_u64(0); - let input_states = vec![KeccakfState(rng.gen())]; - let _state_out = keccakf(&mut builder, Some(input_states), log_size); - - let witness = builder.take_witness().unwrap(); - - let constraint_system = builder.build().unwrap(); - - let boundaries = vec![]; - - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_sha256() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH; - let input: [OracleId; 16] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField1b>(&mut builder, i, log_size).unwrap() - }); - let state_output = sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = - array::from_fn(|i| witness.get(input[i]).unwrap().as_slice::()); - - let output_witneses: [_; 8] = - array::from_fn(|i| witness.get(state_output[i]).unwrap().as_slice::()); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; - } - } - - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_sha256_lasso() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; - let input: [OracleId; 16] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField1b>(&mut builder, i, log_size).unwrap() - }); - let state_output = lasso::sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = array::from_fn(|i| { - witness - .get::(input[i]) - .unwrap() - .as_slice::() - }); - - let output_witneses: [_; 8] = array::from_fn(|i| { - witness - .get::(state_output[i]) - .unwrap() - .as_slice::() - }); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; - } - } - - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_groestl() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - let log_size = 9; - let _state_out = groestl_p_permutation(&mut builder, log_size).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_vision32b() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - let log_size = 8; - let state_in: [OracleId; 24] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField32b>(&mut builder, format!("p_in[{i}]"), log_size) - .unwrap() - }); - let _state_out = vision_permutation(&mut builder, log_size, state_in).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } + use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }; #[test] fn test_boundaries() { // Proving Collatz Orbits let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = PackedType::::LOG_WIDTH + 2; @@ -633,74 +161,4 @@ mod tests { >(&constraint_system, 1, 10, &boundaries, proof) .unwrap(); } - - #[test] - fn test_plain_u8_mul_lookup() { - const MAX_LOG_MULTIPLICITY: usize = 18; - let log_lookup_count = 19; - - let log_inv_rate = 1; - let security_bits = 20; - - let proof = { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let boundary = plain_lookup::test_plain_lookup::test_u8_mul_lookup::< - _, - _, - MAX_LOG_MULTIPLICITY, - >(&mut builder, log_lookup_count) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - // validating witness with `validate_witness` is too slow for large transparents like the `table` - - let domain_factory = DefaultEvaluationDomainFactory::default(); - let backend = make_portable_backend(); - - constraint_system::prove::< - U, - CanonicalTowerFamily, - _, - Groestl256, - Groestl256ByteCompression, - HasherChallenger, - _, - >( - &constraint_system, - log_inv_rate, - security_bits, - &[boundary], - witness, - &domain_factory, - &backend, - ) - .unwrap() - }; - - // verify - { - let mut builder = ConstraintSystemBuilder::::new(); - - let boundary = plain_lookup::test_plain_lookup::test_u8_mul_lookup::< - _, - _, - MAX_LOG_MULTIPLICITY, - >(&mut builder, log_lookup_count) - .unwrap(); - - let constraint_system = builder.build().unwrap(); - - constraint_system::verify::< - U, - CanonicalTowerFamily, - Groestl256, - Groestl256ByteCompression, - HasherChallenger, - >(&constraint_system, log_inv_rate, security_bits, &[boundary], proof) - .unwrap(); - } - } } diff --git a/crates/circuits/src/pack.rs b/crates/circuits/src/pack.rs index 5d3c8f79..290a4472 100644 --- a/crates/circuits/src/pack.rs +++ b/crates/circuits/src/pack.rs @@ -2,22 +2,23 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, ExtensionField, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, ExtensionField, TowerField}; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn pack( +pub fn pack( oracle_id: OracleId, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result where - F: TowerField + ExtensionField + ExtensionField, + F: ExtensionField + ExtensionField, FInput: TowerField, FOutput: TowerField + ExtensionField, - U: UnderlierType + PackScalar + PackScalar + PackScalar, + U: PackScalar + PackScalar, { if FInput::TOWER_LEVEL == FOutput::TOWER_LEVEL { return Ok(oracle_id); diff --git a/crates/circuits/src/plain_lookup.rs b/crates/circuits/src/plain_lookup.rs index 0c41cc6e..afcb77ac 100644 --- a/crates/circuits/src/plain_lookup.rs +++ b/crates/circuits/src/plain_lookup.rs @@ -1,17 +1,16 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_core::{ - constraint_system::channel::{Boundary, FlushDirection}, - oracle::OracleId, -}; +use binius_core::{constraint_system::channel::FlushDirection, oracle::OracleId}; use binius_field::{ as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, ExtensionField, Field, TowerField, }; use bytemuck::Pod; -use itertools::izip; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; /// Checks values in `lookup_values` to be in `table`. /// @@ -24,45 +23,36 @@ use crate::builder::ConstraintSystemBuilder; /// # Parameters /// - `builder`: a mutable reference to the `ConstraintSystemBuilder`. /// - `table`: an oracle holding the table of valid lookup values. -/// - `table_count`: only the first `table_count` values of `table` are considered valid lookup values. -/// - `balancer_value`: any valid table value, needed for balancing the channel. /// - `lookup_values`: an oracle holding the values to be looked up. /// - `lookup_values_count`: only the first `lookup_values_count` values in `lookup_values` will be looked up. /// -/// # Constraints -/// - no value in `lookup_values` can be looked only less than `1 << LOG_MAX_MULTIPLICITY` times, limiting completeness not soundness. -/// /// # How this Works /// We create a single channel for this lookup. /// We let the prover push all values in `lookup_values`, that is all values to be looked up, into the channel. /// We also must pull valid table values (i.e. values that appear in `table`) from the channel if the channel is to balance. /// By ensuring that only valid table values get pulled from the channel, and observing the channel to balance, we ensure that only valid table values get pushed (by the prover) into the channel. /// Therefore our construction is sound. -/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each table value from the channel with exactly the same multiplicity (duplicate count) that the prover pushed that table value into the channel. +/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each +/// table value from the channel with exactly the same multiplicity (duplicate count) that the prover pushed that table value into the channel. /// To do so, we allow the prover to commit information on the multiplicity of each table value. /// -/// The prover counts the multiplicity of each table value, and commits columns holding the bit-decomposition of the multiplicities. -/// Using these bit columns we create `component` columns the same height as the table, which select the table value where a multiplicity bit is 1 and select `balancer_value` where the bit is 0. -/// Pulling these component columns out of the channel with appropriate multiplicities, we pull out each table value from the channel with the multiplicity requested by the prover. -/// Due to the `balancer_value` appearing in the component columns, however, we will also pull the table value `balancer_value` from the channel many more times than needed. -/// To rectify this we put `balancer_value` in a boundary value and push this boundary value to the channel with a multiplicity that will balance the channel. -/// This boundary value is returned from the gadget. +/// The prover counts the multiplicity of each table value, and creates a bit column for +/// each of the LOG_MAX_MULTIPLICITY bits in the bit-decomposition of the multiplicities. +/// Then we flush the table values LOG_MAX_MULTIPLICITY times, each time using a different bit column as the 'selector' oracle to select which values in the +/// table actually get pushed into the channel flushed. When flushing the table with the i'th bit column as the selector, we flush with multiplicity 1 << i. /// -pub fn plain_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn plain_lookup( + builder: &mut ConstraintSystemBuilder, table: OracleId, - table_count: usize, - balancer_value: FS, lookup_values: OracleId, lookup_values_count: usize, -) -> Result, anyhow::Error> +) -> Result<(), anyhow::Error> where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, + U: PackScalar + Pod, + F: ExtensionField, FS: TowerField + Pod, { let n_vars = builder.log_rows([table])?; - debug_assert!(table_count <= 1 << n_vars); let channel = builder.add_channel(); @@ -75,58 +65,29 @@ where let values_slice = witness.get::(lookup_values)?.as_slice::(); multiplicities = Some(count_multiplicities( - &table_slice[0..table_count], + &table_slice[0..1 << n_vars], &values_slice[0..lookup_values_count], false, )?); } - let components: [OracleId; LOG_MAX_MULTIPLICITY] = - get_components::<_, _, FS, LOG_MAX_MULTIPLICITY>( - builder, - table, - table_count, - balancer_value, - multiplicities, - )?; - - components - .into_iter() - .enumerate() - .try_for_each(|(i, component)| { - builder.flush_with_multiplicity( - FlushDirection::Pull, - channel, - table_count, - [component], - 1 << i, - ) - })?; - - let balancer_value_multiplicity = - (((1 << LOG_MAX_MULTIPLICITY) - 1) * table_count - lookup_values_count) as u64; - - let boundary = Boundary { - values: vec![balancer_value.into()], - channel_id: channel, - direction: FlushDirection::Push, - multiplicity: balancer_value_multiplicity, - }; + let bits: [OracleId; LOG_MAX_MULTIPLICITY] = get_bits(builder, table, multiplicities)?; + bits.into_iter().enumerate().try_for_each(|(i, bit)| { + builder.flush_custom(FlushDirection::Pull, channel, bit, [table], 1 << i) + })?; - Ok(boundary) + Ok(()) } -// the `i`'th returned component holds values that are the product of the `table` values and the bits had by taking the `i`'th bit across the multiplicities. -fn get_components( - builder: &mut ConstraintSystemBuilder, +// the `i`'th returned bit column holds the `i`'th multiplicity bit. +fn get_bits( + builder: &mut ConstraintSystemBuilder, table: OracleId, - table_count: usize, - balancer_value: FS, multiplicities: Option>, ) -> Result<[OracleId; LOG_MAX_MULTIPLICITY], anyhow::Error> where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, FS: TowerField + Pod, { let n_vars = builder.log_rows([table])?; @@ -134,13 +95,10 @@ where let bits: [OracleId; LOG_MAX_MULTIPLICITY] = builder .add_committed_multiple::("bits", n_vars, BinaryField1b::TOWER_LEVEL); - let components: [OracleId; LOG_MAX_MULTIPLICITY] = builder - .add_committed_multiple::("components", n_vars, FS::TOWER_LEVEL); - if let Some(witness) = builder.witness() { let multiplicities = multiplicities.ok_or_else(|| anyhow::anyhow!("multiplicities empty for prover"))?; - debug_assert_eq!(table_count, multiplicities.len()); + debug_assert_eq!(1 << n_vars, multiplicities.len()); // check all multiplicities are in range if multiplicities @@ -155,19 +113,13 @@ where // create the columns for the bits let mut bit_cols = bits.map(|bit| witness.new_column::(bit)); let mut packed_bit_cols = bit_cols.each_mut().map(|bit_col| bit_col.packed()); - // create the columns for the components - let mut component_cols = components.map(|component| witness.new_column::(component)); - let mut packed_component_cols = component_cols - .each_mut() - .map(|component_col| component_col.packed()); - - let table_slice = witness.get::(table)?.as_slice::(); - izip!(table_slice, multiplicities).enumerate().for_each( - |(i, (table_val, multiplicity))| { - for j in 0..LOG_MAX_MULTIPLICITY { + multiplicities + .iter() + .enumerate() + .for_each(|(i, multiplicity)| { + (0..LOG_MAX_MULTIPLICITY).for_each(|j| { let bit_set = multiplicity & (1 << j) != 0; - // set the bit value set_packed_slice( packed_bit_cols[j], i, @@ -176,36 +128,11 @@ where false => BinaryField1b::ZERO, }, ); - // set the component value - set_packed_slice( - packed_component_cols[j], - i, - match bit_set { - true => *table_val, - false => balancer_value, - }, - ); - } - }, - ); + }) + }); } - let expression = { - use binius_math::ArithExpr as Expr; - let table = Expr::Var(0); - let bit = Expr::Var(1); - let component = Expr::Var(2); - component - (bit.clone() * table + (Expr::one() - bit) * Expr::Const(balancer_value)) - }; - (0..LOG_MAX_MULTIPLICITY).for_each(|i| { - builder.assert_zero( - format!("lookup_{i}"), - [table, bits[i], components[i]], - expression.convert_field(), - ); - }); - - Ok(components) + Ok(bits) } #[cfg(test)] @@ -242,28 +169,20 @@ pub mod test_plain_lookup { }); } - pub fn test_u8_mul_lookup( - builder: &mut ConstraintSystemBuilder, + pub fn test_u8_mul_lookup( + builder: &mut ConstraintSystemBuilder, log_lookup_count: usize, - ) -> Result, anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let table_values = generate_u8_mul_table(); let table = transparent::make_transparent( builder, "u8_mul_table", bytemuck::cast_slice::<_, BinaryField32b>(&table_values), )?; - let balancer_value = BinaryField32b::new(table_values[99]); // any table value let lookup_values = builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL); - // reduce these if only some table values are valid - // or only some lookup_values are to be looked up - let table_count = table_values.len(); let lookup_values_count = 1 << log_lookup_count; if let Some(witness) = builder.witness() { @@ -272,16 +191,14 @@ pub mod test_plain_lookup { generate_random_u8_mul_claims(&mut mut_slice[0..lookup_values_count]); } - let boundary = plain_lookup::( + plain_lookup::( builder, table, - table_count, - balancer_value, lookup_values, lookup_values_count, )?; - Ok(boundary) + Ok(()) } } @@ -365,3 +282,83 @@ mod count_multiplicity_tests { assert_eq!(result, vec![1, 2, 3]); } } + +#[cfg(test)] +mod tests { + use binius_core::{fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; + use binius_hal::make_portable_backend; + use binius_hash::compress::Groestl256ByteCompression; + use binius_math::DefaultEvaluationDomainFactory; + use groestl_crypto::Groestl256; + + use super::test_plain_lookup; + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_plain_u8_mul_lookup() { + const MAX_LOG_MULTIPLICITY: usize = 18; + let log_lookup_count = 19; + + let log_inv_rate = 1; + let security_bits = 20; + + let proof = { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + test_plain_lookup::test_u8_mul_lookup::( + &mut builder, + log_lookup_count, + ) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + // validating witness with `validate_witness` is too slow for large transparents like the `table` + + let domain_factory = DefaultEvaluationDomainFactory::default(); + let backend = make_portable_backend(); + + binius_core::constraint_system::prove::< + crate::builder::types::U, + CanonicalTowerFamily, + _, + Groestl256, + Groestl256ByteCompression, + HasherChallenger, + _, + >( + &constraint_system, + log_inv_rate, + security_bits, + &[], + witness, + &domain_factory, + &backend, + ) + .unwrap() + }; + + // verify + { + let mut builder = ConstraintSystemBuilder::new(); + + test_plain_lookup::test_u8_mul_lookup::( + &mut builder, + log_lookup_count, + ) + .unwrap(); + + let constraint_system = builder.build().unwrap(); + + binius_core::constraint_system::verify::< + crate::builder::types::U, + CanonicalTowerFamily, + Groestl256, + Groestl256ByteCompression, + HasherChallenger, + >(&constraint_system, log_inv_rate, security_bits, &[], proof) + .unwrap(); + } + } +} diff --git a/crates/circuits/src/sha256.rs b/crates/circuits/src/sha256.rs index d2e8fe6e..53b4a33b 100644 --- a/crates/circuits/src/sha256.rs +++ b/crates/circuits/src/sha256.rs @@ -5,16 +5,21 @@ use binius_core::{ transparent::multilinear_extension::MultilinearExtensionTransparent, }; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, PackedField, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, Field, PackedField, + TowerField, }; use binius_macros::arith_expr; use binius_utils::checked_arithmetics::checked_log_2; use bytemuck::{pod_collect_to_vec, Pod}; use itertools::izip; -use crate::{arithmetic, builder::ConstraintSystemBuilder}; +use crate::{ + arithmetic, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, +}; const LOG_U32_BITS: usize = checked_log_2(32); @@ -41,15 +46,11 @@ pub enum RotateRightType { Logical, } -pub fn rotate_and_xor( +pub fn rotate_and_xor( log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, r: &[(OracleId, usize, RotateRightType)], -) -> Result -where - F: TowerField, - U: UnderlierType + Pod + PackScalar + PackScalar, -{ +) -> Result { let shifted_oracle_ids = r .iter() .map(|(oracle_id, shift, t)| { @@ -76,7 +77,7 @@ where let result_oracle_id = builder.add_linear_combination( format!("linear combination of {:?}", shifted_oracle_ids), log_size, - shifted_oracle_ids.iter().map(|s| (*s, F::ONE)), + shifted_oracle_ids.iter().map(|s| (*s, Field::ONE)), )?; if let Some(witness) = builder.witness() { @@ -116,16 +117,12 @@ where .collect() } -pub fn u32const_repeating( +pub fn u32const_repeating( log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, x: u32, name: &str, -) -> Result -where - F: TowerField, - U: UnderlierType + Pod + PackScalar + PackScalar, -{ +) -> Result { let brodcasted = vec![x; 1 << (PackedType::::LOG_WIDTH.saturating_sub(LOG_U32_BITS))]; let transparent_id = builder.add_transparent( @@ -152,15 +149,11 @@ where Ok(repeating_id) } -pub fn sha256( - builder: &mut ConstraintSystemBuilder, +pub fn sha256( + builder: &mut ConstraintSystemBuilder, input: [OracleId; 16], log_size: usize, -) -> Result<[OracleId; 8], anyhow::Error> -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result<[OracleId; 8], anyhow::Error> { if log_size < >::LOG_WIDTH { Err(anyhow::Error::msg("log_size too small"))? } @@ -319,3 +312,64 @@ where Ok(output) } + +#[cfg(test)] +mod tests { + use binius_core::oracle::OracleId; + use binius_field::{as_packed_field::PackedType, BinaryField1b}; + use sha2::{compress256, digest::generic_array::GenericArray}; + + use crate::{ + builder::{test_utils::test_circuit, types::U}, + unconstrained::unconstrained, + }; + + #[test] + fn test_sha256() { + test_circuit(|builder| { + let log_size = PackedType::::LOG_WIDTH; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(builder, i, log_size).unwrap() + }); + let state_output = super::sha256(builder, input, log_size).unwrap(); + + if let Some(witness) = builder.witness() { + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } + } + } + + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/transparent.rs b/crates/circuits/src/transparent.rs index 16efed3a..11e36c22 100644 --- a/crates/circuits/src/transparent.rs +++ b/crates/circuits/src/transparent.rs @@ -3,23 +3,20 @@ use binius_core::{oracle::OracleId, transparent}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, BinaryField1b, ExtensionField, PackedField, TowerField, }; -use bytemuck::Pod; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn step_down( - builder: &mut ConstraintSystemBuilder, +pub fn step_down( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, index: usize, -) -> Result -where - U: UnderlierType + PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let step_down = transparent::step_down::StepDown::new(log_size, index)?; let id = builder.add_transparent(name, step_down.clone())?; if let Some(witness) = builder.witness() { @@ -28,16 +25,12 @@ where Ok(id) } -pub fn step_up( - builder: &mut ConstraintSystemBuilder, +pub fn step_up( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, index: usize, -) -> Result -where - U: UnderlierType + PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let step_up = transparent::step_up::StepUp::new(log_size, index)?; let id = builder.add_transparent(name, step_up.clone())?; if let Some(witness) = builder.witness() { @@ -46,14 +39,14 @@ where Ok(id) } -pub fn constant( - builder: &mut ConstraintSystemBuilder, +pub fn constant( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, value: FS, ) -> Result where - U: UnderlierType + PackScalar + PackScalar, + U: PackScalar, F: TowerField + ExtensionField, FS: TowerField, { @@ -68,13 +61,13 @@ where Ok(id) } -pub fn make_transparent( - builder: &mut ConstraintSystemBuilder, +pub fn make_transparent( + builder: &mut ConstraintSystemBuilder, name: impl ToString, values: &[FS], ) -> Result where - U: PackScalar + PackScalar, + U: PackScalar, F: TowerField + ExtensionField, FS: TowerField, { diff --git a/crates/circuits/src/u32fib.rs b/crates/circuits/src/u32fib.rs index 1f4197af..59563fb7 100644 --- a/crates/circuits/src/u32fib.rs +++ b/crates/circuits/src/u32fib.rs @@ -1,26 +1,22 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::{OracleId, ShiftVariant}; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, BinaryField32b, - ExtensionField, TowerField, -}; +use binius_field::{BinaryField1b, BinaryField32b, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use rand::{thread_rng, Rng}; -use crate::{arithmetic, builder::ConstraintSystemBuilder, transparent::step_down}; +use crate::{ + arithmetic, + builder::{types::F, ConstraintSystemBuilder}, + transparent::step_down, +}; -pub fn u32fib( - builder: &mut ConstraintSystemBuilder, +pub fn u32fib( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let current = builder.add_committed("current", log_size, BinaryField1b::TOWER_LEVEL); let next = builder.add_shifted("next", current, 32, log_size, ShiftVariant::LogicalRight)?; @@ -75,3 +71,18 @@ where builder.pop_namespace(); Ok(current) } + +#[cfg(test)] +mod tests { + use crate::builder::test_utils::test_circuit; + + #[test] + fn test_u32fib() { + test_circuit(|builder| { + let log_size_1b = 14; + let _ = super::u32fib(builder, "u32fib", log_size_1b)?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/unconstrained.rs b/crates/circuits/src/unconstrained.rs index a798ec5f..f39f48cd 100644 --- a/crates/circuits/src/unconstrained.rs +++ b/crates/circuits/src/unconstrained.rs @@ -1,21 +1,22 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, ExtensionField, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, ExtensionField, TowerField}; use binius_maybe_rayon::prelude::*; use bytemuck::Pod; use rand::{thread_rng, Rng}; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn unconstrained( - builder: &mut ConstraintSystemBuilder, +pub fn unconstrained( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, ) -> Result where - U: UnderlierType + Pod + PackScalar + PackScalar, + U: PackScalar + Pod, F: TowerField + ExtensionField, FS: TowerField, { @@ -33,3 +34,31 @@ where Ok(rng) } + +// Same as 'unconstrained' but uses some pre-defined values instead of a random ones +pub fn fixed_u32( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + value: Vec, +) -> Result +where + U: PackScalar + Pod, + F: TowerField + ExtensionField, + FS: TowerField, +{ + let rng = builder.add_committed(name, log_size, FS::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + witness + .new_column::(rng) + .as_mut_slice::() + .into_par_iter() + .zip(value.into_par_iter()) + .for_each(|(data, value)| { + *data = value; + }); + } + + Ok(rng) +} diff --git a/crates/circuits/src/vision.rs b/crates/circuits/src/vision.rs index 5e8fe20a..182566db 100644 --- a/crates/circuits/src/vision.rs +++ b/crates/circuits/src/vision.rs @@ -12,36 +12,22 @@ use std::array; use anyhow::Result; use binius_core::{oracle::OracleId, transparent::constant::Constant}; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - linear_transformation::Transformation, - make_aes_to_binary_packed_transformer, - packed::get_packed_slice, - underlier::UnderlierType, - BinaryField1b, BinaryField32b, BinaryField64b, ExtensionField, PackedAESBinaryField8x32b, - PackedBinaryField8x32b, PackedField, TowerField, + linear_transformation::Transformation, make_aes_to_binary_packed_transformer, + packed::get_packed_slice, BinaryField1b, BinaryField32b, ExtensionField, Field, + PackedAESBinaryField8x32b, PackedBinaryField8x32b, PackedField, TowerField, }; use binius_hash::{Vision32MDSTransform, INV_PACKED_TRANS_AES}; use binius_macros::arith_expr; use binius_math::ArithExpr; -use bytemuck::{must_cast_slice, Pod}; +use bytemuck::must_cast_slice; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; -pub fn vision_permutation( - builder: &mut ConstraintSystemBuilder, +pub fn vision_permutation( + builder: &mut ConstraintSystemBuilder, log_size: usize, p_in: [OracleId; STATE_SIZE], -) -> Result<[OracleId; STATE_SIZE]> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - F: TowerField + ExtensionField + ExtensionField, - PackedType: Pod, -{ +) -> Result<[OracleId; STATE_SIZE]> { // This only acts as a shorthand type B32 = BinaryField32b; @@ -77,7 +63,7 @@ where } let perm_out = (0..N_ROUNDS).try_fold(round_0_input, |state, round_i| { - vision_round::(builder, log_size, round_i, state) + vision_round(builder, log_size, round_i, state) })?; #[cfg(debug_assertions)] @@ -237,22 +223,13 @@ fn inv_constraint_expr() -> Result> { Ok(non_zero_case * zero_case) } -fn vision_round( - builder: &mut ConstraintSystemBuilder, +fn vision_round( + builder: &mut ConstraintSystemBuilder, log_size: usize, round_i: usize, perm_in: [OracleId; STATE_SIZE], ) -> Result<[OracleId; STATE_SIZE]> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - F: TowerField + ExtensionField + ExtensionField, - PackedType: Pod, -{ +where { builder.push_namespace(format!("round[{round_i}]")); let inv_0 = builder.add_committed_multiple::( "inv_evens", @@ -318,7 +295,10 @@ where .add_linear_combination( format!("round_out_evens_{}", row), log_size, - [(mds_out_0[row], F::ONE), (even_round_consts[row], F::ONE)], + [ + (mds_out_0[row], Field::ONE), + (even_round_consts[row], Field::ONE), + ], ) .unwrap() }); @@ -328,7 +308,10 @@ where .add_linear_combination( format!("round_out_odd_{}", row), log_size, - [(mds_out_1[row], F::ONE), (odd_round_consts[row], F::ONE)], + [ + (mds_out_1[row], Field::ONE), + (odd_round_consts[row], Field::ONE), + ], ) .unwrap() }); @@ -476,3 +459,25 @@ where Ok(perm_out) } + +#[cfg(test)] +mod tests { + use binius_core::oracle::OracleId; + use binius_field::BinaryField32b; + + use super::vision_permutation; + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; + + #[test] + fn test_vision32b() { + test_circuit(|builder| { + let log_size = 8; + let state_in: [OracleId; 24] = std::array::from_fn(|i| { + unconstrained::(builder, format!("p_in[{i}]"), log_size).unwrap() + }); + let _state_out = vision_permutation(builder, log_size, state_in).unwrap(); + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 113cac7a..7f74200a 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [dependencies] assert_matches.workspace = true auto_impl.workspace = true +binius_macros = { path = "../macros" } binius_field = { path = "../field" } binius_hal = { path = "../hal" } binius_hash = { path = "../hash" } @@ -23,6 +24,7 @@ derive_more.workspace = true digest.workspace = true either.workspace = true getset.workspace = true +inventory.workspace = true itertools.workspace = true rand.workspace = true stackalloc.workspace = true diff --git a/crates/core/benches/composition_poly.rs b/crates/core/benches/composition_poly.rs index 845c865b..166dcfe5 100644 --- a/crates/core/benches/composition_poly.rs +++ b/crates/core/benches/composition_poly.rs @@ -8,7 +8,7 @@ use binius_field::{ PackedField, }; use binius_macros::{arith_circuit_poly, composition_poly}; -use binius_math::{ArithExpr as Expr, CompositionPolyOS}; +use binius_math::{ArithExpr as Expr, CompositionPoly}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use rand::{thread_rng, RngCore}; @@ -28,7 +28,7 @@ fn generate_input_data(mut rng: impl RngCore) -> Vec> { fn evaluate_arith_circuit_poly( query: &[&[P]], - arith_circuit_poly: &impl CompositionPolyOS

, + arith_circuit_poly: &impl CompositionPoly

, ) { for i in 0..BATCH_SIZE { let result = arith_circuit_poly diff --git a/crates/core/src/composition/index.rs b/crates/core/src/composition/index.rs index 0eff1949..2a918439 100644 --- a/crates/core/src/composition/index.rs +++ b/crates/core/src/composition/index.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use binius_field::{Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; use crate::polynomial::Error; @@ -34,7 +34,7 @@ impl IndexComposition { } } -impl, const N: usize> CompositionPolyOS

+impl, const N: usize> CompositionPoly

for IndexComposition { fn n_vars(&self) -> usize { @@ -159,7 +159,7 @@ mod tests { }; assert_eq!( - (&composition as &dyn CompositionPolyOS).expression(), + (&composition as &dyn CompositionPoly).expression(), ArithExpr::Add( Box::new(ArithExpr::Var(1)), Box::new(ArithExpr::Mul( diff --git a/crates/core/src/composition/product_composition.rs b/crates/core/src/composition/product_composition.rs index 245c2865..7fddf331 100644 --- a/crates/core/src/composition/product_composition.rs +++ b/crates/core/src/composition/product_composition.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::PackedField; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; #[derive(Debug, Default, Copy, Clone)] @@ -17,7 +17,7 @@ impl ProductComposition { } } -impl CompositionPolyOS

for ProductComposition { +impl CompositionPoly

for ProductComposition { fn n_vars(&self) -> usize { self.n_vars() } diff --git a/crates/core/src/constraint_system/channel.rs b/crates/core/src/constraint_system/channel.rs index 93a28cb0..01891f18 100644 --- a/crates/core/src/constraint_system/channel.rs +++ b/crates/core/src/constraint_system/channel.rs @@ -52,14 +52,14 @@ use std::collections::HashMap; use binius_field::{as_packed_field::PackScalar, underlier::UnderlierType, TowerField}; -use bytes::BufMut; +use binius_macros::{DeserializeBytes, SerializeBytes}; use super::error::{Error, VerificationError}; -use crate::{oracle::OracleId, transcript::TranscriptWriter, witness::MultilinearExtensionIndex}; +use crate::{oracle::OracleId, witness::MultilinearExtensionIndex}; pub type ChannelId = usize; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct Flush { pub oracles: Vec, pub channel_id: ChannelId, @@ -68,7 +68,7 @@ pub struct Flush { pub multiplicity: u64, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub struct Boundary { pub values: Vec, pub channel_id: ChannelId, @@ -76,7 +76,7 @@ pub struct Boundary { pub multiplicity: u64, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum FlushDirection { Push, Pull, @@ -220,26 +220,6 @@ impl Channel { } } -impl Boundary { - pub fn write_to(&self, writer: &mut TranscriptWriter) { - writer.buffer().put_u64(self.values.len() as u64); - writer.write_slice( - &self - .values - .iter() - .copied() - .map(F::Canonical::from) - .collect::>(), - ); - writer.buffer().put_u64(self.channel_id as u64); - writer.buffer().put_u64(self.multiplicity); - writer.buffer().put_u64(match self.direction { - FlushDirection::Pull => 0, - FlushDirection::Push => 1, - }); - } -} - #[cfg(test)] mod tests { use binius_field::BinaryField64b; diff --git a/crates/core/src/constraint_system/mod.rs b/crates/core/src/constraint_system/mod.rs index de178b31..81719eac 100644 --- a/crates/core/src/constraint_system/mod.rs +++ b/crates/core/src/constraint_system/mod.rs @@ -7,7 +7,9 @@ mod prove; pub mod validate; mod verify; -use binius_field::TowerField; +use binius_field::{BinaryField128b, TowerField}; +use binius_macros::SerializeBytes; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode}; use channel::{ChannelId, Flush}; pub use prove::prove; pub use verify::verify; @@ -21,7 +23,7 @@ use crate::oracle::{ConstraintSet, MultilinearOracleSet, OracleId}; /// /// As a result, a ConstraintSystem allows us to validate all of these /// constraints against a witness, as well as enabling generic prove/verify -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes)] pub struct ConstraintSystem { pub oracles: MultilinearOracleSet, pub table_constraints: Vec>, @@ -30,6 +32,24 @@ pub struct ConstraintSystem { pub max_channel_id: ChannelId, } +impl DeserializeBytes for ConstraintSystem { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self { + oracles: DeserializeBytes::deserialize(&mut read_buf, mode)?, + table_constraints: DeserializeBytes::deserialize(&mut read_buf, mode)?, + non_zero_oracle_ids: DeserializeBytes::deserialize(&mut read_buf, mode)?, + flushes: DeserializeBytes::deserialize(&mut read_buf, mode)?, + max_channel_id: DeserializeBytes::deserialize(&mut read_buf, mode)?, + }) + } +} + impl ConstraintSystem { pub const fn no_base_constraints(self) -> Self { self diff --git a/crates/core/src/constraint_system/prove.rs b/crates/core/src/constraint_system/prove.rs index d405405e..a04539d5 100644 --- a/crates/core/src/constraint_system/prove.rs +++ b/crates/core/src/constraint_system/prove.rs @@ -12,7 +12,7 @@ use binius_field::{ use binius_hal::ComputationBackend; use binius_hash::PseudoCompressionFunction; use binius_math::{ - ArithExpr, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, + EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, }; use binius_maybe_rayon::prelude::*; @@ -54,7 +54,7 @@ use crate::{ }, }, ring_switch, - tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier, TowerFamily}, + tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier}, transcript::ProverTranscript, witness::{MultilinearExtensionIndex, MultilinearWitness}, }; @@ -104,12 +104,7 @@ where let fast_domain_factory = IsomorphicEvaluationDomainFactory::>::default(); let mut transcript = ProverTranscript::::new(); - { - let mut observer = transcript.observe(); - for boundary in boundaries { - boundary.write_to(&mut observer); - } - } + transcript.observe().write_slice(boundaries); let ConstraintSystem { mut oracles, @@ -241,7 +236,7 @@ where let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) = reorder_for_flushing_by_n_vars( &oracles, - flush_oracle_ids, + &flush_oracle_ids, flush_selectors, flush_final_layer_claims, ); @@ -302,7 +297,7 @@ where .map(|multilinear| 7 - multilinear.log_extension_degree()), constraints .iter() - .map(|constraint| arith_expr_base_tower_level::(&constraint.composition)) + .map(|constraint| constraint.composition.binary_tower_level()) ) .max() .unwrap_or(0); @@ -423,7 +418,7 @@ where let system = ring_switch::EvalClaimSystem::new( &oracles, &commit_meta, - oracle_to_commit_index, + &oracle_to_commit_index, &eval_claims, )?; @@ -457,30 +452,6 @@ where }) } -fn arith_expr_base_tower_level(composition: &ArithExpr>) -> usize { - if composition.try_convert_field::().is_ok() { - return 0; - } - - if composition.try_convert_field::().is_ok() { - return 3; - } - - if composition.try_convert_field::().is_ok() { - return 4; - } - - if composition.try_convert_field::().is_ok() { - return 5; - } - - if composition.try_convert_field::().is_ok() { - return 6; - } - - 7 -} - type TypeErasedUnivariateZerocheck<'a, F> = Box + 'a>; type TypeErasedSumcheck<'a, F> = Box + 'a>; type TypeErasedProver<'a, F> = @@ -518,7 +489,7 @@ where P: PackedExtension + PackedExtension + PackedExtension, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, { let univariate_prover = sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _>( diff --git a/crates/core/src/constraint_system/verify.rs b/crates/core/src/constraint_system/verify.rs index 467efd93..415f909d 100644 --- a/crates/core/src/constraint_system/verify.rs +++ b/crates/core/src/constraint_system/verify.rs @@ -4,7 +4,7 @@ use std::{cmp::Reverse, iter}; use binius_field::{BinaryField, PackedField, TowerField}; use binius_hash::PseudoCompressionFunction; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::{bail, checked_arithmetics::log2_ceil_usize}; use digest::{core_api::BlockSizeUser, Digest, Output}; use itertools::{izip, multiunzip, Itertools}; @@ -75,12 +75,7 @@ where let Proof { transcript } = proof; let mut transcript = VerifierTranscript::::new(transcript); - { - let mut observer = transcript.observe(); - for boundary in boundaries { - boundary.write_to(&mut observer); - } - } + transcript.observe().write_slice(boundaries); let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default()); let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?; @@ -155,7 +150,7 @@ where let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) = reorder_for_flushing_by_n_vars( &oracles, - flush_oracle_ids, + &flush_oracle_ids, flush_selectors, flush_final_layer_claims, ); @@ -284,7 +279,7 @@ where let system = ring_switch::EvalClaimSystem::new( &oracles, &commit_meta, - oracle_to_commit_index, + &oracle_to_commit_index, &eval_claims, )?; @@ -315,7 +310,7 @@ pub fn max_n_vars_and_skip_rounds( ) -> (usize, usize) where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, { let max_n_vars = max_n_vars(zerocheck_claims); @@ -339,7 +334,7 @@ where fn max_n_vars(zerocheck_claims: &[ZerocheckClaim]) -> usize where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, { zerocheck_claims .iter() @@ -572,7 +567,7 @@ pub fn get_flush_dedup_sumcheck_metas( #[derive(Debug)] pub struct FlushSumcheckComposition; -impl CompositionPolyOS

for FlushSumcheckComposition { +impl CompositionPoly

for FlushSumcheckComposition { fn n_vars(&self) -> usize { 2 } @@ -644,7 +639,7 @@ pub fn get_post_flush_sumcheck_eval_claims_without_eq( Ok(evalcheck_claims) } -pub struct DedupSumcheckClaims> { +pub struct DedupSumcheckClaims> { sumcheck_claims: Vec>, gkr_eval_points: Vec>, flush_selectors_unique_by_claim: Vec>, @@ -654,7 +649,7 @@ pub struct DedupSumcheckClaims> #[allow(clippy::type_complexity)] pub fn get_flush_dedup_sumcheck_claims( flush_sumcheck_metas: Vec>, -) -> Result>, Error> { +) -> Result>, Error> { let n_claims = flush_sumcheck_metas.len(); let mut sumcheck_claims = Vec::with_capacity(n_claims); let mut gkr_eval_points = Vec::with_capacity(n_claims); @@ -698,7 +693,7 @@ pub fn get_flush_dedup_sumcheck_claims( pub fn reorder_for_flushing_by_n_vars( oracles: &MultilinearOracleSet, - flush_oracle_ids: Vec, + flush_oracle_ids: &[OracleId], flush_selectors: Vec, flush_final_layer_claims: Vec>, ) -> (Vec, Vec, Vec>) { diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index c7ebeab0..48f1f060 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -7,7 +7,7 @@ //! performance, while verifier-side functions are optimized for auditability and security. // This is to silence clippy errors around suspicious usage of XOR -// in our arithmetic. This is safe to do becasue we're operating +// in our arithmetic. This is safe to do because we're operating // over binary fields. #![allow(clippy::suspicious_arithmetic_impl)] #![allow(clippy::suspicious_op_assign_impl)] @@ -28,3 +28,5 @@ pub mod tower; pub mod transcript; pub mod transparent; pub mod witness; + +pub use inventory; diff --git a/crates/core/src/merkle_tree/binary_merkle_tree.rs b/crates/core/src/merkle_tree/binary_merkle_tree.rs index f15d689b..86383830 100644 --- a/crates/core/src/merkle_tree/binary_merkle_tree.rs +++ b/crates/core/src/merkle_tree/binary_merkle_tree.rs @@ -2,10 +2,12 @@ use std::{array, fmt::Debug, mem::MaybeUninit}; -use binius_field::{serialize_canonical, TowerField}; +use binius_field::TowerField; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_maybe_rayon::{prelude::*, slice::ParallelSlice}; -use binius_utils::{bail, checked_arithmetics::log2_strict_usize}; +use binius_utils::{ + bail, checked_arithmetics::log2_strict_usize, SerializationMode, SerializeBytes, +}; use digest::{crypto_common::BlockSizeUser, Digest, FixedOutputReset, Output}; use tracing::instrument; @@ -210,7 +212,8 @@ where { let mut hash_buffer = HashBuffer::new(hasher); for elem in elems { - serialize_canonical(elem, &mut hash_buffer) + let mode = SerializationMode::CanonicalTower; + SerializeBytes::serialize(&elem, &mut hash_buffer, mode) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/core/src/merkle_tree/scheme.rs b/crates/core/src/merkle_tree/scheme.rs index c29940cf..5055e354 100644 --- a/crates/core/src/merkle_tree/scheme.rs +++ b/crates/core/src/merkle_tree/scheme.rs @@ -2,11 +2,12 @@ use std::{array, fmt::Debug, marker::PhantomData}; -use binius_field::{serialize_canonical, TowerField}; +use binius_field::TowerField; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_utils::{ bail, checked_arithmetics::{log2_ceil_usize, log2_strict_usize}, + SerializationMode, SerializeBytes, }; use bytes::Buf; use digest::{core_api::BlockSizeUser, Digest, Output}; @@ -108,42 +109,36 @@ where fn verify_opening( &self, - index: usize, + mut index: usize, values: &[F], layer_depth: usize, tree_depth: usize, layer_digests: &[Self::Digest], proof: &mut TranscriptReader, ) -> Result<(), Error> { - if 1 << layer_depth != layer_digests.len() { - bail!(VerificationError::IncorrectVectorLength) + if (1 << layer_depth) != layer_digests.len() { + bail!(VerificationError::IncorrectVectorLength); } - if index > (1 << tree_depth) - 1 { + if index >= (1 << tree_depth) { bail!(Error::IndexOutOfRange { - max: (1 << tree_depth) - 1, + max: (1 << tree_depth) - 1 }); } - let leaf_digest = hash_field_elems::<_, H>(values); - let branch = proof.read_vec(tree_depth - layer_depth)?; - - let mut index = index; - let root = branch.into_iter().fold(leaf_digest, |node, branch_node| { - let next_node = if index & 1 == 0 { - self.compression.compress([node, branch_node]) + let mut leaf_digest = hash_field_elems::<_, H>(values); + for branch_node in proof.read_vec(tree_depth - layer_depth)? { + leaf_digest = self.compression.compress(if index & 1 == 0 { + [leaf_digest, branch_node] } else { - self.compression.compress([branch_node, node]) - }; + [branch_node, leaf_digest] + }); index >>= 1; - next_node - }); - - if root == layer_digests[index] { - Ok(()) - } else { - bail!(VerificationError::InvalidProof) } + + (leaf_digest == layer_digests[index]) + .then_some(()) + .ok_or_else(|| VerificationError::InvalidProof.into()) } } @@ -178,8 +173,10 @@ where let mut hasher = H::new(); { let mut buffer = HashBuffer::new(&mut hasher); - for &elem in elems { - serialize_canonical(elem, &mut buffer).expect("HashBuffer has infinite capacity"); + for elem in elems { + let mode = SerializationMode::CanonicalTower; + SerializeBytes::serialize(elem, &mut buffer, mode) + .expect("HashBuffer has infinite capacity"); } } hasher.finalize() diff --git a/crates/core/src/oracle/composite.rs b/crates/core/src/oracle/composite.rs index 7c904dea..d90102a2 100644 --- a/crates/core/src/oracle/composite.rs +++ b/crates/core/src/oracle/composite.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use binius_field::TowerField; -use binius_math::CompositionPolyOS; +use binius_math::CompositionPoly; use binius_utils::bail; use crate::oracle::{Error, MultilinearPolyOracle, OracleId}; @@ -12,11 +12,11 @@ use crate::oracle::{Error, MultilinearPolyOracle, OracleId}; pub struct CompositePolyOracle { n_vars: usize, inner: Vec>, - composition: Arc>, + composition: Arc>, } impl CompositePolyOracle { - pub fn new + 'static>( + pub fn new + 'static>( n_vars: usize, inner: Vec>, composition: C, @@ -67,7 +67,7 @@ impl CompositePolyOracle { self.inner.clone() } - pub fn composition(&self) -> Arc> { + pub fn composition(&self) -> Arc> { self.composition.clone() } } @@ -82,7 +82,7 @@ mod tests { #[derive(Clone, Debug)] struct TestByteComposition; - impl CompositionPolyOS for TestByteComposition { + impl CompositionPoly for TestByteComposition { fn n_vars(&self) -> usize { 3 } diff --git a/crates/core/src/oracle/constraint.rs b/crates/core/src/oracle/constraint.rs index 4a9d0225..df683394 100644 --- a/crates/core/src/oracle/constraint.rs +++ b/crates/core/src/oracle/constraint.rs @@ -4,7 +4,8 @@ use core::iter::IntoIterator; use std::sync::Arc; use binius_field::{Field, TowerField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_macros::{DeserializeBytes, SerializeBytes}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; use itertools::Itertools; @@ -12,26 +13,26 @@ use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId}; /// Composition trait object that can be used to create lists of compositions of differing /// concrete types. -pub type TypeErasedComposition

= Arc>; +pub type TypeErasedComposition

= Arc>; /// Constraint is a type erased composition along with a predicate on its values on the boolean hypercube -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct Constraint { - pub name: Arc, + pub name: String, pub composition: ArithExpr, pub predicate: ConstraintPredicate, } /// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality to zero /// on the hypercube (zerocheck) -#[derive(Clone, Debug)] +#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)] pub enum ConstraintPredicate { Sum(F), Zero, } /// Constraint set is a group of constraints that operate over the same set of oracle-identified multilinears -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct ConstraintSet { pub n_vars: usize, pub oracle_ids: Vec, @@ -41,7 +42,7 @@ pub struct ConstraintSet { // A deferred constraint constructor that instantiates index composition after the superset of oracles is known #[allow(clippy::type_complexity)] struct UngroupedConstraint { - name: Arc, + name: String, oracle_ids: Vec, composition: ArithExpr, predicate: ConstraintPredicate, @@ -82,7 +83,7 @@ impl ConstraintSetBuilder { composition: ArithExpr, ) { self.constraints.push(UngroupedConstraint { - name: name.to_string().into(), + name: name.to_string(), oracle_ids: oracle_ids.into_iter().collect(), composition, predicate: ConstraintPredicate::Zero, diff --git a/crates/core/src/oracle/multilinear.rs b/crates/core/src/oracle/multilinear.rs index bfba6348..46d39168 100644 --- a/crates/core/src/oracle/multilinear.rs +++ b/crates/core/src/oracle/multilinear.rs @@ -2,8 +2,10 @@ use std::{array, fmt::Debug, sync::Arc}; -use binius_field::{Field, TowerField}; -use binius_utils::bail; +use binius_field::{BinaryField128b, Field, TowerField}; +use binius_macros::{DeserializeBytes, SerializeBytes}; +use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; +use bytes::Buf; use getset::{CopyGetters, Getters}; use crate::{ @@ -280,9 +282,20 @@ impl MultilinearOracleSetAddition<'_, F> { /// /// The oracle set also tracks the committed polynomial in batches where each batch is committed /// together with a polynomial commitment scheme. -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, SerializeBytes)] pub struct MultilinearOracleSet { - oracles: Vec>>, + oracles: Vec>, +} + +impl DeserializeBytes for MultilinearOracleSet { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self { + oracles: DeserializeBytes::deserialize(read_buf, mode)?, + }) + } } impl MultilinearOracleSet { @@ -323,12 +336,11 @@ impl MultilinearOracleSet { oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle, ) -> OracleId { let id = self.oracles.len(); - - self.oracles.push(Arc::new(oracle(id))); + self.oracles.push(oracle(id)); id } - fn get_from_set(&self, id: OracleId) -> Arc> { + fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle { self.oracles[id].clone() } @@ -401,7 +413,7 @@ impl MultilinearOracleSet { } pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle { - (*self.oracles[id]).clone() + self.oracles[id].clone() } pub fn n_vars(&self, id: OracleId) -> usize { @@ -438,7 +450,7 @@ impl MultilinearOracleSet { /// other oracles. This is formalized in [DP23] Section 4. /// /// [DP23]: -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)] pub struct MultilinearPolyOracle { pub id: OracleId, pub name: Option, @@ -447,7 +459,25 @@ pub struct MultilinearPolyOracle { pub variant: MultilinearPolyVariant, } -#[derive(Debug, Clone, PartialEq, Eq)] +impl DeserializeBytes for MultilinearPolyOracle { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self { + id: DeserializeBytes::deserialize(&mut read_buf, mode)?, + name: DeserializeBytes::deserialize(&mut read_buf, mode)?, + n_vars: DeserializeBytes::deserialize(&mut read_buf, mode)?, + tower_level: DeserializeBytes::deserialize(&mut read_buf, mode)?, + variant: DeserializeBytes::deserialize(&mut read_buf, mode)?, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)] pub enum MultilinearPolyVariant { Committed, Transparent(TransparentPolyOracle), @@ -459,6 +489,36 @@ pub enum MultilinearPolyVariant { ZeroPadded(OracleId), } +impl DeserializeBytes for MultilinearPolyVariant { + fn deserialize( + mut buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(match u8::deserialize(&mut buf, mode)? { + 0 => Self::Committed, + 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?), + 2 => Self::Repeating { + id: DeserializeBytes::deserialize(&mut buf, mode)?, + log_count: DeserializeBytes::deserialize(buf, mode)?, + }, + 3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?), + 4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?), + 5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?), + 6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?), + 7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?), + variant_index => { + return Err(SerializationError::UnknownEnumVariant { + name: "MultilinearPolyVariant", + index: variant_index, + }); + } + }) + } +} + /// A transparent multilinear polynomial oracle. /// /// See the [`MultilinearPolyOracle`] documentation for context. @@ -468,6 +528,30 @@ pub struct TransparentPolyOracle { poly: Arc>, } +impl SerializeBytes for TransparentPolyOracle { + fn serialize( + &self, + mut write_buf: impl bytes::BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.poly.erased_serialize(&mut write_buf, mode) + } +} + +impl DeserializeBytes for TransparentPolyOracle { + fn deserialize( + read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self { + poly: Box::>::deserialize(read_buf, mode)?.into(), + }) + } +} + impl TransparentPolyOracle { fn new(poly: Arc>) -> Result { if poly.binary_tower_level() > F::TOWER_LEVEL { @@ -494,13 +578,13 @@ impl PartialEq for TransparentPolyOracle { impl Eq for TransparentPolyOracle {} -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ProjectionVariant { FirstVars, LastVars, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Projected { #[get_copy = "pub"] id: OracleId, @@ -530,14 +614,14 @@ impl Projected { } } -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ShiftVariant { CircularLeft, LogicalLeft, LogicalRight, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Shifted { #[get_copy = "pub"] id: OracleId, @@ -579,7 +663,7 @@ impl Shifted { } } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Packed { #[get_copy = "pub"] id: OracleId, @@ -593,7 +677,7 @@ pub struct Packed { log_degree: usize, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct LinearCombination { #[get_copy = "pub"] n_vars: usize, @@ -606,7 +690,7 @@ impl LinearCombination { fn new( n_vars: usize, offset: F, - inner: impl IntoIterator>, F)>, + inner: impl IntoIterator, F)>, ) -> Result { let inner = inner .into_iter() diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index 9faed21c..2d1331ec 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{ - packed::set_packed_slice, BinaryField, ExtensionField, Field, PackedExtension, PackedField, + packed::set_packed_slice, BinaryField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::ComputationBackend; @@ -10,7 +10,7 @@ use binius_math::{ }; use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*}; use binius_ntt::{NTTOptions, ThreadingSettings}; -use binius_utils::{bail, serialization::SerializeBytes, sorting::is_sorted_ascending}; +use binius_utils::{bail, sorting::is_sorted_ascending, SerializeBytes}; use either::Either; use itertools::{chain, Itertools}; @@ -101,7 +101,7 @@ pub fn commit( multilins: &[M], ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FEncode: BinaryField, P: PackedField + PackedExtension, M: MultilinearPoly

, @@ -133,7 +133,7 @@ where let rs_code = ReedSolomonCode::new( fri_params.rs_code().log_dim(), fri_params.rs_code().log_inv_rate(), - NTTOptions { + &NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::MultithreadedDefault, }, @@ -166,7 +166,7 @@ pub fn prove Result<(), Error> where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: Field, FEncode: BinaryField, P: PackedFieldIndexable @@ -234,7 +234,7 @@ where merkle_prover, sumcheck_provers, codeword, - committed, + &committed, transcript, )?; @@ -247,11 +247,11 @@ fn prove_interleaved_fri_sumcheck>, codeword: &[P], - committed: MTProver::Committed, + committed: &MTProver::Committed, transcript: &mut ProverTranscript, ) -> Result<(), Error> where - F: TowerField + ExtensionField, + F: TowerField, FEncode: BinaryField, P: PackedFieldIndexable + PackedExtension, MTScheme: MerkleTreeScheme, @@ -259,7 +259,7 @@ where Challenger_: Challenger, { let mut fri_prover = - FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), &committed)?; + FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), committed)?; let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?; diff --git a/crates/core/src/piop/tests.rs b/crates/core/src/piop/tests.rs index 69f8a054..67a52273 100644 --- a/crates/core/src/piop/tests.rs +++ b/crates/core/src/piop/tests.rs @@ -3,15 +3,15 @@ use std::iter::repeat_with; use binius_field::{ - BinaryField, BinaryField16b, BinaryField8b, ExtensionField, Field, PackedBinaryField2x128b, - PackedExtension, PackedField, PackedFieldIndexable, TowerField, + BinaryField, BinaryField16b, BinaryField8b, Field, PackedBinaryField2x128b, PackedExtension, + PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::{ DefaultEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, }; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; +use binius_utils::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::{rngs::StdRng, Rng, SeedableRng}; @@ -104,7 +104,7 @@ fn commit_prove_verify( merkle_prover: &impl MerkleTreeProver, log_inv_rate: usize, ) where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: BinaryField, FEncode: BinaryField, P: PackedFieldIndexable diff --git a/crates/core/src/piop/verify.rs b/crates/core/src/piop/verify.rs index 96858f7f..a0c5a5e7 100644 --- a/crates/core/src/piop/verify.rs +++ b/crates/core/src/piop/verify.rs @@ -5,7 +5,7 @@ use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range}; use binius_field::{BinaryField, ExtensionField, Field, TowerField}; use binius_math::evaluate_piecewise_multilinear; use binius_ntt::NTTOptions; -use binius_utils::{bail, serialization::DeserializeBytes}; +use binius_utils::{bail, DeserializeBytes}; use getset::CopyGetters; use tracing::instrument; @@ -137,7 +137,7 @@ where let log_batch_size = fold_arities.first().copied().unwrap_or(0); let log_dim = commit_meta.total_vars - log_batch_size; - let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, NTTOptions::default())?; + let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, &NTTOptions::default())?; let n_test_queries = fri::calculate_n_test_queries::(security_bits, &rs_code)?; let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?; Ok(fri_params) diff --git a/crates/core/src/polynomial/arith_circuit.rs b/crates/core/src/polynomial/arith_circuit.rs index 3d7e1862..2bec2030 100644 --- a/crates/core/src/polynomial/arith_circuit.rs +++ b/crates/core/src/polynomial/arith_circuit.rs @@ -3,14 +3,12 @@ use std::{fmt::Debug, mem::MaybeUninit, sync::Arc}; use binius_field::{ExtensionField, Field, PackedField, TowerField}; -use binius_math::{ArithExpr, CompositionPoly, CompositionPolyOS, Error}; +use binius_math::{ArithExpr, CompositionPoly, Error}; use stackalloc::{ helpers::{slice_assume_init, slice_assume_init_mut}, stackalloc_uninit, }; -use super::MultivariatePoly; - /// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence. fn circuit_steps_for_expr( expr: &ArithExpr, @@ -50,10 +48,22 @@ fn circuit_steps_for_expr( result.push(CircuitStep::Mul(left, right)); CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)) } - ArithExpr::Pow(id, exp) => { - let id = to_circuit_inner(id, result); - result.push(CircuitStep::Pow(id, *exp)); - CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)) + ArithExpr::Pow(base, exp) => { + let mut acc = to_circuit_inner(base, result); + let base_expr = acc; + let highest_bit = exp.ilog2(); + + for i in (0..highest_bit).rev() { + result.push(CircuitStep::Square(acc)); + acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)); + + if (exp >> i) & 1 != 0 { + result.push(CircuitStep::Mul(acc, base_expr)); + acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)); + } + } + + acc } } } @@ -63,7 +73,7 @@ fn circuit_steps_for_expr( } /// Input of the circuit calculation step -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] enum CircuitNode { /// Input variable Var(usize), @@ -87,7 +97,7 @@ impl CircuitNode { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] enum CircuitStepArgument { Expr(CircuitNode), Const(F), @@ -101,15 +111,15 @@ enum CircuitStepArgument { enum CircuitStep { Add(CircuitStepArgument, CircuitStepArgument), Mul(CircuitStepArgument, CircuitStepArgument), - Pow(CircuitStepArgument, u64), + Square(CircuitStepArgument), AddMul(usize, CircuitStepArgument, CircuitStepArgument), } /// Describes polynomial evaluations using a directed acyclic graph of expressions. /// -/// This is meant as an alternative to a hard-coded CompositionPolyOS. +/// This is meant as an alternative to a hard-coded CompositionPoly. /// -/// The advantage over a hard coded CompositionPolyOS is that this can be constructed and manipulated dynamically at runtime +/// The advantage over a hard coded CompositionPoly is that this can be constructed and manipulated dynamically at runtime /// and the object representing different polnomials can be stored in a homogeneous collection. #[derive(Debug, Clone)] pub struct ArithCircuitPoly { @@ -119,12 +129,14 @@ pub struct ArithCircuitPoly { retval: CircuitStepArgument, degree: usize, n_vars: usize, + tower_level: usize, } -impl ArithCircuitPoly { +impl ArithCircuitPoly { pub fn new(expr: ArithExpr) -> Self { let degree = expr.degree(); let n_vars = expr.n_vars(); + let tower_level = expr.binary_tower_level(); let (exprs, retval) = circuit_steps_for_expr(&expr); Self { @@ -133,6 +145,7 @@ impl ArithCircuitPoly { retval, degree, n_vars, + tower_level, } } @@ -142,6 +155,7 @@ impl ArithCircuitPoly { /// arithmetic expression. pub fn with_n_vars(n_vars: usize, expr: ArithExpr) -> Result { let degree = expr.degree(); + let tower_level = expr.binary_tower_level(); if n_vars < expr.n_vars() { return Err(Error::IncorrectNumberOfVariables { expected: expr.n_vars(), @@ -156,11 +170,14 @@ impl ArithCircuitPoly { retval, n_vars, degree, + tower_level, }) } } -impl CompositionPoly for ArithCircuitPoly { +impl>> CompositionPoly

+ for ArithCircuitPoly +{ fn degree(&self) -> usize { self.degree } @@ -170,14 +187,14 @@ impl CompositionPoly for ArithCircuitPoly { } fn binary_tower_level(&self) -> usize { - F::TOWER_LEVEL + self.tower_level } - fn expression>(&self) -> ArithExpr { + fn expression(&self) -> ArithExpr { self.expr.convert_field() } - fn evaluate>>(&self, query: &[P]) -> Result { + fn evaluate(&self, query: &[P]) -> Result { if query.len() != self.n_vars { return Err(Error::IncorrectQuerySize { expected: self.n_vars, @@ -226,8 +243,8 @@ impl CompositionPoly for ArithCircuitPoly { after, get_argument_value(*x, before) * get_argument_value(*y, before), ), - CircuitStep::Pow(id, exp) => { - write_result(after, get_argument_value(*id, before).pow(*exp)) + CircuitStep::Square(x) => { + write_result(after, get_argument_value(*x, before).square()) } }; } @@ -241,11 +258,7 @@ impl CompositionPoly for ArithCircuitPoly { }) } - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error> { + fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { let row_len = evals.len(); if batch_query.iter().any(|row| row.len() != row_len) { return Err(Error::BatchEvaluateSizeMismatch); @@ -285,29 +298,31 @@ impl CompositionPoly for ArithCircuitPoly { }, ); } - CircuitStep::Pow(id, exp) => match id { - CircuitStepArgument::Expr(id) => { - let id = id.get_sparse_chunk(batch_query, before, row_len); - for j in 0..row_len { - // Safety: `current` and `id` have length equal to `row_len` - unsafe { - current - .get_unchecked_mut(j) - .write(id.get_unchecked(j).pow(*exp)); + CircuitStep::Square(arg) => { + match arg { + CircuitStepArgument::Expr(node) => { + let id_chunk = node.get_sparse_chunk(batch_query, before, row_len); + for j in 0..row_len { + // Safety: `current` and `id_chunk` have length equal to `row_len` + unsafe { + current + .get_unchecked_mut(j) + .write(id_chunk.get_unchecked(j).square()); + } } } - } - CircuitStepArgument::Const(id) => { - let id: P = P::broadcast((*id).into()); - let result = id.pow(*exp); - for j in 0..row_len { - // Safety: `current` has length equal to `row_len` - unsafe { - current.get_unchecked_mut(j).write(result); + CircuitStepArgument::Const(value) => { + let value: P = P::broadcast((*value).into()); + let result = value.square(); + for j in 0..row_len { + // Safety: `current` has length equal to `row_len` + unsafe { + current.get_unchecked_mut(j).write(result); + } } } } - }, + } CircuitStep::AddMul(target, left, right) => { let target = &before[row_len * target..(target + 1) * row_len]; // Safety: by construction of steps and evaluation order we know @@ -349,52 +364,6 @@ impl CompositionPoly for ArithCircuitPoly { } } -impl>> CompositionPolyOS

- for ArithCircuitPoly -{ - fn degree(&self) -> usize { - CompositionPoly::degree(self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(self) - } - - fn expression(&self) -> ArithExpr { - self.expr.convert_field() - } - - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(self) - } - - fn evaluate(&self, query: &[P]) -> Result { - CompositionPoly::evaluate(self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - CompositionPoly::batch_evaluate(self, batch_query, evals) - } -} - -impl MultivariatePoly for ArithCircuitPoly { - fn degree(&self) -> usize { - CompositionPoly::degree(&self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(&self) - } - - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(&self) - } - - fn evaluate(&self, query: &[F]) -> Result { - CompositionPoly::evaluate(&self, query).map_err(|e| e.into()) - } -} - /// Apply a binary operation to two arguments and store the result in `current_evals`. /// `op` must be a function that takes two arguments and initialized the result with the third argument. fn apply_binary_op>>( @@ -466,7 +435,7 @@ mod tests { use binius_field::{ BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField, }; - use binius_math::CompositionPolyOS; + use binius_math::CompositionPoly; use binius_utils::felts; use super::*; @@ -479,7 +448,7 @@ mod tests { let expr = ArithExpr::Const(F::new(123)); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); assert_eq!(typed_circuit.degree(), 0); assert_eq!(typed_circuit.n_vars(), 0); @@ -500,8 +469,8 @@ mod tests { let expr = ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -528,8 +497,8 @@ mod tests { let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -548,8 +517,8 @@ mod tests { let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -574,8 +543,8 @@ mod tests { let expr = ArithExpr::Var(0).pow(13); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 13); assert_eq!(typed_circuit.n_vars(), 1); @@ -600,8 +569,8 @@ mod tests { let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123))); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 3); assert_eq!(typed_circuit.n_vars(), 2); @@ -662,7 +631,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); assert_eq!(circuit.steps.len(), 2); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 2); @@ -721,8 +690,8 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); assert_eq!(circuit.steps.len(), 1); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 1); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -746,10 +715,10 @@ mod tests { // ((x0^2)^3)^4 let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4); let circuit = ArithCircuitPoly::::new(expr); - assert_eq!(circuit.steps.len(), 1); + assert_eq!(circuit.steps.len(), 5); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 24); assert_eq!(typed_circuit.n_vars(), 1); @@ -764,4 +733,358 @@ mod tests { P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])), ); } + + #[test] + fn test_circuit_steps_for_expr_constant() { + type F = BinaryField8b; + + let expr = ArithExpr::Const(F::new(5)); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert!(steps.is_empty(), "No steps should be generated for a constant"); + assert_eq!(retval, CircuitStepArgument::Const(F::new(5))); + } + + #[test] + fn test_circuit_steps_for_expr_variable() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(18); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert!(steps.is_empty(), "No steps should be generated for a variable"); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18)))); + } + + #[test] + fn test_circuit_steps_for_expr_addition() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(14) + ArithExpr::::Var(56); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "One addition step should be generated"); + assert!(matches!( + steps[0], + CircuitStep::Add( + CircuitStepArgument::Expr(CircuitNode::Var(14)), + CircuitStepArgument::Expr(CircuitNode::Var(56)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_multiplication() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(36) * ArithExpr::Var(26); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "One multiplication step should be generated"); + assert!(matches!( + steps[0], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Var(36)), + CircuitStepArgument::Expr(CircuitNode::Var(26)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_1() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(12).pow(1); + let (steps, retval) = circuit_steps_for_expr(&expr); + + // No steps should be generated for x^1 + assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps"); + + // The return value should just be the variable itself + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_2() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(10).pow(2); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10))) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_3() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(5).pow(3); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 2, + "Pow(3) should generate one squaring and one multiplication step" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(5)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_4() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(7).pow(4); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7))) + )); + + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_5() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(3).pow(5); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 3, + "Pow(5) should generate two squaring steps and one multiplication" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(1)), + CircuitStepArgument::Expr(CircuitNode::Var(3)) + ) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_8() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(4).pow(8); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_9() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(8).pow(9); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 4, + "Pow(9) should generate three squaring steps and one multiplication" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(2)), + CircuitStepArgument::Expr(CircuitNode::Var(8)) + ) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_12() { + type F = BinaryField8b; + let expr = ArithExpr::::Var(6).pow(12); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps."); + + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(6)) + ) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_13() { + type F = BinaryField8b; + let expr = ArithExpr::::Var(7).pow(13); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps."); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(7)) + ) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2))) + )); + assert!(matches!( + steps[4], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(3)), + CircuitStepArgument::Expr(CircuitNode::Var(7)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4)))); + } + + #[test] + fn test_circuit_steps_for_expr_complex() { + type F = BinaryField8b; + + let expr = (ArithExpr::::Var(0) * ArithExpr::Var(1)) + + (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2) + - ArithExpr::Var(3); + + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps"); + + assert!( + matches!( + steps[0], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Var(0)), + CircuitStepArgument::Expr(CircuitNode::Var(1)) + ) + ), + "First step should be multiplication x0 * x1" + ); + + assert!( + matches!( + steps[1], + CircuitStep::Add( + CircuitStepArgument::Const(F::ONE), + CircuitStepArgument::Expr(CircuitNode::Var(0)) + ) + ), + "Second step should be (1 - x0)" + ); + + assert!( + matches!( + steps[2], + CircuitStep::AddMul( + 0, + CircuitStepArgument::Expr(CircuitNode::Slot(1)), + CircuitStepArgument::Expr(CircuitNode::Var(2)) + ) + ), + "Third step should be (1 - x0) * x2" + ); + + assert!( + matches!( + steps[3], + CircuitStep::Add( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(3)) + ) + ), + "Fourth step should be x0 * x1 + (1 - x0) * x2 + x3" + ); + + assert!( + matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))), + "Final result should be stored in Slot(3)" + ); + } } diff --git a/crates/core/src/polynomial/cached.rs b/crates/core/src/polynomial/cached.rs deleted file mode 100644 index 181257e5..00000000 --- a/crates/core/src/polynomial/cached.rs +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -use std::{ - any::{Any, TypeId}, - collections::HashMap, - fmt::Debug, - marker::PhantomData, -}; - -use binius_field::{ExtensionField, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPoly, CompositionPolyOS, Error}; - -/// Cached composition poly wrapper. -/// -/// It stores the efficient implementations of the composition poly for some known set of packed field types. -/// We are usually able to use this when the inner poly is constructed with a macro for the known field and packed field types. -#[derive(Default, Debug)] -pub struct CachedPoly> { - inner: Inner, - cache: PackedFieldCache, -} - -impl> CachedPoly { - /// Create a new cached polynomial with the given inner polynomial. - pub fn new(inner: Inner) -> Self { - Self { - inner, - cache: Default::default(), - } - } - - /// Register efficient implementations for the `P` packed field type in the cache. - pub fn register>>( - &mut self, - composition: impl CompositionPolyOS

+ 'static, - ) { - self.cache.register(composition); - } -} - -impl> CompositionPoly for CachedPoly { - fn n_vars(&self) -> usize { - self.inner.n_vars() - } - - fn degree(&self) -> usize { - self.inner.degree() - } - - fn binary_tower_level(&self) -> usize { - self.inner.binary_tower_level() - } - - fn expression>(&self) -> ArithExpr { - self.inner.expression() - } - - fn evaluate>>(&self, query: &[P]) -> Result { - if let Some(result) = self.cache.try_evaluate(query) { - result - } else { - self.inner.evaluate(query) - } - } - - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error> { - if let Some(result) = self.cache.try_batch_evaluate(batch_query, evals) { - result - } else { - self.inner.batch_evaluate(batch_query, evals) - } - } -} - -impl, P: PackedField>> - CompositionPolyOS

for CachedPoly -{ - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(&self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(&self) - } - - fn degree(&self) -> usize { - CompositionPoly::degree(&self) - } - - fn expression(&self) -> ArithExpr { - CompositionPoly::expression(&self) - } - - fn evaluate(&self, query: &[P]) -> Result { - CompositionPoly::evaluate(&self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - CompositionPoly::batch_evaluate(&self, batch_query, evals) - } -} - -#[derive(Default)] -struct PackedFieldCache { - /// Map from the packed field type 'P to the efficient implementation of the composition polynomial - /// with actual type `Box>`. - entries: HashMap>, - _pd: PhantomData, -} - -impl PackedFieldCache { - /// Register efficient implementations for the `P` packed field type in the cache. - fn register>>( - &mut self, - composition: impl CompositionPolyOS

+ 'static, - ) { - let boxed_composition = Box::new(composition) as Box>; - self.entries - .insert(TypeId::of::

(), Box::new(boxed_composition) as Box); - } - - /// Try to evaluate the expression using the efficient implementation for the `P` packed field type. - /// If no implementation is found, return None. - fn try_evaluate>>( - &self, - query: &[P], - ) -> Option> { - if let Some(entry) = self.entries.get(&TypeId::of::

()) { - let entry = entry - .downcast_ref::>>() - .expect("cast must succeed"); - Some(entry.evaluate(query)) - } else { - None - } - } - - /// Try to batch evaluate the expression using the efficient implementation for the `P` packed field type. - /// If no implementation is found, return None. - fn try_batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Option> { - if let Some(entry) = self.entries.get(&TypeId::of::

()) { - let entry = entry - .downcast_ref::>>() - .expect("cast must succeed"); - Some(entry.batch_evaluate(batch_query, evals)) - } else { - None - } - } -} - -impl Debug for PackedFieldCache { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PackedFieldCache") - .field("cached_implementations", &self.entries.len()) - .finish() - } -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use binius_field::{BinaryField8b, ExtensionField, PackedBinaryField16x8b, PackedField}; - use binius_math::{ArithExpr, CompositionPolyOS}; - - use super::*; - use crate::polynomial::{cached::CachedPoly, ArithCircuitPoly}; - - fn ensure_equal_batch_eval_results( - circuit_1: &impl CompositionPolyOS

, - circuit_2: &impl CompositionPolyOS

, - batch_query: &[&[P]], - ) { - for row in 0..batch_query[0].len() { - let query = batch_query.iter().map(|q| q[row]).collect::>(); - - assert_eq!(circuit_1.evaluate(&query).unwrap(), circuit_2.evaluate(&query).unwrap()); - } - - let result_1 = { - let mut uncached_evals = vec![P::zero(); batch_query[0].len()]; - circuit_1 - .batch_evaluate(batch_query, &mut uncached_evals) - .unwrap(); - uncached_evals - }; - - let result_2 = { - let mut cached_evals = vec![P::zero(); batch_query[0].len()]; - circuit_2 - .batch_evaluate(batch_query, &mut cached_evals) - .unwrap(); - cached_evals - }; - - assert_eq!(result_1, result_2); - } - - #[derive(Debug, Copy, Clone)] - struct AddComposition; - - impl>> CompositionPolyOS

- for AddComposition - { - fn binary_tower_level(&self) -> usize { - 0 - } - - fn n_vars(&self) -> usize { - 1 - } - - fn degree(&self) -> usize { - 1 - } - - fn expression(&self) -> ArithExpr { - ArithExpr::Const(BinaryField8b::new(123).into()) + ArithExpr::Var(0) - } - - fn evaluate(&self, query: &[P]) -> Result { - Ok(query[0] + P::broadcast(BinaryField8b::new(123).into())) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - for (input, output) in zip(batch_query[0], evals) { - *output = *input + P::broadcast(BinaryField8b::new(123).into()); - } - - Ok(()) - } - } - - #[test] - fn test_cached_impl() { - let expr = ArithExpr::Const(BinaryField8b::new(123)) + ArithExpr::Var(0); - let circuit = ArithCircuitPoly::::new(expr); - - let composition = AddComposition; - - let mut cached_circuit = CachedPoly::new(circuit.clone()); - cached_circuit.register::(composition); - - let batch_query = [(0..255).map(BinaryField8b::new).collect::>()]; - let batch_query = batch_query.iter().map(|q| q.as_slice()).collect::>(); - ensure_equal_batch_eval_results(&circuit, &cached_circuit, &batch_query); - } - - #[test] - fn test_uncached_impl() { - let expr = ArithExpr::Const(BinaryField8b::new(123)) + ArithExpr::Var(0); - let circuit = ArithCircuitPoly::::new(expr); - - let composition = AddComposition; - - let mut cached_circuit = CachedPoly::new(circuit.clone()); - cached_circuit.register::(composition); - - let batch_query = [(0..255).map(BinaryField8b::new).collect::>()]; - let batch_query = batch_query.iter().map(|q| q.as_slice()).collect::>(); - ensure_equal_batch_eval_results(&circuit, &cached_circuit, &batch_query); - } -} diff --git a/crates/core/src/polynomial/mod.rs b/crates/core/src/polynomial/mod.rs index 1c21083a..7f8e4d76 100644 --- a/crates/core/src/polynomial/mod.rs +++ b/crates/core/src/polynomial/mod.rs @@ -1,7 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. mod arith_circuit; -mod cached; mod error; mod multivariate; #[allow(dead_code)] @@ -9,6 +8,5 @@ mod multivariate; pub mod test_utils; pub use arith_circuit::*; -pub use cached::*; pub use error::*; pub use multivariate::*; diff --git a/crates/core/src/polynomial/multivariate.rs b/crates/core/src/polynomial/multivariate.rs index a8e6fdcf..e670f7ac 100644 --- a/crates/core/src/polynomial/multivariate.rs +++ b/crates/core/src/polynomial/multivariate.rs @@ -4,9 +4,10 @@ use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sy use binius_field::{Field, PackedField}; use binius_math::{ - ArithExpr, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, + ArithExpr, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, }; -use binius_utils::bail; +use binius_utils::{bail, SerializationError, SerializationMode}; +use bytes::BufMut; use itertools::Itertools; use rand::{rngs::StdRng, SeedableRng}; @@ -14,8 +15,8 @@ use super::error::Error; /// A multivariate polynomial over a binary tower field. /// -/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPolyOS`], except that -/// `MultivariatePoly` is _object safe_, whereas `CompositionPolyOS` is not. +/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPoly`], except that +/// `MultivariatePoly` is _object safe_, whereas `CompositionPoly` is not. pub trait MultivariatePoly

: Debug + Send + Sync { /// The number of variables. fn n_vars(&self) -> usize; @@ -28,13 +29,24 @@ pub trait MultivariatePoly

: Debug + Send + Sync { /// Returns the maximum binary tower level of all constants in the arithmetic expression. fn binary_tower_level(&self) -> usize; + + /// Serialize a type erased MultivariatePoly. + /// Since not every MultivariatePoly implements serialization, this defaults to returning an error. + fn erased_serialize( + &self, + write_buf: &mut dyn BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let _ = (write_buf, mode); + Err(SerializationError::SerializationNotImplemented) + } } /// Identity composition function $g(X) = X$. #[derive(Clone, Debug)] pub struct IdentityCompositionPoly; -impl CompositionPolyOS

for IdentityCompositionPoly { +impl CompositionPoly

for IdentityCompositionPoly { fn n_vars(&self) -> usize { 1 } @@ -59,7 +71,7 @@ impl CompositionPolyOS

for IdentityCompositionPoly { } } -/// An adapter that constructs a [`CompositionPolyOS`] for a field from a [`CompositionPolyOS`] for a +/// An adapter that constructs a [`CompositionPoly`] for a field from a [`CompositionPoly`] for a /// packing of that field. /// /// This is not intended for use in performance-critical code sections. @@ -72,7 +84,7 @@ pub struct CompositionScalarAdapter { impl CompositionScalarAdapter where P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { pub const fn new(composition: Composition) -> Self { Self { @@ -82,11 +94,11 @@ where } } -impl CompositionPolyOS for CompositionScalarAdapter +impl CompositionPoly for CompositionScalarAdapter where F: Field, P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn n_vars(&self) -> usize { self.composition.n_vars() @@ -141,7 +153,7 @@ where impl MultilinearComposite where P: PackedField, - C: CompositionPolyOS

, + C: CompositionPoly

, M: MultilinearPoly

, { pub fn new(n_vars: usize, composition: C, multilinears: Vec) -> Result { @@ -207,12 +219,10 @@ where impl MultilinearComposite where P: PackedField, - C: CompositionPolyOS

+ 'static, + C: CompositionPoly

+ 'static, M: MultilinearPoly

, { - pub fn to_arc_dyn_composition( - self, - ) -> MultilinearComposite>, M> { + pub fn to_arc_dyn_composition(self) -> MultilinearComposite>, M> { MultilinearComposite { n_vars: self.n_vars, composition: Arc::new(self.composition), @@ -267,7 +277,7 @@ where /// for two distinct multivariate polynomials f and g. /// /// NOTE: THIS IS NOT ADVERSARIALLY COLLISION RESISTANT, COLLISIONS CAN BE MANUFACTURED EASILY -pub fn composition_hash>(composition: &C) -> P { +pub fn composition_hash>(composition: &C) -> P { let mut rng = StdRng::from_seed([0; 32]); let random_point = repeat_with(|| P::random(&mut rng)) @@ -281,7 +291,7 @@ pub fn composition_hash>(composition: &C #[cfg(test)] mod tests { - use binius_math::{ArithExpr, CompositionPolyOS}; + use binius_math::{ArithExpr, CompositionPoly}; #[test] fn test_fingerprint_same_32b() { @@ -291,7 +301,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -308,7 +318,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -326,7 +336,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -342,7 +352,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -360,7 +370,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -376,7 +386,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; diff --git a/crates/core/src/protocols/evalcheck/error.rs b/crates/core/src/protocols/evalcheck/error.rs index 54426fe8..d5bf9447 100644 --- a/crates/core/src/protocols/evalcheck/error.rs +++ b/crates/core/src/protocols/evalcheck/error.rs @@ -45,7 +45,7 @@ pub enum VerificationError { impl VerificationError { pub fn incorrect_composite_poly_evaluation( - oracle: CompositePolyOracle, + oracle: &CompositePolyOracle, ) -> Self { let names = oracle .inner_polys() diff --git a/crates/core/src/protocols/fri/common.rs b/crates/core/src/protocols/fri/common.rs index e8e7318f..9d74660b 100644 --- a/crates/core/src/protocols/fri/common.rs +++ b/crates/core/src/protocols/fri/common.rs @@ -343,13 +343,13 @@ mod tests { #[test] fn test_calculate_n_test_queries() { let security_bits = 96; - let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap(); let n_test_queries = calculate_n_test_queries::(security_bits, &rs_code) .unwrap(); assert_eq!(n_test_queries, 232); - let rs_code = ReedSolomonCode::new(28, 2, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 2, &NTTOptions::default()).unwrap(); let n_test_queries = calculate_n_test_queries::(security_bits, &rs_code) .unwrap(); @@ -359,7 +359,7 @@ mod tests { #[test] fn test_calculate_n_test_queries_unsatisfiable() { let security_bits = 128; - let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap(); assert_matches!( calculate_n_test_queries::(security_bits, &rs_code), Err(Error::ParameterError) diff --git a/crates/core/src/protocols/fri/prove.rs b/crates/core/src/protocols/fri/prove.rs index ba07ebb7..2a3b9f49 100644 --- a/crates/core/src/protocols/fri/prove.rs +++ b/crates/core/src/protocols/fri/prove.rs @@ -3,7 +3,7 @@ use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; use binius_maybe_rayon::prelude::*; -use binius_utils::{bail, serialization::SerializeBytes}; +use binius_utils::{bail, SerializeBytes}; use bytemuck::zeroed_vec; use bytes::BufMut; use itertools::izip; @@ -174,7 +174,7 @@ pub fn commit_interleaved( message: &[P], ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FA: BinaryField, P: PackedField + PackedExtension, PA: PackedField, @@ -209,7 +209,7 @@ pub fn commit_interleaved_with( message_writer: impl FnOnce(&mut [P]), ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FA: BinaryField, P: PackedField + PackedExtension, PA: PackedField, diff --git a/crates/core/src/protocols/fri/tests.rs b/crates/core/src/protocols/fri/tests.rs index df4ecf22..1798b3c1 100644 --- a/crates/core/src/protocols/fri/tests.rs +++ b/crates/core/src/protocols/fri/tests.rs @@ -46,14 +46,14 @@ fn test_commit_prove_verify_success( let committed_rs_code_packed = ReedSolomonCode::>::new( log_dimension, log_inv_rate, - NTTOptions::default(), + &NTTOptions::default(), ) .unwrap(); let merkle_prover = BinaryMerkleTreeProver::<_, Groestl256, _>::new(Groestl256ByteCompression); let committed_rs_code = - ReedSolomonCode::::new(log_dimension, log_inv_rate, NTTOptions::default()).unwrap(); + ReedSolomonCode::::new(log_dimension, log_inv_rate, &NTTOptions::default()).unwrap(); let n_test_queries = 3; let params = diff --git a/crates/core/src/protocols/fri/verify.rs b/crates/core/src/protocols/fri/verify.rs index 0abc4604..85e54814 100644 --- a/crates/core/src/protocols/fri/verify.rs +++ b/crates/core/src/protocols/fri/verify.rs @@ -4,7 +4,7 @@ use std::iter; use binius_field::{BinaryField, ExtensionField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; -use binius_utils::{bail, serialization::DeserializeBytes}; +use binius_utils::{bail, DeserializeBytes}; use bytes::Buf; use itertools::izip; use tracing::instrument; diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs index 800ff7f7..ca11ea7f 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs @@ -2,13 +2,9 @@ use std::ops::Range; -use binius_field::{ - util::eq, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, -}; +use binius_field::{util::eq, Field, PackedExtension, PackedField, PackedFieldIndexable}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; -use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, -}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly}; use binius_maybe_rayon::prelude::*; use binius_utils::bail; use itertools::izip; @@ -48,12 +44,10 @@ where impl<'a, F, FDomain, P, Composition, M, Backend> GPAProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension, - Composition: CompositionPolyOS

, + P: PackedFieldIndexable + PackedExtension, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -104,8 +98,8 @@ where let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); let state = ProverState::new( multilinears, @@ -193,12 +187,12 @@ where impl SumcheckProver for GPAProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -229,7 +223,7 @@ where }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; let coeffs = self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?; @@ -287,13 +281,13 @@ where gpa_round_challenge: P::Scalar, } -impl SumcheckEvaluator +impl SumcheckEvaluator for GPAEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // By definition of grand product GKR circuit, the composition evaluation is a multilinear @@ -344,10 +338,10 @@ where impl SumcheckInterpolator for GPAEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { #[instrument( skip_all, diff --git a/crates/core/src/protocols/gkr_gpa/prove.rs b/crates/core/src/protocols/gkr_gpa/prove.rs index c1ab27bc..e48d0c27 100644 --- a/crates/core/src/protocols/gkr_gpa/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/prove.rs @@ -1,8 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{ - ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, -}; +use binius_field::{Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField}; use binius_hal::ComputationBackend; use binius_math::{ extrapolate_line_scalar, EvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, @@ -46,7 +44,6 @@ where + PackedExtension + PackedExtension, FDomain: Field, - P::Scalar: Field + ExtensionField, Challenger_: Challenger, Backend: ComputationBackend, { @@ -266,7 +263,6 @@ where where FDomain: Field, P: PackedExtension, - F: ExtensionField, { // test same layer let Some(first_prover) = provers.first() else { diff --git a/crates/core/src/protocols/gkr_gpa/tests.rs b/crates/core/src/protocols/gkr_gpa/tests.rs index 87fcefa2..26354460 100644 --- a/crates/core/src/protocols/gkr_gpa/tests.rs +++ b/crates/core/src/protocols/gkr_gpa/tests.rs @@ -7,8 +7,8 @@ use binius_field::{ as_packed_field::{PackScalar, PackedType}, packed::set_packed_slice, underlier::{UnderlierType, WithUnderlier}, - BinaryField128b, BinaryField32b, ExtensionField, Field, PackedExtension, PackedField, - PackedFieldIndexable, RepackedExtension, TowerField, + BinaryField128b, BinaryField32b, Field, PackedExtension, PackedField, PackedFieldIndexable, + RepackedExtension, TowerField, }; use binius_math::{IsomorphicEvaluationDomainFactory, MultilinearExtension}; use bytemuck::zeroed_vec; @@ -24,15 +24,11 @@ use crate::{ witness::MultilinearExtensionIndex, }; -fn generate_poly_helper( +fn generate_poly_helper, F: Field>( rng: &mut StdRng, n_vars: usize, n_multilinears: usize, -) -> Vec<(MultilinearExtension

, F)> -where - P: PackedField>, - F: Field, -{ +) -> Vec<(MultilinearExtension

, F)> { repeat_with(|| { let values = repeat_with(|| F::random(&mut *rng)) .take(1 << n_vars) @@ -119,7 +115,7 @@ fn run_prove_verify_batch_test() where U: UnderlierType + PackScalar, P: PackedExtension + RepackedExtension

+ PackedFieldIndexable, - F: TowerField + ExtensionField, + F: TowerField, FS: TowerField, { let rng = StdRng::seed_from_u64(0); diff --git a/crates/core/src/protocols/gkr_gpa/verify.rs b/crates/core/src/protocols/gkr_gpa/verify.rs index 056f4cb7..a9d96366 100644 --- a/crates/core/src/protocols/gkr_gpa/verify.rs +++ b/crates/core/src/protocols/gkr_gpa/verify.rs @@ -54,7 +54,7 @@ where &mut reverse_sorted_evalcheck_claims, ); - layer_claims = reduce_layer_claim_batch(layer_claims, transcript)?; + layer_claims = reduce_layer_claim_batch(&layer_claims, transcript)?; } process_finished_claims( n_claims, @@ -102,7 +102,7 @@ fn process_finished_claims( /// * `proof` - The batch layer proof that reduces the kth layer claims of the product circuits to the (k+1)th /// * `transcript` - The verifier transcript fn reduce_layer_claim_batch( - claims: Vec>, + claims: &[LayerClaim], transcript: &mut VerifierTranscript, ) -> Result>, Error> where diff --git a/crates/core/src/protocols/gkr_int_mul/error.rs b/crates/core/src/protocols/gkr_int_mul/error.rs index caa824f0..6da9cf33 100644 --- a/crates/core/src/protocols/gkr_int_mul/error.rs +++ b/crates/core/src/protocols/gkr_int_mul/error.rs @@ -13,4 +13,12 @@ pub enum Error { SumcheckError(#[from] SumcheckError), #[error("polynomial error: {0}")] Polynomial(#[from] PolynomialError), + #[error("verification failure: {0}")] + Verification(#[from] VerificationError), +} + +#[derive(Debug, thiserror::Error)] +pub enum VerificationError { + #[error("the proof contains an incorrect evaluation of the eq indicator")] + IncorrectEqIndEvaluation, } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs index a1525158..35cbebf9 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; #[derive(Debug)] @@ -12,7 +12,7 @@ where pub generator_power_constant: F, } -impl CompositionPolyOS

for MultiplyOrDont { +impl CompositionPoly

for MultiplyOrDont { fn n_vars(&self) -> usize { 2 } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs index 070d0637..f6c0ae38 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs @@ -3,7 +3,7 @@ use std::array; use binius_field::{ - BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, + BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::ComputationBackend; @@ -40,18 +40,17 @@ pub fn prove< backend: &Backend, ) -> Result, Error> where + F: ExtensionField + ExtensionField + TowerField, FDomain: Field, - PBits: PackedField, - PGenerator: PackedExtension - + PackedFieldIndexable + FGenerator: TowerField + ExtensionField + ExtensionField, + PBits: PackedField, + PGenerator: PackedField + + PackedExtension + PackedExtension, - PGenerator::Scalar: ExtensionField + ExtensionField, - PChallenge: PackedField - + PackedFieldIndexable + PChallenge: PackedFieldIndexable + + PackedExtension + PackedExtension + PackedExtension, - F: ExtensionField + ExtensionField + BinaryField + TowerField, - FGenerator: Field + TowerField, Backend: ComputationBackend, Challenger_: Challenger, { @@ -60,10 +59,11 @@ where let mut eval_point = claim.eval_point.clone(); let mut eval = claim.eval; + for exponent_bit_number in (1..EXPONENT_BIT_WIDTH).rev() { let this_round_exponent_bit = witness.exponent[exponent_bit_number].clone(); let this_round_generator_power_constant = - F::from(FGenerator::MULTIPLICATIVE_GENERATOR.pow([1 << exponent_bit_number])); + F::from(FGenerator::MULTIPLICATIVE_GENERATOR.pow(1 << exponent_bit_number)); let this_round_input_data = witness.single_bit_output_layers_data[exponent_bit_number - 1].clone(); diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs index 97508b59..421b1efd 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs @@ -124,7 +124,7 @@ fn witness_gen_happens_correctly() { for (row_idx, this_row_exponent) in exponent.into_iter().enumerate() { assert_eq!( ::Scalar::MULTIPLICATIVE_GENERATOR - .pow([this_row_exponent as u64]), + .pow(this_row_exponent as u64), get_packed_slice(results, row_idx) ); } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs index 1024481a..66df9910 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs @@ -3,7 +3,6 @@ use std::array; use binius_field::{ExtensionField, TowerField}; -use binius_utils::bail; use super::{ super::error::Error, common::GeneratorExponentReductionOutput, utils::first_layer_inverse, @@ -13,7 +12,7 @@ use crate::{ polynomial::MultivariatePoly, protocols::{ gkr_gpa::LayerClaim, - gkr_int_mul::generator_exponent::compositions::MultiplyOrDont, + gkr_int_mul::{error::VerificationError, generator_exponent::compositions::MultiplyOrDont}, sumcheck::{self, zerocheck::ExtraProduct, CompositeSumClaim, SumcheckClaim}, }, transcript::VerifierTranscript, @@ -78,7 +77,7 @@ where EqIndPartialEval::new(log_size, sumcheck_query_point.clone())?.evaluate(&eval_point)?; if sumcheck_verification_output.multilinear_evals[0][2] != eq_eval { - bail!(Error::EqEvalDoesntVerify) + return Err(VerificationError::IncorrectEqIndEvaluation.into()); } eval_claims_on_bit_columns[exponent_bit_number] = LayerClaim { diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs index 885c47f9..569c8f49 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs @@ -3,8 +3,7 @@ use std::{array, cmp::min, slice}; use binius_field::{ - ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, Field, PackedExtension, - PackedField, PackedFieldIndexable, + ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, PackedExtension, PackedField, }; use binius_maybe_rayon::{ prelude::{IndexedParallelIterator, ParallelIterator}, @@ -28,7 +27,7 @@ pub struct GeneratorExponentWitness< fn copy_witness_into_vec(poly: &MultilinearWitness) -> Vec

where P: PackedField, - PE: PackedField + PackedExtension, + PE: PackedExtension, PE::Scalar: ExtensionField, { let mut input_layer: Vec

= zeroed_vec(1 << poly.n_vars().saturating_sub(P::LOG_WIDTH)); @@ -68,9 +67,8 @@ fn evaluate_single_bit_output_packed( ) -> Vec where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField + BinaryField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, { debug_assert_eq!( PBits::WIDTH * exponent_bit.len(), @@ -98,9 +96,8 @@ fn evaluate_first_layer_output_packed( ) -> Vec where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, { let mut result = vec![PGenerator::zero(); exponent_bit.len() * PGenerator::Scalar::DEGREE]; @@ -119,11 +116,10 @@ impl<'a, PBits, PGenerator, PChallenge, const EXPONENT_BIT_WIDTH: usize> GeneratorExponentWitness<'a, PBits, PGenerator, PChallenge, EXPONENT_BIT_WIDTH> where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField + BinaryField, - PChallenge: PackedField + PackedExtension, - PChallenge::Scalar: ExtensionField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, + PChallenge: PackedExtension, + PChallenge::Scalar: BinaryField, { pub fn new( exponent: [MultilinearWitness<'a, PChallenge>; EXPONENT_BIT_WIDTH], @@ -140,12 +136,15 @@ where PGenerator::Scalar::MULTIPLICATIVE_GENERATOR, ); + let mut generator_power_constant = PGenerator::Scalar::MULTIPLICATIVE_GENERATOR.square(); + for layer_idx_from_left in 1..EXPONENT_BIT_WIDTH { single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed( &exponent_data[layer_idx_from_left], - PGenerator::Scalar::MULTIPLICATIVE_GENERATOR.pow([1 << layer_idx_from_left]), + generator_power_constant, &single_bit_output_layers_data[layer_idx_from_left - 1], - ) + ); + generator_power_constant = generator_power_constant.square(); } Ok(Self { diff --git a/crates/core/src/protocols/gkr_int_mul/mod.rs b/crates/core/src/protocols/gkr_int_mul/mod.rs index 1e8861bf..02f0dbf2 100644 --- a/crates/core/src/protocols/gkr_int_mul/mod.rs +++ b/crates/core/src/protocols/gkr_int_mul/mod.rs @@ -1,4 +1,4 @@ // Copyright 2024-2025 Irreducible Inc. mod error; -//pub mod generator_exponent; +pub mod generator_exponent; diff --git a/crates/core/src/protocols/sumcheck/common.rs b/crates/core/src/protocols/sumcheck/common.rs index c0d1e2cc..8a9bb6f2 100644 --- a/crates/core/src/protocols/sumcheck/common.rs +++ b/crates/core/src/protocols/sumcheck/common.rs @@ -6,7 +6,7 @@ use binius_field::{ util::{inner_product_unchecked, powers}, ExtensionField, Field, PackedField, }; -use binius_math::{CompositionPolyOS, MultilinearPoly}; +use binius_math::{CompositionPoly, MultilinearPoly}; use binius_utils::bail; use getset::{CopyGetters, Getters}; use tracing::instrument; @@ -45,7 +45,7 @@ pub struct SumcheckClaim { impl SumcheckClaim where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { /// Constructs a new sumcheck claim. /// diff --git a/crates/core/src/protocols/sumcheck/error.rs b/crates/core/src/protocols/sumcheck/error.rs index a0195e68..848123a2 100644 --- a/crates/core/src/protocols/sumcheck/error.rs +++ b/crates/core/src/protocols/sumcheck/error.rs @@ -45,7 +45,7 @@ pub enum Error { oracle: String, hypercube_index: usize, }, - #[error("constraint set containts multilinears of different heights")] + #[error("constraint set contains multilinears of different heights")] ConstraintSetNumberOfVariablesMismatch, #[error("batching sumchecks and zerochecks is not supported yet")] MixedBatchingNotSupported, diff --git a/crates/core/src/protocols/sumcheck/front_loaded.rs b/crates/core/src/protocols/sumcheck/front_loaded.rs index bdba716a..3b1c4c93 100644 --- a/crates/core/src/protocols/sumcheck/front_loaded.rs +++ b/crates/core/src/protocols/sumcheck/front_loaded.rs @@ -3,7 +3,7 @@ use std::{cmp, cmp::Ordering, collections::VecDeque, iter}; use binius_field::{Field, TowerField}; -use binius_math::{evaluate_univariate, CompositionPolyOS}; +use binius_math::{evaluate_univariate, CompositionPoly}; use binius_utils::sorting::is_sorted_ascending; use bytes::Buf; @@ -60,7 +60,7 @@ pub struct BatchVerifier { impl BatchVerifier where F: TowerField, - C: CompositionPolyOS + Clone, + C: CompositionPoly + Clone, { /// Constructs a new verifier for the front-loaded batched sumcheck. /// diff --git a/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs b/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs deleted file mode 100644 index b67e99c2..00000000 --- a/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable}; -use binius_hal::ComputationBackend; -use binius_math::{CompositionPolyOS, MultilinearPoly}; - -use super::{batch_prove::SumcheckProver, RegularSumcheckProver, ZerocheckProver}; -use crate::protocols::sumcheck::{common::RoundCoeffs, error::Error}; - -/// A sum type that is used to put both regular sumchecks and zerochecks into the same `batch_prove` call. -pub enum ConcreteProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend> -where - FDomain: Field, - PBase: PackedField, - P: PackedField, - M: MultilinearPoly

+ Send + Sync, - Backend: ComputationBackend, -{ - Sumcheck(RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>), - Zerocheck(ZerocheckProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend>), -} - -impl SumcheckProver - for ConcreteProver<'_, FDomain, FBase, P, CompositionBase, Composition, M, Backend> -where - F: Field + ExtensionField + ExtensionField, - FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, - M: MultilinearPoly

+ Send + Sync, - Backend: ComputationBackend, -{ - fn n_vars(&self) -> usize { - match self { - ConcreteProver::Sumcheck(prover) => prover.n_vars(), - ConcreteProver::Zerocheck(prover) => prover.n_vars(), - } - } - - fn execute(&mut self, batch_coeff: F) -> Result, Error> { - match self { - ConcreteProver::Sumcheck(prover) => prover.execute(batch_coeff), - ConcreteProver::Zerocheck(prover) => prover.execute(batch_coeff), - } - } - - fn fold(&mut self, challenge: F) -> Result<(), Error> { - match self { - ConcreteProver::Sumcheck(prover) => prover.fold(challenge), - ConcreteProver::Zerocheck(prover) => prover.fold(challenge), - } - } - - fn finish(self: Box) -> Result, Error> { - match *self { - ConcreteProver::Sumcheck(prover) => Box::new(prover).finish(), - ConcreteProver::Zerocheck(prover) => Box::new(prover).finish(), - } - } -} diff --git a/crates/core/src/protocols/sumcheck/prove/mod.rs b/crates/core/src/protocols/sumcheck/prove/mod.rs index 94029d71..44c1f587 100644 --- a/crates/core/src/protocols/sumcheck/prove/mod.rs +++ b/crates/core/src/protocols/sumcheck/prove/mod.rs @@ -3,7 +3,6 @@ mod batch_prove; mod batch_prove_univariate_zerocheck; pub(crate) mod common; -mod concrete_prover; pub mod front_loaded; pub mod oracles; pub mod prover_state; @@ -15,7 +14,6 @@ pub use batch_prove::{batch_prove, batch_prove_with_start, SumcheckProver}; pub use batch_prove_univariate_zerocheck::{ batch_prove_zerocheck_univariate_round, UnivariateZerocheckProver, }; -pub use concrete_prover::ConcreteProver; pub use oracles::{ constraint_set_sumcheck_prover, constraint_set_zerocheck_prover, split_constraint_set, }; diff --git a/crates/core/src/protocols/sumcheck/prove/oracles.rs b/crates/core/src/protocols/sumcheck/prove/oracles.rs index 8b34f4e5..c6f4b798 100644 --- a/crates/core/src/protocols/sumcheck/prove/oracles.rs +++ b/crates/core/src/protocols/sumcheck/prove/oracles.rs @@ -54,7 +54,7 @@ where + PackedExtension + PackedExtension + PackedExtension, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FBase: TowerField + ExtensionField + TryFrom, FDomain: Field, Backend: ComputationBackend, diff --git a/crates/core/src/protocols/sumcheck/prove/prover_state.rs b/crates/core/src/protocols/sumcheck/prove/prover_state.rs index d05e7252..d88272f1 100644 --- a/crates/core/src/protocols/sumcheck/prove/prover_state.rs +++ b/crates/core/src/protocols/sumcheck/prove/prover_state.rs @@ -5,10 +5,10 @@ use std::{ sync::atomic::{AtomicBool, Ordering}, }; -use binius_field::{util::powers, ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{util::powers, Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear}; use binius_math::{ - evaluate_univariate, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, + evaluate_univariate, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::prelude::*; use binius_utils::bail; @@ -70,8 +70,8 @@ where impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + F: Field, + P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -191,8 +191,11 @@ where ref mut large_field_folded_multilinear, } => { // Post-switchover, simply plug in challenge for the zeroth variable. + let single_variable_query = MultilinearQuery::expand(&[challenge]); *large_field_folded_multilinear = MLEDirectAdapter::from( - large_field_folded_multilinear.evaluate_zeroth_variable(challenge)?, + large_field_folded_multilinear + .as_ref() + .evaluate_partial_low(single_variable_query.to_ref())?, ); } }; @@ -242,41 +245,17 @@ where .collect() } - /// Calculate the accumulated evaluations for the first sumcheck round. - #[instrument(skip_all, level = "debug")] - pub fn calculate_first_round_evals( - &self, - evaluators: &[Evaluator], - ) -> Result>, Error> - where - FBase: ExtensionField, - F: ExtensionField, - P: PackedExtension, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - Ok(self.backend.sumcheck_compute_first_round_evals( - self.n_vars, - &self.multilinears, - evaluators, - &self.evaluation_points, - )?) - } - /// Calculate the accumulated evaluations for an arbitrary sumcheck round. - /// - /// See [`Self::calculate_first_round_evals`] for an optimized version of this method that - /// operates over small fields in the first round. #[instrument(skip_all, level = "debug")] - pub fn calculate_later_round_evals( + pub fn calculate_round_evals( &self, evaluators: &[Evaluator], ) -> Result>, Error> where - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { - Ok(self.backend.sumcheck_compute_later_round_evals( + Ok(self.backend.sumcheck_compute_round_evals( self.n_vars, self.tensor_query.as_ref().map(Into::into), &self.multilinears, diff --git a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs index a3695046..2a6ccb3f 100644 --- a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs @@ -2,11 +2,9 @@ use std::{marker::PhantomData, ops::Range}; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; -use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, -}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly}; use binius_maybe_rayon::prelude::*; use binius_utils::bail; use itertools::izip; @@ -31,7 +29,7 @@ where F: Field, P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

+ 'a, + Composition: CompositionPoly

+ 'a, { let n_vars = multilinears .first() @@ -82,10 +80,10 @@ where impl<'a, F, FDomain, P, Composition, M, Backend> RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -142,8 +140,8 @@ where let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); let state = ProverState::new( multilinears, @@ -166,10 +164,10 @@ where impl SumcheckProver for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -193,7 +191,7 @@ where }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals) } @@ -213,13 +211,13 @@ where _marker: PhantomData

, } -impl SumcheckEvaluator +impl SumcheckEvaluator for RegularSumcheckEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // NB: We skip evaluation of $r(X)$ at $X = 0$ as it is derivable from the @@ -256,7 +254,7 @@ where impl SumcheckInterpolator for RegularSumcheckEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, { diff --git a/crates/core/src/protocols/sumcheck/prove/univariate.rs b/crates/core/src/protocols/sumcheck/prove/univariate.rs index 4d77666b..fbcf3efe 100644 --- a/crates/core/src/protocols/sumcheck/prove/univariate.rs +++ b/crates/core/src/protocols/sumcheck/prove/univariate.rs @@ -9,7 +9,7 @@ use binius_field::{ }; use binius_hal::{ComputationBackend, ComputationBackendExt}; use binius_math::{ - CompositionPolyOS, Error as MathError, EvaluationDomainFactory, + CompositionPoly, Error as MathError, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearPoly, }; use binius_maybe_rayon::prelude::*; @@ -95,7 +95,7 @@ pub fn univariatizing_reduction_prover<'a, F, FDomain, P, Backend>( backend: &'a Backend, ) -> Result, Error> where - F: TowerField + ExtensionField, + F: TowerField, FDomain: TowerField, P: PackedFieldIndexable + PackedExtension @@ -132,12 +132,7 @@ where } #[derive(Debug)] -struct ParFoldStates -where - FBase: Field + PackedField, - P: PackedField + PackedExtension, - P::Scalar: ExtensionField, -{ +struct ParFoldStates> { /// Evaluations of a multilinear subcube, embedded into P (see MultilinearPoly::subcube_evals). Scratch space. evals: Vec

, /// `evals` cast to base field and transposed to 2^skip_rounds * 2^log_batch row-major form. Scratch space. @@ -150,12 +145,7 @@ where round_evals: Vec>, } -impl ParFoldStates -where - FBase: Field, - P: PackedField + PackedExtension, - P::Scalar: ExtensionField, -{ +impl> ParFoldStates { fn new( n_multilinears: usize, skip_rounds: usize, @@ -335,11 +325,11 @@ pub fn zerocheck_univariate_evals where FDomain: TowerField, FBase: ExtensionField, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, P: PackedFieldIndexable + PackedExtension + PackedExtension, - Composition: CompositionPolyOS>, + Composition: CompositionPoly>, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -388,7 +378,7 @@ where // univariatized subcube. // NB: expansion of the first `skip_rounds` variables is applied to the round evals sum let partial_eq_ind_evals = backend.tensor_product_full_query(zerocheck_challenges)?; - let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals[..]); + let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals); // Evaluate each composition on a minimal packed prefix corresponding to the degree let pbase_prefix_lens = composition_degrees @@ -754,7 +744,7 @@ mod tests { }; use binius_hal::make_portable_backend; use binius_math::{ - CompositionPolyOS, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly, + CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly, }; use binius_ntt::SingleThreadedNTT; use rand::{prelude::StdRng, SeedableRng}; @@ -823,7 +813,7 @@ mod tests { .map(|i| interleaved_scalars[(i << log_batch) + batch_idx]) .collect::>(); - for (i, &point) in max_domain.points()[1 << skip_rounds..] + for (i, &point) in max_domain.finite_points()[1 << skip_rounds..] [..extrapolated_scalars_cnt] .iter() .enumerate() @@ -889,11 +879,11 @@ mod tests { let compositions = [ Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap()) - as Arc>>, + as Arc>>, Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap()) - as Arc>>, + as Arc>>, Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap()) - as Arc>>, + as Arc>>, ]; let backend = make_portable_backend(); diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index c3325604..bb9867c6 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -9,7 +9,7 @@ use binius_field::{ }; use binius_hal::{ComputationBackend, SumcheckEvaluator}; use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MLEDirectAdapter, + CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::prelude::*; @@ -23,7 +23,7 @@ use tracing::instrument; use crate::{ polynomial::{Error as PolynomialError, MultilinearComposite}, protocols::sumcheck::{ - common::{determine_switchovers, equal_n_vars_check, small_field_embedding_degree_check}, + common::{determine_switchovers, equal_n_vars_check}, prove::{ common::fold_partial_eq_ind, univariate::{ @@ -41,13 +41,13 @@ use crate::{ pub fn validate_witness<'a, F, P, M, Composition>( multilinears: &[M], - zero_claims: impl IntoIterator, Composition)>, + zero_claims: impl IntoIterator, ) -> Result<(), Error> where F: Field, P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

+ 'a, + Composition: CompositionPoly

+ 'a, { let n_vars = multilinears .first() @@ -99,7 +99,7 @@ where #[getset(get = "pub")] multilinears: Vec, switchover_rounds: Vec, - compositions: Vec<(Arc, CompositionBase, Composition)>, + compositions: Vec<(String, CompositionBase, Composition)>, zerocheck_challenges: Vec, domains: Vec>, backend: &'a Backend, @@ -111,21 +111,21 @@ where impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, FBase: ExtensionField, P: PackedFieldIndexable + PackedExtension + PackedExtension + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, + CompositionBase: CompositionPoly<

>::PackedSubfield>, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync + 'm, Backend: ComputationBackend, { pub fn new( multilinears: Vec, - zero_claims: impl IntoIterator, CompositionBase, Composition)>, + zero_claims: impl IntoIterator, zerocheck_challenges: &[F], evaluation_domain_factory: impl EvaluationDomainFactory, switchover_fn: impl Fn(usize) -> usize, @@ -154,8 +154,6 @@ where validate_witness(&multilinears, &compositions)?; } - small_field_embedding_degree_check::<_, FBase, P, _>(&multilinears)?; - let switchover_rounds = determine_switchovers(&multilinears, switchover_fn); let zerocheck_challenges = zerocheck_challenges.to_vec(); @@ -188,16 +186,7 @@ where pub fn into_regular_zerocheck( self, ) -> Result< - ZerocheckProver< - 'a, - FDomain, - FBase, - P, - CompositionBase, - Composition, - MultilinearWitness<'m, P>, - Backend, - >, + ZerocheckProver<'a, FDomain, P, Composition, MultilinearWitness<'m, P>, Backend>, Error, > { if self.univariate_evals_output.is_some() { @@ -224,26 +213,29 @@ where validate_witness(&multilinears, &compositions)?; } + let compositions = self + .compositions + .into_iter() + .map(|(_, _, composition)| composition) + .collect::>(); + // Evaluate zerocheck partial indicator in variables 1..n_vars let start = self.n_vars.min(1); let partial_eq_ind_evals = self .backend .tensor_product_full_query(&self.zerocheck_challenges[start..])?; - let claimed_sums = vec![F::ZERO; self.compositions.len()]; + let claimed_sums = vec![F::ZERO; compositions.len()]; // This is a regular multilinear zerocheck constructor, split over two creation stages. ZerocheckProver::new( multilinears, - self.switchover_rounds, - self.compositions - .into_iter() - .map(|(_, a, b)| (a, b)) - .collect(), + &self.switchover_rounds, + compositions, partial_eq_ind_evals, self.zerocheck_challenges, claimed_sums, self.domains, - RegularFirstRound::BaseField, + RegularFirstRound::SkipCube, self.backend, ) } @@ -253,15 +245,15 @@ impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> UnivariateZerocheckProver<'a, F> for UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend> where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: TowerField, FBase: ExtensionField, P: PackedFieldIndexable + PackedExtension + PackedExtension + PackedExtension, - CompositionBase: CompositionPolyOS> + 'static, - Composition: CompositionPolyOS

+ 'static, + CompositionBase: CompositionPoly> + 'static, + Composition: CompositionPoly

+ 'static, M: MultilinearPoly

+ Send + Sync + 'm, Backend: ComputationBackend, { @@ -383,27 +375,30 @@ where .switchover_rounds .into_iter() .map(|switchover_round| switchover_round.saturating_sub(skip_rounds)) - .collect(); + .collect::>(); let zerocheck_challenges = self.zerocheck_challenges.clone(); + let compositions = self + .compositions + .into_iter() + .map(|(_, _, composition)| composition) + .collect(); + // This is also regular multilinear zerocheck constructor, but "jump started" in round // `skip_rounds` while using witness with a projected univariate round. - // NB: first round evaluator has to be overriden due to issues proving + // NB: first round evaluator has to be overridden due to issues proving // `P: RepackedExtension

` relation in the generic context, as well as the need // to use later round evaluator (as this _is_ a "later" round, albeit numbered at zero) - let regular_prover = ZerocheckProver::<_, FBase, _, _, _, _, _>::new( + let regular_prover = ZerocheckProver::new( partial_low_multilinears, - switchover_rounds, - self.compositions - .into_iter() - .map(|(_, a, b)| (a, b)) - .collect(), + &switchover_rounds, + compositions, partial_eq_ind_evals, zerocheck_challenges, claimed_prime_sums, self.domains, - RegularFirstRound::LargeField, + RegularFirstRound::LaterRound, self.backend, )?; @@ -413,8 +408,8 @@ where #[derive(Debug, Clone, Copy)] enum RegularFirstRound { - BaseField, - LargeField, + SkipCube, + LaterRound, } /// A "regular" multilinear zerocheck prover. @@ -432,10 +427,9 @@ enum RegularFirstRound { /// /// [Gruen24]: #[derive(Debug)] -pub struct ZerocheckProver<'a, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +pub struct ZerocheckProver<'a, FDomain, P, Composition, M, Backend> where FDomain: Field, - FBase: PackedField, P: PackedField, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, @@ -445,32 +439,26 @@ where eq_ind_eval: P::Scalar, partial_eq_ind_evals: Backend::Vec

, zerocheck_challenges: Vec, - compositions: Vec<(CompositionBase, Composition)>, + compositions: Vec, domains: Vec>, first_round: RegularFirstRound, - _f_base_marker: PhantomData, } -impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> - ZerocheckProver<'a, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +impl<'a, F, FDomain, P, Composition, M, Backend> + ZerocheckProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS>, - Composition: CompositionPolyOS

, + P: PackedFieldIndexable + PackedExtension, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { #[allow(clippy::too_many_arguments)] fn new( multilinears: Vec, - switchover_rounds: Vec, - compositions: Vec<(CompositionBase, Composition)>, + switchover_rounds: &[usize], + compositions: Vec, partial_eq_ind_evals: Backend::Vec

, zerocheck_challenges: Vec, claimed_prime_sums: Vec, @@ -480,8 +468,8 @@ where ) -> Result { let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); if claimed_prime_sums.len() != compositions.len() { bail!(Error::IncorrectClaimedPrimeSumsLength); @@ -489,7 +477,7 @@ where let state = ProverState::new_with_switchover_rounds( multilinears, - &switchover_rounds, + switchover_rounds, claimed_prime_sums, evaluation_points, backend, @@ -517,7 +505,6 @@ where compositions, domains, first_round, - _f_base_marker: PhantomData, }) } @@ -547,18 +534,13 @@ where } } -impl SumcheckProver - for ZerocheckProver<'_, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +impl SumcheckProver + for ZerocheckProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, + P: PackedFieldIndexable + PackedExtension, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -580,33 +562,29 @@ where #[instrument(skip_all, name = "ZerocheckProver::execute", level = "debug")] fn execute(&mut self, batch_coeff: F) -> Result, Error> { let round = self.round(); - let base_field_first_round = - round == 0 && matches!(self.first_round, RegularFirstRound::BaseField); - let coeffs = if base_field_first_round { + let skip_cube_first_round = + round == 0 && matches!(self.first_round, RegularFirstRound::SkipCube); + let coeffs = if skip_cube_first_round { let evaluators = izip!(&self.compositions, &self.domains) - .map(|((composition_base, composition), interpolation_domain)| { - ZerocheckFirstRoundEvaluator { - composition_base, - composition, - interpolation_domain, - partial_eq_ind_evals: &self.partial_eq_ind_evals, - _f_base_marker: PhantomData::, - } + .map(|(composition, interpolation_domain)| ZerocheckFirstRoundEvaluator { + composition, + interpolation_domain, + partial_eq_ind_evals: &self.partial_eq_ind_evals, }) .collect::>(); - let evals = self.state.calculate_first_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)? } else { let evaluators = izip!(&self.compositions, &self.domains) - .map(|((_, composition), interpolation_domain)| ZerocheckLaterRoundEvaluator { + .map(|(composition, interpolation_domain)| ZerocheckLaterRoundEvaluator { composition, interpolation_domain, partial_eq_ind_evals: &self.partial_eq_ind_evals, round_zerocheck_challenge: self.zerocheck_challenges[round], }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)? }; @@ -636,28 +614,22 @@ where } } -struct ZerocheckFirstRoundEvaluator<'a, P, FBase, FDomain, CompositionBase, Composition> +struct ZerocheckFirstRoundEvaluator<'a, P, FDomain, Composition> where P: PackedField, - FBase: Field, FDomain: Field, { - composition_base: &'a CompositionBase, composition: &'a Composition, interpolation_domain: &'a InterpolationDomain, partial_eq_ind_evals: &'a [P], - _f_base_marker: PhantomData, } -impl SumcheckEvaluator - for ZerocheckFirstRoundEvaluator<'_, P, FBase, FDomain, CompositionBase, Composition> +impl SumcheckEvaluator + for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField + ExtensionField, - FBase: Field, - P: PackedField + PackedExtension, + P: PackedField>, FDomain: Field, - CompositionBase: CompositionPolyOS>, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // In the first round of zerocheck we can uniquely determine the degree d @@ -670,7 +642,7 @@ where &self, subcube_vars: usize, subcube_index: usize, - batch_query: &[&[PackedSubfield]], + batch_query: &[&[P]], ) -> P { // If the composition is a linear polynomial, then the composite multivariate polynomial // is multilinear. If the prover is honest, then this multilinear is identically zero, @@ -681,7 +653,7 @@ where let row_len = batch_query.first().map_or(0, |row| row.len()); stackalloc_with_default(row_len, |evals| { - self.composition_base + self.composition .batch_evaluate(batch_query, evals) .expect("correct by query construction invariant"); @@ -705,13 +677,12 @@ where } } -impl SumcheckInterpolator - for ZerocheckFirstRoundEvaluator<'_, P, FBase, FDomain, CompositionBase, Composition> +impl SumcheckInterpolator + for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField + ExtensionField, - FBase: Field, - FDomain: Field, + F: Field + ExtensionField, P: PackedField, + FDomain: Field, { fn round_evals_to_coeffs( &self, @@ -741,13 +712,12 @@ where round_zerocheck_challenge: P::Scalar, } -impl SumcheckEvaluator +impl SumcheckEvaluator for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + P: PackedField>, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // We can uniquely derive the degree d univariate round polynomial r from evaluations at @@ -796,7 +766,7 @@ where impl SumcheckInterpolator for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, { diff --git a/crates/core/src/protocols/sumcheck/tests.rs b/crates/core/src/protocols/sumcheck/tests.rs index 667f6316..a7720964 100644 --- a/crates/core/src/protocols/sumcheck/tests.rs +++ b/crates/core/src/protocols/sumcheck/tests.rs @@ -16,7 +16,7 @@ use binius_field::{ }; use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt}; use binius_math::{ - ArithExpr, CompositionPolyOS, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, + ArithExpr, CompositionPoly, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::{current_num_threads, prelude::*}; @@ -50,7 +50,7 @@ struct PowerComposition { exponent: usize, } -impl CompositionPolyOS

for PowerComposition { +impl CompositionPoly

for PowerComposition { fn n_vars(&self) -> usize { 1 } @@ -103,7 +103,7 @@ fn compute_composite_sum( where P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { let n_vars = multilinears .first() @@ -263,11 +263,11 @@ fn make_test_sumcheck<'a, F, FDomain, P, PExt, Backend>( backend: &'a Backend, ) -> ( Vec>, - SumcheckClaim + Clone + 'static>, + SumcheckClaim + Clone + 'static>, impl SumcheckProver + 'a, ) where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, P: PackedField, PExt: PackedField @@ -287,10 +287,9 @@ where .map(MLEEmbeddingAdapter::<_, PExt, _>::from) .collect::>(); - let mut claim_composite_sums = - Vec::>>>::new(); + let mut claim_composite_sums = Vec::>>>::new(); let mut prover_composite_sums = - Vec::>>>::new(); + Vec::>>>::new(); if max_degree >= 1 { let identity_composition = diff --git a/crates/core/src/protocols/sumcheck/univariate.rs b/crates/core/src/protocols/sumcheck/univariate.rs index d433b5d6..df833eef 100644 --- a/crates/core/src/protocols/sumcheck/univariate.rs +++ b/crates/core/src/protocols/sumcheck/univariate.rs @@ -230,7 +230,7 @@ mod tests { }; use binius_hal::ComputationBackend; use binius_math::{ - CompositionPolyOS, DefaultEvaluationDomainFactory, EvaluationDomainFactory, + CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MultilinearPoly, }; use groestl_crypto::Groestl256; @@ -437,31 +437,31 @@ mod tests { let prover_compositions = [ ( "pair".into(), - pair.clone() as Arc>>, - pair.clone() as Arc>>, + pair.clone() as Arc>>, + pair.clone() as Arc>>, ), ( "triple".into(), - triple.clone() as Arc>>, - triple.clone() as Arc>>, + triple.clone() as Arc>>, + triple.clone() as Arc>>, ), ( "quad".into(), - quad.clone() as Arc>>, - quad.clone() as Arc>>, + quad.clone() as Arc>>, + quad.clone() as Arc>>, ), ]; let prover_adapter_compositions = [ - CompositionScalarAdapter::new(pair.clone() as Arc>), - CompositionScalarAdapter::new(triple.clone() as Arc>), - CompositionScalarAdapter::new(quad.clone() as Arc>), + CompositionScalarAdapter::new(pair.clone() as Arc>), + CompositionScalarAdapter::new(triple.clone() as Arc>), + CompositionScalarAdapter::new(quad.clone() as Arc>), ]; let verifier_compositions = [ - pair as Arc>, - triple as Arc>, - quad as Arc>, + pair as Arc>, + triple as Arc>, + quad as Arc>, ]; for skip_rounds in 0..=max_n_vars { diff --git a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs index eb14a3bb..75756851 100644 --- a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{util::inner_product_unchecked, Field, TowerField}; -use binius_math::{CompositionPolyOS, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory}; use binius_utils::{bail, sorting::is_sorted_ascending}; use tracing::instrument; @@ -50,7 +50,7 @@ pub fn batch_verify_zerocheck_univariate_round( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { // Check that the claims are in descending order by n_vars diff --git a/crates/core/src/protocols/sumcheck/verify.rs b/crates/core/src/protocols/sumcheck/verify.rs index b06de94a..fe778c1a 100644 --- a/crates/core/src/protocols/sumcheck/verify.rs +++ b/crates/core/src/protocols/sumcheck/verify.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, TowerField}; -use binius_math::{evaluate_univariate, CompositionPolyOS}; +use binius_math::{evaluate_univariate, CompositionPoly}; use binius_utils::{bail, sorting::is_sorted_ascending}; use itertools::izip; use tracing::instrument; @@ -34,7 +34,7 @@ pub fn batch_verify( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { let start = BatchVerifyStart { @@ -69,7 +69,7 @@ pub fn batch_verify_with_start( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { let BatchVerifyStart { @@ -177,7 +177,7 @@ pub fn compute_expected_batch_composite_evaluation_single_claim Result where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { let composite_evals = claim .composite_sums() @@ -193,7 +193,7 @@ fn compute_expected_batch_composite_evaluation_multi_claim], ) -> Result where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { izip!(batch_coeffs, claims, multilinear_evals.iter()) .map(|(batch_coeff, claim, multilinear_evals)| { diff --git a/crates/core/src/protocols/sumcheck/zerocheck.rs b/crates/core/src/protocols/sumcheck/zerocheck.rs index 4bce691c..5d47e28b 100644 --- a/crates/core/src/protocols/sumcheck/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/zerocheck.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use binius_field::{util::eq, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::{bail, sorting::is_sorted_ascending}; use getset::CopyGetters; @@ -22,7 +22,7 @@ pub struct ZerocheckClaim { impl ZerocheckClaim where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { pub fn new( n_vars: usize, @@ -60,7 +60,7 @@ where } /// Requirement: zerocheck challenges have been sampled before this is called -pub fn reduce_to_sumchecks>( +pub fn reduce_to_sumchecks>( claims: &[ZerocheckClaim], ) -> Result>>, Error> { // Check that the claims are in descending order by n_vars @@ -100,7 +100,7 @@ pub fn reduce_to_sumchecks>( /// /// Note that due to univariatization of some rounds the number of challenges may be less than /// the maximum number of variables among claims. -pub fn verify_sumcheck_outputs>( +pub fn verify_sumcheck_outputs>( claims: &[ZerocheckClaim], zerocheck_challenges: &[F], sumcheck_output: BatchSumcheckOutput, @@ -158,10 +158,10 @@ pub struct ExtraProduct { pub inner: Composition, } -impl CompositionPolyOS

for ExtraProduct +impl CompositionPoly

for ExtraProduct where P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn n_vars(&self) -> usize { self.inner.n_vars() + 1 @@ -195,8 +195,8 @@ mod tests { use std::{iter, sync::Arc}; use binius_field::{ - BinaryField128b, BinaryField32b, BinaryField8b, ExtensionField, PackedBinaryField1x128b, - PackedExtension, PackedFieldIndexable, PackedSubfield, RepackedExtension, + BinaryField128b, BinaryField32b, BinaryField8b, PackedBinaryField1x128b, PackedExtension, + PackedFieldIndexable, PackedSubfield, RepackedExtension, }; use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt}; use binius_math::{ @@ -236,10 +236,10 @@ mod tests { Backend, > where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension + RepackedExtension

, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync + 'static, Backend: ComputationBackend, { diff --git a/crates/core/src/protocols/test_utils.rs b/crates/core/src/protocols/test_utils.rs index 1ad370ad..bfa4a0b4 100644 --- a/crates/core/src/protocols/test_utils.rs +++ b/crates/core/src/protocols/test_utils.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use binius_field::{ExtensionField, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS, MLEEmbeddingAdapter, MultilinearExtension}; +use binius_math::{ArithExpr, CompositionPoly, MLEEmbeddingAdapter, MultilinearExtension}; use rand::Rng; use crate::polynomial::Error as PolynomialError; @@ -19,10 +19,10 @@ impl AddOneComposition { } } -impl CompositionPolyOS

for AddOneComposition +impl CompositionPoly

for AddOneComposition where P: PackedField, - Inner: CompositionPolyOS

, + Inner: CompositionPoly

, { fn n_vars(&self) -> usize { self.inner.n_vars() @@ -56,7 +56,7 @@ impl TestProductComposition { } } -impl

CompositionPolyOS

for TestProductComposition +impl

CompositionPoly

for TestProductComposition where P: PackedField, { @@ -122,7 +122,7 @@ where } pub fn transform_poly( - multilin: MultilinearExtension, + multilin: &MultilinearExtension, ) -> Result, PolynomialError> where F: Field, diff --git a/crates/core/src/reed_solomon/reed_solomon.rs b/crates/core/src/reed_solomon/reed_solomon.rs index 771a306f..ade7bace 100644 --- a/crates/core/src/reed_solomon/reed_solomon.rs +++ b/crates/core/src/reed_solomon/reed_solomon.rs @@ -15,7 +15,7 @@ use std::marker::PhantomData; use binius_field::{BinaryField, ExtensionField, PackedField, RepackedExtension}; use binius_maybe_rayon::prelude::*; use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings}; -use binius_utils::{bail, checked_arithmetics::checked_log_2}; +use binius_utils::bail; use getset::CopyGetters; use tracing::instrument; @@ -40,7 +40,7 @@ where pub fn new( log_dimension: usize, log_inv_rate: usize, - ntt_options: NTTOptions, + ntt_options: &NTTOptions, ) -> Result { // Since we split work between log_inv_rate threads, we need to decrease the number of threads per each NTT transformation. let ntt_log_threads = ntt_options @@ -49,11 +49,11 @@ where .saturating_sub(log_inv_rate); let ntt = DynamicDispatchNTT::new( log_dimension + log_inv_rate, - NTTOptions { + &NTTOptions { thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: ntt_log_threads, }, - ..ntt_options + precompute_twiddles: ntt_options.precompute_twiddles, }, )?; @@ -160,16 +160,11 @@ where /// /// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements. #[instrument(skip_all, level = "debug")] - pub fn encode_ext_batch_inplace( + pub fn encode_ext_batch_inplace>( &self, code: &mut [PE], log_batch_size: usize, - ) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField<

::Scalar>, - { - let log_degree = checked_log_2(PE::Scalar::DEGREE); - self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + log_degree) + ) -> Result<(), Error> { + self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + PE::Scalar::LOG_DEGREE) } } diff --git a/crates/core/src/ring_switch/common.rs b/crates/core/src/ring_switch/common.rs index e86876e4..fb76fe80 100644 --- a/crates/core/src/ring_switch/common.rs +++ b/crates/core/src/ring_switch/common.rs @@ -72,7 +72,7 @@ impl<'a, F: TowerField> EvalClaimSystem<'a, F> { pub fn new( oracles: &MultilinearOracleSet, commit_meta: &'a CommitMeta, - oracle_to_commit_index: SparseIndex, + oracle_to_commit_index: &SparseIndex, eval_claims: &'a [EvalcheckMultilinearClaim], ) -> Result { // Sort evaluation claims in ascending order by number of packed variables. This must diff --git a/crates/core/src/ring_switch/eq_ind.rs b/crates/core/src/ring_switch/eq_ind.rs index daffe0ba..3505969f 100644 --- a/crates/core/src/ring_switch/eq_ind.rs +++ b/crates/core/src/ring_switch/eq_ind.rs @@ -1,10 +1,14 @@ // Copyright 2024-2025 Irreducible Inc. -use std::{iter, marker::PhantomData, sync::Arc}; +use std::{any::TypeId, iter, marker::PhantomData, sync::Arc}; use binius_field::{ - util::inner_product_unchecked, ExtensionField, Field, PackedExtension, PackedField, - PackedFieldIndexable, TowerField, + byte_iteration::{ + can_iterate_bytes, create_partial_sums_lookup_tables, iterate_bytes, ByteIteratorCallback, + }, + util::inner_product_unchecked, + BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, + TowerField, }; use binius_math::{tensor_prod_eq_ind, MultilinearExtension}; use binius_maybe_rayon::prelude::*; @@ -17,6 +21,34 @@ use crate::{ tensor_algebra::TensorAlgebra, }; +/// Information about the row-batching coefficients. +#[derive(Debug)] +pub struct RowBatchCoeffs { + coeffs: Vec, + /// This is a lookup table for the partial sums of the coefficients + /// that is used to efficiently fold with 1-bit coefficients. + partial_sums_lookup_table: Vec, +} + +impl RowBatchCoeffs { + pub fn new(coeffs: Vec) -> Self { + let partial_sums_lookup_table = if coeffs.len() >= 8 { + create_partial_sums_lookup_tables(coeffs.as_slice()) + } else { + Vec::new() + }; + + Self { + coeffs, + partial_sums_lookup_table, + } + } + + pub fn coeffs(&self) -> &[F] { + &self.coeffs + } +} + /// The multilinear function $A$ from [DP24] Section 5. /// /// The function $A$ is $\ell':= \ell - \kappa$-variate and depends on the last $\ell'$ coordinates @@ -27,7 +59,7 @@ use crate::{ pub struct RingSwitchEqInd { /// $z_{\kappa}, \ldots, z_{\ell-1}$ z_vals: Arc<[F]>, - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeff: F, _marker: PhantomData, } @@ -39,10 +71,10 @@ where { pub fn new( z_vals: Arc<[F]>, - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeff: F, ) -> Result { - if row_batch_coeffs.len() < F::DEGREE { + if row_batch_coeffs.coeffs.len() < F::DEGREE { bail!(Error::InvalidArgs( "RingSwitchEqInd::new expects row_batch_coeffs length greater than or equal to \ the extension degree" @@ -67,20 +99,53 @@ where P::unpack_scalars_mut(&mut evals) .par_iter_mut() .for_each(|val| { - let vert = *val; - *val = inner_product_unchecked( - self.row_batch_coeffs.iter().copied(), - ExtensionField::::iter_bases(&vert), - ); + *val = inner_product_subfield(*val, &self.row_batch_coeffs); }); Ok(MultilinearExtension::from_values(evals)?) } } +#[inline(always)] +fn inner_product_subfield(value: F, row_batch_coeffs: &RowBatchCoeffs) -> F +where + FSub: Field, + F: ExtensionField, +{ + if TypeId::of::() == TypeId::of::() && can_iterate_bytes::() { + // Special case when we are folding with 1-bit coefficients. + // Use partial sums lookup table to speed up the computation. + + struct Callback<'a, F> { + partial_sums_lookup: &'a [F], + result: F, + } + + impl ByteIteratorCallback for Callback<'_, F> { + #[inline(always)] + fn call(&mut self, iter: impl Iterator) { + for (byte_index, byte) in iter.enumerate() { + self.result += self.partial_sums_lookup[(byte_index << 8) + byte as usize]; + } + } + } + + let mut callback = Callback { + partial_sums_lookup: &row_batch_coeffs.partial_sums_lookup_table, + result: F::ZERO, + }; + iterate_bytes(std::slice::from_ref(&value), &mut callback); + + callback.result + } else { + // fall back to the general case + inner_product_unchecked(row_batch_coeffs.coeffs.iter().copied(), F::iter_bases(&value)) + } +} + impl MultivariatePoly for RingSwitchEqInd where FSub: TowerField, - F: TowerField + PackedField + ExtensionField + PackedExtension, + F: TowerField + PackedField + PackedExtension, { fn n_vars(&self) -> usize { self.z_vals.len() @@ -108,7 +173,7 @@ where }, ); - let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs); + let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs.coeffs); Ok(folded_eval) } @@ -141,7 +206,8 @@ mod tests { let row_batch_coeffs = repeat_with(|| ::random(&mut rng)) .take(1 << kappa) - .collect::>(); + .collect::>(); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(row_batch_coeffs)); let eval_point = repeat_with(|| ::random(&mut rng)) .take(n_vars) diff --git a/crates/core/src/ring_switch/prove.rs b/crates/core/src/ring_switch/prove.rs index ab2d6aee..5a7423aa 100644 --- a/crates/core/src/ring_switch/prove.rs +++ b/crates/core/src/ring_switch/prove.rs @@ -11,6 +11,7 @@ use tracing::instrument; use super::{ common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc}, + eq_ind::RowBatchCoeffs, error::Error, tower_tensor_algebra::TowerTensorAlgebra, }; @@ -75,11 +76,12 @@ where // Sample the row-batching randomness. let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); - let row_batch_coeffs = - Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new( + MultilinearQuery::::expand(&row_batch_challenges).into_expansion(), + )); let row_batched_evals = - compute_row_batched_sumcheck_evals(scaled_tensor_elems, &row_batch_coeffs); + compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs()); transcript.message().write_scalar_slice(&row_batched_evals); // Create the reduced PIOP sumcheck witnesses. @@ -217,7 +219,7 @@ where fn make_ring_switch_eq_inds( sumcheck_claim_descs: &[PIOPSumcheckClaimDesc], suffix_descs: &[EvalClaimSuffixDesc], - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeffs: &[F], ) -> Result>, Error> where @@ -238,7 +240,7 @@ where fn make_ring_switch_eq_ind( suffix_desc: &EvalClaimSuffixDesc>, - row_batch_coeffs: Arc<[FExt]>, + row_batch_coeffs: Arc>>, mixing_coeff: FExt, ) -> Result, Error> where diff --git a/crates/core/src/ring_switch/tests.rs b/crates/core/src/ring_switch/tests.rs index 5546e786..2cb0ce77 100644 --- a/crates/core/src/ring_switch/tests.rs +++ b/crates/core/src/ring_switch/tests.rs @@ -14,7 +14,7 @@ use binius_math::{ DefaultEvaluationDomainFactory, MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery, }; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; +use binius_utils::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::prelude::*; @@ -208,7 +208,7 @@ fn with_test_instance_from_oracles( // Finish setting up the test case let system = - EvalClaimSystem::new(oracles, &commit_meta, oracle_to_commit_index, &eval_claims).unwrap(); + EvalClaimSystem::new(oracles, &commit_meta, &oracle_to_commit_index, &eval_claims).unwrap(); check_eval_point_consistency(&system); func(rng, system, witnesses) @@ -304,7 +304,7 @@ fn commit_prove_verify_piop( // Finish setting up the test case let system = - EvalClaimSystem::new(oracles, &commit_meta, oracle_to_commit_index, &eval_claims).unwrap(); + EvalClaimSystem::new(oracles, &commit_meta, &oracle_to_commit_index, &eval_claims).unwrap(); check_eval_point_consistency(&system); let mut proof = ProverTranscript::>::new(); diff --git a/crates/core/src/ring_switch/verify.rs b/crates/core/src/ring_switch/verify.rs index 330b0ffc..8a841b2a 100644 --- a/crates/core/src/ring_switch/verify.rs +++ b/crates/core/src/ring_switch/verify.rs @@ -8,6 +8,7 @@ use binius_utils::checked_arithmetics::log2_ceil_usize; use bytes::Buf; use itertools::izip; +use super::eq_ind::RowBatchCoeffs; use crate::{ fiat_shamir::{CanSample, Challenger}, piop::PIOPSumcheckClaim, @@ -50,8 +51,9 @@ where // Sample the row-batching randomness. let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); - let row_batch_coeffs = - Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new( + MultilinearQuery::::expand(&row_batch_challenges).into_expansion(), + )); // For each original evaluation claim, receive the row-batched evaluation claim. let row_batched_evals = transcript @@ -66,7 +68,7 @@ where &system.eval_claim_to_prefix_desc_index, ); for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) { - if tensor_elem.fold_vertical(&row_batch_coeffs) != expected { + if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected { return Err(VerificationError::IncorrectRowBatchedSum.into()); } } @@ -75,7 +77,7 @@ where let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, Tower>( &system.sumcheck_claim_descs, &system.suffix_descs, - row_batch_coeffs, + &row_batch_coeffs, &mixing_coeffs, )?; let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals) @@ -173,7 +175,7 @@ fn accumulate_evaluations_by_prefixes( fn make_ring_switch_eq_inds( sumcheck_claim_descs: &[PIOPSumcheckClaimDesc], suffix_descs: &[EvalClaimSuffixDesc], - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: &Arc>, mixing_coeffs: &[F], ) -> Result>>, Error> where @@ -190,7 +192,7 @@ where fn make_ring_switch_eq_ind( suffix_desc: &EvalClaimSuffixDesc>, - row_batch_coeffs: Arc<[FExt]>, + row_batch_coeffs: Arc>>, mixing_coeff: FExt, ) -> Result>>, Error> where diff --git a/crates/core/src/tensor_algebra.rs b/crates/core/src/tensor_algebra.rs index a5d7ea5a..cb071f25 100644 --- a/crates/core/src/tensor_algebra.rs +++ b/crates/core/src/tensor_algebra.rs @@ -10,7 +10,6 @@ use std::{ use binius_field::{ square_transpose, util::inner_product_unchecked, ExtensionField, Field, PackedExtension, }; -use binius_utils::checked_arithmetics::checked_log_2; /// An element of the tensor algebra defined as the tensor product of `FE` and `FE` as fields. /// @@ -64,7 +63,7 @@ where /// Returns $\kappa$, the base-2 logarithm of the extension degree. pub const fn kappa() -> usize { - checked_log_2(FE::DEGREE) + FE::LOG_DEGREE } /// Returns the byte size of an element. @@ -124,12 +123,7 @@ where } } -impl TensorAlgebra -where - F: Field, - FE: ExtensionField + PackedExtension, - FE::Scalar: ExtensionField, -{ +impl + PackedExtension> TensorAlgebra { /// Multiply by an element from the vertical subring. /// /// Internally, this performs a transpose, vertical scaling, then transpose sequence. If diff --git a/crates/core/src/transcript/error.rs b/crates/core/src/transcript/error.rs index 68a5eefd..ee57c376 100644 --- a/crates/core/src/transcript/error.rs +++ b/crates/core/src/transcript/error.rs @@ -1,7 +1,5 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_utils::serialization::Error as SerializationError; - #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Transcript is not empty, {remaining} bytes")] @@ -9,5 +7,5 @@ pub enum Error { #[error("Not enough bytes in the buffer")] NotEnoughBytes, #[error("Serialization error: {0}")] - Serialization(#[from] SerializationError), + Serialization(#[from] binius_utils::SerializationError), } diff --git a/crates/core/src/transcript/mod.rs b/crates/core/src/transcript/mod.rs index e6081b6c..d014ea9a 100644 --- a/crates/core/src/transcript/mod.rs +++ b/crates/core/src/transcript/mod.rs @@ -16,8 +16,8 @@ mod error; use std::{iter::repeat_with, slice}; -use binius_field::{deserialize_canonical, serialize_canonical, PackedField, TowerField}; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; +use binius_field::{PackedField, TowerField}; +use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes}; use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut}; pub use error::Error; use tracing::warn; @@ -259,12 +259,14 @@ impl TranscriptReader<'_, B> { } pub fn read(&mut self) -> Result { - T::deserialize(self.buffer()).map_err(Into::into) + let mode = SerializationMode::CanonicalTower; + T::deserialize(self.buffer(), mode).map_err(Into::into) } pub fn read_vec(&mut self, n: usize) -> Result, Error> { + let mode = SerializationMode::CanonicalTower; let mut buffer = self.buffer(); - repeat_with(move || T::deserialize(&mut buffer).map_err(Into::into)) + repeat_with(move || T::deserialize(&mut buffer, mode).map_err(Into::into)) .take(n) .collect() } @@ -287,7 +289,8 @@ impl TranscriptReader<'_, B> { pub fn read_scalar_slice_into(&mut self, buf: &mut [F]) -> Result<(), Error> { let mut buffer = self.buffer(); for elem in buf { - *elem = deserialize_canonical(&mut buffer)?; + let mode = SerializationMode::CanonicalTower; + *elem = DeserializeBytes::deserialize(&mut buffer, mode)?; } Ok(()) } @@ -334,20 +337,27 @@ impl TranscriptWriter<'_, B> { } pub fn write(&mut self, value: &T) { - value - .serialize(self.buffer()) - .expect("TODO: propagate error") + self.proof_size_event_wrapper(|buffer| { + value + .serialize(buffer, SerializationMode::CanonicalTower) + .expect("TODO: propagate error"); + }); } pub fn write_slice(&mut self, values: &[T]) { - let mut buffer = self.buffer(); - for value in values { - value.serialize(&mut buffer).expect("TODO: propagate error") - } + self.proof_size_event_wrapper(|buffer| { + for value in values { + value + .serialize(&mut *buffer, SerializationMode::CanonicalTower) + .expect("TODO: propagate error"); + } + }); } pub fn write_bytes(&mut self, data: &[u8]) { - self.buffer().put_slice(data); + self.proof_size_event_wrapper(|buffer| { + buffer.put_slice(data); + }); } pub fn write_scalar(&mut self, f: F) { @@ -355,10 +365,12 @@ impl TranscriptWriter<'_, B> { } pub fn write_scalar_slice(&mut self, elems: &[F]) { - let mut buffer = self.buffer(); - for elem in elems { - serialize_canonical(*elem, &mut buffer).expect("TODO: propagate error"); - } + self.proof_size_event_wrapper(|buffer| { + for elem in elems { + SerializeBytes::serialize(elem, &mut *buffer, SerializationMode::CanonicalTower) + .expect("TODO: propagate error"); + } + }); } pub fn write_packed>(&mut self, packed: P) { @@ -378,6 +390,14 @@ impl TranscriptWriter<'_, B> { self.write_bytes(msg.as_bytes()) } } + + fn proof_size_event_wrapper(&mut self, f: F) { + let buffer = self.buffer(); + let start_bytes = buffer.remaining_mut(); + f(buffer); + let end_bytes = buffer.remaining_mut(); + tracing::event!(name: "proof_size", tracing::Level::INFO, counter=true, incremental=true, value=start_bytes - end_bytes); + } } impl CanSample for VerifierTranscript @@ -386,7 +406,8 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - deserialize_canonical(self.combined.challenger.sampler()) + let mode = SerializationMode::CanonicalTower; + DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode) .expect("challenger has infinite buffer") } } @@ -397,7 +418,8 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - deserialize_canonical(self.combined.challenger.sampler()) + let mode = SerializationMode::CanonicalTower; + DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode) .expect("challenger has infinite buffer") } } diff --git a/crates/core/src/transparent/constant.rs b/crates/core/src/transparent/constant.rs index 1c010873..860ced32 100644 --- a/crates/core/src/transparent/constant.rs +++ b/crates/core/src/transparent/constant.rs @@ -1,18 +1,26 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{ExtensionField, TowerField}; -use binius_utils::bail; +use binius_field::{BinaryField128b, ExtensionField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; /// A constant polynomial. -#[derive(Debug, Copy, Clone)] -pub struct Constant { +#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)] +pub struct Constant { n_vars: usize, value: F, tower_level: usize, } +inventory::submit! { + >::register_deserializer( + "Constant", + |buf, mode| Ok(Box::new(Constant::::deserialize(&mut *buf, mode)?)) + ) +} + impl Constant { pub fn new(n_vars: usize, value: FS) -> Self where @@ -26,6 +34,7 @@ impl Constant { } } +#[erased_serialize_bytes] impl MultivariatePoly for Constant { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/mod.rs b/crates/core/src/transparent/mod.rs index e95864fe..a1e79368 100644 --- a/crates/core/src/transparent/mod.rs +++ b/crates/core/src/transparent/mod.rs @@ -6,6 +6,7 @@ pub mod eq_ind; pub mod multilinear_extension; pub mod powers; pub mod select_row; +pub mod serialization; pub mod shift_ind; pub mod step_down; pub mod step_up; diff --git a/crates/core/src/transparent/multilinear_extension.rs b/crates/core/src/transparent/multilinear_extension.rs index 7df54751..ef55e4d7 100644 --- a/crates/core/src/transparent/multilinear_extension.rs +++ b/crates/core/src/transparent/multilinear_extension.rs @@ -2,9 +2,15 @@ use std::{fmt::Debug, ops::Deref}; -use binius_field::{ExtensionField, PackedField, RepackedExtension, TowerField}; +use binius_field::{ + arch::OptimalUnderlier, as_packed_field::PackedType, packed::pack_slice, BinaryField128b, + BinaryField16b, BinaryField1b, BinaryField2b, BinaryField32b, BinaryField4b, BinaryField64b, + BinaryField8b, ExtensionField, PackedField, RepackedExtension, TowerField, +}; use binius_hal::{make_portable_backend, ComputationBackendExt}; +use binius_macros::erased_serialize_bytes; use binius_math::{MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly}; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -26,6 +32,72 @@ where data: MLEEmbeddingAdapter, } +impl SerializeBytes for MultilinearExtensionTransparent +where + P: PackedField, + PE: RepackedExtension

, + PE::Scalar: TowerField + ExtensionField, + Data: Deref + Debug + Send + Sync, +{ + fn serialize( + &self, + write_buf: impl bytes::BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let elems = PE::iter_slice( + self.data + .packed_evals() + .expect("Evals should always be available here"), + ) + .collect::>(); + SerializeBytes::serialize(&elems, write_buf, mode) + } +} + +inventory::submit! { + >::register_deserializer( + "MultilinearExtensionTransparent", + |buf, mode| { + type U = OptimalUnderlier; + type F = BinaryField128b; + type P = PackedType; + let hypercube_evals = Vec::::deserialize(&mut *buf, mode)?; + let result: Box> = if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else { + Box::new(MultilinearExtensionTransparent::::from_values(pack_slice(&hypercube_evals)).unwrap()) + }; + Ok(result) + } + ) +} + +fn try_pack_slice(xs: &[F]) -> Option> +where + PS: PackedField, + F: ExtensionField, +{ + Some(pack_slice( + &xs.iter() + .copied() + .map(TryInto::try_into) + .collect::, _>>() + .ok()?, + )) +} + impl MultilinearExtensionTransparent where P: PackedField, @@ -49,6 +121,7 @@ where } } +#[erased_serialize_bytes] impl MultivariatePoly for MultilinearExtensionTransparent where F: TowerField + ExtensionField, diff --git a/crates/core/src/transparent/powers.rs b/crates/core/src/transparent/powers.rs index c0c912d9..01a925a1 100644 --- a/crates/core/src/transparent/powers.rs +++ b/crates/core/src/transparent/powers.rs @@ -2,10 +2,11 @@ use std::iter::successors; -use binius_field::{Field, PackedField, TowerField}; +use binius_field::{BinaryField128b, PackedField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; use binius_maybe_rayon::prelude::*; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use bytemuck::zeroed_vec; use itertools::{izip, Itertools}; @@ -13,13 +14,20 @@ use crate::polynomial::{Error, MultivariatePoly}; /// A transparent multilinear polynomial whose evaluation at index $i$ is $g^i$ for /// some field element $g$. -#[derive(Debug)] -pub struct Powers { +#[derive(Debug, SerializeBytes, DeserializeBytes)] +pub struct Powers { n_vars: usize, base: F, } -impl Powers { +inventory::submit! { + >::register_deserializer( + "Powers", + |buf, mode| Ok(Box::new(Powers::::deserialize(&mut *buf, mode)?)) + ) +} + +impl Powers { pub const fn new(n_vars: usize, base: F) -> Self { Self { n_vars, base } } @@ -49,6 +57,7 @@ impl Powers { } } +#[erased_serialize_bytes] impl> MultivariatePoly

for Powers { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/select_row.rs b/crates/core/src/transparent/select_row.rs index fdcd32c5..9ef3bf0e 100644 --- a/crates/core/src/transparent/select_row.rs +++ b/crates/core/src/transparent/select_row.rs @@ -1,8 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{packed::set_packed_slice, BinaryField1b, Field, PackedField}; +use binius_field::{packed::set_packed_slice, BinaryField128b, BinaryField1b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -18,12 +19,19 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for defining boundary constraints -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct SelectRow { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "SelectRow", + |buf, mode| Ok(Box::new(SelectRow::deserialize(&mut *buf, mode)?)) + ) +} + impl SelectRow { pub fn new(n_vars: usize, index: usize) -> Result { if index >= (1 << n_vars) { @@ -50,6 +58,7 @@ impl SelectRow { } } +#[erased_serialize_bytes] impl MultivariatePoly for SelectRow { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/serialization.rs b/crates/core/src/transparent/serialization.rs new file mode 100644 index 00000000..2d5ef14f --- /dev/null +++ b/crates/core/src/transparent/serialization.rs @@ -0,0 +1,82 @@ +// Copyright 2025 Irreducible Inc. + +//! The purpose of this module is to enable serialization/deserialization of generic MultivariatePoly implementations +//! +//! The simplest way to do this would be to create an enum with all the possible structs that implement MultivariatePoly +//! +//! This has a few problems, though: +//! - Third party code is not able to define custom transparent polynomials +//! - The enum would inherit, or be forced to enumerate possible type parameters of every struct variant + +use std::{collections::HashMap, sync::LazyLock}; + +use binius_field::{BinaryField128b, TowerField}; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; + +use crate::polynomial::MultivariatePoly; + +impl SerializeBytes for Box> { + fn serialize( + &self, + mut write_buf: impl bytes::BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.erased_serialize(&mut write_buf, mode) + } +} + +impl DeserializeBytes for Box> { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let name = String::deserialize(&mut read_buf, mode)?; + match REGISTRY.get(name.as_str()) { + Some(Some(erased_deserialize)) => erased_deserialize(&mut read_buf, mode), + Some(None) => Err(SerializationError::DeserializerNameConflict { name }), + None => Err(SerializationError::DeserializerNotImplented), + } + } +} + +// Using the inventory crate we can collect all deserializers before the main function runs +// This allows third party code to submit their own deserializers as well +inventory::collect!(DeserializerEntry); + +static REGISTRY: LazyLock>>> = + LazyLock::new(|| { + let mut registry = HashMap::new(); + inventory::iter::> + .into_iter() + .for_each(|&DeserializerEntry { name, deserializer }| match registry.entry(name) { + std::collections::hash_map::Entry::Vacant(entry) => { + entry.insert(Some(deserializer)); + } + std::collections::hash_map::Entry::Occupied(mut entry) => { + entry.insert(None); + } + }); + registry + }); + +impl dyn MultivariatePoly { + pub const fn register_deserializer( + name: &'static str, + deserializer: ErasedDeserializeBytes, + ) -> DeserializerEntry { + DeserializerEntry { name, deserializer } + } +} + +pub struct DeserializerEntry { + name: &'static str, + deserializer: ErasedDeserializeBytes, +} + +type ErasedDeserializeBytes = fn( + &mut dyn bytes::Buf, + mode: SerializationMode, +) -> Result>, SerializationError>; diff --git a/crates/core/src/transparent/step_down.rs b/crates/core/src/transparent/step_down.rs index 8c588f73..00dec677 100644 --- a/crates/core/src/transparent/step_down.rs +++ b/crates/core/src/transparent/step_down.rs @@ -1,8 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{Field, PackedField}; +use binius_field::{BinaryField128b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -20,12 +21,19 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the last rows of the trace -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct StepDown { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "StepDown", + |buf, mode| Ok(Box::new(StepDown::deserialize(&mut *buf, mode)?)) + ) +} + impl StepDown { pub fn new(n_vars: usize, index: usize) -> Result { if index > 1 << n_vars { @@ -68,6 +76,7 @@ impl StepDown { } } +#[erased_serialize_bytes] impl MultivariatePoly for StepDown { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/step_up.rs b/crates/core/src/transparent/step_up.rs index 3a24b9f5..ad022df9 100644 --- a/crates/core/src/transparent/step_up.rs +++ b/crates/core/src/transparent/step_up.rs @@ -1,8 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{Field, PackedField}; +use binius_field::{BinaryField128b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -20,12 +21,19 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the first rows of the trace -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct StepUp { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "StepUp", + |buf, mode| Ok(Box::new(StepUp::deserialize(&mut *buf, mode)?)) + ) +} + impl StepUp { pub fn new(n_vars: usize, index: usize) -> Result { if index > 1 << n_vars { @@ -64,6 +72,7 @@ impl StepUp { } } +#[erased_serialize_bytes] impl MultivariatePoly for StepUp { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/tower_basis.rs b/crates/core/src/transparent/tower_basis.rs index 20992c46..8b8d32bf 100644 --- a/crates/core/src/transparent/tower_basis.rs +++ b/crates/core/src/transparent/tower_basis.rs @@ -2,9 +2,10 @@ use std::marker::PhantomData; -use binius_field::{Field, PackedField, TowerField}; +use binius_field::{BinaryField128b, Field, PackedField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -20,13 +21,20 @@ use crate::polynomial::{Error, MultivariatePoly}; /// /// Thus, $\mathcal{T}_{\iota+k}$ has a $\mathcal{T}_{\iota}$-basis of size $2^k$: /// * $1, X_{\iota}, X_{\iota+1}, X_{\iota}X_{\iota+1}, X_{\iota+2}, \ldots, X_{\iota} X_{\iota+1} \ldots X_{\iota+k-1}$ -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)] pub struct TowerBasis { k: usize, iota: usize, _marker: PhantomData, } +inventory::submit! { + >::register_deserializer( + "TowerBasis", + |buf, mode| Ok(Box::new(TowerBasis::::deserialize(&mut *buf, mode)?)) + ) +} + impl TowerBasis { pub fn new(k: usize, iota: usize) -> Result { if iota + k > F::TOWER_LEVEL { @@ -62,6 +70,7 @@ impl TowerBasis { } } +#[erased_serialize_bytes] impl MultivariatePoly for TowerBasis where F: TowerField, diff --git a/crates/field/Cargo.toml b/crates/field/Cargo.toml index 36de13de..11903020 100644 --- a/crates/field/Cargo.toml +++ b/crates/field/Cargo.toml @@ -11,7 +11,6 @@ workspace = true binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } binius_utils = { path = "../utils", default-features = false } bytemuck.workspace = true -bytes.workspace = true cfg-if.workspace = true derive_more.workspace = true rand.workspace = true diff --git a/crates/field/benches/packed_extension_mul.rs b/crates/field/benches/packed_extension_mul.rs index a119dc9a..836b8474 100644 --- a/crates/field/benches/packed_extension_mul.rs +++ b/crates/field/benches/packed_extension_mul.rs @@ -16,7 +16,6 @@ fn benchmark_packed_extension_mul( label: &str, ) where F: Field, - BinaryField128b: ExtensionField, PackedBinaryField2x128b: PackedExtension, { let mut rng = thread_rng(); diff --git a/crates/field/benches/packed_field_element_access.rs b/crates/field/benches/packed_field_element_access.rs index d8f4d0ee..834f280a 100644 --- a/crates/field/benches/packed_field_element_access.rs +++ b/crates/field/benches/packed_field_element_access.rs @@ -3,11 +3,15 @@ use std::array; use binius_field::{ - PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b, - PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b, - PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b, - PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, + ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b, + ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b, + ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b, + ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b, + PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b, + PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b, + PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, + PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b, + PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput, @@ -86,5 +90,43 @@ fn packed_512(c: &mut Criterion) { benchmark_get_set!(PackedBinaryField4x128b, group); } -criterion_group!(get_set, packed_128, packed_256, packed_512); +fn byte_sliced_128(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_128"); + + benchmark_get_set!(ByteSlicedAES16x8b, group); + benchmark_get_set!(ByteSlicedAES16x16b, group); + benchmark_get_set!(ByteSlicedAES16x32b, group); + benchmark_get_set!(ByteSlicedAES16x64b, group); + benchmark_get_set!(ByteSlicedAES16x128b, group); +} + +fn byte_sliced_256(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_256"); + + benchmark_get_set!(ByteSlicedAES32x8b, group); + benchmark_get_set!(ByteSlicedAES32x16b, group); + benchmark_get_set!(ByteSlicedAES32x32b, group); + benchmark_get_set!(ByteSlicedAES32x64b, group); + benchmark_get_set!(ByteSlicedAES32x128b, group); +} + +fn byte_sliced_512(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_512"); + + benchmark_get_set!(ByteSlicedAES64x8b, group); + benchmark_get_set!(ByteSlicedAES64x16b, group); + benchmark_get_set!(ByteSlicedAES64x32b, group); + benchmark_get_set!(ByteSlicedAES64x64b, group); + benchmark_get_set!(ByteSlicedAES64x128b, group); +} + +criterion_group!( + get_set, + packed_128, + packed_256, + packed_512, + byte_sliced_128, + byte_sliced_256, + byte_sliced_512 +); criterion_main!(get_set); diff --git a/crates/field/benches/packed_field_init.rs b/crates/field/benches/packed_field_init.rs index b2a3feb5..43c644e9 100644 --- a/crates/field/benches/packed_field_init.rs +++ b/crates/field/benches/packed_field_init.rs @@ -3,11 +3,15 @@ use std::array; use binius_field::{ - PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b, - PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b, - PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b, - PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, + ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b, + ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b, + ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b, + ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b, + PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b, + PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b, + PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, + PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b, + PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput, @@ -71,5 +75,43 @@ fn packed_512(c: &mut Criterion) { benchmark_from_fn!(PackedBinaryField4x128b, group); } -criterion_group!(initialization, packed_128, packed_256, packed_512); +fn byte_sliced_128(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_128"); + + benchmark_from_fn!(ByteSlicedAES16x8b, group); + benchmark_from_fn!(ByteSlicedAES16x16b, group); + benchmark_from_fn!(ByteSlicedAES16x32b, group); + benchmark_from_fn!(ByteSlicedAES16x64b, group); + benchmark_from_fn!(ByteSlicedAES16x128b, group); +} + +fn byte_sliced_256(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_256"); + + benchmark_from_fn!(ByteSlicedAES32x8b, group); + benchmark_from_fn!(ByteSlicedAES32x16b, group); + benchmark_from_fn!(ByteSlicedAES32x32b, group); + benchmark_from_fn!(ByteSlicedAES32x64b, group); + benchmark_from_fn!(ByteSlicedAES32x128b, group); +} + +fn byte_sliced_512(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_512"); + + benchmark_from_fn!(ByteSlicedAES64x8b, group); + benchmark_from_fn!(ByteSlicedAES64x16b, group); + benchmark_from_fn!(ByteSlicedAES64x32b, group); + benchmark_from_fn!(ByteSlicedAES64x64b, group); + benchmark_from_fn!(ByteSlicedAES64x128b, group); +} + +criterion_group!( + initialization, + packed_128, + packed_256, + packed_512, + byte_sliced_128, + byte_sliced_256, + byte_sliced_512 +); criterion_main!(initialization); diff --git a/crates/field/benches/packed_field_subfield_ops.rs b/crates/field/benches/packed_field_subfield_ops.rs index cfd2dff3..bc2a3653 100644 --- a/crates/field/benches/packed_field_subfield_ops.rs +++ b/crates/field/benches/packed_field_subfield_ops.rs @@ -5,8 +5,8 @@ use std::array; use binius_field::{ packed::mul_by_subfield_scalar, underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, ExtensionField, - Field, PackedBinaryField16x8b, PackedBinaryField1x128b, PackedBinaryField2x128b, + BinaryField1b, BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, Field, + PackedBinaryField16x8b, PackedBinaryField1x128b, PackedBinaryField2x128b, PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedExtension, }; @@ -17,11 +17,7 @@ use rand::thread_rng; const BATCH_SIZE: usize = 32; -fn bench_mul_subfield(group: &mut BenchmarkGroup<'_, WallTime>) -where - PE: PackedExtension>, - F: Field, -{ +fn bench_mul_subfield, F: Field>(group: &mut BenchmarkGroup<'_, WallTime>) { let mut rng = thread_rng(); let packed: [PE; BATCH_SIZE] = array::from_fn(|_| PE::random(&mut rng)); let scalars: [F; BATCH_SIZE] = array::from_fn(|_| F::random(&mut rng)); diff --git a/crates/field/benches/packed_field_utils.rs b/crates/field/benches/packed_field_utils.rs index 516af5e6..622a49af 100644 --- a/crates/field/benches/packed_field_utils.rs +++ b/crates/field/benches/packed_field_utils.rs @@ -274,11 +274,23 @@ macro_rules! benchmark_packed_operation { PackedBinaryPolyval4x128b // Byte sliced AES fields + ByteSlicedAES16x8b + ByteSlicedAES16x16b + ByteSlicedAES16x32b + ByteSlicedAES16x64b + ByteSlicedAES16x128b + ByteSlicedAES32x8b ByteSlicedAES32x16b ByteSlicedAES32x32b ByteSlicedAES32x64b ByteSlicedAES32x128b + + ByteSlicedAES64x8b + ByteSlicedAES64x16b + ByteSlicedAES64x32b + ByteSlicedAES64x64b + ByteSlicedAES64x128b ]); }; } diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index c69b665a..74f7e76d 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -2,13 +2,16 @@ use std::{ any::TypeId, - array, fmt::{Debug, Display, Formatter}, iter::{Product, Sum}, marker::PhantomData, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{ + bytes::{Buf, BufMut}, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, Zeroable}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; @@ -79,6 +82,13 @@ impl_arithmetic_using_packed!(AESTowerField128b); impl TowerField for AESTowerField8b { type Canonical = BinaryField8b; + fn min_tower_level(self) -> usize { + match self { + Self::ZERO | Self::ONE => 0, + _ => 3, + } + } + fn mul_primitive(self, iota: usize) -> Result { match iota { 0..=1 => Ok(self * ISOMORPHIC_ALPHAS[iota]), @@ -158,8 +168,8 @@ impl Transformation for SubfieldTransformer>, - OEP: PackedExtension>, + IEP: PackedExtension, + OEP: PackedExtension, T: Transformation, PackedSubfield>, { fn transform(&self, input: &IEP) -> OEP { @@ -172,9 +182,7 @@ where pub fn make_aes_to_binary_packed_transformer() -> impl Transformation where IP: PackedExtension, - IP::Scalar: ExtensionField, OP: PackedExtension, - OP::Scalar: ExtensionField, PackedSubfield: PackedTransformationFactory>, { @@ -191,9 +199,7 @@ where pub fn make_binary_to_aes_packed_transformer() -> impl Transformation where IP: PackedExtension, - IP::Scalar: ExtensionField, OP: PackedExtension, - OP::Scalar: ExtensionField, PackedSubfield: PackedTransformationFactory>, { @@ -281,18 +287,61 @@ impl_tower_field_conversion!(AESTowerField32b, BinaryField32b); impl_tower_field_conversion!(AESTowerField64b, BinaryField64b); impl_tower_field_conversion!(AESTowerField128b, BinaryField128b); +macro_rules! serialize_deserialize_non_canonical { + ($field:ident, canonical=$canonical:ident) => { + impl SerializeBytes for $field { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match mode { + SerializationMode::Native => self.0.serialize(write_buf, mode), + SerializationMode::CanonicalTower => { + $canonical::from(*self).serialize(write_buf, mode) + } + } + } + } + + impl DeserializeBytes for $field { + fn deserialize( + read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + match mode { + SerializationMode::Native => { + Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)) + } + SerializationMode::CanonicalTower => { + Ok(Self::from($canonical::deserialize(read_buf, mode)?)) + } + } + } + } + }; +} + +serialize_deserialize_non_canonical!(AESTowerField8b, canonical = BinaryField8b); +serialize_deserialize_non_canonical!(AESTowerField16b, canonical = BinaryField16b); +serialize_deserialize_non_canonical!(AESTowerField32b, canonical = BinaryField32b); +serialize_deserialize_non_canonical!(AESTowerField64b, canonical = BinaryField64b); +serialize_deserialize_non_canonical!(AESTowerField128b, canonical = BinaryField128b); + #[cfg(test)] mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode, SerializeBytes}; use proptest::{arbitrary::any, proptest}; use rand::thread_rng; use super::*; use crate::{ - binary_field::tests::is_binary_field_valid_generator, deserialize_canonical, - serialize_canonical, underlier::WithUnderlier, PackedAESBinaryField16x32b, - PackedAESBinaryField4x32b, PackedAESBinaryField8x32b, PackedBinaryField16x32b, - PackedBinaryField4x32b, PackedBinaryField8x32b, + binary_field::tests::is_binary_field_valid_generator, underlier::WithUnderlier, + PackedAESBinaryField16x32b, PackedAESBinaryField4x32b, PackedAESBinaryField8x32b, + PackedBinaryField16x32b, PackedBinaryField4x32b, PackedBinaryField8x32b, }; fn check_square(f: impl Field) { @@ -591,28 +640,24 @@ mod tests { let aes64 = ::random(&mut rng); let aes128 = ::random(&mut rng); - serialize_canonical(aes8, &mut buffer).unwrap(); - serialize_canonical(aes16, &mut buffer).unwrap(); - serialize_canonical(aes32, &mut buffer).unwrap(); - serialize_canonical(aes64, &mut buffer).unwrap(); - serialize_canonical(aes128, &mut buffer).unwrap(); + let mode = SerializationMode::CanonicalTower; + + SerializeBytes::serialize(&aes8, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes16, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes32, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes64, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes128, &mut buffer, mode).unwrap(); - serialize_canonical(aes128, &mut buffer).unwrap(); + SerializeBytes::serialize(&aes128, &mut buffer, mode).unwrap(); let mut read_buffer = buffer.freeze(); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes8); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes16); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes32); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes64); - assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), - aes128 - ); - - assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), - aes128.into() - ) + assert_eq!(AESTowerField8b::deserialize(&mut read_buffer, mode).unwrap(), aes8); + assert_eq!(AESTowerField16b::deserialize(&mut read_buffer, mode).unwrap(), aes16); + assert_eq!(AESTowerField32b::deserialize(&mut read_buffer, mode).unwrap(), aes32); + assert_eq!(AESTowerField64b::deserialize(&mut read_buffer, mode).unwrap(), aes64); + assert_eq!(AESTowerField128b::deserialize(&mut read_buffer, mode).unwrap(), aes128); + + assert_eq!(BinaryField128b::deserialize(&mut read_buffer, mode).unwrap(), aes128.into()) } } diff --git a/crates/field/src/arch/aarch64/m128.rs b/crates/field/src/arch/aarch64/m128.rs index a32a82b7..70c49660 100644 --- a/crates/field/src/arch/aarch64/m128.rs +++ b/crates/field/src/arch/aarch64/m128.rs @@ -19,8 +19,8 @@ use crate::{ arch::binary_utils::{as_array_mut, as_array_ref}, arithmetic_traits::Broadcast, underlier::{ - impl_divisible, impl_iteration, NumCast, Random, SmallU, UnderlierType, - UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + impl_divisible, impl_iteration, unpack_lo_128b_fallback, NumCast, Random, SmallU, + UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -337,6 +337,40 @@ impl UnderlierWithBitOps for M128 { _ => panic!("unsupported bit count"), } } + + #[inline(always)] + fn shl_128b_lanes(self, rhs: usize) -> Self { + Self(self.0 << rhs) + } + + #[inline(always)] + fn shr_128b_lanes(self, rhs: usize) -> Self { + Self(self.0 >> rhs) + } + + #[inline(always)] + fn unpack_lo_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, rhs, log_block_len), + 3 => unsafe { vzip1q_u8(self.into(), rhs.into()).into() }, + 4 => unsafe { vzip1q_u16(self.into(), rhs.into()).into() }, + 5 => unsafe { vzip1q_u32(self.into(), rhs.into()).into() }, + 6 => unsafe { vzip1q_u64(self.into(), rhs.into()).into() }, + _ => panic!("Unsupported block length"), + } + } + + #[inline(always)] + fn unpack_hi_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, rhs, log_block_len), + 3 => unsafe { vzip2q_u8(self.into(), rhs.into()).into() }, + 4 => unsafe { vzip2q_u16(self.into(), rhs.into()).into() }, + 5 => unsafe { vzip2q_u32(self.into(), rhs.into()).into() }, + 6 => unsafe { vzip2q_u64(self.into(), rhs.into()).into() }, + _ => panic!("Unsupported block length"), + } + } } impl UnderlierWithBitConstants for M128 { @@ -401,6 +435,37 @@ impl UnderlierWithBitConstants for M128 { } } } + + #[inline] + fn transpose(self, other: Self, log_block_len: usize) -> (Self, Self) { + unsafe { + match log_block_len { + 0..=3 => { + let (a, b) = (self.into(), other.into()); + let (mut a, mut b) = (Self::from(vuzp1q_u8(a, b)), Self::from(vuzp2q_u8(a, b))); + + for log_block_len in (log_block_len..3).rev() { + (a, b) = a.interleave(b, log_block_len); + } + + (a, b) + } + 4 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u16(a, b).into(), vuzp2q_u16(a, b).into()) + } + 5 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u32(a, b).into(), vuzp2q_u32(a, b).into()) + } + 6 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u64(a, b).into(), vuzp2q_u64(a, b).into()) + } + _ => panic!("Unsupported block length"), + } + } + } } impl From for PackedPrimitiveType { diff --git a/crates/field/src/arch/portable/byte_sliced/invert.rs b/crates/field/src/arch/portable/byte_sliced/invert.rs index 3544cbbd..8581e566 100644 --- a/crates/field/src/arch/portable/byte_sliced/invert.rs +++ b/crates/field/src/arch/portable/byte_sliced/invert.rs @@ -6,25 +6,24 @@ use super::{ use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn invert_or_zero>( - field_element: &Level::Data, - destination: &mut Level::Data, +pub fn invert_or_zero, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); - inv_main::(field_element, destination, base_alpha); + inv_main::(field_element, destination, base_alpha); } #[inline(always)] -fn inv_main>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +fn inv_main, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { destination.as_mut()[0] = field_element.as_ref()[0].invert_or_zero(); @@ -35,36 +34,30 @@ fn inv_main>( let (result0, result1) = Level::split_mut(destination); - let mut intermediate = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut intermediate = <::Base as TowerLevel>::default(); // intermediate = subfield_alpha*a1 - mul_alpha::(a1, &mut intermediate, base_alpha); + mul_alpha::(a1, &mut intermediate, base_alpha); // intermediate = a0 + subfield_alpha*a1 Level::Base::add_into(a0, &mut intermediate); - let mut delta = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut delta = <::Base as TowerLevel>::default(); // delta = intermediate * a0 - mul_main::(&intermediate, a0, &mut delta, base_alpha); + mul_main::(&intermediate, a0, &mut delta, base_alpha); // delta = intermediate * a0 + a1^2 - square_main::(a1, &mut delta, base_alpha); + square_main::(a1, &mut delta, base_alpha); - let mut delta_inv = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut delta_inv = <::Base as TowerLevel>::default(); // delta_inv = 1/delta - inv_main::(&delta, &mut delta_inv, base_alpha); + inv_main::(&delta, &mut delta_inv, base_alpha); // result0 = delta_inv*intermediate - mul_main::(&delta_inv, &intermediate, result0, base_alpha); + mul_main::(&delta_inv, &intermediate, result0, base_alpha); // result1 = delta_inv*intermediate - mul_main::(&delta_inv, a1, result1, base_alpha); + mul_main::(&delta_inv, a1, result1, base_alpha); } diff --git a/crates/field/src/arch/portable/byte_sliced/mod.rs b/crates/field/src/arch/portable/byte_sliced/mod.rs index 0c42539b..0f29290a 100644 --- a/crates/field/src/arch/portable/byte_sliced/mod.rs +++ b/crates/field/src/arch/portable/byte_sliced/mod.rs @@ -15,8 +15,8 @@ pub mod tests { use proptest::prelude::*; use crate::{$scalar_type, underlier::WithUnderlier, packed::PackedField, arch::byte_sliced::$name}; - fn scalar_array_strategy() -> impl Strategy { - any::<[<$scalar_type as WithUnderlier>::Underlier; 32]>().prop_map(|arr| arr.map(<$scalar_type>::from_underlier)) + fn scalar_array_strategy() -> impl Strategy::WIDTH]> { + any::<[<$scalar_type as WithUnderlier>::Underlier; <$name>::WIDTH]>().prop_map(|arr| arr.map(<$scalar_type>::from_underlier)) } proptest! { @@ -27,7 +27,7 @@ pub mod tests { let bytesliced_result = bytesliced_a + bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -39,7 +39,7 @@ pub mod tests { bytesliced_a += bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -51,7 +51,7 @@ pub mod tests { let bytesliced_result = bytesliced_a - bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -63,7 +63,7 @@ pub mod tests { bytesliced_a -= bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -75,7 +75,7 @@ pub mod tests { let bytesliced_result = bytesliced_a * bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -87,7 +87,7 @@ pub mod tests { bytesliced_a *= bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -118,9 +118,21 @@ pub mod tests { }; } + define_byte_sliced_test!(tests_16x128, ByteSlicedAES16x128b, AESTowerField128b); + define_byte_sliced_test!(tests_16x64, ByteSlicedAES16x64b, AESTowerField64b); + define_byte_sliced_test!(tests_16x32, ByteSlicedAES16x32b, AESTowerField32b); + define_byte_sliced_test!(tests_16x16, ByteSlicedAES16x16b, AESTowerField16b); + define_byte_sliced_test!(tests_16x8, ByteSlicedAES16x8b, AESTowerField8b); + define_byte_sliced_test!(tests_32x128, ByteSlicedAES32x128b, AESTowerField128b); define_byte_sliced_test!(tests_32x64, ByteSlicedAES32x64b, AESTowerField64b); define_byte_sliced_test!(tests_32x32, ByteSlicedAES32x32b, AESTowerField32b); define_byte_sliced_test!(tests_32x16, ByteSlicedAES32x16b, AESTowerField16b); define_byte_sliced_test!(tests_32x8, ByteSlicedAES32x8b, AESTowerField8b); + + define_byte_sliced_test!(tests_64x128, ByteSlicedAES64x128b, AESTowerField128b); + define_byte_sliced_test!(tests_64x64, ByteSlicedAES64x64b, AESTowerField64b); + define_byte_sliced_test!(tests_64x32, ByteSlicedAES64x32b, AESTowerField32b); + define_byte_sliced_test!(tests_64x16, ByteSlicedAES64x16b, AESTowerField16b); + define_byte_sliced_test!(tests_64x8, ByteSlicedAES64x8b, AESTowerField8b); } diff --git a/crates/field/src/arch/portable/byte_sliced/multiply.rs b/crates/field/src/arch/portable/byte_sliced/multiply.rs index cfaf7339..c2037703 100644 --- a/crates/field/src/arch/portable/byte_sliced/multiply.rs +++ b/crates/field/src/arch/portable/byte_sliced/multiply.rs @@ -2,25 +2,28 @@ use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn mul>( - field_element_a: &Level::Data, - field_element_b: &Level::Data, - destination: &mut Level::Data, +pub fn mul, Level: TowerLevel>( + field_element_a: &Level::Data

, + field_element_b: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); - mul_main::(field_element_a, field_element_b, destination, base_alpha); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); + mul_main::(field_element_a, field_element_b, destination, base_alpha); } #[inline(always)] -pub fn mul_alpha>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn mul_alpha< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -49,15 +52,19 @@ pub fn mul_alpha(a1, result1, base_alpha); + mul_alpha::(a1, result1, base_alpha); } #[inline(always)] -pub fn mul_main>( - field_element_a: &Level::Data, - field_element_b: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn mul_main< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element_a: &Level::Data

, + field_element_b: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -78,21 +85,19 @@ pub fn mul_main>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut z2_z0 = <::Base as TowerLevel>::default(); // z2_z0 = z2 - mul_main::(a1, b1, &mut z2_z0, base_alpha); + mul_main::(a1, b1, &mut z2_z0, base_alpha); // result1 = z2 * alpha - mul_alpha::(&z2_z0, result1, base_alpha); + mul_alpha::(&z2_z0, result1, base_alpha); // z2_z0 = z2 + z0 - mul_main::(a0, b0, &mut z2_z0, base_alpha); + mul_main::(a0, b0, &mut z2_z0, base_alpha); // result1 = z1 + z2 * alpha - mul_main::(&xored_halves_a, &xored_halves_b, result1, base_alpha); + mul_main::(&xored_halves_a, &xored_halves_b, result1, base_alpha); // result1 = z2+ z0+ z1 + z2 * alpha Level::Base::add_into(&z2_z0, result1); diff --git a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs index 7746795d..40e9648a 100644 --- a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs +++ b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs @@ -7,7 +7,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use bytemuck::Zeroable; +use bytemuck::{Pod, Zeroable}; use super::{invert::invert_or_zero, multiply::mul, square::square}; use crate::{ @@ -15,7 +15,7 @@ use crate::{ tower_levels::*, underlier::{UnderlierWithBitOps, WithUnderlier}, AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b, - PackedField, + PackedAESBinaryField16x8b, PackedAESBinaryField64x8b, PackedField, }; /// Represents 32 AES Tower Field elements in byte-sliced form backed by Packed 32x8b AES fields. @@ -24,16 +24,15 @@ use crate::{ /// multiplication circuit on GFNI machines, since multiplication of two 32x8b field elements is /// handled in one instruction. macro_rules! define_byte_sliced { - ($name:ident, $scalar_type:ty, $tower_level: ty) => { - #[derive(Default, Clone, Debug, Copy, PartialEq, Eq, Zeroable)] + ($name:ident, $scalar_type:ty, $packed_storage:ty, $tower_level: ty) => { + #[derive(Default, Clone, Debug, Copy, PartialEq, Eq, Pod, Zeroable)] + #[repr(transparent)] pub struct $name { - pub(super) data: [PackedAESBinaryField32x8b; - <$tower_level as TowerLevel>::WIDTH], + pub(super) data: [$packed_storage; <$tower_level as TowerLevel>::WIDTH], } impl $name { - pub const BYTES: usize = PackedAESBinaryField32x8b::WIDTH - * <$tower_level as TowerLevel>::WIDTH; + pub const BYTES: usize = <$packed_storage>::WIDTH * <$tower_level as TowerLevel>::WIDTH; /// Get the byte at the given index. /// @@ -41,11 +40,8 @@ macro_rules! define_byte_sliced { /// The caller must ensure that `byte_index` is less than `BYTES`. #[allow(clippy::modulo_one)] pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 { - self.data - [byte_index % <$tower_level as TowerLevel>::WIDTH] - .get( - byte_index / <$tower_level as TowerLevel>::WIDTH, - ) + self.data[byte_index % <$tower_level as TowerLevel>::WIDTH] + .get(byte_index / <$tower_level as TowerLevel>::WIDTH) .to_underlier() } } @@ -53,28 +49,26 @@ macro_rules! define_byte_sliced { impl PackedField for $name { type Scalar = $scalar_type; - const LOG_WIDTH: usize = 5; + const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH; + #[inline(always)] unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar { - let mut result_underlier = 0; - for (byte_index, val) in self.data.iter().enumerate() { - // Safety: - // - `byte_index` is less than 16 - // - `i` must be less than 32 due to safety conditions of this method - unsafe { - result_underlier - .set_subvalue(byte_index, val.get_unchecked(i).to_underlier()) - } - } + let result_underlier = + ::Underlier::from_fn(|byte_index| unsafe { + self.data + .get_unchecked(byte_index) + .get_unchecked(i) + .to_underlier() + }); Self::Scalar::from_underlier(result_underlier) } + #[inline(always)] unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) { let underlier = scalar.to_underlier(); - for byte_index in 0..<$tower_level as TowerLevel>::WIDTH - { + for byte_index in 0..<$tower_level as TowerLevel>::WIDTH { self.data[byte_index].set_unchecked( i, AESTowerField8b::from_underlier(underlier.get_subvalue(byte_index)), @@ -86,16 +80,18 @@ macro_rules! define_byte_sliced { Self::from_scalars([Self::Scalar::random(rng); 32]) } + #[inline] fn broadcast(scalar: Self::Scalar) -> Self { Self { data: array::from_fn(|byte_index| { - PackedAESBinaryField32x8b::broadcast(AESTowerField8b::from_underlier( - unsafe { scalar.to_underlier().get_subvalue(byte_index) }, - )) + <$packed_storage>::broadcast(AESTowerField8b::from_underlier(unsafe { + scalar.to_underlier().get_subvalue(byte_index) + })) }), } } + #[inline] fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self { let mut result = Self::default(); @@ -107,30 +103,43 @@ macro_rules! define_byte_sliced { result } + #[inline] fn square(self) -> Self { let mut result = Self::default(); - square::<$tower_level>(&self.data, &mut result.data); + square::<$packed_storage, $tower_level>(&self.data, &mut result.data); result } + #[inline] fn invert_or_zero(self) -> Self { let mut result = Self::default(); - invert_or_zero::<$tower_level>(&self.data, &mut result.data); + invert_or_zero::<$packed_storage, $tower_level>(&self.data, &mut result.data); result } + #[inline] fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) { let mut result1 = Self::default(); let mut result2 = Self::default(); - for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { - let (this_byte_result1, this_byte_result2) = + for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { + (result1.data[byte_num], result2.data[byte_num]) = self.data[byte_num].interleave(other.data[byte_num], log_block_len); + } + + (result1, result2) + } - result1.data[byte_num] = this_byte_result1; - result2.data[byte_num] = this_byte_result2; + #[inline] + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + let mut result1 = Self::default(); + let mut result2 = Self::default(); + + for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { + (result1.data[byte_num], result2.data[byte_num]) = + self.data[byte_num].unzip(other.data[byte_num], log_block_len); } (result1, result2) @@ -203,12 +212,9 @@ macro_rules! define_byte_sliced { type Output = Self; fn mul(self, rhs: Self) -> Self { - let mut result = $name { - data: [PackedAESBinaryField32x8b::default(); - <$tower_level as TowerLevel>::WIDTH], - }; + let mut result = Self::default(); - mul::<$tower_level>(&self.data, &rhs.data, &mut result.data); + mul::<$packed_storage, $tower_level>(&self.data, &rhs.data, &mut result.data); result } @@ -267,8 +273,38 @@ macro_rules! define_byte_sliced { }; } -define_byte_sliced!(ByteSlicedAES32x128b, AESTowerField128b, TowerLevel16); -define_byte_sliced!(ByteSlicedAES32x64b, AESTowerField64b, TowerLevel8); -define_byte_sliced!(ByteSlicedAES32x32b, AESTowerField32b, TowerLevel4); -define_byte_sliced!(ByteSlicedAES32x16b, AESTowerField16b, TowerLevel2); -define_byte_sliced!(ByteSlicedAES32x8b, AESTowerField8b, TowerLevel1); +// 128 bit +define_byte_sliced!( + ByteSlicedAES16x128b, + AESTowerField128b, + PackedAESBinaryField16x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES16x64b, AESTowerField64b, PackedAESBinaryField16x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES16x32b, AESTowerField32b, PackedAESBinaryField16x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES16x16b, AESTowerField16b, PackedAESBinaryField16x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES16x8b, AESTowerField8b, PackedAESBinaryField16x8b, TowerLevel1); + +// 256 bit +define_byte_sliced!( + ByteSlicedAES32x128b, + AESTowerField128b, + PackedAESBinaryField32x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES32x64b, AESTowerField64b, PackedAESBinaryField32x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES32x32b, AESTowerField32b, PackedAESBinaryField32x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES32x16b, AESTowerField16b, PackedAESBinaryField32x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES32x8b, AESTowerField8b, PackedAESBinaryField32x8b, TowerLevel1); + +// 512 bit +define_byte_sliced!( + ByteSlicedAES64x128b, + AESTowerField128b, + PackedAESBinaryField64x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES64x64b, AESTowerField64b, PackedAESBinaryField64x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES64x32b, AESTowerField32b, PackedAESBinaryField64x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES64x16b, AESTowerField16b, PackedAESBinaryField64x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES64x8b, AESTowerField8b, PackedAESBinaryField64x8b, TowerLevel1); diff --git a/crates/field/src/arch/portable/byte_sliced/square.rs b/crates/field/src/arch/portable/byte_sliced/square.rs index bcd9514a..0c0b6ab7 100644 --- a/crates/field/src/arch/portable/byte_sliced/square.rs +++ b/crates/field/src/arch/portable/byte_sliced/square.rs @@ -3,24 +3,27 @@ use super::multiply::mul_alpha; use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn square>( - field_element: &Level::Data, - destination: &mut Level::Data, +pub fn square, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); - square_main::(field_element, destination, base_alpha); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); + square_main::(field_element, destination, base_alpha); } #[inline(always)] -pub fn square_main>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn square_main< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -34,15 +37,13 @@ pub fn square_main>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut a1_squared = <::Base as TowerLevel>::default(); - square_main::(a1, &mut a1_squared, base_alpha); + square_main::(a1, &mut a1_squared, base_alpha); - mul_alpha::(&a1_squared, result1, base_alpha); + mul_alpha::(&a1_squared, result1, base_alpha); - square_main::(a0, result0, base_alpha); + square_main::(a0, result0, base_alpha); Level::Base::add_into(&a1_squared, result0); } diff --git a/crates/field/src/arch/portable/packed.rs b/crates/field/src/arch/portable/packed.rs index ad372dd9..66bbb965 100644 --- a/crates/field/src/arch/portable/packed.rs +++ b/crates/field/src/arch/portable/packed.rs @@ -336,6 +336,14 @@ where (c.into(), d.into()) } + #[inline] + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + assert!(log_block_len < Self::LOG_WIDTH); + let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize; + let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len); + (c.into(), d.into()) + } + #[inline] unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self { debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH); diff --git a/crates/field/src/arch/portable/packed_arithmetic.rs b/crates/field/src/arch/portable/packed_arithmetic.rs index c15f65fa..98d0f130 100644 --- a/crates/field/src/arch/portable/packed_arithmetic.rs +++ b/crates/field/src/arch/portable/packed_arithmetic.rs @@ -37,6 +37,18 @@ where (c, d) } + + /// Transpose with the given bit size + fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) { + // There are 2^7 = 128 bits in a u128 + assert!(log_block_len < Self::INTERLEAVE_EVEN_MASK.len()); + + for log_block_len in (log_block_len..Self::LOG_BITS).rev() { + (self, other) = self.interleave(other, log_block_len); + } + + (self, other) + } } /// Abstraction for a packed tower field of height greater than 0. @@ -331,7 +343,7 @@ where OP: PackedBinaryField, { pub fn new + Sync>( - transformation: FieldLinearTransformation, + transformation: &FieldLinearTransformation, ) -> Self { Self { bases: transformation @@ -387,7 +399,7 @@ where fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { - PackedTransformation::new(transformation) + PackedTransformation::new(&transformation) } } diff --git a/crates/field/src/arch/portable/packed_scaled.rs b/crates/field/src/arch/portable/packed_scaled.rs index a4b55a2c..8417939e 100644 --- a/crates/field/src/arch/portable/packed_scaled.rs +++ b/crates/field/src/arch/portable/packed_scaled.rs @@ -216,14 +216,17 @@ where Self(array::from_fn(|_| PT::random(&mut rng))) } + #[inline] fn broadcast(scalar: Self::Scalar) -> Self { Self(array::from_fn(|_| PT::broadcast(scalar))) } + #[inline] fn square(self) -> Self { Self(self.0.map(|v| v.square())) } + #[inline] fn invert_or_zero(self) -> Self { Self(self.0.map(|v| v.invert_or_zero())) } @@ -253,6 +256,40 @@ where (Self(first), Self(second)) } + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + let mut first = [Default::default(); N]; + let mut second = [Default::default(); N]; + + if log_block_len >= PT::LOG_WIDTH { + let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH); + for i in (0..N / 2).step_by(block_in_pts) { + first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]); + + second[i..i + block_in_pts] + .copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]); + } + + for i in (0..N / 2).step_by(block_in_pts) { + first[i + N / 2..i + N / 2 + block_in_pts] + .copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]); + + second[i + N / 2..i + N / 2 + block_in_pts] + .copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]); + } + } else { + for i in 0..N / 2 { + (first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len); + } + + for i in 0..N / 2 { + (first[i + N / 2], second[i + N / 2]) = + other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len); + } + } + + (Self(first), Self(second)) + } + #[inline] unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self { let log_n = checked_log_2(N); @@ -360,20 +397,23 @@ macro_rules! packed_scaled_field { impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn add(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] + rhs; + #[inline] + fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v += broadcast; } - result + self } } impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] += rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v += broadcast; } } } @@ -381,20 +421,23 @@ macro_rules! packed_scaled_field { impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn sub(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] - rhs; + #[inline] + fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v -= broadcast; } - result + self } } impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] -= rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v -= broadcast; } } } @@ -402,20 +445,23 @@ macro_rules! packed_scaled_field { impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn mul(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] * rhs; + #[inline] + fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v *= broadcast; } - result + self } } impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] *= rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v *= broadcast; } } } diff --git a/crates/field/src/arch/x86_64/m128.rs b/crates/field/src/arch/x86_64/m128.rs index f75d152d..d1d715f6 100644 --- a/crates/field/src/arch/x86_64/m128.rs +++ b/crates/field/src/arch/x86_64/m128.rs @@ -13,7 +13,7 @@ use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ arch::{ - binary_utils::{as_array_mut, make_func_to_i8}, + binary_utils::{as_array_mut, as_array_ref, make_func_to_i8}, portable::{ packed::{impl_pack_scalar, PackedPrimitiveType}, packed_arithmetic::{ @@ -23,8 +23,9 @@ use crate::{ }, arithmetic_traits::Broadcast, underlier::{ - impl_divisible, impl_iteration, spread_fallback, NumCast, Random, SmallU, SpreadToByte, - UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + impl_divisible, impl_iteration, spread_fallback, unpack_hi_128b_fallback, + unpack_lo_128b_fallback, NumCast, Random, SmallU, SpreadToByte, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -181,7 +182,7 @@ impl Not for M128 { } /// `std::cmp::max` isn't const, so we need our own implementation -const fn max_i32(left: i32, right: i32) -> i32 { +pub(crate) const fn max_i32(left: i32, right: i32) -> i32 { if left > right { left } else { @@ -193,22 +194,37 @@ const fn max_i32(left: i32, right: i32) -> i32 { /// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed as constant /// and Rust currently doesn't allow passing expressions (`count - 64`) where variable is a generic constant parameter. /// Source: https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688 -macro_rules! bitshift_right { - ($val:expr, $count:literal) => { +macro_rules! bitshift_128b { + ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => { unsafe { - let carry = _mm_bsrli_si128($val, 8); - if $count >= 64 { - _mm_srli_epi64(carry, max_i32($count - 64, 0)) - } else { - let carry = _mm_slli_epi64(carry, max_i32(64 - $count, 0)); - - let val = _mm_srli_epi64($val, $count); - _mm_or_si128(val, carry) - } + let carry = $byte_shift($val, 8); + seq!(N in 64..128 { + if $shift == N { + return $bit_shift_64( + carry, + crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _, + ).into(); + } + }); + seq!(N in 0..64 { + if $shift == N { + let carry = $bit_shift_64_opposite( + carry, + crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _, + ); + + let val = $bit_shift_64($val, N); + return $or(val, carry).into(); + } + }); + + return Default::default() } }; } +pub(crate) use bitshift_128b; + impl Shr for M128 { type Output = Self; @@ -216,32 +232,10 @@ impl Shr for M128 { fn shr(self, rhs: usize) -> Self::Output { // This implementation is effective when `rhs` is known at compile-time. // In our code this is always the case. - seq!(N in 0..128 { - if rhs == N { - return Self(bitshift_right!(self.0, N)); - } - }); - - Self::default() + bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128) } } -macro_rules! bitshift_left { - ($val:expr, $count:literal) => { - unsafe { - let carry = _mm_bslli_si128($val, 8); - if $count >= 64 { - _mm_slli_epi64(carry, max_i32($count - 64, 0)) - } else { - let carry = _mm_srli_epi64(carry, max_i32(64 - $count, 0)); - - let val = _mm_slli_epi64($val, $count); - _mm_or_si128(val, carry) - } - } - }; -} - impl Shl for M128 { type Output = Self; @@ -249,13 +243,7 @@ impl Shl for M128 { fn shl(self, rhs: usize) -> Self::Output { // This implementation is effective when `rhs` is known at compile-time. // In our code this is always the case. - seq!(N in 0..128 { - if rhs == N { - return Self(bitshift_left!(self.0, N)); - } - }); - - Self::default() + bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128); } } @@ -420,39 +408,41 @@ impl UnderlierWithBitOps for M128 { #[inline(always)] unsafe fn get_subvalue(&self, i: usize) -> T where - T: WithUnderlier, - T::Underlier: NumCast, + T: UnderlierType + NumCast, { - match T::Underlier::BITS { - 1 | 2 | 4 | 8 | 16 | 32 | 64 => { - let elements_in_64 = 64 / T::Underlier::BITS; - let chunk_64 = unsafe { - if i >= elements_in_64 { - _mm_extract_epi64(self.0, 1) - } else { - _mm_extract_epi64(self.0, 0) - } - }; - - let result_64 = if T::Underlier::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::Underlier::BITS) - 1) as u64; - let val_64 = (chunk_64 as u64) - >> (T::Underlier::BITS - * (if i >= elements_in_64 { - i - elements_in_64 - } else { - i - })) & ones; - - val_64 as i64 - }; - T::from_underlier(T::Underlier::num_cast_from(Self(unsafe { - _mm_set_epi64x(0, result_64) - }))) + match T::BITS { + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); + + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; + + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) } - 128 => T::from_underlier(T::Underlier::num_cast_from(*self)), + 64 => { + let value_u64 = + as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) + } + 128 => T::from_underlier(T::num_cast_from(*self)), _ => panic!("unsupported bit count"), } } @@ -471,23 +461,23 @@ impl UnderlierWithBitOps for M128 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 16>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 16>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 16>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 8>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 4>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 4>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 2>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 2>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => { *self = Self::from(val); @@ -705,6 +695,40 @@ impl UnderlierWithBitOps for M128 { _ => panic!("unsupported bit length"), } } + + #[inline] + fn shl_128b_lanes(self, shift: usize) -> Self { + self << shift + } + + #[inline] + fn shr_128b_lanes(self, shift: usize) -> Self { + self >> shift + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M128 {} @@ -922,6 +946,10 @@ mod tests { assert_eq!(M128::from(1u128), M128::ONE); } + fn get(value: M128, log_block_len: usize, index: usize) -> M128 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[test] fn test_conversion(a in any::()) { @@ -955,15 +983,36 @@ mod tests { let (c, d) = unsafe {interleave_bits(a.0, b.0, height)}; let (c, d) = (M128::from(c), M128::from(d)); - let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; - for i in (0..128/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + for i in (0..128>>height).step_by(2) { + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::(), b in any::(), height in 1usize..7) { + let a = M128::from(a); + let b = M128::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + for i in 0..128>>(height + 1) { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + } + } + + #[test] + fn test_unpack_hi(a in any::(), b in any::(), height in 1usize..7) { + let a = M128::from(a); + let b = M128::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); } } } diff --git a/crates/field/src/arch/x86_64/m256.rs b/crates/field/src/arch/x86_64/m256.rs index 3c36827f..a2fc71ab 100644 --- a/crates/field/src/arch/x86_64/m256.rs +++ b/crates/field/src/arch/x86_64/m256.rs @@ -9,6 +9,7 @@ use std::{ use bytemuck::{must_cast, Pod, Zeroable}; use cfg_if::cfg_if; use rand::{Rng, RngCore}; +use seq_macro::seq; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ @@ -20,11 +21,13 @@ use crate::{ interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants, }, }, + x86_64::m128::bitshift_128b, }, arithmetic_traits::Broadcast, underlier::{ get_block_values, get_spread_bytes, impl_divisible, impl_iteration, spread_fallback, - NumCast, Random, SmallU, UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -323,7 +326,7 @@ impl UnderlierType for M256 { impl UnderlierWithBitOps for M256 { const ZERO: Self = { Self(m256_from_u128s!(0, 0,)) }; - const ONE: Self = { Self(m256_from_u128s!(0, 1,)) }; + const ONE: Self = { Self(m256_from_u128s!(1, 0,)) }; const ONES: Self = { Self(m256_from_u128s!(u128::MAX, u128::MAX,)) }; #[inline] @@ -463,42 +466,41 @@ impl UnderlierWithBitOps for M256 { T: UnderlierType + NumCast, { match T::BITS { - 1 | 2 | 4 | 8 | 16 | 32 => { - let elements_in_64 = 64 / T::BITS; - let chunk_64 = unsafe { - match i / elements_in_64 { - 0 => _mm256_extract_epi64(self.0, 0), - 1 => _mm256_extract_epi64(self.0, 1), - 2 => _mm256_extract_epi64(self.0, 2), - _ => _mm256_extract_epi64(self.0, 3), - } - }; + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 32, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); - let result_64 = if T::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::BITS) - 1) as u64; - let val_64 = (chunk_64 as u64) >> (T::BITS * (i % elements_in_64)) & ones; + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; - val_64 as i64 - }; - T::num_cast_from(Self(unsafe { _mm256_set_epi64x(0, 0, 0, result_64) })) + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 32, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) } - // NOTE: benchmark show that this strategy is optimal for getting 64-bit subvalues from 256-bit register. - // However using similar code for 1..32 bits is slower than the version above. - // Also even getting `chunk_64` in the code above using this code shows worser benchmarks results. 64 => { - T::num_cast_from(as_array_ref::<_, u64, 4, _>(self, |array| Self::from(array[i]))) + let value_u64 = + as_array_ref::<_, u64, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) } 128 => { - let chunk_128 = unsafe { - if i == 0 { - _mm256_extracti128_si256(self.0, 0) - } else { - _mm256_extracti128_si256(self.0, 1) - } - }; - T::num_cast_from(Self(unsafe { _mm256_set_m128i(_mm_setzero_si128(), chunk_128) })) + let value_u128 = + as_array_ref::<_, u128, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u128))) } _ => panic!("unsupported bit count"), } @@ -518,26 +520,26 @@ impl UnderlierWithBitOps for M256 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 32>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 32>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 32>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 32>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 16>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 8>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 4>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 4>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => as_array_mut::<_, u128, 2>(self, |array| { - array[i] = u128::num_cast_from(Self::from(val)); + *array.get_unchecked_mut(i) = u128::num_cast_from(Self::from(val)); }), _ => panic!("unsupported bit count"), } @@ -835,6 +837,58 @@ impl UnderlierWithBitOps for M256 { _ => spread_fallback(self, log_block_len, block_idx), } } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm256_bsrli_epi128, + _mm256_srli_epi64, + _mm256_slli_epi64, + _mm256_or_si256 + ) + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm256_bslli_epi128, + _mm256_slli_epi64, + _mm256_srli_epi64, + _mm256_or_si256 + ); + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm256_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm256_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm256_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm256_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm256_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm256_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm256_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm256_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M256 {} @@ -910,6 +964,13 @@ impl UnderlierWithBitConstants for M256 { let (a, b) = unsafe { interleave_bits(self.0, other.0, log_block_len) }; (Self(a), Self(b)) } + + fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) { + let (a, b) = unsafe { transpose_bits(self.0, other.0, log_block_len) }; + self.0 = a; + other.0 = b; + (self, other) + } } #[inline] @@ -974,6 +1035,57 @@ unsafe fn interleave_bits(a: __m256i, b: __m256i, log_block_len: usize) -> (__m2 } } +#[inline] +unsafe fn transpose_bits(a: __m256i, b: __m256i, log_block_len: usize) -> (__m256i, __m256i) { + match log_block_len { + 0..=3 => { + let shuffle = _mm256_set_epi8( + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, + 14, 12, 10, 8, 6, 4, 2, 0, + ); + let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); + for log_block_len in (log_block_len..3).rev() { + (a, b) = interleave_bits(a, b, log_block_len); + } + + (a, b) + } + 4 => { + let shuffle = _mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, + 13, 12, 9, 8, 5, 4, 1, 0, + ); + + transpose_with_shuffle(a, b, shuffle) + } + 5 => { + let shuffle = _mm256_set_epi8( + 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, + 11, 10, 9, 8, 3, 2, 1, 0, + ); + + transpose_with_shuffle(a, b, shuffle) + } + 6 => { + let (a, b) = (_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); + + (_mm256_permute4x64_epi64(a, 0b11011000), _mm256_permute4x64_epi64(b, 0b11011000)) + } + 7 => (_mm256_permute2x128_si256(a, b, 0x20), _mm256_permute2x128_si256(a, b, 0x31)), + _ => panic!("unsupported block length"), + } +} + +#[inline(always)] +unsafe fn transpose_with_shuffle(a: __m256i, b: __m256i, shuffle: __m256i) -> (__m256i, __m256i) { + let a = _mm256_shuffle_epi8(a, shuffle); + let b = _mm256_shuffle_epi8(b, shuffle); + + let (a, b) = (_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); + + (_mm256_permute4x64_epi64(a, 0b11011000), _mm256_permute4x64_epi64(b, 0b11011000)) +} + #[inline] unsafe fn interleave_bits_imm( a: __m256i, @@ -1075,7 +1187,7 @@ mod tests { fn test_constants() { assert_eq!(M256::default(), M256::ZERO); assert_eq!(M256::from(0u128), M256::ZERO); - assert_eq!(M256::from([0u128, 1u128]), M256::ONE); + assert_eq!(M256::from([1u128, 0u128]), M256::ONE); } #[derive(Default)] @@ -1139,6 +1251,10 @@ mod tests { } } + fn get(value: M256, log_block_len: usize, index: usize) -> M256 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[allow(clippy::tuple_array_conversions)] // false positive #[test] @@ -1175,14 +1291,41 @@ mod tests { let (c, d) = (M256::from(c), M256::from(d)); let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; for i in (0..256/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::<[u128; 2]>(), b in any::<[u128; 2]>(), height in 0usize..7) { + let a = M256::from(a); + let b = M256::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + assert_eq!(get(result, height, 2*(i + half_block_count)), get(a, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + half_block_count)+1), get(b, height, 2 * half_block_count + i)); + } + } + + #[test] + fn test_unpack_hi(a in any::<[u128; 2]>(), b in any::<[u128; 2]>(), height in 0usize..7) { + let a = M256::from(a); + let b = M256::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); + assert_eq!(get(result, height, 2*(half_block_count + i)), get(a, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(half_block_count + i) +1), get(b, height, 3*half_block_count + i)); } } } diff --git a/crates/field/src/arch/x86_64/m512.rs b/crates/field/src/arch/x86_64/m512.rs index caa821c6..8afc2825 100644 --- a/crates/field/src/arch/x86_64/m512.rs +++ b/crates/field/src/arch/x86_64/m512.rs @@ -8,23 +8,28 @@ use std::{ use bytemuck::{must_cast, Pod, Zeroable}; use rand::{Rng, RngCore}; +use seq_macro::seq; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ arch::{ - binary_utils::{as_array_mut, make_func_to_i8}, + binary_utils::{as_array_mut, as_array_ref, make_func_to_i8}, portable::{ packed::{impl_pack_scalar, PackedPrimitiveType}, packed_arithmetic::{ interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants, }, }, - x86_64::{m128::M128, m256::M256}, + x86_64::{ + m128::{bitshift_128b, M128}, + m256::M256, + }, }, arithmetic_traits::Broadcast, underlier::{ get_block_values, get_spread_bytes, impl_divisible, impl_iteration, spread_fallback, - NumCast, Random, SmallU, UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -370,7 +375,7 @@ impl UnderlierType for M512 { impl UnderlierWithBitOps for M512 { const ZERO: Self = { Self(m512_from_u128s!(0, 0, 0, 0,)) }; - const ONE: Self = { Self(m512_from_u128s!(0, 0, 0, 1,)) }; + const ONE: Self = { Self(m512_from_u128s!(1, 0, 0, 0,)) }; const ONES: Self = { Self(m512_from_u128s!(u128::MAX, u128::MAX, u128::MAX, u128::MAX,)) }; #[inline(always)] @@ -617,31 +622,41 @@ impl UnderlierWithBitOps for M512 { T: UnderlierType + NumCast, { match T::BITS { - 1 | 2 | 4 | 8 | 16 | 32 | 64 => { - let elements_in_64 = 64 / T::BITS; - let shuffle = unsafe { _mm512_set1_epi64((i / elements_in_64) as i64) }; - let chunk_64 = - u64::num_cast_from(Self(unsafe { _mm512_permutexvar_epi64(shuffle, self.0) })); - - let result_64 = if T::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::BITS) - 1) as u64; - (chunk_64 >> (T::BITS * (i % elements_in_64))) & ones - }; + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 64, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); + + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; - T::num_cast_from(Self::from(result_64)) + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 64, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 32, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) + } + 64 => { + let value_u64 = + as_array_ref::<_, u64, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) } 128 => { - let chunk_128 = unsafe { - match i { - 0 => _mm512_extracti32x4_epi32(self.0, 0), - 1 => _mm512_extracti32x4_epi32(self.0, 1), - 2 => _mm512_extracti32x4_epi32(self.0, 2), - _ => _mm512_extracti32x4_epi32(self.0, 3), - } - }; - T::num_cast_from(Self(unsafe { _mm512_castsi128_si512(chunk_128) })) + let value_u128 = + as_array_ref::<_, u128, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u128))) } _ => panic!("unsupported bit count"), } @@ -661,26 +676,26 @@ impl UnderlierWithBitOps for M512 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 64>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 64>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 64>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 64>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 32>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 32>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 16>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 8>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => as_array_mut::<_, u128, 4>(self, |array| { - array[i] = u128::num_cast_from(Self::from(val)); + *array.get_unchecked_mut(i) = u128::num_cast_from(Self::from(val)); }), _ => panic!("unsupported bit count"), } @@ -855,6 +870,58 @@ impl UnderlierWithBitOps for M512 { _ => spread_fallback(self, log_block_len, block_idx), } } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm512_bsrli_epi128, + _mm512_srli_epi64, + _mm512_slli_epi64, + _mm512_or_si512 + ); + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm512_bslli_epi128, + _mm512_slli_epi64, + _mm512_srli_epi64, + _mm512_or_si512 + ); + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm512_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm512_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm512_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm512_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm512_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm512_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm512_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm512_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M512 {} @@ -932,6 +999,12 @@ impl UnderlierWithBitConstants for M512 { let (a, b) = unsafe { interleave_bits(self.0, other.0, log_block_len) }; (Self(a), Self(b)) } + + #[inline(always)] + fn transpose(self, other: Self, log_bit_len: usize) -> (Self, Self) { + let (a, b) = unsafe { transpose_bits(self.0, other.0, log_bit_len) }; + (Self(a), Self(b)) + } } #[inline] @@ -1103,6 +1176,95 @@ const fn precompute_spread_mask( m512_masks } +#[inline(always)] +unsafe fn transpose_bits(a: __m512i, b: __m512i, log_block_len: usize) -> (__m512i, __m512i) { + match log_block_len { + 0..=3 => { + let shuffle = _mm512_set_epi8( + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, + 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, + ); + let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); + for log_block_len in (log_block_len..3).rev() { + (a, b) = interleave_bits(a, b, log_block_len); + } + + (a, b) + } + 4 => { + let shuffle = _mm512_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, + 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, + 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, + ); + transpose_with_shuffle(a, b, shuffle) + } + 5 => { + let shuffle = _mm512_set_epi8( + 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, + 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, + 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, + ); + transpose_with_shuffle(a, b, shuffle) + } + 6 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1110, 0b1100, 0b1010, 0b1000, 0b0110, 0b0100, 0b0010, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1101, 0b1011, 0b1001, 0b0111, 0b0101, 0b0011, 0b0001), + b, + ), + ), + 7 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1101, 0b1100, 0b1001, 0b1000, 0b0101, 0b0100, 0b0001, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1110, 0b1011, 0b1010, 0b0111, 0b0110, 0b0011, 0b0010), + b, + ), + ), + 8 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1011, 0b1010, 0b1001, 0b1000, 0b0011, 0b0010, 0b0001, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1110, 0b1101, 0b1100, 0b0111, 0b0110, 0b0101, 0b0100), + b, + ), + ), + _ => panic!("unsupported block length"), + } +} + +unsafe fn transpose_with_shuffle(a: __m512i, b: __m512i, shuffle: __m512i) -> (__m512i, __m512i) { + let (a, b) = (_mm512_shuffle_epi8(a, shuffle), _mm512_shuffle_epi8(b, shuffle)); + + ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1110, 0b1100, 0b1010, 0b1000, 0b0110, 0b0100, 0b0010, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1101, 0b1011, 0b1001, 0b0111, 0b0101, 0b0011, 0b0001), + b, + ), + ) +} + impl_iteration!(M512, @strategy BitIterationStrategy, U1, @strategy FallbackStrategy, U2, U4, @@ -1128,7 +1290,7 @@ mod tests { fn test_constants() { assert_eq!(M512::default(), M512::ZERO); assert_eq!(M512::from(0u128), M512::ZERO); - assert_eq!(M512::from([0u128, 0u128, 0u128, 1u128]), M512::ONE); + assert_eq!(M512::from([1u128, 0u128, 0u128, 0u128]), M512::ONE); } #[derive(Default)] @@ -1192,6 +1354,10 @@ mod tests { } } + fn get(value: M512, log_block_len: usize, index: usize) -> M512 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[test] fn test_conversion(a in any::<[u128; 4]>()) { @@ -1225,14 +1391,49 @@ mod tests { let (c, d) = (M512::from(c), M512::from(d)); let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; for i in (0..512/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::<[u128; 4]>(), b in any::<[u128; 4]>(), height in 0usize..7) { + let a = M512::from(a); + let b = M512::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + assert_eq!(get(result, height, 2*(i + half_block_count)), get(a, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + half_block_count)+1), get(b, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 2*half_block_count)), get(a, height, 4 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 2*half_block_count)+1), get(b, height, 4 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 3*half_block_count)), get(a, height, 6 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 3*half_block_count)+1), get(b, height, 6 * half_block_count + i)); + } + } + + #[test] + fn test_unpack_hi(a in any::<[u128; 4]>(), b in any::<[u128; 4]>(), height in 0usize..7) { + let a = M512::from(a); + let b = M512::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); + assert_eq!(get(result, height, 2*(half_block_count + i)), get(a, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(half_block_count + i) +1), get(b, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(2*half_block_count + i)), get(a, height, 5*half_block_count + i)); + assert_eq!(get(result, height, 2*(2*half_block_count + i) +1), get(b, height, 5*half_block_count + i)); + assert_eq!(get(result, height, 2*(3*half_block_count + i)), get(a, height, 7*half_block_count + i)); + assert_eq!(get(result, height, 2*(3*half_block_count + i) +1), get(b, height, 7*half_block_count + i)); } } } diff --git a/crates/field/src/binary_field.rs b/crates/field/src/binary_field.rs index be6f8ea8..31f30f32 100644 --- a/crates/field/src/binary_field.rs +++ b/crates/field/src/binary_field.rs @@ -2,15 +2,16 @@ use std::{ any::TypeId, - array, fmt::{Debug, Display, Formatter}, iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -use binius_utils::serialization::{DeserializeBytes, Error as SerializationError, SerializeBytes}; +use binius_utils::{ + bytes::{Buf, BufMut}, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, Zeroable}; -use bytes::{Buf, BufMut}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; @@ -18,7 +19,7 @@ use super::{ binary_field_arithmetic::TowerFieldArithmetic, error::Error, extension::ExtensionField, }; use crate::{ - underlier::{SmallU, U1, U2, U4}, + underlier::{U1, U2, U4}, Field, }; @@ -46,6 +47,13 @@ where /// Currently for every tower field, the canonical field is Fan-Paar's binary field of the same degree. type Canonical: TowerField + SerializeBytes + DeserializeBytes; + /// Returns the smallest valid `TOWER_LEVEL` in the tower that can fit the same value. + /// + /// Since which `TOWER_LEVEL` values are valid depends on the tower, + /// `F::Canonical::from(elem).min_tower_level()` can return a different result + /// from `elem.min_tower_level()`. + fn min_tower_level(self) -> usize; + fn basis(iota: usize, i: usize) -> Result { if iota > Self::TOWER_LEVEL { return Err(Error::ExtensionDegreeTooHigh); @@ -564,7 +572,6 @@ macro_rules! impl_field_extension { } impl ExtensionField<$subfield_name> for $name { - type Iterator = <[$subfield_name; 1 << $log_degree] as IntoIterator>::IntoIter; const LOG_DEGREE: usize = $log_degree; #[inline] @@ -597,15 +604,21 @@ macro_rules! impl_field_extension { } #[inline] - fn iter_bases(&self) -> Self::Iterator { - use $crate::underlier::NumCast; + fn iter_bases(&self) -> impl Iterator { + use $crate::underlier::{WithUnderlier, IterationMethods, IterationStrategy}; + use binius_utils::iter::IterExtensions; + + IterationMethods::<<$subfield_name as WithUnderlier>::Underlier, Self::Underlier>::ref_iter(&self.0) + .map_skippable($subfield_name::from) + } + + #[inline] + fn into_iter_bases(self) -> impl Iterator { + use $crate::underlier::{WithUnderlier, IterationMethods, IterationStrategy}; + use binius_utils::iter::IterExtensions; - let base_elems = array::from_fn(|i| { - <$subfield_name>::new(<$subfield_typ>::num_cast_from( - (self.0 >> (i * $subfield_name::N_BITS)), - )) - }); - base_elems.into_iter() + IterationMethods::<<$subfield_name as WithUnderlier>::Underlier, Self::Underlier>::value_iter(self.0) + .map_skippable($subfield_name::from) } } }; @@ -621,10 +634,16 @@ pub(super) trait MulPrimitive: Sized { #[macro_export] macro_rules! binary_tower { - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty)) => { - binary_tower!($subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ, $name)); + (BinaryField1b($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([BinaryField1b::TOWER_LEVEL]; BinaryField1b($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); }; - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty, $canonical:ident)) => { + ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([BinaryField1b::TOWER_LEVEL, $subfield_name::TOWER_LEVEL]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); + }; + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty)) => { + binary_tower!([$($valid_tower_levels)*]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ, $name)); + }; + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty, $canonical:ident)) => { impl From<$name> for ($subfield_name, $subfield_name) { #[inline] fn from(src: $name) -> ($subfield_name, $subfield_name) { @@ -648,6 +667,16 @@ macro_rules! binary_tower { type Canonical = $canonical; + fn min_tower_level(self) -> usize { + let zero = <$typ as $crate::underlier::UnderlierWithBitOps>::ZERO; + for level in [$($valid_tower_levels)*] { + if self.0 >> (1 << level) == zero { + return level; + } + } + Self::TOWER_LEVEL + } + fn mul_primitive(self, iota: usize) -> Result { ::mul_primitive(self, iota) } @@ -659,14 +688,13 @@ macro_rules! binary_tower { binary_tower!($subfield_name($subfield_typ) < @1 => $name($typ)); }; - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { - binary_tower!($subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?)); - binary_tower!($name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([$($valid_tower_levels)*]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?)); + binary_tower!([$($valid_tower_levels)*, $name::TOWER_LEVEL]; $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); binary_tower!($subfield_name($subfield_typ) < @2 => $($extfield_name($extfield_typ))<+); }; ($subfield_name:ident($subfield_typ:ty) < @$log_degree:expr => $name:ident($typ:ty)) => { $crate::binary_field::impl_field_extension!($subfield_name($subfield_typ) < @$log_degree => $name($typ)); - $crate::binary_field::binary_tower_subfield_mul!($subfield_name, $name); }; ($subfield_name:ident($subfield_typ:ty) < @$log_degree:expr => $name:ident($typ:ty) $(< $extfield_name:ident($extfield_typ:ty))+) => { @@ -707,76 +735,36 @@ pub fn is_canonical_tower() -> bool { } macro_rules! serialize_deserialize { - ($bin_type:ty, SmallU<$U:literal>) => { - impl SerializeBytes for $bin_type { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> { - if write_buf.remaining_mut() < 1 { - ::binius_utils::bail!(SerializationError::WriteBufferFull); - } - let b = self.0.val(); - write_buf.put_u8(b); - Ok(()) - } - } - - impl DeserializeBytes for $bin_type { - fn deserialize(mut read_buf: impl Buf) -> Result { - if read_buf.remaining() < 1 { - ::binius_utils::bail!(SerializationError::NotEnoughBytes); - } - let b: u8 = read_buf.get_u8(); - Ok(Self(SmallU::<$U>::new(b))) - } - } - }; - ($bin_type:ty, $inner_type:ty) => { + ($bin_type:ty) => { impl SerializeBytes for $bin_type { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> { - if write_buf.remaining_mut() < (<$inner_type>::BITS / 8) as usize { - ::binius_utils::bail!(SerializationError::WriteBufferFull); - } - write_buf.put_slice(&self.0.to_le_bytes()); - Ok(()) + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.0.serialize(write_buf, mode) } } impl DeserializeBytes for $bin_type { - fn deserialize(mut read_buf: impl Buf) -> Result { - let mut inner = <$inner_type>::default().to_le_bytes(); - if read_buf.remaining() < inner.len() { - ::binius_utils::bail!(SerializationError::NotEnoughBytes); - } - read_buf.copy_to_slice(&mut inner); - Ok(Self(<$inner_type>::from_le_bytes(inner))) + fn deserialize( + read_buf: impl Buf, + mode: SerializationMode, + ) -> Result { + Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)) } } }; } -serialize_deserialize!(BinaryField1b, SmallU<1>); -serialize_deserialize!(BinaryField2b, SmallU<2>); -serialize_deserialize!(BinaryField4b, SmallU<4>); -serialize_deserialize!(BinaryField8b, u8); -serialize_deserialize!(BinaryField16b, u16); -serialize_deserialize!(BinaryField32b, u32); -serialize_deserialize!(BinaryField64b, u64); -serialize_deserialize!(BinaryField128b, u128); - -/// Serializes a [`TowerField`] element to a byte buffer with a canonical encoding. -pub fn serialize_canonical( - elem: F, - mut writer: W, -) -> Result<(), SerializationError> { - F::Canonical::from(elem).serialize(&mut writer) -} - -/// Deserializes a [`TowerField`] element from a byte buffer with a canonical encoding. -pub fn deserialize_canonical( - mut reader: R, -) -> Result { - let as_canonical = F::Canonical::deserialize(&mut reader)?; - Ok(F::from(as_canonical)) -} +serialize_deserialize!(BinaryField1b); +serialize_deserialize!(BinaryField2b); +serialize_deserialize!(BinaryField4b); +serialize_deserialize!(BinaryField8b); +serialize_deserialize!(BinaryField16b); +serialize_deserialize!(BinaryField32b); +serialize_deserialize!(BinaryField64b); +serialize_deserialize!(BinaryField128b); impl From for Choice { fn from(val: BinaryField1b) -> Self { @@ -867,7 +855,7 @@ impl From for u8 { #[cfg(test)] pub(crate) mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode}; use proptest::prelude::*; use super::{ @@ -1236,6 +1224,7 @@ pub(crate) mod tests { #[test] fn test_serialization() { + let mode = SerializationMode::CanonicalTower; let mut buffer = BytesMut::new(); let b1 = BinaryField1b::from(0x1); let b8 = BinaryField8b::new(0x12); @@ -1246,25 +1235,25 @@ pub(crate) mod tests { let b64 = BinaryField64b::new(0x13579BDF02468ACE); let b128 = BinaryField128b::new(0x147AD0369CF258BE8899AABBCCDDEEFF); - b1.serialize(&mut buffer).unwrap(); - b8.serialize(&mut buffer).unwrap(); - b2.serialize(&mut buffer).unwrap(); - b16.serialize(&mut buffer).unwrap(); - b32.serialize(&mut buffer).unwrap(); - b4.serialize(&mut buffer).unwrap(); - b64.serialize(&mut buffer).unwrap(); - b128.serialize(&mut buffer).unwrap(); + b1.serialize(&mut buffer, mode).unwrap(); + b8.serialize(&mut buffer, mode).unwrap(); + b2.serialize(&mut buffer, mode).unwrap(); + b16.serialize(&mut buffer, mode).unwrap(); + b32.serialize(&mut buffer, mode).unwrap(); + b4.serialize(&mut buffer, mode).unwrap(); + b64.serialize(&mut buffer, mode).unwrap(); + b128.serialize(&mut buffer, mode).unwrap(); let mut read_buffer = buffer.freeze(); - assert_eq!(BinaryField1b::deserialize(&mut read_buffer).unwrap(), b1); - assert_eq!(BinaryField8b::deserialize(&mut read_buffer).unwrap(), b8); - assert_eq!(BinaryField2b::deserialize(&mut read_buffer).unwrap(), b2); - assert_eq!(BinaryField16b::deserialize(&mut read_buffer).unwrap(), b16); - assert_eq!(BinaryField32b::deserialize(&mut read_buffer).unwrap(), b32); - assert_eq!(BinaryField4b::deserialize(&mut read_buffer).unwrap(), b4); - assert_eq!(BinaryField64b::deserialize(&mut read_buffer).unwrap(), b64); - assert_eq!(BinaryField128b::deserialize(&mut read_buffer).unwrap(), b128); + assert_eq!(BinaryField1b::deserialize(&mut read_buffer, mode).unwrap(), b1); + assert_eq!(BinaryField8b::deserialize(&mut read_buffer, mode).unwrap(), b8); + assert_eq!(BinaryField2b::deserialize(&mut read_buffer, mode).unwrap(), b2); + assert_eq!(BinaryField16b::deserialize(&mut read_buffer, mode).unwrap(), b16); + assert_eq!(BinaryField32b::deserialize(&mut read_buffer, mode).unwrap(), b32); + assert_eq!(BinaryField4b::deserialize(&mut read_buffer, mode).unwrap(), b4); + assert_eq!(BinaryField64b::deserialize(&mut read_buffer, mode).unwrap(), b64); + assert_eq!(BinaryField128b::deserialize(&mut read_buffer, mode).unwrap(), b128); } #[test] diff --git a/crates/field/src/binary_field_arithmetic.rs b/crates/field/src/binary_field_arithmetic.rs index 0e913f50..481b0af1 100644 --- a/crates/field/src/binary_field_arithmetic.rs +++ b/crates/field/src/binary_field_arithmetic.rs @@ -61,6 +61,10 @@ pub(crate) use impl_arithmetic_using_packed; impl TowerField for BinaryField1b { type Canonical = Self; + fn min_tower_level(self) -> usize { + 0 + } + #[inline] fn mul_primitive(self, _: usize) -> Result { Err(crate::Error::ExtensionDegreeMismatch) diff --git a/crates/field/src/byte_iteration.rs b/crates/field/src/byte_iteration.rs new file mode 100644 index 00000000..e7acda99 --- /dev/null +++ b/crates/field/src/byte_iteration.rs @@ -0,0 +1,440 @@ +// Copyright 2023-2025 Irreducible Inc. + +use std::any::TypeId; + +use bytemuck::Pod; + +use crate::{ + packed::get_packed_slice, AESTowerField128b, AESTowerField16b, AESTowerField32b, + AESTowerField64b, AESTowerField8b, BinaryField128b, BinaryField128bPolyval, BinaryField16b, + BinaryField32b, BinaryField64b, BinaryField8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, + ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b, Field, + PackedAESBinaryField16x16b, PackedAESBinaryField16x32b, PackedAESBinaryField16x8b, + PackedAESBinaryField1x128b, PackedAESBinaryField1x16b, PackedAESBinaryField1x32b, + PackedAESBinaryField1x64b, PackedAESBinaryField1x8b, PackedAESBinaryField2x128b, + PackedAESBinaryField2x16b, PackedAESBinaryField2x32b, PackedAESBinaryField2x64b, + PackedAESBinaryField2x8b, PackedAESBinaryField32x16b, PackedAESBinaryField32x8b, + PackedAESBinaryField4x128b, PackedAESBinaryField4x16b, PackedAESBinaryField4x32b, + PackedAESBinaryField4x64b, PackedAESBinaryField4x8b, PackedAESBinaryField64x8b, + PackedAESBinaryField8x16b, PackedAESBinaryField8x64b, PackedAESBinaryField8x8b, + PackedBinaryField128x1b, PackedBinaryField128x2b, PackedBinaryField128x4b, + PackedBinaryField16x16b, PackedBinaryField16x1b, PackedBinaryField16x2b, + PackedBinaryField16x32b, PackedBinaryField16x4b, PackedBinaryField16x8b, + PackedBinaryField1x128b, PackedBinaryField1x16b, PackedBinaryField1x32b, + PackedBinaryField1x64b, PackedBinaryField1x8b, PackedBinaryField256x1b, + PackedBinaryField256x2b, PackedBinaryField2x128b, PackedBinaryField2x16b, + PackedBinaryField2x32b, PackedBinaryField2x4b, PackedBinaryField2x64b, PackedBinaryField2x8b, + PackedBinaryField32x16b, PackedBinaryField32x1b, PackedBinaryField32x2b, + PackedBinaryField32x4b, PackedBinaryField32x8b, PackedBinaryField4x128b, + PackedBinaryField4x16b, PackedBinaryField4x2b, PackedBinaryField4x32b, PackedBinaryField4x4b, + PackedBinaryField4x64b, PackedBinaryField4x8b, PackedBinaryField512x1b, PackedBinaryField64x1b, + PackedBinaryField64x2b, PackedBinaryField64x4b, PackedBinaryField64x8b, PackedBinaryField8x16b, + PackedBinaryField8x1b, PackedBinaryField8x2b, PackedBinaryField8x32b, PackedBinaryField8x4b, + PackedBinaryField8x64b, PackedBinaryField8x8b, PackedBinaryPolyval1x128b, + PackedBinaryPolyval2x128b, PackedBinaryPolyval4x128b, PackedField, +}; + +/// A marker trait that the slice of packed values can be iterated as a sequence of bytes. +/// The order of the iteration by BinaryField1b subfield elements and bits within iterated bytes must +/// be the same. +/// +/// # Safety +/// The implementor must ensure that the cast of the slice of packed values to the slice of bytes +/// is safe and preserves the order of the 1-bit elements. +#[allow(unused)] +unsafe trait SequentialBytes: Pod {} + +unsafe impl SequentialBytes for BinaryField8b {} +unsafe impl SequentialBytes for BinaryField16b {} +unsafe impl SequentialBytes for BinaryField32b {} +unsafe impl SequentialBytes for BinaryField64b {} +unsafe impl SequentialBytes for BinaryField128b {} + +unsafe impl SequentialBytes for PackedBinaryField8x1b {} +unsafe impl SequentialBytes for PackedBinaryField16x1b {} +unsafe impl SequentialBytes for PackedBinaryField32x1b {} +unsafe impl SequentialBytes for PackedBinaryField64x1b {} +unsafe impl SequentialBytes for PackedBinaryField128x1b {} +unsafe impl SequentialBytes for PackedBinaryField256x1b {} +unsafe impl SequentialBytes for PackedBinaryField512x1b {} + +unsafe impl SequentialBytes for PackedBinaryField4x2b {} +unsafe impl SequentialBytes for PackedBinaryField8x2b {} +unsafe impl SequentialBytes for PackedBinaryField16x2b {} +unsafe impl SequentialBytes for PackedBinaryField32x2b {} +unsafe impl SequentialBytes for PackedBinaryField64x2b {} +unsafe impl SequentialBytes for PackedBinaryField128x2b {} +unsafe impl SequentialBytes for PackedBinaryField256x2b {} + +unsafe impl SequentialBytes for PackedBinaryField2x4b {} +unsafe impl SequentialBytes for PackedBinaryField4x4b {} +unsafe impl SequentialBytes for PackedBinaryField8x4b {} +unsafe impl SequentialBytes for PackedBinaryField16x4b {} +unsafe impl SequentialBytes for PackedBinaryField32x4b {} +unsafe impl SequentialBytes for PackedBinaryField64x4b {} +unsafe impl SequentialBytes for PackedBinaryField128x4b {} + +unsafe impl SequentialBytes for PackedBinaryField1x8b {} +unsafe impl SequentialBytes for PackedBinaryField2x8b {} +unsafe impl SequentialBytes for PackedBinaryField4x8b {} +unsafe impl SequentialBytes for PackedBinaryField8x8b {} +unsafe impl SequentialBytes for PackedBinaryField16x8b {} +unsafe impl SequentialBytes for PackedBinaryField32x8b {} +unsafe impl SequentialBytes for PackedBinaryField64x8b {} + +unsafe impl SequentialBytes for PackedBinaryField1x16b {} +unsafe impl SequentialBytes for PackedBinaryField2x16b {} +unsafe impl SequentialBytes for PackedBinaryField4x16b {} +unsafe impl SequentialBytes for PackedBinaryField8x16b {} +unsafe impl SequentialBytes for PackedBinaryField16x16b {} +unsafe impl SequentialBytes for PackedBinaryField32x16b {} + +unsafe impl SequentialBytes for PackedBinaryField1x32b {} +unsafe impl SequentialBytes for PackedBinaryField2x32b {} +unsafe impl SequentialBytes for PackedBinaryField4x32b {} +unsafe impl SequentialBytes for PackedBinaryField8x32b {} +unsafe impl SequentialBytes for PackedBinaryField16x32b {} + +unsafe impl SequentialBytes for PackedBinaryField1x64b {} +unsafe impl SequentialBytes for PackedBinaryField2x64b {} +unsafe impl SequentialBytes for PackedBinaryField4x64b {} +unsafe impl SequentialBytes for PackedBinaryField8x64b {} + +unsafe impl SequentialBytes for PackedBinaryField1x128b {} +unsafe impl SequentialBytes for PackedBinaryField2x128b {} +unsafe impl SequentialBytes for PackedBinaryField4x128b {} + +unsafe impl SequentialBytes for AESTowerField8b {} +unsafe impl SequentialBytes for AESTowerField16b {} +unsafe impl SequentialBytes for AESTowerField32b {} +unsafe impl SequentialBytes for AESTowerField64b {} +unsafe impl SequentialBytes for AESTowerField128b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField32x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField64x8b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField32x16b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x32b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x64b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x128b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x128b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x128b {} + +unsafe impl SequentialBytes for BinaryField128bPolyval {} + +unsafe impl SequentialBytes for PackedBinaryPolyval1x128b {} +unsafe impl SequentialBytes for PackedBinaryPolyval2x128b {} +unsafe impl SequentialBytes for PackedBinaryPolyval4x128b {} + +/// Returns true if T implements `SequentialBytes` trait. +/// Use a hack that exploits that array copying is optimized for the `Copy` types. +/// Unfortunately there is no more proper way to perform this check this in Rust at runtime. +#[inline(always)] +#[allow(clippy::redundant_clone)] // this is intentional in this method +pub fn is_sequential_bytes() -> bool { + struct X(bool, std::marker::PhantomData); + + impl Clone for X { + fn clone(&self) -> Self { + Self(false, std::marker::PhantomData) + } + } + + impl Copy for X {} + + let value = [X::(true, std::marker::PhantomData)]; + let cloned = value.clone(); + + cloned[0].0 +} + +/// Returns if we can iterate over bytes, each representing 8 1-bit values. +pub fn can_iterate_bytes() -> bool { + // Packed fields with sequential byte order + if is_sequential_bytes::

() { + return true; + } + + // Byte-sliced fields + // Note: add more byte sliced types here as soon as they are added + match TypeId::of::

() { + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + _ => false, + } +} + +/// Helper macro to generate the iteration over bytes for byte-sliced types. +macro_rules! iterate_byte_sliced { + ($packed_type:ty, $data:ident, $callback:ident) => { + assert_eq!(TypeId::of::<$packed_type>(), TypeId::of::

()); + + // Safety: the cast is safe because the type is checked by arm statement + let data = unsafe { + std::slice::from_raw_parts($data.as_ptr() as *const $packed_type, $data.len()) + }; + let iter = data.iter().flat_map(|value| { + (0..<$packed_type>::BYTES).map(move |i| unsafe { value.get_byte_unchecked(i) }) + }); + + $callback.call(iter); + }; +} + +/// Callback for byte iteration. +/// We can't return different types from the `iterate_bytes` and Fn traits don't support associated types +/// that's why we use a callback with a generic function. +pub trait ByteIteratorCallback { + fn call(&mut self, iter: impl Iterator); +} + +/// Iterate over bytes of a slice of the packed values. +/// The method panics if the packed field doesn't support byte iteration, so use `can_iterate_bytes` to check it. +#[inline(always)] +pub fn iterate_bytes(data: &[P], callback: &mut impl ByteIteratorCallback) { + if is_sequential_bytes::

() { + // Safety: `P` implements `SequentialBytes` trait, so the following cast is safe + // and preserves the order. + let bytes = unsafe { + std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) + }; + callback.call(bytes.iter().copied()); + } else { + // Note: add more byte sliced types here as soon as they are added + match TypeId::of::

() { + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x128b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x64b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x32b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x16b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x8b, data, callback); + } + _ => unreachable!("packed field doesn't support byte iteration"), + } + } +} + +/// Scalars collection abstraction. +/// This trait is used to abstract over different types of collections of field elements. +pub trait ScalarsCollection { + fn len(&self) -> usize; + fn get(&self, i: usize) -> T; + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl ScalarsCollection for &[F] { + #[inline(always)] + fn len(&self) -> usize { + <[F]>::len(self) + } + + #[inline(always)] + fn get(&self, i: usize) -> F { + self[i] + } +} + +pub struct PackedSlice<'a, P: PackedField> { + slice: &'a [P], + len: usize, +} + +impl<'a, P: PackedField> PackedSlice<'a, P> { + #[inline(always)] + pub const fn new(slice: &'a [P], len: usize) -> Self { + Self { slice, len } + } +} + +impl ScalarsCollection for PackedSlice<'_, P> { + #[inline(always)] + fn len(&self) -> usize { + self.len + } + + #[inline(always)] + fn get(&self, i: usize) -> P::Scalar { + get_packed_slice(self.slice, i) + } +} + +/// Create a lookup table for partial sums of 8 consequent elements with coefficients corresponding to bits in a byte. +/// The lookup table has the following structure: +/// [ +/// partial_sum_chunk_0_7_byte_0, partial_sum_chunk_0_7_byte_1, ..., partial_sum_chunk_0_7_byte_255, +/// partial_sum_chunk_8_15_byte_0, partial_sum_chunk_8_15_byte_1, ..., partial_sum_chunk_8_15_byte_255, +/// ... +/// ] +pub fn create_partial_sums_lookup_tables( + values: impl ScalarsCollection

, +) -> Vec

{ + let len = values.len(); + assert!(len % 8 == 0); + + let mut result = Vec::with_capacity(len * 256 / 8); + for chunk_i in 0..len / 8 { + let offset = chunk_i * 8; + for i in 0..256 { + let mut sum = P::zero(); + for j in 0..8 { + if i & (1 << j) != 0 { + sum += values.get(offset + j); + } + } + result.push(sum); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{PackedBinaryField1x1b, PackedBinaryField2x1b, PackedBinaryField4x1b}; + + #[test] + fn test_sequential_bits() { + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + } +} diff --git a/crates/field/src/extension.rs b/crates/field/src/extension.rs index 45ca2e36..84a660aa 100644 --- a/crates/field/src/extension.rs +++ b/crates/field/src/extension.rs @@ -18,9 +18,6 @@ pub trait ExtensionField: + SubAssign + MulAssign { - /// Iterator returned by `iter_bases`. - type Iterator: Iterator; - /// Base-2 logarithm of the extension degree. const LOG_DEGREE: usize; @@ -47,12 +44,13 @@ pub trait ExtensionField: fn from_bases_sparse(base_elems: &[F], log_stride: usize) -> Result; /// Iterator over base field elements. - fn iter_bases(&self) -> Self::Iterator; + fn iter_bases(&self) -> impl Iterator; + + /// Convert into an iterator over base field elements. + fn into_iter_bases(self) -> impl Iterator; } impl ExtensionField for F { - type Iterator = iter::Once; - const LOG_DEGREE: usize = 0; fn basis(i: usize) -> Result { @@ -74,7 +72,11 @@ impl ExtensionField for F { } } - fn iter_bases(&self) -> Self::Iterator { + fn iter_bases(&self) -> impl Iterator { iter::once(*self) } + + fn into_iter_bases(self) -> impl Iterator { + iter::once(self) + } } diff --git a/crates/field/src/field.rs b/crates/field/src/field.rs index 10395e9c..4a546951 100644 --- a/crates/field/src/field.rs +++ b/crates/field/src/field.rs @@ -7,6 +7,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{DeserializeBytes, SerializeBytes}; use rand::RngCore; use crate::{ @@ -49,6 +50,8 @@ pub trait Field: + InvertOrZero // `Underlier: PackScalar` is an obvious property but it can't be deduced by the compiler so we are id here. + WithUnderlier> + + SerializeBytes + + DeserializeBytes { /// The zero element of the field, the additive identity. const ZERO: Self; diff --git a/crates/field/src/lib.rs b/crates/field/src/lib.rs index cae1070d..76414c03 100644 --- a/crates/field/src/lib.rs +++ b/crates/field/src/lib.rs @@ -20,6 +20,7 @@ pub mod arithmetic_traits; pub mod as_packed_field; pub mod binary_field; mod binary_field_arithmetic; +pub mod byte_iteration; pub mod error; pub mod extension; pub mod field; diff --git a/crates/field/src/packed.rs b/crates/field/src/packed.rs index 7c0b9ad8..9139b67f 100644 --- a/crates/field/src/packed.rs +++ b/crates/field/src/packed.rs @@ -20,8 +20,7 @@ use super::{ Error, }; use crate::{ - arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, ExtensionField, Field, - PackedExtension, + arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, Field, PackedExtension, }; /// A packed field represents a vector of underlying field elements. @@ -216,6 +215,20 @@ pub trait PackedField: /// * `log_block_len` must be strictly less than `LOG_WIDTH`. fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self); + /// Unzips interleaved blocks of this packed vector with another packed vector. + /// + /// Consider this example, where `LOG_WIDTH` is 3 and `log_block_len` is 1: + /// A = [a0, a1, b0, b1, a2, a3, b2, b3] + /// B = [a4, a5, b4, b5, a6, a7, b6, b7] + /// + /// The transposed result is + /// A' = [a0, a1, a2, a3, a4, a5, a6, a7] + /// B' = [b0, b1, b2, b3, b4, b5, b6, b7] + /// + /// ## Preconditions + /// * `log_block_len` must be strictly less than `LOG_WIDTH`. + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self); + /// Spread takes a block of elements within a packed field and repeats them to the full packing /// width. /// @@ -357,11 +370,7 @@ pub const fn len_packed_slice(packed: &[P]) -> usize { } /// Multiply packed field element by a subfield scalar. -pub fn mul_by_subfield_scalar(val: P, multiplier: FS) -> P -where - P: PackedExtension>, - FS: Field, -{ +pub fn mul_by_subfield_scalar, FS: Field>(val: P, multiplier: FS) -> P { use crate::underlier::UnderlierType; // This is a workaround not to make the multiplication slower in certain cases. @@ -376,6 +385,14 @@ where } } +pub fn pack_slice(scalars: &[P::Scalar]) -> Vec

{ + let mut packed_slice = vec![P::default(); scalars.len() / P::WIDTH]; + for (i, scalar) in scalars.iter().enumerate() { + set_packed_slice(&mut packed_slice, i, *scalar); + } + packed_slice +} + impl Broadcast for F { fn broadcast(scalar: F) -> Self { scalar @@ -427,6 +444,10 @@ impl PackedField for F { panic!("cannot interleave when WIDTH = 1"); } + fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) { + panic!("cannot transpose when WIDTH = 1"); + } + fn broadcast(scalar: Self::Scalar) -> Self { scalar } @@ -495,7 +516,7 @@ mod tests { } /// Run the test for all the packed fields defined in this crate. - fn run_for_all_packed_fields(test: impl PackedFieldTest) { + fn run_for_all_packed_fields(test: &impl PackedFieldTest) { // canonical tower test.run::(); @@ -665,6 +686,6 @@ mod tests { #[test] fn test_iteration() { - run_for_all_packed_fields(PackedFieldIterationTest); + run_for_all_packed_fields(&PackedFieldIterationTest); } } diff --git a/crates/field/src/packed_binary_field.rs b/crates/field/src/packed_binary_field.rs index b22f48fe..152cbf6b 100644 --- a/crates/field/src/packed_binary_field.rs +++ b/crates/field/src/packed_binary_field.rs @@ -807,6 +807,63 @@ pub mod test_utils { check_interleave::

(lhs, rhs, log_block_len); } } + + pub fn check_unzip( + lhs: P::Underlier, + rhs: P::Underlier, + log_block_len: usize, + ) { + let lhs = P::from_underlier(lhs); + let rhs = P::from_underlier(rhs); + let block_len = 1 << log_block_len; + let (a, b) = lhs.unzip(rhs, log_block_len); + for i in (0..P::WIDTH / 2).step_by(block_len) { + for j in 0..block_len { + assert_eq!( + a.get(i + j), + lhs.get(2 * i + j), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + assert_eq!( + b.get(i + j), + lhs.get(2 * i + j + block_len), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + } + } + + for i in (0..P::WIDTH / 2).step_by(block_len) { + for j in 0..block_len { + assert_eq!( + a.get(i + j + P::WIDTH / 2), + rhs.get(2 * i + j), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + assert_eq!(b.get(i + j + P::WIDTH / 2), rhs.get(2 * i + j + block_len)); + } + } + } + + pub fn check_transpose_all_heights( + lhs: P::Underlier, + rhs: P::Underlier, + ) { + for log_block_len in 0..P::LOG_WIDTH { + check_unzip::

(lhs, rhs, log_block_len); + } + } } #[cfg(test)] @@ -831,6 +888,7 @@ mod tests { }, arithmetic_traits::MulAlpha, linear_transformation::PackedTransformationFactory, + test_utils::check_transpose_all_heights, underlier::{U2, U4}, Field, PackedField, PackedFieldIndexable, }; @@ -1206,5 +1264,93 @@ mod tests { check_interleave_all_heights::(a_val.into(), b_val.into()); check_interleave_all_heights::(a_val.into(), b_val.into()); } + + #[test] + fn check_transpose_2b(a_val in 0u8..3, b_val in 0u8..3) { + check_transpose_all_heights::(U2::new(a_val), U2::new(b_val)); + check_transpose_all_heights::(U2::new(a_val), U2::new(b_val)); + } + + #[test] + fn check_transpose_4b(a_val in 0u8..16, b_val in 0u8..16) { + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + } + + #[test] + fn check_transpose_8b(a_val in 0u8.., b_val in 0u8..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_16b(a_val in 0u16.., b_val in 0u16..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_32b(a_val in 0u32.., b_val in 0u32..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_64b(a_val in 0u64.., b_val in 0u64..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + #[allow(clippy::useless_conversion)] // this warning depends on the target platform + fn check_transpose_128b(a_val in 0u128.., b_val in 0u128..) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } + + #[test] + fn check_transpose_256b(a_val in any::<[u128; 2]>(), b_val in any::<[u128; 2]>()) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } + + #[test] + fn check_transpose_512b(a_val in any::<[u128; 4]>(), b_val in any::<[u128; 4]>()) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } } } diff --git a/crates/field/src/packed_extension.rs b/crates/field/src/packed_extension.rs index 3a2b11a2..8f2aaf02 100644 --- a/crates/field/src/packed_extension.rs +++ b/crates/field/src/packed_extension.rs @@ -54,12 +54,12 @@ where /// PE: PackedField>, /// F: Field, /// { -/// packed.iter().flat_map(|ext| ext.iter_bases()) +/// packed.iter().flat_map(|ext| ext.into_iter_bases()) /// } /// /// fn cast_then_iter<'a, F, PE>(packed: &'a PE) -> impl Iterator + 'a /// where -/// PE: PackedExtension>, +/// PE: PackedExtension, /// F: Field, /// { /// PE::cast_base_ref(packed).into_iter() @@ -71,10 +71,7 @@ where /// In order for the above relation to be guaranteed, the memory representation of /// `PackedExtensionField` element must be the same as a slice of the underlying `PackedField` /// element. -pub trait PackedExtension: PackedField -where - Self::Scalar: ExtensionField, -{ +pub trait PackedExtension: PackedField> { type PackedSubfield: PackedField; fn cast_bases(packed: &[Self]) -> &[Self::PackedSubfield]; @@ -187,9 +184,7 @@ where /// This trait is a shorthand for the case `PackedExtension` which is a /// quite common case in our codebase. pub trait RepackedExtension: - PackedExtension -where - Self::Scalar: ExtensionField, + PackedField> + PackedExtension { } @@ -202,10 +197,8 @@ where /// This trait adds shortcut methods for the case `PackedExtension` which is a /// quite common case in our codebase. -pub trait PackedExtensionIndexable: PackedExtension -where - Self::Scalar: ExtensionField, - Self::PackedSubfield: PackedFieldIndexable, +pub trait PackedExtensionIndexable: + PackedExtension + PackedField> { fn unpack_base_scalars(packed: &[Self]) -> &[F] { Self::PackedSubfield::unpack_scalars(Self::cast_bases(packed)) @@ -219,7 +212,7 @@ where impl PackedExtensionIndexable for PT where F: Field, - PT: PackedExtension, PackedSubfield: PackedFieldIndexable>, + PT: PackedExtension, { } diff --git a/crates/field/src/packed_extension_ops.rs b/crates/field/src/packed_extension_ops.rs index 6d2ede80..30fad681 100644 --- a/crates/field/src/packed_extension_ops.rs +++ b/crates/field/src/packed_extension_ops.rs @@ -6,21 +6,17 @@ use binius_maybe_rayon::prelude::{ use crate::{Error, ExtensionField, Field, PackedExtension, PackedField}; -pub fn ext_base_mul(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) -> Result<(), Error> -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +pub fn ext_base_mul, F: Field>( + lhs: &mut [PE], + rhs: &[PE::PackedSubfield], +) -> Result<(), Error> { ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs)) } -pub fn ext_base_mul_par(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) -> Result<(), Error> -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +pub fn ext_base_mul_par, F: Field>( + lhs: &mut [PE], + rhs: &[PE::PackedSubfield], +) -> Result<(), Error> { ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| { PE::cast_ext(lhs.cast_base() * broadcasted_rhs) }) @@ -29,15 +25,10 @@ where /// # Safety /// /// Width of PackedSubfield is >= the width of the field implementing PackedExtension. -pub unsafe fn get_packed_subfields_at_pe_idx( +pub unsafe fn get_packed_subfields_at_pe_idx, F: Field>( packed_subfields: &[PE::PackedSubfield], i: usize, -) -> PE::PackedSubfield -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +) -> PE::PackedSubfield { let bottom_most_scalar_idx = i * PE::WIDTH; let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH; let bottom_most_scalar_idx_within_packed_subfield = @@ -67,7 +58,6 @@ pub fn ext_base_op( ) -> Result<(), Error> where PE: PackedExtension, - PE::Scalar: ExtensionField, F: Field, Func: Fn(usize, PE, PE::PackedSubfield) -> PE, { @@ -93,7 +83,6 @@ pub fn ext_base_op_par( ) -> Result<(), Error> where PE: PackedExtension, - PE::Scalar: ExtensionField, F: Field, Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync, { @@ -117,10 +106,10 @@ mod tests { use crate::{ ext_base_mul, ext_base_mul_par, - packed::{get_packed_slice, set_packed_slice}, + packed::{get_packed_slice, pack_slice}, underlier::WithUnderlier, BinaryField128b, BinaryField16b, BinaryField8b, PackedBinaryField16x16b, - PackedBinaryField2x128b, PackedBinaryField32x8b, PackedField, + PackedBinaryField2x128b, PackedBinaryField32x8b, }; fn strategy_8b_scalars() -> impl Strategy { @@ -138,16 +127,6 @@ mod tests { .prop_map(|arr| arr.map(::from_underlier)) } - fn pack_slice(scalar_slice: &[P::Scalar]) -> Vec

{ - let mut packed_slice = vec![P::default(); scalar_slice.len() / P::WIDTH]; - - for (i, scalar) in scalar_slice.iter().enumerate() { - set_packed_slice(&mut packed_slice, i, *scalar); - } - - packed_slice - } - proptest! { #[test] fn test_base_ext_mul_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()){ diff --git a/crates/field/src/polyval.rs b/crates/field/src/polyval.rs index 989e12ee..cd3d944a 100644 --- a/crates/field/src/polyval.rs +++ b/crates/field/src/polyval.rs @@ -4,12 +4,16 @@ use std::{ any::TypeId, - array, fmt::{self, Debug, Display, Formatter}, iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{ + bytes::{Buf, BufMut}, + iter::IterExtensions, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, TransparentWrapper, Zeroable}; use rand::{Rng, RngCore}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; @@ -29,7 +33,7 @@ use crate::{ invert_or_zero_using_packed, multiple_using_packed, square_using_packed, }, linear_transformation::{FieldLinearTransformation, Transformation}, - underlier::UnderlierWithBitOps, + underlier::{IterationMethods, IterationStrategy, UnderlierWithBitOps, U1}, Field, }; @@ -414,7 +418,6 @@ impl Mul for BinaryField1b { } impl ExtensionField for BinaryField128bPolyval { - type Iterator = <[BinaryField1b; 128] as IntoIterator>::IntoIter; const LOG_DEGREE: usize = 7; #[inline] @@ -439,9 +442,44 @@ impl ExtensionField for BinaryField128bPolyval { } #[inline] - fn iter_bases(&self) -> Self::Iterator { - let base_elems = array::from_fn(|i| BinaryField1b::from((self.0 >> i) as u8)); - base_elems.into_iter() + fn iter_bases(&self) -> impl Iterator { + IterationMethods::::value_iter(self.0) + .map_skippable(BinaryField1b::from) + } + + #[inline] + fn into_iter_bases(self) -> impl Iterator { + IterationMethods::::value_iter(self.0) + .map_skippable(BinaryField1b::from) + } +} + +impl SerializeBytes for BinaryField128bPolyval { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match mode { + SerializationMode::Native => self.0.serialize(write_buf, mode), + SerializationMode::CanonicalTower => { + BinaryField128b::from(*self).serialize(write_buf, mode) + } + } + } +} + +impl DeserializeBytes for BinaryField128bPolyval { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + match mode { + SerializationMode::Native => Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)), + SerializationMode::CanonicalTower => { + Ok(Self::from(BinaryField128b::deserialize(read_buf, mode)?)) + } + } } } @@ -452,6 +490,13 @@ impl BinaryField for BinaryField128bPolyval { impl TowerField for BinaryField128bPolyval { type Canonical = BinaryField128b; + fn min_tower_level(self) -> usize { + match self { + Self::ZERO | Self::ONE => 0, + _ => 7, + } + } + fn mul_primitive(self, _iota: usize) -> Result { // This method could be implemented by multiplying by isomorphic alpha value // But it's not being used as for now @@ -1019,7 +1064,7 @@ pub fn is_polyval_tower() -> bool { #[cfg(test)] mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode, SerializeBytes}; use proptest::prelude::*; use rand::thread_rng; @@ -1030,11 +1075,10 @@ mod tests { packed_polyval_512::PackedBinaryPolyval4x128b, }, binary_field::tests::is_binary_field_valid_generator, - deserialize_canonical, linear_transformation::PackedTransformationFactory, - serialize_canonical, AESTowerField128b, PackedAESBinaryField1x128b, - PackedAESBinaryField2x128b, PackedAESBinaryField4x128b, PackedBinaryField1x128b, - PackedBinaryField2x128b, PackedBinaryField4x128b, PackedField, + AESTowerField128b, PackedAESBinaryField1x128b, PackedAESBinaryField2x128b, + PackedAESBinaryField4x128b, PackedBinaryField1x128b, PackedBinaryField2x128b, + PackedBinaryField4x128b, PackedField, }; #[test] @@ -1177,25 +1221,25 @@ mod tests { #[test] fn test_canonical_serialization() { + let mode = SerializationMode::CanonicalTower; let mut buffer = BytesMut::new(); let mut rng = thread_rng(); let b128_poly1 = ::random(&mut rng); let b128_poly2 = ::random(&mut rng); - serialize_canonical(b128_poly1, &mut buffer).unwrap(); - serialize_canonical(b128_poly2, &mut buffer).unwrap(); + SerializeBytes::serialize(&b128_poly1, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&b128_poly2, &mut buffer, mode).unwrap(); + let mode = SerializationMode::CanonicalTower; let mut read_buffer = buffer.freeze(); assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), + BinaryField128bPolyval::deserialize(&mut read_buffer, mode).unwrap(), b128_poly1 ); assert_eq!( - BinaryField128bPolyval::from( - deserialize_canonical::(&mut read_buffer).unwrap() - ), + BinaryField128bPolyval::deserialize(&mut read_buffer, mode).unwrap(), b128_poly2 ); } diff --git a/crates/field/src/tower_levels.rs b/crates/field/src/tower_levels.rs index bdac60d0..bfddced0 100644 --- a/crates/field/src/tower_levels.rs +++ b/crates/field/src/tower_levels.rs @@ -16,110 +16,104 @@ use std::{ /// These separate implementations are necessary to overcome the limitations of const generics in Rust. /// These implementations eliminate costly bounds checking that would otherwise be imposed by the compiler /// and allow easy inlining of recursive functions. -pub trait TowerLevel -where - T: Default + Copy, -{ +pub trait TowerLevel { // WIDTH is ALWAYS a power of 2 const WIDTH: usize; // The underlying Data should ALWAYS be a fixed-width array of T's - type Data: AsMut<[T]> + type Data: AsMut<[T]> + AsRef<[T]> + Sized + Index + IndexMut; - type Base: TowerLevel; + type Base: TowerLevel; - // Split something of type Self::Data into two equal halves + // Split something of type Self::Datainto two equal halves #[allow(clippy::type_complexity)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data); + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data); - // Split something of type Self::Data into two equal mutable halves + // Split something of type Self::Datainto two equal mutable halves #[allow(clippy::type_complexity)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data); + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data); // Join two equal-length arrays (the reverse of split) #[allow(clippy::type_complexity)] - fn join( - first: &>::Data, - second: &>::Data, - ) -> Self::Data; + fn join( + first: &::Data, + second: &::Data, + ) -> Self::Data; // Fills an array of T's containing WIDTH elements - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data; + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data; // Fills an array of T's containing WIDTH elements with T::default() - fn default() -> Self::Data { + fn default() -> Self::Data { Self::from_fn(|_| T::default()) } } -pub trait TowerLevelWithArithOps: TowerLevel -where - T: Default + Add + AddAssign + Copy, -{ +pub trait TowerLevelWithArithOps: TowerLevel { #[inline(always)] - fn add_into(field_element: &Self::Data, destination: &mut Self::Data) { + fn add_into( + field_element: &Self::Data, + destination: &mut Self::Data, + ) { for i in 0..Self::WIDTH { destination[i] += field_element[i]; } } #[inline(always)] - fn copy_into(field_element: &Self::Data, destination: &mut Self::Data) { + fn copy_into(field_element: &Self::Data, destination: &mut Self::Data) { for i in 0..Self::WIDTH { destination[i] = field_element[i]; } } #[inline(always)] - fn sum(field_element_a: &Self::Data, field_element_b: &Self::Data) -> Self::Data { + fn sum>( + field_element_a: &Self::Data, + field_element_b: &Self::Data, + ) -> Self::Data { Self::from_fn(|i| field_element_a[i] + field_element_b[i]) } } -impl> TowerLevelWithArithOps for U where - T: Default + Add + AddAssign + Copy -{ -} +impl TowerLevelWithArithOps for T {} pub struct TowerLevel64; -impl TowerLevel for TowerLevel64 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel64 { const WIDTH: usize = 64; - type Data = [T; 64]; + type Data = [T; 64]; type Base = TowerLevel32; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..32].try_into().unwrap()), (data[32..64].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(32); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 64]; result[..32].copy_from_slice(left); result[32..].copy_from_slice(right); @@ -127,43 +121,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel32; -impl TowerLevel for TowerLevel32 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel32 { const WIDTH: usize = 32; - type Data = [T; 32]; + type Data = [T; 32]; type Base = TowerLevel16; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..16].try_into().unwrap()), (data[16..32].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(16); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 32]; result[..16].copy_from_slice(left); result[16..].copy_from_slice(right); @@ -171,43 +162,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel16; -impl TowerLevel for TowerLevel16 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel16 { const WIDTH: usize = 16; - type Data = [T; 16]; + type Data = [T; 16]; type Base = TowerLevel8; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..8].try_into().unwrap()), (data[8..16].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(8); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 16]; result[..8].copy_from_slice(left); result[8..].copy_from_slice(right); @@ -215,43 +203,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel8; -impl TowerLevel for TowerLevel8 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel8 { const WIDTH: usize = 8; - type Data = [T; 8]; + type Data = [T; 8]; type Base = TowerLevel4; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..4].try_into().unwrap()), (data[4..8].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(4); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 8]; result[..4].copy_from_slice(left); result[4..].copy_from_slice(right); @@ -259,43 +244,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel4; -impl TowerLevel for TowerLevel4 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel4 { const WIDTH: usize = 4; - type Data = [T; 4]; + type Data = [T; 4]; type Base = TowerLevel2; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..2].try_into().unwrap()), (data[2..4].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(2); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 4]; result[..2].copy_from_slice(left); result[2..].copy_from_slice(right); @@ -303,43 +285,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel2; -impl TowerLevel for TowerLevel2 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel2 { const WIDTH: usize = 2; - type Data = [T; 2]; + type Data = [T; 2]; type Base = TowerLevel1; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..1].try_into().unwrap()), (data[1..2].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(1); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 2]; result[..1].copy_from_slice(left); result[1..].copy_from_slice(right); @@ -347,48 +326,45 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel1; -impl TowerLevel for TowerLevel1 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel1 { const WIDTH: usize = 1; - type Data = [T; 1]; + type Data = [T; 1]; type Base = Self; // Level 1 is the atomic unit of backing data and must not be split. #[inline(always)] - fn split( - _data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + _data: &Self::Data, + ) -> (&::Data, &::Data) { unreachable!() } #[inline(always)] - fn split_mut( - _data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + _data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { unreachable!() } #[inline(always)] - fn join<'a>( - _left: &>::Data, - _right: &>::Data, - ) -> Self::Data { + fn join( + _left: &::Data, + _right: &::Data, + ) -> Self::Data { unreachable!() } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } diff --git a/crates/field/src/transpose.rs b/crates/field/src/transpose.rs index efabbfc6..b3838b57 100644 --- a/crates/field/src/transpose.rs +++ b/crates/field/src/transpose.rs @@ -2,7 +2,7 @@ use binius_utils::checked_arithmetics::log2_strict_usize; -use super::{packed::PackedField, ExtensionField, PackedFieldIndexable, RepackedExtension}; +use super::{packed::PackedField, Field, PackedFieldIndexable, RepackedExtension}; /// Error thrown when a transpose operation fails. #[derive(Clone, thiserror::Error, Debug)] @@ -76,7 +76,7 @@ pub fn square_transpose(log_n: usize, elems: &mut [P]) -> Result pub fn transpose_scalars(src: &[PE], dst: &mut [P]) -> Result<(), Error> where P: PackedField, - FE: ExtensionField, + FE: Field, PE: PackedFieldIndexable + RepackedExtension

, { let len = src.len(); diff --git a/crates/field/src/underlier/scaled.rs b/crates/field/src/underlier/scaled.rs index 4210bc65..40cabc5c 100644 --- a/crates/field/src/underlier/scaled.rs +++ b/crates/field/src/underlier/scaled.rs @@ -1,13 +1,16 @@ // Copyright 2024-2025 Irreducible Inc. -use std::array; +use std::{ + array, + ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr}, +}; use binius_utils::checked_arithmetics::checked_log_2; use bytemuck::{must_cast_mut, must_cast_ref, NoUninit, Pod, Zeroable}; use rand::RngCore; use subtle::{Choice, ConstantTimeEq}; -use super::{Divisible, Random, UnderlierType}; +use super::{Divisible, Random, UnderlierType, UnderlierWithBitOps}; /// A type that represents a pair of elements of the same underlier type. /// We use it as an underlier for the `ScaledPAckedField` type. @@ -104,3 +107,189 @@ where must_cast_mut::(self) } } + +impl + Copy, const N: usize> BitAnd for ScaledUnderlier { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] & rhs.0[i])) + } +} + +impl BitAndAssign for ScaledUnderlier { + fn bitand_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] &= rhs.0[i]; + } + } +} + +impl + Copy, const N: usize> BitOr for ScaledUnderlier { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] | rhs.0[i])) + } +} + +impl BitOrAssign for ScaledUnderlier { + fn bitor_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] |= rhs.0[i]; + } + } +} + +impl + Copy, const N: usize> BitXor for ScaledUnderlier { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] ^ rhs.0[i])) + } +} + +impl BitXorAssign for ScaledUnderlier { + fn bitxor_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] ^= rhs.0[i]; + } + } +} + +impl Shr for ScaledUnderlier { + type Output = Self; + + fn shr(self, rhs: usize) -> Self::Output { + let mut result = Self::default(); + + let shift_in_items = rhs / U::BITS; + for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) { + if i + shift_in_items < N { + result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS); + } + if i + shift_in_items + 1 < N && rhs % U::BITS != 0 { + result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS)); + } + } + + result + } +} + +impl Shl for ScaledUnderlier { + type Output = Self; + + fn shl(self, rhs: usize) -> Self::Output { + let mut result = Self::default(); + + let shift_in_items = rhs / U::BITS; + for i in shift_in_items.saturating_sub(1)..N { + if i >= shift_in_items { + result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS); + } + if i > shift_in_items && rhs % U::BITS != 0 { + result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS)); + } + } + + result + } +} + +impl, const N: usize> Not for ScaledUnderlier { + type Output = Self; + + fn not(self) -> Self::Output { + Self(self.0.map(U::not)) + } +} + +impl UnderlierWithBitOps for ScaledUnderlier { + const ZERO: Self = Self([U::ZERO; N]); + const ONE: Self = { + let mut arr = [U::ZERO; N]; + arr[0] = U::ONE; + Self(arr) + }; + const ONES: Self = Self([U::ONES; N]); + + #[inline] + fn fill_with_bit(val: u8) -> Self { + Self(array::from_fn(|_| U::fill_with_bit(val))) + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(self.0.map(|x| x.shl_128b_lanes(rhs))) + } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(self.0.map(|x| x.shr_128b_lanes(rhs))) + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len))) + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_shr() { + let val = ScaledUnderlier::([0, 1, 2, 3]); + assert_eq!( + val >> 1, + ScaledUnderlier::([0b10000000, 0b00000000, 0b10000001, 0b00000001]) + ); + assert_eq!( + val >> 2, + ScaledUnderlier::([0b01000000, 0b10000000, 0b11000000, 0b00000000]) + ); + assert_eq!( + val >> 8, + ScaledUnderlier::([0b00000001, 0b00000010, 0b00000011, 0b00000000]) + ); + assert_eq!( + val >> 9, + ScaledUnderlier::([0b00000000, 0b10000001, 0b00000001, 0b00000000]) + ); + } + + #[test] + fn test_shl() { + let val = ScaledUnderlier::([0, 1, 2, 3]); + assert_eq!(val << 1, ScaledUnderlier::([0, 2, 4, 6])); + assert_eq!(val << 2, ScaledUnderlier::([0, 4, 8, 12])); + assert_eq!(val << 8, ScaledUnderlier::([0, 0, 1, 2])); + assert_eq!(val << 9, ScaledUnderlier::([0, 0, 2, 4])); + } +} diff --git a/crates/field/src/underlier/small_uint.rs b/crates/field/src/underlier/small_uint.rs index 2e99f211..3413f5ca 100644 --- a/crates/field/src/underlier/small_uint.rs +++ b/crates/field/src/underlier/small_uint.rs @@ -6,7 +6,12 @@ use std::{ ops::{Not, Shl, Shr}, }; -use binius_utils::checked_arithmetics::checked_log_2; +use binius_utils::{ + bytes::{Buf, BufMut}, + checked_arithmetics::checked_log_2, + serialization::DeserializeBytes, + SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{NoUninit, Zeroable}; use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign}; use rand::{ @@ -154,6 +159,14 @@ impl UnderlierWithBitOps for SmallU { fn fill_with_bit(val: u8) -> Self { Self(u8::fill_with_bit(val)) & Self::ONES } + + fn shl_128b_lanes(self, rhs: usize) -> Self { + self << rhs + } + + fn shr_128b_lanes(self, rhs: usize) -> Self { + self >> rhs + } } impl From> for u8 { @@ -222,3 +235,22 @@ impl From> for SmallU<4> { pub type U1 = SmallU<1>; pub type U2 = SmallU<2>; pub type U4 = SmallU<4>; + +impl SerializeBytes for SmallU { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.val().serialize(write_buf, mode) + } +} + +impl DeserializeBytes for SmallU { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self::new(DeserializeBytes::deserialize(read_buf, mode)?)) + } +} diff --git a/crates/field/src/underlier/underlier_impls.rs b/crates/field/src/underlier/underlier_impls.rs index bb52b689..e28e27dc 100644 --- a/crates/field/src/underlier/underlier_impls.rs +++ b/crates/field/src/underlier/underlier_impls.rs @@ -23,6 +23,16 @@ macro_rules! impl_underlier_type { debug_assert!(val == 0 || val == 1); (val as Self).wrapping_neg() } + + #[inline(always)] + fn shl_128b_lanes(self, rhs: usize) -> Self { + self << rhs + } + + #[inline(always)] + fn shr_128b_lanes(self, rhs: usize) -> Self { + self >> rhs + } } }; () => {}; diff --git a/crates/field/src/underlier/underlier_with_bit_ops.rs b/crates/field/src/underlier/underlier_with_bit_ops.rs index e151f501..8cff6d18 100644 --- a/crates/field/src/underlier/underlier_with_bit_ops.rs +++ b/crates/field/src/underlier/underlier_with_bit_ops.rs @@ -118,6 +118,38 @@ pub trait UnderlierWithBitOps: { spread_fallback(self, log_block_len, block_idx) } + + /// Left shift within 128-bit lanes. + /// This can be more efficient than the full `Shl` implementation. + fn shl_128b_lanes(self, shift: usize) -> Self; + + /// Right shift within 128-bit lanes. + /// This can be more efficient than the full `Shr` implementation. + fn shr_128b_lanes(self, shift: usize) -> Self; + + /// Unpacks `1 << log_block_len`-bit values from low parts of `self` and `other` within 128-bit lanes. + /// + /// Example: + /// self: [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7] + /// other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7] + /// log_block_len: 1 + /// + /// result: [a_0, a_0, b_0, b_1, a_2, a_3, b_2, b_3] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + unpack_lo_128b_fallback(self, other, log_block_len) + } + + /// Unpacks `1 << log_block_len`-bit values from high parts of `self` and `other` within 128-bit lanes. + /// + /// Example: + /// self: [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7] + /// other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7] + /// log_block_len: 1 + /// + /// result: [a_4, a_5, b_4, b_5, a_6, a_7, b_6, b_7] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + unpack_hi_128b_fallback(self, other, log_block_len) + } } /// Returns a bit mask for a single `T` element inside underlier type. @@ -171,6 +203,55 @@ where result } +#[inline(always)] +fn single_element_mask_bits_128b_lanes(log_block_len: usize) -> T { + let mut mask = single_element_mask_bits(1 << log_block_len); + for i in 1..T::BITS / 128 { + mask |= mask << (i * 128); + } + + mask +} + +pub(crate) fn unpack_lo_128b_fallback( + lhs: T, + rhs: T, + log_block_len: usize, +) -> T { + assert!(log_block_len <= 6); + + let mask = single_element_mask_bits_128b_lanes::(log_block_len); + + let mut result = T::ZERO; + for i in 0..1 << (6 - log_block_len) { + result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask) + .shl_128b_lanes(i << (log_block_len + 1)); + result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask) + .shl_128b_lanes((2 * i + 1) << log_block_len); + } + + result +} + +pub(crate) fn unpack_hi_128b_fallback( + lhs: T, + rhs: T, + log_block_len: usize, +) -> T { + assert!(log_block_len <= 6); + + let mask = single_element_mask_bits_128b_lanes::(log_block_len); + let mut result = T::ZERO; + for i in 0..1 << (6 - log_block_len) { + result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask) + .shl_128b_lanes(i << (log_block_len + 1)); + result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask) + .shl_128b_lanes((2 * i + 1) << log_block_len); + } + + result +} + pub(crate) fn single_element_mask_bits(bits_count: usize) -> T { if bits_count == T::BITS { !T::ZERO diff --git a/crates/field/src/util.rs b/crates/field/src/util.rs index 3bc6ec29..12f99a25 100644 --- a/crates/field/src/util.rs +++ b/crates/field/src/util.rs @@ -16,7 +16,7 @@ where F: Field, FE: ExtensionField, { - iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum::() + iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum() } /// Calculate inner product for potentially big slices of xs and ys. @@ -38,9 +38,8 @@ where return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys)); } - let calc_product_by_ys = |x_offset, ys: &[PY]| { + let calc_product_by_ys = |xs: &[PX], ys: &[PY]| { let mut result = FX::ZERO; - let xs = &xs[x_offset..]; for (j, y) in ys.iter().enumerate() { for (k, y) in y.iter().enumerate() { @@ -56,14 +55,14 @@ where // For different field sizes, the numbers may need to be adjusted. const CHUNK_SIZE: usize = 64; if ys.len() < 16 * CHUNK_SIZE { - calc_product_by_ys(0, ys) + calc_product_by_ys(xs, ys) } else { // According to benchmark results iterating by chunks here is more efficient than using `par_iter` with `min_length` directly. ys.par_chunks(CHUNK_SIZE) .enumerate() .map(|(i, ys)| { let offset = i * checked_int_div(CHUNK_SIZE * PY::WIDTH, PX::WIDTH); - calc_product_by_ys(offset, ys) + calc_product_by_ys(&xs[offset..], ys) }) .sum() } @@ -79,3 +78,91 @@ pub fn eq(x: F, y: F) -> F { pub fn powers(val: F) -> impl Iterator { iter::successors(Some(F::ONE), move |&power| Some(power * val)) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::PackedBinaryField4x32b; + + type P = PackedBinaryField4x32b; + type F =

::Scalar; + + #[test] + fn test_inner_product_par_equal_length() { + // xs and ys have the same number of packed elements + let xs1 = F::new(1); + let xs2 = F::new(2); + let xs = vec![P::set_single(xs1), P::set_single(xs2)]; + let ys1 = F::new(3); + let ys2 = F::new(4); + let ys = vec![P::set_single(ys1), P::set_single(ys2)]; + + let result = inner_product_par::(&xs, &ys); + let expected = xs1 * ys1 + xs2 * ys2; + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_unequal_length() { + // ys is larger than xs due to packing differences + let xs1 = F::new(1); + let xs = vec![P::set_single(xs1)]; + let ys1 = F::new(2); + let ys2 = F::new(3); + let ys = vec![P::set_single(ys1), P::set_single(ys2)]; + + let result = inner_product_par::(&xs, &ys); + let expected = xs1 * ys1; + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_large_input_single_threaded() { + // Large input but not enough to trigger parallel execution + let size = 256; + let xs: Vec

= (0..size).map(|i| P::set_single(F::new(i as u32))).collect(); + let ys: Vec

= (0..size) + .map(|i| P::set_single(F::new((i + 1) as u32))) + .collect(); + + let result = inner_product_par::(&xs, &ys); + + let expected = (0..size) + .map(|i| F::new(i as u32) * F::new((i + 1) as u32)) + .sum::(); + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_large_input_par() { + // Large input to test parallel execution + let size = 2000; + let xs: Vec

= (0..size).map(|i| P::set_single(F::new(i as u32))).collect(); + let ys: Vec

= (0..size) + .map(|i| P::set_single(F::new((i + 1) as u32))) + .collect(); + + let result = inner_product_par::(&xs, &ys); + + let expected = (0..size) + .map(|i| F::new(i as u32) * F::new((i + 1) as u32)) + .sum::(); + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_empty() { + // Case: Empty input should return 0 + let xs: Vec

= vec![]; + let ys: Vec

= vec![]; + + let result = inner_product_par::(&xs, &ys); + let expected = F::ZERO; + + assert_eq!(result, expected); + } +} diff --git a/crates/hal/src/backend.rs b/crates/hal/src/backend.rs index 05bd2e87..8d9dac1d 100644 --- a/crates/hal/src/backend.rs +++ b/crates/hal/src/backend.rs @@ -5,9 +5,9 @@ use std::{ ops::{Deref, DerefMut}, }; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ - CompositionPolyOS, MultilinearExtension, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, + CompositionPoly, MultilinearExtension, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; use binius_maybe_rayon::iter::FromParallelIterator; use tracing::instrument; @@ -42,28 +42,8 @@ pub trait ComputationBackend: Send + Sync + Debug { query: &[P::Scalar], ) -> Result, Error>; - /// Calculate the accumulated evaluations for the first round of zerocheck. - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension - + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

; - /// Calculate the accumulated evaluations for an arbitrary round of zerocheck. - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -73,13 +53,10 @@ pub trait ComputationBackend: Send + Sync + Debug { ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

; + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

; /// Partially evaluate the polynomial with assignment to the high-indexed variables. fn evaluate_partial_high( @@ -108,35 +85,7 @@ where T::tensor_product_full_query(self, query) } - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension - + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - T::sumcheck_compute_first_round_evals::<_, FBase, _, _, _, _, _>( - self, - n_vars, - multilinears, - evaluators, - evaluation_points, - ) - } - - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -146,15 +95,12 @@ where ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { - T::sumcheck_compute_later_round_evals( + T::sumcheck_compute_round_evals( self, n_vars, tensor_query, diff --git a/crates/hal/src/cpu.rs b/crates/hal/src/cpu.rs index 3d9eac8f..acd17b22 100644 --- a/crates/hal/src/cpu.rs +++ b/crates/hal/src/cpu.rs @@ -2,16 +2,16 @@ use std::fmt::Debug; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ - eq_ind_partial_eval, CompositionPolyOS, MultilinearExtension, MultilinearPoly, + eq_ind_partial_eval, CompositionPoly, MultilinearExtension, MultilinearPoly, MultilinearQueryRef, }; use tracing::instrument; use crate::{ - sumcheck_round_calculator::{calculate_first_round_evals, calculate_later_round_evals}, - ComputationBackend, Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear, + sumcheck_round_calculator::calculate_round_evals, ComputationBackend, Error, RoundEvals, + SumcheckEvaluator, SumcheckMultilinear, }; /// Implementation of ComputationBackend for the default Backend that uses the CPU for all computations. @@ -37,31 +37,7 @@ impl ComputationBackend for CpuBackend { Ok(eq_ind_partial_eval(query)) } - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - calculate_first_round_evals::<_, FBase, _, _, _, _, _>( - n_vars, - multilinears, - evaluators, - evaluation_points, - ) - } - - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -71,21 +47,12 @@ impl ComputationBackend for CpuBackend { ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { - calculate_later_round_evals( - n_vars, - tensor_query, - multilinears, - evaluators, - evaluation_points, - ) + calculate_round_evals(n_vars, tensor_query, multilinears, evaluators, evaluation_points) } #[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")] diff --git a/crates/hal/src/sumcheck_evaluator.rs b/crates/hal/src/sumcheck_evaluator.rs index 982b9c6c..31e0721a 100644 --- a/crates/hal/src/sumcheck_evaluator.rs +++ b/crates/hal/src/sumcheck_evaluator.rs @@ -2,17 +2,13 @@ use std::ops::Range; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedSubfield}; +use binius_field::{Field, PackedField}; /// Evaluations of a polynomial at a set of evaluation points. #[derive(Debug, Clone)] pub struct RoundEvals(pub Vec); -pub trait SumcheckEvaluator -where - FBase: Field, - P: PackedField> + PackedExtension, -{ +pub trait SumcheckEvaluator { /// The range of eval point indices over which composition evaluation and summation should happen. /// Returned range must equal the result of `n_round_evals()` in length. fn eval_point_indices(&self) -> Range; @@ -27,7 +23,7 @@ where &self, subcube_vars: usize, subcube_index: usize, - batch_query: &[&[PackedSubfield]], + batch_query: &[&[P]], ) -> P; /// Returns the composition evaluated by this object. diff --git a/crates/hal/src/sumcheck_round_calculator.rs b/crates/hal/src/sumcheck_round_calculator.rs index f71d1f84..d14b1594 100644 --- a/crates/hal/src/sumcheck_round_calculator.rs +++ b/crates/hal/src/sumcheck_round_calculator.rs @@ -4,19 +4,16 @@ //! //! This is one of the core computational tasks in the sumcheck proving algorithm. -use std::{iter, marker::PhantomData}; +use std::iter; -use binius_field::{ - recast_packed, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, - RepackedExtension, -}; +use binius_field::{Field, PackedExtension, PackedField, PackedSubfield}; use binius_math::{ - deinterleave, extrapolate_lines, CompositionPolyOS, MultilinearPoly, MultilinearQuery, + deinterleave, extrapolate_lines, CompositionPoly, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; use binius_maybe_rayon::prelude::*; use bytemuck::zeroed_vec; -use itertools::izip; +use itertools::{izip, Itertools}; use stackalloc::stackalloc_with_iter; use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear}; @@ -44,39 +41,11 @@ trait SumcheckMultilinearAccess { ) -> Result<(), Error>; } -/// Calculate the accumulated evaluations for the first sumcheck round. -pub(crate) fn calculate_first_round_evals( - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], -) -> Result>, Error> -where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, -{ - let accesses = multilinears - .iter() - .map(FirstRoundAccess::new) - .collect::>(); - calculate_round_evals::<_, FBase, _, _, _, _, _>( - n_vars, - &accesses, - evaluators, - evaluation_points, - ) -} - /// Calculate the accumulated evaluations for an arbitrary sumcheck round. /// /// See [`calculate_first_round_evals`] for an optimized version of this method /// that works over small fields in the first round. -pub(crate) fn calculate_later_round_evals( +pub(crate) fn calculate_round_evals( n_vars: usize, tensor_query: Option>, multilinears: &[SumcheckMultilinear], @@ -85,26 +54,27 @@ pub(crate) fn calculate_later_round_evals Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + F: Field, + P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { let empty_query = MultilinearQuery::with_capacity(0); - let query = tensor_query.unwrap_or_else(|| empty_query.to_ref()); + let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref()); - let accesses = multilinears + let later_rounds_accesses = multilinears .iter() - .map(|multilinear| LaterRoundAccess { + .map(|multilinear| LargeFieldAccess { multilinear, - tensor_query: query, + tensor_query, }) - .collect::>(); - calculate_round_evals::<_, F, _, _, _, _, _>(n_vars, &accesses, evaluators, evaluation_points) + .collect_vec(); + + calculate_round_evals_with_access(n_vars, &later_rounds_accesses, evaluators, evaluation_points) } -fn calculate_round_evals( +fn calculate_round_evals_with_access( n_vars: usize, multilinears: &[Access], evaluators: &[Evaluator], @@ -112,12 +82,11 @@ fn calculate_round_evals( ) -> Result>, Error> where FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - Evaluator: SumcheckEvaluator + Sync, - Access: SumcheckMultilinearAccess> + Sync, - Composition: CompositionPolyOS

, + F: Field, + P: PackedField + PackedExtension, + Evaluator: SumcheckEvaluator + Sync, + Access: SumcheckMultilinearAccess

+ Sync, + Composition: CompositionPoly

, { let n_multilinears = multilinears.len(); let n_round_evals = evaluators @@ -182,9 +151,9 @@ where // `binius_math::univariate::extrapolate_line`, except that we do // not repeat the broadcast of the subfield element to a packed // subfield. - *eval_z = recast_packed::(extrapolate_lines( - recast_packed::(eval_0), - recast_packed::(eval_1), + *eval_z = P::cast_ext(extrapolate_lines( + P::cast_base(eval_0), + P::cast_base(eval_1), eval_point_broadcast, )); } @@ -316,57 +285,7 @@ impl ParFoldStates { } #[derive(Debug)] -struct FirstRoundAccess<'a, PBase, P, M> -where - P: PackedField, - M: MultilinearPoly

+ Send + Sync, -{ - multilinear: &'a SumcheckMultilinear, - _marker: PhantomData, -} - -impl<'a, PBase, P, M> FirstRoundAccess<'a, PBase, P, M> -where - P: PackedField, - M: MultilinearPoly

+ Send + Sync, -{ - const fn new(multilinear: &'a SumcheckMultilinear) -> Self { - Self { - multilinear, - _marker: PhantomData, - } - } -} - -impl SumcheckMultilinearAccess for FirstRoundAccess<'_, PBase, P, M> -where - PBase: PackedField, - P: RepackedExtension, - P::Scalar: ExtensionField, - M: MultilinearPoly

+ Send + Sync, -{ - fn subcube_evaluations( - &self, - subcube_vars: usize, - subcube_index: usize, - evals: &mut [PBase], - ) -> Result<(), Error> { - if let SumcheckMultilinear::Transparent { multilinear, .. } = self.multilinear { - let evals =

>::cast_exts_mut(evals); - Ok(multilinear.subcube_evals( - subcube_vars, - subcube_index, - >::LOG_DEGREE, - evals, - )?) - } else { - panic!("precondition: no folded multilinears in the first round"); - } - } -} - -#[derive(Debug)] -struct LaterRoundAccess<'a, P, M> +struct LargeFieldAccess<'a, P, M> where P: PackedField, M: MultilinearPoly

+ Send + Sync, @@ -375,7 +294,7 @@ where tensor_query: MultilinearQueryRef<'a, P>, } -impl SumcheckMultilinearAccess

for LaterRoundAccess<'_, P, M> +impl SumcheckMultilinearAccess

for LargeFieldAccess<'_, P, M> where P: PackedField, M: MultilinearPoly

+ Send + Sync, @@ -388,28 +307,28 @@ where ) -> Result<(), Error> { match self.multilinear { SumcheckMultilinear::Transparent { multilinear, .. } => { - // TODO: Stop using LaterRoundAccess for first round in RegularSumcheckProver and - // GPASumcheckProver, then remove this conditional. if self.tensor_query.n_vars() == 0 { - Ok(multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals)?) + multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals)? } else { - Ok(multilinear.subcube_inner_products( + multilinear.subcube_inner_products( self.tensor_query, subcube_vars, subcube_index, evals, - )?) + )? } } SumcheckMultilinear::Folded { large_field_folded_multilinear, - } => Ok(large_field_folded_multilinear.subcube_evals( + } => large_field_folded_multilinear.subcube_evals( subcube_vars, subcube_index, 0, evals, - )?), + )?, } + + Ok(()) } } diff --git a/crates/hash/src/groestl/hasher.rs b/crates/hash/src/groestl/hasher.rs index 74f9e49c..d6ee087c 100644 --- a/crates/hash/src/groestl/hasher.rs +++ b/crates/hash/src/groestl/hasher.rs @@ -153,7 +153,6 @@ impl Hasher

for Groestl256 where F: BinaryField + From + Into, P: PackedExtension, - P::Scalar: ExtensionField, OptimalUnderlier256b: PackScalar + Divisible, Self: UpdateOverSlice, { @@ -300,7 +299,7 @@ mod tests { } #[test] - fn test_aes_binary_convertion() { + fn test_aes_binary_conversion() { let mut rng = thread_rng(); let input_aes: [PackedAESBinaryField32x8b; 90] = array::from_fn(|_| PackedAESBinaryField32x8b::random(&mut rng)); diff --git a/crates/hash/src/serialization.rs b/crates/hash/src/serialization.rs index ff64c975..f1292c64 100644 --- a/crates/hash/src/serialization.rs +++ b/crates/hash/src/serialization.rs @@ -2,7 +2,7 @@ use std::{borrow::Borrow, cmp::min}; -use binius_utils::serialization::SerializeBytes; +use binius_utils::{SerializationMode, SerializeBytes}; use bytes::{buf::UninitSlice, BufMut}; use digest::{ core_api::{Block, BlockSizeUser}, @@ -75,7 +75,7 @@ where let mut buffer = HashBuffer::new(&mut hasher); for item in items { item.borrow() - .serialize(&mut buffer) + .serialize(&mut buffer, SerializationMode::CanonicalTower) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/hash/src/vision.rs b/crates/hash/src/vision.rs index 45940192..750db41b 100644 --- a/crates/hash/src/vision.rs +++ b/crates/hash/src/vision.rs @@ -221,7 +221,6 @@ where U: PackScalar + Divisible, F: BinaryField + From + Into, P: PackedExtension, - P::Scalar: ExtensionField, PackedAESBinaryField8x32b: WithUnderlier, { type Digest = PackedType; diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index c0293a3c..c3b8defb 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -16,6 +16,7 @@ proc-macro2.workspace = true binius_core = { path = "../core" } binius_field = { path = "../field" } binius_math = { path = "../math" } +binius_utils = { path = "../utils" } paste.workspace = true rand.workspace = true diff --git a/crates/macros/src/arith_circuit_poly.rs b/crates/macros/src/arith_circuit_poly.rs index 0f8f1f6a..93ba74a7 100644 --- a/crates/macros/src/arith_circuit_poly.rs +++ b/crates/macros/src/arith_circuit_poly.rs @@ -3,55 +3,22 @@ use quote::{quote, ToTokens}; use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token}; -use crate::composition_poly::CompositionPolyItem; - #[derive(Debug)] pub(crate) struct ArithCircuitPolyItem { poly: syn::Expr, - /// We create a composition poly to cache the efficient evaluation implementations - /// for the known packed field types. - composition_poly: CompositionPolyItem, field_name: syn::Ident, } impl ToTokens for ArithCircuitPolyItem { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - let Self { - poly, - composition_poly, - field_name, - } = self; - - let mut register_cached_impls = proc_macro2::TokenStream::new(); - let packed_extensions = get_packed_extensions(field_name); - if packed_extensions.is_empty() { - register_cached_impls.extend(quote! { result }); - } else { - register_cached_impls.extend(quote! ( - let mut cached = binius_core::polynomial::CachedPoly::new(composition); - - )); - - for packed_extension in get_packed_extensions(field_name) { - register_cached_impls.extend(quote! { - cached.register::(composition.clone()); - }); - } - - register_cached_impls.extend(quote! { - cached - }); - } + let Self { poly, field_name } = self; tokens.extend(quote! { { use binius_field::Field; use binius_math::ArithExpr as Expr; - let mut result = binius_core::polynomial::ArithCircuitPoly::::new(#poly); - let composition = #composition_poly; - - #register_cached_impls + binius_core::polynomial::ArithCircuitPoly::::new(#poly) } }); } @@ -59,7 +26,6 @@ impl ToTokens for ArithCircuitPolyItem { impl Parse for ArithCircuitPolyItem { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let original_tokens = input.fork(); let vars: Vec = { let content; bracketed!(content in input); @@ -73,14 +39,8 @@ impl Parse for ArithCircuitPolyItem { input.parse::()?; let field_name = input.parse()?; - // Here we assume that the `composition_poly` shares the expression syntax with the `arithmetic_circuit_poly`. - let composition_poly = CompositionPolyItem::parse(&original_tokens)?; - Ok(Self { - poly, - composition_poly, - field_name, - }) + Ok(Self { poly, field_name }) } } @@ -120,350 +80,3 @@ fn flatten_expr(expr: &syn::Expr, vars: &[syn::Ident]) -> Result Vec { - match ident.to_string().as_str() { - "BinaryField1b" => vec![ - parse_quote!(PackedBinaryField1x1b), - parse_quote!(PackedBinaryField2x1b), - parse_quote!(PackedBinaryField4x1b), - parse_quote!(PackedBinaryField8x1b), - parse_quote!(PackedBinaryField16x1b), - parse_quote!(PackedBinaryField32x1b), - parse_quote!(PackedBinaryField64x1b), - parse_quote!(PackedBinaryField128x1b), - parse_quote!(PackedBinaryField256x1b), - parse_quote!(PackedBinaryField512x1b), - parse_quote!(PackedBinaryField1x2b), - parse_quote!(PackedBinaryField2x2b), - parse_quote!(PackedBinaryField4x2b), - parse_quote!(PackedBinaryField8x2b), - parse_quote!(PackedBinaryField16x2b), - parse_quote!(PackedBinaryField32x2b), - parse_quote!(PackedBinaryField64x2b), - parse_quote!(PackedBinaryField128x2b), - parse_quote!(PackedBinaryField256x2b), - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - parse_quote!(PackedAESBinaryField1x8b), - parse_quote!(PackedAESBinaryField2x8b), - parse_quote!(PackedAESBinaryField4x8b), - parse_quote!(PackedAESBinaryField8x8b), - parse_quote!(PackedAESBinaryField16x8b), - parse_quote!(PackedAESBinaryField32x8b), - parse_quote!(PackedAESBinaryField64x8b), - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - parse_quote!(PackedBinaryPolyval1x128b), - parse_quote!(PackedBinaryPolyval2x128b), - parse_quote!(PackedBinaryPolyval4x128b), - ], - "BinaryField2b" => { - vec![ - parse_quote!(PackedBinaryField1x2b), - parse_quote!(PackedBinaryField2x2b), - parse_quote!(PackedBinaryField4x2b), - parse_quote!(PackedBinaryField8x2b), - parse_quote!(PackedBinaryField16x2b), - parse_quote!(PackedBinaryField32x2b), - parse_quote!(PackedBinaryField64x2b), - parse_quote!(PackedBinaryField128x2b), - parse_quote!(PackedBinaryField256x2b), - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField4b" => { - vec![ - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField8b" => { - vec![ - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField16b" => { - vec![ - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField32b" => { - vec![ - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField64b" => { - vec![ - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - - "BinaryField128b" => { - vec![ - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - - "AESTowerField8b" => { - vec![ - parse_quote!(PackedAESBinaryField1x8b), - parse_quote!(PackedAESBinaryField2x8b), - parse_quote!(PackedAESBinaryField4x8b), - parse_quote!(PackedAESBinaryField8x8b), - parse_quote!(PackedAESBinaryField16x8b), - parse_quote!(PackedAESBinaryField32x8b), - parse_quote!(PackedAESBinaryField64x8b), - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - parse_quote!(ByteSlicedAES32x128b), - ] - } - "AESTowerField16b" => { - vec![ - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField32b" => { - vec![ - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField64b" => { - vec![ - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField128b" => { - vec![ - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - - _ => vec![], - } -} diff --git a/crates/macros/src/composition_poly.rs b/crates/macros/src/composition_poly.rs index 36200a76..0472fc32 100644 --- a/crates/macros/src/composition_poly.rs +++ b/crates/macros/src/composition_poly.rs @@ -43,7 +43,10 @@ impl ToTokens for CompositionPolyItem { #[derive(Debug, Clone, Copy)] struct #name; - impl binius_math::CompositionPoly<#scalar_type> for #name { + impl

binius_math::CompositionPoly

for #name + where + P: binius_field::PackedField>, + { fn n_vars(&self) -> usize { #n_vars } @@ -56,18 +59,18 @@ impl ToTokens for CompositionPolyItem { 0 } - fn expression>(&self) -> binius_math::ArithExpr { + fn expression(&self) -> binius_math::ArithExpr { (#expr).convert_field() } - fn evaluate>>(&self, query: &[P]) -> Result { + fn evaluate(&self, query: &[P]) -> Result { if query.len() != #n_vars { return Err(binius_math::Error::IncorrectQuerySize { expected: #n_vars }); } Ok(#eval_single) } - fn batch_evaluate>>( + fn batch_evaluate( &self, batch_query: &[&[P]], evals: &mut [P], @@ -89,36 +92,6 @@ impl ToTokens for CompositionPolyItem { Ok(()) } } - - impl

binius_math::CompositionPolyOS

for #name - where - P: binius_field::PackedField>, - { - fn n_vars(&self) -> usize { - >::n_vars(self) - } - - fn degree(&self) -> usize { - >::degree(self) - } - - fn binary_tower_level(&self) -> usize { - >::binary_tower_level(self) - } - - fn expression(&self) -> binius_math::ArithExpr { - >::expression(self) - } - - fn evaluate(&self, query: &[P]) -> Result { - >::evaluate(self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), binius_math::Error> { - >::batch_evaluate(self, batch_query, evals) - } - } - }; if *is_anonymous { diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index c2d8f069..08fa219a 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -5,26 +5,24 @@ mod arith_circuit_poly; mod arith_expr; mod composition_poly; -use std::collections::BTreeSet; - use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::{parse_macro_input, Data, DeriveInput, Fields}; +use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, ItemImpl}; use crate::{ arith_circuit_poly::ArithCircuitPolyItem, arith_expr::ArithExprItem, composition_poly::CompositionPolyItem, }; -/// Useful for concisely creating structs that implement CompositionPolyOS. +/// Useful for concisely creating structs that implement CompositionPoly. /// This currently only supports creating composition polynomials of tower level 0. /// /// ``` /// use binius_macros::composition_poly; -/// use binius_math::CompositionPolyOS; +/// use binius_math::CompositionPoly; /// use binius_field::{Field, BinaryField1b as F}; /// -/// // Defines named struct without any fields that implements CompositionPolyOS +/// // Defines named struct without any fields that implements CompositionPoly /// composition_poly!(MyComposition[x, y, z] = x + y * z); /// assert_eq!( /// MyComposition.evaluate(&[F::ONE, F::ONE, F::ONE]).unwrap(), @@ -76,156 +74,261 @@ pub fn arith_circuit_poly(input: TokenStream) -> TokenStream { .into() } -/// Implements `pub fn iter_oracles(&self) -> impl Iterator`. +/// Derives the trait binius_utils::DeserializeBytes for a struct or enum /// -/// Detects and includes fields with type `OracleId`, `[OracleId; N]` +/// See the DeserializeBytes derive macro docs for examples/tests +#[proc_macro_derive(SerializeBytes)] +pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream { + let input: DeriveInput = parse_macro_input!(input); + let span = input.span(); + let name = input.ident; + let mut generics = input.generics.clone(); + generics.type_params_mut().for_each(|type_param| { + type_param + .bounds + .push(parse_quote!(binius_utils::SerializeBytes)) + }); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let body = match input.data { + Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(), + Data::Struct(data) => { + let fields = field_names(data.fields, None); + quote! { + #(binius_utils::SerializeBytes::serialize(&self.#fields, &mut write_buf, mode)?;)* + } + } + Data::Enum(data) => { + let variants = data + .variants + .into_iter() + .enumerate() + .map(|(i, variant)| { + let variant_ident = &variant.ident; + let variant_index = i as u8; + let fields = field_names(variant.fields.clone(), Some("field_")); + let serialize_variant = quote! { + binius_utils::SerializeBytes::serialize(&#variant_index, &mut write_buf, mode)?; + #(binius_utils::SerializeBytes::serialize(#fields, &mut write_buf, mode)?;)* + }; + match variant.fields { + Fields::Named(_) => quote! { + Self::#variant_ident { #(#fields),* } => { + #serialize_variant + } + }, + Fields::Unnamed(_) => quote! { + Self::#variant_ident(#(#fields),*) => { + #serialize_variant + } + }, + Fields::Unit => quote! { + Self::#variant_ident => { + #serialize_variant + } + }, + } + }) + .collect::>(); + + quote! { + match self { + #(#variants)* + } + } + } + }; + quote! { + impl #impl_generics binius_utils::SerializeBytes for #name #ty_generics #where_clause { + fn serialize(&self, mut write_buf: impl binius_utils::bytes::BufMut, mode: binius_utils::SerializationMode) -> Result<(), binius_utils::SerializationError> { + #body + Ok(()) + } + } + }.into() +} + +/// Derives the trait binius_utils::DeserializeBytes for a struct or enum /// /// ``` -/// use binius_macros::IterOracles; -/// type OracleId = usize; -/// type BatchId = usize; -/// -/// #[derive(IterOracles)] -/// struct Oracle { -/// x: OracleId, -/// y: [OracleId; 5], -/// z: [OracleId; 5*2], -/// ignored_field1: usize, -/// ignored_field2: BatchId, -/// ignored_field3: [[OracleId; 5]; 2], +/// use binius_field::BinaryField128b; +/// use binius_utils::{SerializeBytes, DeserializeBytes, SerializationMode}; +/// use binius_macros::{SerializeBytes, DeserializeBytes}; +/// +/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)] +/// enum MyEnum { +/// A(usize), +/// B { x: u32, y: u32 }, +/// C /// } +/// +/// +/// let mut buf = vec![]; +/// let value = MyEnum::B { x: 42, y: 1337 }; +/// MyEnum::serialize(&value, &mut buf, SerializationMode::Native).unwrap(); +/// assert_eq!( +/// MyEnum::deserialize(buf.as_slice(), SerializationMode::Native).unwrap(), +/// value +/// ); +/// +/// +/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)] +/// struct MyStruct { +/// data: Vec +/// } +/// +/// let mut buf = vec![]; +/// let value = MyStruct { +/// data: vec![BinaryField128b::new(1234), BinaryField128b::new(5678)] +/// }; +/// MyStruct::serialize(&value, &mut buf, SerializationMode::CanonicalTower).unwrap(); +/// assert_eq!( +/// MyStruct::::deserialize(buf.as_slice(), SerializationMode::CanonicalTower).unwrap(), +/// value +/// ); /// ``` -#[proc_macro_derive(IterOracles)] -pub fn iter_oracle_derive(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let Data::Struct(data) = &input.data else { - panic!("#[derive(IterOracles)] is only defined for structs with named fields"); - }; - let Fields::Named(fields) = &data.fields else { - panic!("#[derive(IterOracles)] is only defined for structs with named fields"); +#[proc_macro_derive(DeserializeBytes)] +pub fn derive_deserialize_bytes(input: TokenStream) -> TokenStream { + let input: DeriveInput = parse_macro_input!(input); + let span = input.span(); + let name = input.ident; + let mut generics = input.generics.clone(); + generics.type_params_mut().for_each(|type_param| { + type_param + .bounds + .push(parse_quote!(binius_utils::DeserializeBytes)) + }); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let deserialize_value = quote! { + binius_utils::DeserializeBytes::deserialize(&mut read_buf, mode)? }; + let body = match input.data { + Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(), + Data::Struct(data) => { + let fields = field_names(data.fields, None); + quote! { + Ok(Self { + #(#fields: #deserialize_value,)* + }) + } + } + Data::Enum(data) => { + let variants = data + .variants + .into_iter() + .enumerate() + .map(|(i, variant)| { + let variant_ident = &variant.ident; + let variant_index: u8 = i as u8; + match variant.fields { + Fields::Named(fields) => { + let fields = fields + .named + .into_iter() + .map(|field| field.ident) + .map(|field_name| quote!(#field_name: #deserialize_value)) + .collect::>(); - let name = &input.ident; - let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); + quote! { + #variant_index => Self::#variant_ident { #(#fields,)* } + } + } + Fields::Unnamed(fields) => { + let fields = fields + .unnamed + .into_iter() + .map(|_| quote!(#deserialize_value)) + .collect::>(); - let oracles = fields - .named - .iter() - .filter_map(|f| { - let name = f.ident.clone(); - match &f.ty { - syn::Type::Path(type_path) if type_path.path.is_ident("OracleId") => { - Some(quote!(std::iter::once(self.#name))) - } - syn::Type::Array(array) => { - if let syn::Type::Path(type_path) = *array.elem.clone() { - type_path - .path - .is_ident("OracleId") - .then(|| quote!(self.#name.into_iter())) - } else { - None + quote! { + #variant_index => Self::#variant_ident(#(#fields,)*) + } + } + Fields::Unit => quote! { + #variant_index => Self::#variant_ident + }, } - } - _ => None, - } - }) - .collect::>(); + }) + .collect::>(); + let name = name.to_string(); + quote! { + let variant_index: u8 = #deserialize_value; + Ok(match variant_index { + #(#variants,)* + _ => { + return Err(binius_utils::SerializationError::UnknownEnumVariant { + name: #name, + index: variant_index + }) + } + }) + } + } + }; quote! { - impl #impl_generics #name #ty_generics #where_clause { - pub fn iter_oracles(&self) -> impl Iterator { - std::iter::empty() - #(.chain(#oracles))* + impl #impl_generics binius_utils::DeserializeBytes for #name #ty_generics #where_clause { + fn deserialize(mut read_buf: impl binius_utils::bytes::Buf, mode: binius_utils::SerializationMode) -> Result + where + Self: Sized + { + #body } } } .into() } -/// Implements `pub fn iter_polys(&self) -> impl Iterator>`. +/// Use on an impl block for MultivariatePoly, to automatically implement erased_serialize_bytes. /// -/// Supports `Vec

`, `[Vec

; N]`. Currently doesn't filter out fields from the struct, so you can't add any other fields. +/// Importantly, this will serialize the concrete instance, prefixed by the identifier of the data type. /// -/// ``` -/// use binius_macros::IterPolys; -/// use binius_field::PackedField; -/// -/// #[derive(IterPolys)] -/// struct Witness { -/// x: Vec

, -/// y: [Vec

; 5], -/// z: [Vec

; 5*2], -/// } -/// ``` -#[proc_macro_derive(IterPolys)] -pub fn iter_witness_derive(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let Data::Struct(data) = &input.data else { - panic!("#[derive(IterPolys)] is only defined for structs with named fields"); - }; - let Fields::Named(fields) = &data.fields else { - panic!("#[derive(IterPolys)] is only defined for structs with named fields"); +/// This prefix can be used to figure out which concrete data type it should use for deserialization later. +#[proc_macro_attribute] +pub fn erased_serialize_bytes(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut item_impl: ItemImpl = parse_macro_input!(item); + let syn::Type::Path(p) = &*item_impl.self_ty else { + return syn::Error::new( + item_impl.span(), + "#[erased_serialize_bytes] can only be used on an impl for a concrete type", + ) + .into_compile_error() + .into(); }; - - let name = &input.ident; - let witnesses = fields - .named - .iter() - .map(|f| { - let name = f.ident.clone(); - match &f.ty { - syn::Type::Array(_) => quote!(self.#name.iter()), - _ => quote!(std::iter::once(&self.#name)), - } - }) - .collect::>(); - - let packed_field_vars = generic_vars_with_trait(&input.generics, "PackedField"); - assert_eq!(packed_field_vars.len(), 1, "Only a single packed field is supported for now"); - let p = packed_field_vars.first(); - let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); - quote! { - impl #impl_generics #name #ty_generics #where_clause { - pub fn iter_polys(&self) -> impl Iterator> { - std::iter::empty() - #(.chain(#witnesses))* - .map(|values| binius_math::MultilinearExtension::from_values_slice(values.as_slice()).unwrap()) - } + let name = p.path.segments.last().unwrap().ident.to_string(); + item_impl.items.push(syn::ImplItem::Fn(parse_quote! { + fn erased_serialize( + &self, + write_buf: &mut dyn binius_utils::bytes::BufMut, + mode: binius_utils::SerializationMode, + ) -> Result<(), binius_utils::SerializationError> { + binius_utils::SerializeBytes::serialize(&#name, &mut *write_buf, mode)?; + binius_utils::SerializeBytes::serialize(self, &mut *write_buf, mode) } + })); + quote! { + #item_impl } .into() } -/// This will accept the generics definition of a struct (relevant for derive macros), -/// and return all the generic vars that are constrained by a specific trait identifier. -/// ``` -/// use binius_field::{PackedField, Field}; -/// struct Example(A, B, C); -/// ``` -/// In the above example, when matching against the trait_name "PackedField", -/// the identifiers A and B will be returned, but not C -pub(crate) fn generic_vars_with_trait( - vars: &syn::Generics, - trait_name: &str, -) -> BTreeSet { - vars.params - .iter() - .filter_map(|param| match param { - syn::GenericParam::Type(type_param) => { - let is_bounded_by_trait_name = type_param.bounds.iter().any(|bound| match bound { - syn::TypeParamBound::Trait(trait_bound) => { - if let Some(last_segment) = trait_bound.path.segments.last() { - last_segment.ident == trait_name - } else { - false - } - } - _ => false, - }); - is_bounded_by_trait_name.then(|| type_param.ident.clone()) - } - syn::GenericParam::Const(_) | syn::GenericParam::Lifetime(_) => None, - }) - .collect() +fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec { + match fields { + Fields::Named(fields) => fields + .named + .into_iter() + .map(|field| field.ident.into_token_stream()) + .collect(), + Fields::Unnamed(fields) => fields + .unnamed + .into_iter() + .enumerate() + .map(|(i, _)| match positional_prefix { + Some(prefix) => { + quote::format_ident!("{}{}", prefix, syn::Index::from(i)).into_token_stream() + } + None => syn::Index::from(i).into_token_stream(), + }) + .collect(), + Fields::Unit => vec![], + } } diff --git a/crates/macros/tests/arithmetic_circuit.rs b/crates/macros/tests/arithmetic_circuit.rs index 895f6003..c840e3b9 100644 --- a/crates/macros/tests/arithmetic_circuit.rs +++ b/crates/macros/tests/arithmetic_circuit.rs @@ -2,7 +2,7 @@ use binius_field::*; use binius_macros::arith_circuit_poly; -use binius_math::CompositionPolyOS; +use binius_math::CompositionPoly; use paste::paste; use rand::{rngs::StdRng, SeedableRng}; diff --git a/crates/math/Cargo.toml b/crates/math/Cargo.toml index d077173a..d37e6528 100644 --- a/crates/math/Cargo.toml +++ b/crates/math/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] binius_field = { path = "../field" } +binius_macros = { path = "../macros" } binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } binius_utils = { path = "../utils", default-features = false } auto_impl.workspace = true diff --git a/crates/math/src/arith_expr.rs b/crates/math/src/arith_expr.rs index 6607901f..84cb05cc 100644 --- a/crates/math/src/arith_expr.rs +++ b/crates/math/src/arith_expr.rs @@ -7,7 +7,8 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use binius_field::{Field, PackedField}; +use binius_field::{Field, PackedField, TowerField}; +use binius_macros::{DeserializeBytes, SerializeBytes}; use super::error::Error; @@ -16,7 +17,7 @@ use super::error::Error; /// Arithmetic expressions are trees, where the leaves are either constants or variables, and the /// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are /// specific representations of multivariate polynomials. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ArithExpr { Const(F), Var(usize), @@ -136,7 +137,7 @@ impl ArithExpr { &self, ) -> Result, >::Error> { Ok(match self { - Self::Const(val) => ArithExpr::Const((*val).try_into()?), + Self::Const(val) => ArithExpr::Const(FTgt::try_from(*val)?), Self::Var(index) => ArithExpr::Var(*index), Self::Add(left, right) => { let new_left = left.try_convert_field()?; @@ -197,6 +198,19 @@ impl ArithExpr { } } +impl ArithExpr { + pub fn binary_tower_level(&self) -> usize { + match self { + Self::Const(value) => value.min_tower_level(), + Self::Var(_) => 0, + Self::Add(left, right) | Self::Mul(left, right) => { + max(left.binary_tower_level(), right.binary_tower_level()) + } + Self::Pow(base, _) => base.binary_tower_level(), + } + } +} + impl Default for ArithExpr where F: Field, diff --git a/crates/math/src/composition_poly.rs b/crates/math/src/composition_poly.rs index 100c9d31..8fc883f0 100644 --- a/crates/math/src/composition_poly.rs +++ b/crates/math/src/composition_poly.rs @@ -3,16 +3,14 @@ use std::fmt::Debug; use auto_impl::auto_impl; -use binius_field::{ExtensionField, Field, PackedField}; +use binius_field::PackedField; use stackalloc::stackalloc_with_default; use crate::{ArithExpr, Error}; /// A multivariate polynomial that is used as a composition of several multilinear polynomials. -/// -/// This is an object-safe version of the [`CompositionPoly`] trait. #[auto_impl(Arc, &)] -pub trait CompositionPolyOS

: Debug + Send + Sync +pub trait CompositionPoly

: Debug + Send + Sync where P: PackedField, { @@ -65,23 +63,3 @@ where }) } } - -/// A generic version of the `CompositionPolyOS` trait that is not object-safe. -#[auto_impl(&)] -pub trait CompositionPoly: Debug + Send + Sync { - fn n_vars(&self) -> usize; - - fn degree(&self) -> usize; - - fn binary_tower_level(&self) -> usize; - - fn expression>(&self) -> ArithExpr; - - fn evaluate>>(&self, query: &[P]) -> Result; - - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error>; -} diff --git a/crates/math/src/deinterleave.rs b/crates/math/src/deinterleave.rs index 4f01b437..1a597290 100644 --- a/crates/math/src/deinterleave.rs +++ b/crates/math/src/deinterleave.rs @@ -30,14 +30,11 @@ pub fn deinterleave( } let deinterleaved = (0..1 << (log_scalar_count - P::LOG_WIDTH)).map(|i| { - let mut even = interleaved[2 * i]; - let mut odd = interleaved[2 * i + 1]; - - for log_block_len in (0..P::LOG_WIDTH).rev() { - let (even_interleaved, odd_interleaved) = even.interleave(odd, log_block_len); - even = even_interleaved; - odd = odd_interleaved; - } + let (even, odd) = if P::LOG_WIDTH > 0 { + P::unzip(interleaved[2 * i], interleaved[2 * i + 1], 0) + } else { + (interleaved[2 * i], interleaved[2 * i + 1]) + }; (i, even, odd) }); diff --git a/crates/math/src/error.rs b/crates/math/src/error.rs index adfc1f74..bb53b885 100644 --- a/crates/math/src/error.rs +++ b/crates/math/src/error.rs @@ -10,6 +10,8 @@ pub enum Error { MatrixNotSquare, #[error("the matrix is singular")] MatrixIsSingular, + #[error("domain size needs to be at least one")] + DomainSizeAtLeastOne, #[error("domain size is larger than the field")] DomainSizeTooLarge, #[error("the inputted packed values slice had an unexpected length")] diff --git a/crates/math/src/fold.rs b/crates/math/src/fold.rs index a7002a50..a35271a3 100644 --- a/crates/math/src/fold.rs +++ b/crates/math/src/fold.rs @@ -4,22 +4,18 @@ use core::slice; use std::{any::TypeId, cmp::min, mem::MaybeUninit}; use binius_field::{ - arch::{byte_sliced::ByteSlicedAES32x128b, ArchOptimal, OptimalUnderlier}, - packed::{get_packed_slice, set_packed_slice_unchecked}, + arch::{ArchOptimal, OptimalUnderlier}, + byte_iteration::{ + can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes, + ByteIteratorCallback, PackedSlice, + }, + packed::{get_packed_slice, get_packed_slice_unchecked, set_packed_slice_unchecked}, underlier::{UnderlierWithBitOps, WithUnderlier}, - AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ByteSlicedAES32x16b, - ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b, ExtensionField, Field, - PackedBinaryField128x1b, PackedBinaryField16x1b, PackedBinaryField256x1b, - PackedBinaryField32x1b, PackedBinaryField512x1b, PackedBinaryField64x1b, PackedBinaryField8x1b, - PackedField, -}; -use binius_maybe_rayon::{ - iter::{IndexedParallelIterator, ParallelIterator}, - slice::ParallelSliceMut, + AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField, + Field, PackedField, }; use binius_utils::bail; -use bytemuck::{fill_zeroes, Pod}; -use itertools::max; +use bytemuck::fill_zeroes; use lazy_static::lazy_static; use stackalloc::helpers::slice_assume_init_mut; @@ -29,6 +25,9 @@ use crate::Error; /// /// Every consequent `1 << log_query_size` scalar values are dot-producted with the corresponding /// query elements. The result is stored in the `output` slice of packed values. +/// +/// Please note that this method is single threaded. Currently we always have some +/// parallelism above this level, so it's not a problem. pub fn fold_right( evals: &[P], log_evals_size: usize, @@ -49,7 +48,16 @@ where return Ok(()); } - fold_right_fallback(evals, log_evals_size, query, log_query_size, out); + // Use linear interpolation for single variable multilinear queries. + let is_lerp = log_query_size == 1 + && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE; + + if is_lerp { + let lerp_query = get_packed_slice(query, 1); + fold_right_lerp(evals, log_evals_size, lerp_query, out); + } else { + fold_right_fallback(evals, log_evals_size, query, log_query_size, out); + } Ok(()) } @@ -60,7 +68,7 @@ where /// with the corresponding query element. The results is written to the `output` slice of packed values. /// If the function returns `Ok(())`, then `out` can be safely interpreted as initialized. /// -/// Please note that unlike `fold_right`, this method is single threaded. Currently we always have some +/// Please note that this method is single threaded. Currently we always have some /// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to /// use more efficient optimizations for special cases. If we ever need a parallel version of this /// function, we can implement it separately. @@ -129,115 +137,6 @@ where Ok(()) } -/// A marker trait that the slice of packed values can be iterated as a sequence of bytes. -/// The order of the iteration by BinaryField1b subfield elements and bits within iterated bytes must -/// be the same. -/// -/// # Safety -/// The implementor must ensure that the cast of the slice of packed values to the slice of bytes -/// is safe and preserves the order of the 1-bit elements. -#[allow(unused)] -unsafe trait SequentialBytes: Pod {} - -unsafe impl SequentialBytes for PackedBinaryField8x1b {} -unsafe impl SequentialBytes for PackedBinaryField16x1b {} -unsafe impl SequentialBytes for PackedBinaryField32x1b {} -unsafe impl SequentialBytes for PackedBinaryField64x1b {} -unsafe impl SequentialBytes for PackedBinaryField128x1b {} -unsafe impl SequentialBytes for PackedBinaryField256x1b {} -unsafe impl SequentialBytes for PackedBinaryField512x1b {} - -/// Returns true if T implements `SequentialBytes` trait. -/// Use a hack that exploits that array copying is optimized for the `Copy` types. -/// Unfortunately there is no more proper way to perform this check this in Rust at runtime. -#[allow(clippy::redundant_clone)] -fn is_sequential_bytes() -> bool { - struct X(bool, std::marker::PhantomData); - - impl Clone for X { - fn clone(&self) -> Self { - Self(false, std::marker::PhantomData) - } - } - - impl Copy for X {} - - let value = [X::(true, std::marker::PhantomData)]; - let cloned = value.clone(); - - cloned[0].0 -} - -/// Returns if we can iterate over bytes, each representing 8 1-bit values. -fn can_iterate_bytes() -> bool { - // Packed fields with sequential byte order - if is_sequential_bytes::

() { - return true; - } - - // Byte-sliced fields - // Note: add more byte sliced types here as soon as they are added - match TypeId::of::

() { - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - _ => false, - } -} - -/// Helper macro to generate the iteration over bytes for byte-sliced types. -macro_rules! iterate_byte_sliced { - ($packed_type:ty, $data:ident, $f:ident) => { - assert_eq!(TypeId::of::<$packed_type>(), TypeId::of::

()); - - // Safety: the cast is safe because the type is checked by arm statement - let data = - unsafe { slice::from_raw_parts($data.as_ptr() as *const $packed_type, $data.len()) }; - for value in data.iter() { - for i in 0..<$packed_type>::BYTES { - // Safety: j is less than `ByteSlicedAES32x128b::BYTES` - $f(unsafe { value.get_byte_unchecked(i) }); - } - } - }; -} - -/// Iterate over bytes of a slice of the packed values. -fn iterate_bytes(data: &[P], mut f: impl FnMut(u8)) { - if is_sequential_bytes::

() { - // Safety: `P` implements `SequentialBytes` trait, so the following cast is safe - // and preserves the order. - let bytes = unsafe { - std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) - }; - for byte in bytes { - f(*byte); - } - } else { - // Note: add more byte sliced types here as soon as they are added - match TypeId::of::

() { - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x128b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x64b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x32b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x16b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x8b, data, f); - } - _ => unreachable!("packed field doesn't support byte iteration"), - } - } -} - /// Optimized version for 1-bit values with query size 0-2 fn fold_right_1bit_evals_small_query( evals: &[P], @@ -251,18 +150,8 @@ where if LOG_QUERY_SIZE >= 3 { return false; } - let chunk_size = 1 - << max(&[ - 10, - (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), - PE::LOG_WIDTH, - ]) - .unwrap(); - if out.len() % chunk_size != 0 { - return false; - } - if P::WIDTH << LOG_QUERY_SIZE > chunk_size << PE::LOG_WIDTH { + if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH { return false; } @@ -279,33 +168,43 @@ where }) .collect::>(); - out.par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(index, chunk)| { - let input_offset = - ((index * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let input_end = - (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + out: &'a mut [PE], + cached_table: &'a [PE::Scalar], + } + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; + let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); let mut current_index = 0; - iterate_bytes(&evals[input_offset..input_end], |byte| { - let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; - let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); + for byte in iterator { for k in 0..values_in_byte { let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask; // Safety: `i` is less than `chunk_size` unsafe { set_packed_slice_unchecked( - chunk, + self.out, current_index + k, - cached_table[index as usize], + self.cached_table[index as usize], ); } } current_index += values_in_byte; - }); - }); + } + } + } + + let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> { + out, + cached_table: &cached_table, + }; + + iterate_bytes(evals, &mut callback); true } @@ -323,61 +222,52 @@ where if LOG_QUERY_SIZE < 3 { return false; } - let chunk_size = 1 - << max(&[ - 10, - (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), - PE::LOG_WIDTH, - ]) - .unwrap(); - if out.len() % chunk_size != 0 { + + if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH { return false; } - let log_tables_count = LOG_QUERY_SIZE - 3; - let tables_count = 1 << log_tables_count; - let cached_tables = (0..tables_count) - .map(|i| { - (0..256) - .map(|j| { - let mut result = PE::Scalar::ZERO; - for k in 0..8 { - if j >> k & 1 == 1 { - result += get_packed_slice(query, (i << 3) | k); - } - } - result - }) - .collect::>() - }) - .collect::>(); + let cached_tables = + create_partial_sums_lookup_tables(PackedSlice::new(query, 1 << LOG_QUERY_SIZE)); - out.par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(index, chunk)| { - let input_offset = - ((index * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let input_end = - (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + out: &'a mut [PE], + cached_tables: &'a [PE::Scalar], + } - let mut current_value = PE::Scalar::ZERO; - let mut current_table = 0; + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let log_tables_count = LOG_QUERY_SIZE - 3; + let tables_count = 1 << log_tables_count; let mut current_index = 0; - iterate_bytes(&evals[input_offset..input_end], |byte| { - current_value += cached_tables[current_table][byte as usize]; + let mut current_table = 0; + let mut current_value = PE::Scalar::ZERO; + for byte in iterator { + current_value += self.cached_tables[(current_table << 8) + byte as usize]; current_table += 1; if current_table == tables_count { // Safety: `i` is less than `chunk_size` unsafe { - set_packed_slice_unchecked(chunk, current_index, current_value); + set_packed_slice_unchecked(self.out, current_index, current_value); } - current_table = 0; current_index += 1; + current_table = 0; current_value = PE::Scalar::ZERO; } - }); - }); + } + } + } + + let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> { + out, + cached_tables: &cached_tables, + }; + + iterate_bytes(evals, &mut callback); true } @@ -419,6 +309,46 @@ where } } +/// Specialized implementation for a single parameter right fold using linear interpolation +/// instead of tensor expansion resulting in a single multiplication instead of two: +/// f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w). +/// +/// The same approach may be generalized to higher variable counts, with diminishing returns. +fn fold_right_lerp( + evals: &[P], + log_evals_size: usize, + lerp_query: PE::Scalar, + out: &mut [PE], +) where + P: PackedField, + PE: PackedField>, +{ + assert_eq!(1 << log_evals_size.saturating_sub(PE::LOG_WIDTH + 1), out.len()); + + out.iter_mut() + .enumerate() + .for_each(|(i, packed_result_eval)| { + for j in 0..min(PE::WIDTH, 1 << (log_evals_size - 1)) { + let index = (i << PE::LOG_WIDTH) | j; + + let (eval0, eval1) = unsafe { + ( + get_packed_slice_unchecked(evals, index << 1), + get_packed_slice_unchecked(evals, (index << 1) | 1), + ) + }; + + let result_eval = + PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0); + + // Safety: `j` < `PE::WIDTH` + unsafe { + packed_result_eval.set_unchecked(j, result_eval); + } + } + }) +} + /// Fallback implementation for fold that can be executed for any field types and sizes. fn fold_right_fallback( evals: &[P], @@ -430,34 +360,26 @@ fn fold_right_fallback( P: PackedField, PE: PackedField>, { - const CHUNK_SIZE: usize = 1 << 10; - let packed_result_evals = out; - packed_result_evals - .par_chunks_mut(CHUNK_SIZE) - .enumerate() - .for_each(|(i, packed_result_evals)| { - for (k, packed_result_eval) in packed_result_evals.iter_mut().enumerate() { - let offset = i * CHUNK_SIZE; - for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) { - let index = ((offset + k) << PE::LOG_WIDTH) | j; - - let offset = index << log_query_size; - - let mut result_eval = PE::Scalar::ZERO; - for (t, query_expansion) in PackedField::iter_slice(query) - .take(1 << log_query_size) - .enumerate() - { - result_eval += query_expansion * get_packed_slice(evals, t + offset); - } + for (k, packed_result_eval) in out.iter_mut().enumerate() { + for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) { + let index = (k << PE::LOG_WIDTH) | j; + + let offset = index << log_query_size; + + let mut result_eval = PE::Scalar::ZERO; + for (t, query_expansion) in PackedField::iter_slice(query) + .take(1 << log_query_size) + .enumerate() + { + result_eval += query_expansion * get_packed_slice(evals, t + offset); + } - // Safety: `j` < `PE::WIDTH` - unsafe { - packed_result_eval.set_unchecked(j, result_eval); - } - } + // Safety: `j` < `PE::WIDTH` + unsafe { + packed_result_eval.set_unchecked(j, result_eval); } - }); + } + } } type ArchOptimaType = ::OptimalThroughputPacked; @@ -685,27 +607,13 @@ mod tests { use std::iter::repeat_with; use binius_field::{ - packed::set_packed_slice, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField4x1b, PackedBinaryField512x1b, + packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b, + PackedBinaryField16x8b, PackedBinaryField512x1b, PackedBinaryField64x8b, }; use rand::{rngs::StdRng, SeedableRng}; use super::*; - #[test] - fn test_sequential_bits() { - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - - assert!(!is_sequential_bytes::()); - assert!(!is_sequential_bytes::()); - } - fn fold_right_reference( evals: &[P], log_evals_size: usize, @@ -782,7 +690,9 @@ mod tests { let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng)) .take(1 << LOG_EVALS_SIZE) .collect::>(); - let query = vec![PackedBinaryField512x1b::random(&mut rng)]; + let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng)) + .take(8) + .collect::>(); for log_query_size in 0..10 { check_fold_right( diff --git a/crates/math/src/matrix.rs b/crates/math/src/matrix.rs index 674087d9..408b7986 100644 --- a/crates/math/src/matrix.rs +++ b/crates/math/src/matrix.rs @@ -186,8 +186,6 @@ impl Matrix { } fn scale_row(&mut self, i: usize, scalar: F) { - assert!(i < self.m); - for x in self.row_mut(i) { *x *= scalar; } diff --git a/crates/math/src/mle_adapters.rs b/crates/math/src/mle_adapters.rs index 079fc4dd..47c227cd 100644 --- a/crates/math/src/mle_adapters.rs +++ b/crates/math/src/mle_adapters.rs @@ -8,7 +8,6 @@ use binius_field::{ }, ExtensionField, Field, PackedField, RepackedExtension, }; -use binius_maybe_rayon::prelude::*; use binius_utils::bail; use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef}; @@ -274,11 +273,9 @@ where P: PackedField, Data: Deref + Send + Sync + Debug + 'a, { - pub fn specialize_arc_dyn(self) -> Arc + Send + Sync + 'a> - where - PE: PackedField + RepackedExtension

, - PE::Scalar: ExtensionField, - { + pub fn specialize_arc_dyn>( + self, + ) -> Arc + Send + Sync + 'a> { self.specialize().upcast_arc_dyn() } } @@ -299,44 +296,6 @@ where pub fn upcast_arc_dyn(self) -> Arc + Send + Sync + 'a> { Arc::new(self) } - - /// Given a ($mu$-variate) multilinear function $f$ and an element $r$, - /// return the multilinear function $f(r, X_1, ..., X_{\mu - 1})$. - pub fn evaluate_zeroth_variable(&self, r: P::Scalar) -> Result, Error> { - let multilin = &self.0; - let mu = multilin.n_vars(); - if mu == 0 { - bail!(Error::ConstantFold); - } - let packed_length = 1 << mu.saturating_sub(P::LOG_WIDTH + 1); - // in general, the formula is: f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w). - let result = (0..packed_length) - .into_par_iter() - .map(|i| { - let eval0_minus_eval1 = P::from_fn(|j| { - let index = (i << P::LOG_WIDTH) | j; - // necessary if `mu_minus_one` < `P::LOG_WIDTH` - if index >= 1 << (mu - 1) { - return P::Scalar::ZERO; - } - let eval0 = get_packed_slice(multilin.evals(), index << 1); - let eval1 = get_packed_slice(multilin.evals(), (index << 1) | 1); - eval0 - eval1 - }); - let eval0 = P::from_fn(|j| { - let index = (i << P::LOG_WIDTH) | j; - // necessary if `mu_minus_one` < `P::LOG_WIDTH` - if index >= 1 << (mu - 1) { - return P::Scalar::ZERO; - } - get_packed_slice(multilin.evals(), index << 1) - }); - eval0_minus_eval1 * r + eval0 - }) - .collect::>(); - - MultilinearExtension::new(mu - 1, result) - } } impl From> for MLEDirectAdapter @@ -699,23 +658,4 @@ mod tests { .unwrap(); assert_eq!(evals_out, poly.packed_evals().unwrap()); } - - #[test] - fn test_evaluate_zeroth_evaluate_partial_low_consistent() { - let mut rng = StdRng::seed_from_u64(0); - let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng)) - .take(1 << 8) - .collect(); - - let me = MultilinearExtension::from_values(values).unwrap(); - let mled = MLEDirectAdapter::from(me); - let r = ::random(&mut rng); - - let eval_1: MultilinearExtension = - mled.evaluate_zeroth_variable(r).unwrap(); - let eval_2 = mled - .evaluate_partial_low(multilinear_query(&[r]).to_ref()) - .unwrap(); - assert_eq!(eval_1, eval_2); - } } diff --git a/crates/math/src/multilinear_extension.rs b/crates/math/src/multilinear_extension.rs index 91c379c8..8403476e 100644 --- a/crates/math/src/multilinear_extension.rs +++ b/crates/math/src/multilinear_extension.rs @@ -245,6 +245,7 @@ where PE::Scalar: ExtensionField, { let query = query.into(); + if self.mu < query.n_vars() { bail!(Error::IncorrectQuerySize { expected: self.mu }); } @@ -275,15 +276,6 @@ where PE: PackedField, PE::Scalar: ExtensionField, { - if self.mu < query.n_vars() { - bail!(Error::IncorrectQuerySize { expected: self.mu }); - } - if out.len() != 1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)) { - bail!(Error::IncorrectOutputPolynomialSize { - expected: self.mu - query.n_vars(), - }); - } - // This operation is a matrix-vector product of the matrix of multilinear coefficients with // the vector of tensor product-expanded query coefficients. fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out) @@ -559,6 +551,28 @@ mod tests { ); } + #[test] + fn test_evaluate_partial_low_single_and_multiple_var_consistent() { + let mut rng = StdRng::seed_from_u64(0); + let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng)) + .take(1 << 8) + .collect(); + + let mle = MultilinearExtension::from_values(values).unwrap(); + let r1 = ::random(&mut rng); + let r2 = ::random(&mut rng); + + let eval_1: MultilinearExtension = mle + .evaluate_partial_low::(multilinear_query(&[r1]).to_ref()) + .unwrap() + .evaluate_partial_low(multilinear_query(&[r2]).to_ref()) + .unwrap(); + let eval_2 = mle + .evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref()) + .unwrap(); + assert_eq!(eval_1, eval_2); + } + #[test] fn test_new_mle_with_tiny_nvars() { MultilinearExtension::new( diff --git a/crates/math/src/multilinear_query.rs b/crates/math/src/multilinear_query.rs index fc27a689..9593367a 100644 --- a/crates/math/src/multilinear_query.rs +++ b/crates/math/src/multilinear_query.rs @@ -36,7 +36,7 @@ impl<'a, P: PackedField, Data: DerefMut> From<&'a MultilinearQuery for MultilinearQueryRef<'a, P> { fn from(query: &'a MultilinearQuery) -> Self { - MultilinearQueryRef::new(query) + Self::new(query) } } @@ -73,7 +73,7 @@ impl MultilinearQuery> { } pub fn expand(query: &[P::Scalar]) -> Self { - let expanded_query = eq_ind_partial_eval::

(query); + let expanded_query = eq_ind_partial_eval(query); Self { expanded_query, n_vars: query.len(), @@ -148,12 +148,16 @@ impl> MultilinearQuery { #[cfg(test)] mod tests { - use binius_field::{Field, PackedField}; + use binius_field::{Field, PackedBinaryField4x32b, PackedField}; use binius_utils::felts; + use itertools::Itertools; use super::*; use crate::tensor_prod_eq_ind; + type P = PackedBinaryField4x32b; + type F =

::Scalar; + fn tensor_prod(p: &[P::Scalar]) -> Vec

{ let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)]; result[0] = P::set_single(P::Scalar::ONE); @@ -252,4 +256,98 @@ mod tests { felts!(BinaryField16b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2]) ); } + + #[test] + fn test_update_single_var() { + let query = MultilinearQuery::

::with_capacity(2); + let r0 = F::new(2); + let extra_query = [r0]; + + let updated_query = query.update(&extra_query).unwrap(); + + assert_eq!(updated_query.n_vars(), 1); + + let expansion = updated_query.into_expansion(); + let expansion = PackedField::iter_slice(&expansion).collect_vec(); + + assert_eq!(expansion, vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]); + } + + #[test] + fn test_update_two_vars() { + let query = MultilinearQuery::

::with_capacity(3); + let r0 = F::new(2); + let r1 = F::new(3); + let extra_query = [r0, r1]; + + let updated_query = query.update(&extra_query).unwrap(); + assert_eq!(updated_query.n_vars(), 2); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!( + expansion, + vec![ + (F::ONE - r0) * (F::ONE - r1), + r0 * (F::ONE - r1), + (F::ONE - r0) * r1, + r0 * r1, + ] + ); + } + + #[test] + fn test_update_three_vars() { + let query = MultilinearQuery::

::with_capacity(4); + let r0 = F::new(2); + let r1 = F::new(3); + let r2 = F::new(5); + let extra_query = [r0, r1, r2]; + + let updated_query = query.update(&extra_query).unwrap(); + assert_eq!(updated_query.n_vars(), 3); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!( + expansion, + vec![ + (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2), + r0 * (F::ONE - r1) * (F::ONE - r2), + (F::ONE - r0) * r1 * (F::ONE - r2), + r0 * r1 * (F::ONE - r2), + (F::ONE - r0) * (F::ONE - r1) * r2, + r0 * (F::ONE - r1) * r2, + (F::ONE - r0) * r1 * r2, + r0 * r1 * r2, + ] + ); + } + + #[test] + fn test_update_exceeds_capacity() { + let query = MultilinearQuery::

::with_capacity(2); + // More than allowed capacity + let extra_query = [F::new(2), F::new(3), F::new(5)]; + + let result = query.update(&extra_query); + // Expecting an error due to exceeding max_query_vars + assert!(result.is_err()); + } + + #[test] + fn test_update_empty() { + let query = MultilinearQuery::

::with_capacity(2); + // Updating with no new coordinates should be fine + let updated_query = query.update(&[]).unwrap(); + + assert_eq!(updated_query.n_vars(), 0); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!(expansion, vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]); + } } diff --git a/crates/math/src/tensor_prod_eq_ind.rs b/crates/math/src/tensor_prod_eq_ind.rs index 8bcba36f..ed48bb1d 100644 --- a/crates/math/src/tensor_prod_eq_ind.rs +++ b/crates/math/src/tensor_prod_eq_ind.rs @@ -62,7 +62,7 @@ pub fn tensor_prod_eq_ind( xs.par_iter_mut() .zip(ys.par_iter_mut()) .with_min_len(64) - .for_each(|(x, y): (&mut P, &mut P)| { + .for_each(|(x, y)| { // x = x * (1 - packed_r_i) = x - x * packed_r_i // y = x * packed_r_i // Notice that we can reuse the multiplication: (x * packed_r_i) @@ -95,8 +95,7 @@ pub fn eq_ind_partial_eval(point: &[P::Scalar]) -> Vec

{ let len = 1 << n.saturating_sub(P::LOG_WIDTH); let mut buffer = zeroed_vec::

(len); buffer[0].set(0, P::Scalar::ONE); - tensor_prod_eq_ind(0, &mut buffer[..], point) - .expect("buffer is allocated with the correct length"); + tensor_prod_eq_ind(0, &mut buffer, point).expect("buffer is allocated with the correct length"); buffer } @@ -107,10 +106,11 @@ mod tests { use super::*; + type P = PackedBinaryField4x32b; + type F =

::Scalar; + #[test] fn test_tensor_prod_eq_ind() { - type P = PackedBinaryField4x32b; - type F =

::Scalar; let v0 = F::new(1); let v1 = F::new(2); let query = vec![v0, v1]; @@ -128,4 +128,59 @@ mod tests { ] ); } + + #[test] + fn test_eq_ind_partial_eval_empty() { + let result = eq_ind_partial_eval::

(&[]); + let expected = vec![P::set_single(F::ONE)]; + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_single_var() { + // Only one query coordinate + let r0 = F::new(2); + let result = eq_ind_partial_eval::

(&[r0]); + let expected = vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]; + let result = PackedField::iter_slice(&result).collect_vec(); + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_two_vars() { + // Two query coordinates + let r0 = F::new(2); + let r1 = F::new(3); + let result = eq_ind_partial_eval::

(&[r0, r1]); + let result = PackedField::iter_slice(&result).collect_vec(); + let expected = vec![ + (F::ONE - r0) * (F::ONE - r1), + r0 * (F::ONE - r1), + (F::ONE - r0) * r1, + r0 * r1, + ]; + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_three_vars() { + // Case with three query coordinates + let r0 = F::new(2); + let r1 = F::new(3); + let r2 = F::new(5); + let result = eq_ind_partial_eval::

(&[r0, r1, r2]); + let result = PackedField::iter_slice(&result).collect_vec(); + + let expected = vec![ + (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2), + r0 * (F::ONE - r1) * (F::ONE - r2), + (F::ONE - r0) * r1 * (F::ONE - r2), + r0 * r1 * (F::ONE - r2), + (F::ONE - r0) * (F::ONE - r1) * r2, + r0 * (F::ONE - r1) * r2, + (F::ONE - r0) * r1 * r2, + r0 * r1 * r2, + ]; + assert_eq!(result, expected); + } } diff --git a/crates/math/src/univariate.rs b/crates/math/src/univariate.rs index b99d2606..d2339d21 100644 --- a/crates/math/src/univariate.rs +++ b/crates/math/src/univariate.rs @@ -7,6 +7,7 @@ use binius_field::{ PackedField, }; use binius_utils::bail; +use itertools::{izip, Either}; use super::{binary_subspace::BinarySubspace, error::Error}; use crate::Matrix; @@ -17,8 +18,9 @@ use crate::Matrix; /// to reconstruct a degree <= d. This struct supports Barycentric extrapolation. #[derive(Debug, Clone)] pub struct EvaluationDomain { - points: Vec, + finite_points: Vec, weights: Vec, + with_infinity: bool, } /// An extended version of `EvaluationDomain` that supports interpolation to monomial form. Takes @@ -32,9 +34,20 @@ pub struct InterpolationDomain { /// Wraps type information to enable instantiating EvaluationDomains. #[auto_impl(&)] pub trait EvaluationDomainFactory: Clone + Sync { - /// Instantiates an EvaluationDomain from a set of points isomorphic to direct - /// lexicographic successors of zero in Fan-Paar tower - fn create(&self, size: usize) -> Result, Error>; + /// Instantiates an EvaluationDomain from `size` lexicographically first values from the + /// binary subspace. + fn create(&self, size: usize) -> Result, Error> { + self.create_with_infinity(size, false) + } + + /// Instantiates an EvaluationDomain from `size` values in total: lexicographically first values + /// from the binary subspace and potentially Karatsuba "infinity" point (which is the coefficient of + /// the highest power in the interpolated polynomial). + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error>; } #[derive(Default, Clone)] @@ -48,8 +61,18 @@ pub struct IsomorphicEvaluationDomainFactory { } impl EvaluationDomainFactory for DefaultEvaluationDomainFactory { - fn create(&self, size: usize) -> Result, Error> { - EvaluationDomain::from_points(make_evaluation_points(&self.subspace, size)?) + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error> { + if size == 0 && with_infinity { + bail!(Error::DomainSizeAtLeastOne); + } + EvaluationDomain::from_points( + make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?, + with_infinity, + ) } } @@ -58,9 +81,17 @@ where FSrc: BinaryField, FTgt: Field + From + BinaryField, { - fn create(&self, size: usize) -> Result, Error> { - let points = make_evaluation_points(&self.subspace, size)?; - EvaluationDomain::from_points(points.into_iter().map(Into::into).collect()) + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error> { + if size == 0 && with_infinity { + bail!(Error::DomainSizeAtLeastOne); + } + let points = + make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?; + EvaluationDomain::from_points(points.into_iter().map(Into::into).collect(), false) } } @@ -78,7 +109,8 @@ fn make_evaluation_points( impl From> for InterpolationDomain { fn from(evaluation_domain: EvaluationDomain) -> Self { let n = evaluation_domain.size(); - let evaluation_matrix = vandermonde(evaluation_domain.points()); + let evaluation_matrix = + vandermonde(evaluation_domain.finite_points(), evaluation_domain.with_infinity()); let mut interpolation_matrix = Matrix::zeros(n, n); evaluation_matrix .inverse_into(&mut interpolation_matrix) @@ -97,17 +129,25 @@ impl From> for InterpolationDomain { } impl EvaluationDomain { - pub fn from_points(points: Vec) -> Result { - let weights = compute_barycentric_weights(&points)?; - Ok(Self { points, weights }) + pub fn from_points(finite_points: Vec, with_infinity: bool) -> Result { + let weights = compute_barycentric_weights(&finite_points)?; + Ok(Self { + finite_points, + weights, + with_infinity, + }) } pub fn size(&self) -> usize { - self.points.len() + self.finite_points.len() + if self.with_infinity { 1 } else { 0 } } - pub fn points(&self) -> &[F] { - self.points.as_slice() + pub fn finite_points(&self) -> &[F] { + self.finite_points.as_slice() + } + + pub const fn with_infinity(&self) -> bool { + self.with_infinity } /// Compute a vector of Lagrange polynomial evaluations in $O(N)$ at a given point `x`. @@ -116,19 +156,23 @@ impl EvaluationDomain { /// are defined by /// $$L_i(x) = \sum_{j \neq i}\frac{x - \pi_j}{\pi_i - \pi_j}$$ pub fn lagrange_evals>(&self, x: FE) -> Vec { - let num_evals = self.size(); + let num_evals = self.finite_points().len(); let mut result: Vec = vec![FE::ONE; num_evals]; // Multiply the product suffixes for i in (1..num_evals).rev() { - result[i - 1] = result[i] * (x - self.points[i]); + result[i - 1] = result[i] * (x - self.finite_points[i]); } let mut prefix = FE::ONE; // Multiply the product prefixes and weights - for ((r, &point), &weight) in result.iter_mut().zip(&self.points).zip(&self.weights) { + for ((r, &point), &weight) in result + .iter_mut() + .zip(&self.finite_points) + .zip(&self.weights) + { *r *= prefix * weight; prefix *= x - point; } @@ -141,18 +185,26 @@ impl EvaluationDomain { where PE: PackedField>, { - let lagrange_eval_results = self.lagrange_evals(x); - - let n = self.size(); - if values.len() != n { + if values.len() != self.size() { bail!(Error::ExtrapolateNumberOfEvaluations); } - let result = lagrange_eval_results - .into_iter() - .zip(values) - .map(|(evaluation, &value)| value * evaluation) - .sum::(); + let (values_iter, infinity_term) = if self.with_infinity { + let (&value_at_infinity, finite_values) = + values.split_last().expect("values length checked above"); + let highest_degree = finite_values.len() as u64; + let iter = izip!(&self.finite_points, finite_values).map(move |(&point, &value)| { + value - value_at_infinity * PE::Scalar::from(point).pow(highest_degree) + }); + (Either::Left(iter), value_at_infinity * x.pow(highest_degree)) + } else { + (Either::Right(values.iter().copied()), PE::zero()) + }; + + let result = izip!(self.lagrange_evals(x), values_iter) + .map(|(lagrange_at_x, value)| value * lagrange_at_x) + .sum::() + + infinity_term; Ok(result) } @@ -163,20 +215,24 @@ impl InterpolationDomain { self.evaluation_domain.size() } - pub fn points(&self) -> &[F] { - self.evaluation_domain.points() + pub fn finite_points(&self) -> &[F] { + self.evaluation_domain.finite_points() } - pub fn extrapolate(&self, values: &[PE], x: PE::Scalar) -> Result - where - PE: PackedExtension>, - { + pub const fn with_infinity(&self) -> bool { + self.evaluation_domain.with_infinity() + } + + pub fn extrapolate>( + &self, + values: &[PE], + x: PE::Scalar, + ) -> Result { self.evaluation_domain.extrapolate(values, x) } pub fn interpolate>(&self, values: &[FE]) -> Result, Error> { - let n = self.evaluation_domain.size(); - if values.len() != n { + if values.len() != self.evaluation_domain.size() { bail!(Error::ExtrapolateNumberOfEvaluations); } @@ -188,11 +244,7 @@ impl InterpolationDomain { /// Extrapolates lines through a pair of packed fields at a single point from a subfield. #[inline] -pub fn extrapolate_line(x0: P, x1: P, z: FS) -> P -where - P: PackedExtension>, - FS: Field, -{ +pub fn extrapolate_line, FS: Field>(x0: P, x1: P, z: FS) -> P { x0 + mul_by_subfield_scalar(x1 - x0, z) } @@ -236,8 +288,8 @@ fn compute_barycentric_weights(points: &[F]) -> Result, Error> .collect() } -fn vandermonde(xs: &[F]) -> Matrix { - let n = xs.len(); +fn vandermonde(xs: &[F], with_infinity: bool) -> Matrix { + let n = xs.len() + if with_infinity { 1 } else { 0 }; let mut mat = Matrix::zeros(n, n); for (i, x_i) in xs.iter().copied().enumerate() { @@ -249,6 +301,11 @@ fn vandermonde(xs: &[F]) -> Matrix { mat[(i, j)] = acc; } } + + if with_infinity { + mat[(n - 1, n - 1)] = F::ONE; + } + mat } @@ -277,7 +334,7 @@ mod tests { fn test_new_domain() { let domain_factory = DefaultEvaluationDomainFactory::::default(); assert_eq!( - domain_factory.create(3).unwrap().points, + domain_factory.create(3).unwrap().finite_points, &[ BinaryField8b::new(0), BinaryField8b::new(1), @@ -292,7 +349,7 @@ mod tests { let iso_domain_factory = IsomorphicEvaluationDomainFactory::::default(); let domain_1: EvaluationDomain = default_domain_factory.create(10).unwrap(); let domain_2: EvaluationDomain = iso_domain_factory.create(10).unwrap(); - assert_eq!(domain_1.points, domain_2.points); + assert_eq!(domain_1.finite_points, domain_2.finite_points); } #[test] @@ -303,11 +360,11 @@ mod tests { let domain_2: EvaluationDomain = iso_domain_factory.create(10).unwrap(); assert_eq!( domain_1 - .points + .finite_points .into_iter() .map(AESTowerField32b::from) .collect::>(), - domain_2.points + domain_2.finite_points ); } @@ -343,6 +400,7 @@ mod tests { repeat_with(|| ::random(&mut rng)) .take(degree + 1) .collect(), + false, ) .unwrap(); @@ -351,7 +409,7 @@ mod tests { .collect::>(); let values = domain - .points() + .finite_points() .iter() .map(|&x| evaluate_univariate(&coeffs, x)) .collect::>(); @@ -370,6 +428,7 @@ mod tests { repeat_with(|| ::random(&mut rng)) .take(degree + 1) .collect(), + false, ) .unwrap(); @@ -378,10 +437,44 @@ mod tests { .collect::>(); let values = domain - .points() + .finite_points() + .iter() + .map(|&x| evaluate_univariate(&coeffs, x)) + .collect::>(); + + let interpolated = InterpolationDomain::from(domain) + .interpolate(&values) + .unwrap(); + assert_eq!(interpolated, coeffs); + } + + #[test] + fn test_infinity() { + let mut rng = StdRng::seed_from_u64(0); + let degree = 6; + + let domain = EvaluationDomain::from_points( + repeat_with(|| ::random(&mut rng)) + .take(degree) + .collect(), + true, + ) + .unwrap(); + + let coeffs = repeat_with(|| ::random(&mut rng)) + .take(degree + 1) + .collect::>(); + + let mut values = domain + .finite_points() .iter() .map(|&x| evaluate_univariate(&coeffs, x)) .collect::>(); + values.push(coeffs.last().copied().unwrap()); + + let x = ::random(&mut rng); + let expected_y = evaluate_univariate(&coeffs, x); + assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y); let interpolated = InterpolationDomain::from(domain) .interpolate(&values) diff --git a/crates/ntt/src/additive_ntt.rs b/crates/ntt/src/additive_ntt.rs index a48d7bb9..7b5e0bcc 100644 --- a/crates/ntt/src/additive_ntt.rs +++ b/crates/ntt/src/additive_ntt.rs @@ -1,7 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{ExtensionField, PackedField, RepackedExtension}; -use binius_utils::checked_arithmetics::log2_strict_usize; use super::error::Error; @@ -46,29 +45,19 @@ pub trait AdditiveNTT { log_batch_size: usize, ) -> Result<(), Error>; - fn forward_transform_ext(&self, data: &mut [PE], coset: u32) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField, - { - if !PE::Scalar::DEGREE.is_power_of_two() { - return Err(Error::PowerOfTwoExtensionDegreeRequired); - } - - let log_batch_size = log2_strict_usize(PE::Scalar::DEGREE); - self.forward_transform(PE::cast_bases_mut(data), coset, log_batch_size) + fn forward_transform_ext>( + &self, + data: &mut [PE], + coset: u32, + ) -> Result<(), Error> { + self.forward_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } - fn inverse_transform_ext(&self, data: &mut [PE], coset: u32) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField, - { - if !PE::Scalar::DEGREE.is_power_of_two() { - return Err(Error::PowerOfTwoExtensionDegreeRequired); - } - - let log_batch_size = log2_strict_usize(PE::Scalar::DEGREE); - self.inverse_transform(PE::cast_bases_mut(data), coset, log_batch_size) + fn inverse_transform_ext>( + &self, + data: &mut [PE], + coset: u32, + ) -> Result<(), Error> { + self.inverse_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } } diff --git a/crates/ntt/src/dynamic_dispatch.rs b/crates/ntt/src/dynamic_dispatch.rs index f7eb830f..bd1d0f8f 100644 --- a/crates/ntt/src/dynamic_dispatch.rs +++ b/crates/ntt/src/dynamic_dispatch.rs @@ -54,7 +54,7 @@ pub enum DynamicDispatchNTT { impl DynamicDispatchNTT { /// Create a new AdditiveNTT based on the given settings. - pub fn new(log_domain_size: usize, options: NTTOptions) -> Result { + pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result { let log_threads = options.thread_settings.log_threads_count(); let result = match (options.precompute_twiddles, log_threads) { (false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?), @@ -144,24 +144,24 @@ mod tests { #[test] fn test_creation() { - fn make_ntt(options: NTTOptions) -> DynamicDispatchNTT { + fn make_ntt(options: &NTTOptions) -> DynamicDispatchNTT { DynamicDispatchNTT::::new(6, options).unwrap() } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::SingleThreaded, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::SingleThreaded, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); let multithreaded = get_log_max_threads() > 0; - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::MultithreadedDefault, }); @@ -171,7 +171,7 @@ mod tests { assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_))); } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::MultithreadedDefault, }); @@ -181,19 +181,19 @@ mod tests { assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 2 }, }); assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, }); diff --git a/crates/ntt/src/single_threaded.rs b/crates/ntt/src/single_threaded.rs index a5723a62..2448894c 100644 --- a/crates/ntt/src/single_threaded.rs +++ b/crates/ntt/src/single_threaded.rs @@ -187,7 +187,7 @@ pub fn forward_transform>( // packed twiddles for all packed butterfly units. let log_block_len = i + log_b; let block_twiddle = calculate_twiddle::

( - s_evals[i].coset(log_domain_size - 1 - cutoff, 0), + &s_evals[i].coset(log_domain_size - 1 - cutoff, 0), log_block_len, ); @@ -263,7 +263,7 @@ pub fn inverse_transform>( // packed twiddles for all packed butterfly units. let log_block_len = i + log_b; let block_twiddle = calculate_twiddle::

( - s_evals[i].coset(log_domain_size - 1 - cutoff, 0), + &s_evals[i].coset(log_domain_size - 1 - cutoff, 0), log_block_len, ); @@ -357,7 +357,7 @@ pub const fn check_batch_transform_inputs_and_params( } #[inline] -fn calculate_twiddle

(s_evals: impl TwiddleAccess, log_block_len: usize) -> P +fn calculate_twiddle

(s_evals: &impl TwiddleAccess, log_block_len: usize) -> P where P: PackedField, { diff --git a/crates/ntt/src/tests/ntt_tests.rs b/crates/ntt/src/tests/ntt_tests.rs index 7ab2d2e3..ccdf8b7d 100644 --- a/crates/ntt/src/tests/ntt_tests.rs +++ b/crates/ntt/src/tests/ntt_tests.rs @@ -10,8 +10,8 @@ use binius_field::{ packed_8::PackedBinaryField1x8b, }, underlier::{NumCast, WithUnderlier}, - AESTowerField8b, BinaryField, BinaryField8b, ExtensionField, PackedBinaryField16x32b, - PackedBinaryField8x32b, PackedField, RepackedExtension, + AESTowerField8b, BinaryField, BinaryField8b, PackedBinaryField16x32b, PackedBinaryField8x32b, + PackedField, RepackedExtension, }; use rand::{rngs::StdRng, SeedableRng}; @@ -144,16 +144,12 @@ fn tests_field_512_bits() { check_roundtrip_all_ntts::(12, 6, 4, 0); } -fn check_packed_extension_roundtrip_with_reference( +fn check_packed_extension_roundtrip_with_reference>( reference_ntt: &impl AdditiveNTT

, ntt: &impl AdditiveNTT

, data: &mut [PE], cosets: Range, -) where - P: PackedField, - PE: RepackedExtension

, - PE::Scalar: ExtensionField, -{ +) { let data_copy = data.to_vec(); let mut data_copy_2 = data.to_vec(); @@ -182,7 +178,6 @@ fn check_packed_extension_roundtrip_all_ntts( ) where P: PackedField, PE: RepackedExtension

+ WithUnderlier>, - PE::Scalar: ExtensionField, { let simple_ntt = SingleThreadedNTT::::new(log_domain_size) .unwrap() diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 0d04d291..c0e41840 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -8,9 +8,10 @@ authors.workspace = true workspace = true [dependencies] +auto_impl.workspace = true binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } -bytes.workspace = true bytemuck = { workspace = true, features = ["extern_crate_alloc"] } +bytes.workspace = true cfg-if.workspace = true generic-array.workspace = true itertools.workspace = true diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs index 493fe6e2..3ac565f0 100644 --- a/crates/utils/src/lib.rs +++ b/crates/utils/src/lib.rs @@ -17,3 +17,6 @@ pub mod serialization; pub mod sorting; pub mod sparse_index; pub mod thread_local_mut; + +pub use bytes; +pub use serialization::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; diff --git a/crates/utils/src/serialization.rs b/crates/utils/src/serialization.rs index 434befcc..a6782b36 100644 --- a/crates/utils/src/serialization.rs +++ b/crates/utils/src/serialization.rs @@ -1,50 +1,417 @@ // Copyright 2024-2025 Irreducible Inc. +use auto_impl::auto_impl; use bytes::{Buf, BufMut}; -use generic_array::{ArrayLength, GenericArray}; +use thiserror::Error; + +/// Serialize data according to Mode param +#[auto_impl(Box, &)] +pub trait SerializeBytes { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError>; +} + +/// Deserialize data according to Mode param +pub trait DeserializeBytes { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized; +} + +/// Specifies serialization/deserialization behavior +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SerializationMode { + /// This mode is faster, and serializes to the underlying bytes + Native, + /// Will first convert any tower fields into the Fan-Paar field equivalent + CanonicalTower, +} -#[derive(Clone, thiserror::Error, Debug)] -pub enum Error { +#[derive(Error, Debug, Clone)] +pub enum SerializationError { #[error("Write buffer is full")] WriteBufferFull, #[error("Not enough data in read buffer to deserialize")] NotEnoughBytes, + #[error("Unknown enum variant index {name}::{index}")] + UnknownEnumVariant { name: &'static str, index: u8 }, + #[error("Serialization has not been implemented")] + SerializationNotImplemented, + #[error("Deserializer has not been implemented")] + DeserializerNotImplented, + #[error("Multiple deserializers with the same name {name} has been registered")] + DeserializerNameConflict { name: String }, + #[error("FromUtf8Error: {0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), } -/// Represents type that can be serialized to a byte buffer. -pub trait SerializeBytes { - fn serialize(&self, write_buf: impl BufMut) -> Result<(), Error>; +// Copyright 2025 Irreducible Inc. + +use generic_array::{ArrayLength, GenericArray}; + +impl DeserializeBytes for Box { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self::new(T::deserialize(read_buf, mode)?)) + } } -/// Represents type that can be deserialized from a byte buffer. -pub trait DeserializeBytes { - fn deserialize(read_buf: impl Buf) -> Result +impl SerializeBytes for usize { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&(*self as u64), &mut write_buf, mode) + } +} + +impl DeserializeBytes for usize { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result where - Self: Sized; + Self: Sized, + { + let value: u64 = DeserializeBytes::deserialize(&mut read_buf, mode)?; + Ok(value as Self) + } } -impl> SerializeBytes for GenericArray { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - if write_buf.remaining_mut() < N::USIZE { - return Err(Error::WriteBufferFull); +impl SerializeBytes for u128 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u128_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u128 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u128_le()) + } +} + +impl SerializeBytes for u64 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u64_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u64 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u64_le()) + } +} + +impl SerializeBytes for u32 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u32_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u32 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u32_le()) + } +} + +impl SerializeBytes for u16 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u16_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u16 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u16_le()) + } +} + +impl SerializeBytes for u8 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u8(*self); + Ok(()) + } +} + +impl DeserializeBytes for u8 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u8()) + } +} + +impl SerializeBytes for bool { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + u8::serialize(&(*self as u8), write_buf, mode) + } +} + +impl DeserializeBytes for bool { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(u8::deserialize(read_buf, mode)? != 0) + } +} + +impl SerializeBytes for std::marker::PhantomData { + fn serialize( + &self, + _write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl DeserializeBytes for std::marker::PhantomData { + fn deserialize( + _read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self) + } +} + +impl SerializeBytes for &str { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let bytes = self.as_bytes(); + SerializeBytes::serialize(&bytes.len(), &mut write_buf, mode)?; + assert_enough_space_for(&write_buf, bytes.len())?; + write_buf.put_slice(bytes); + Ok(()) + } +} + +impl SerializeBytes for String { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&self.as_str(), &mut write_buf, mode) + } +} + +impl DeserializeBytes for String { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let len = DeserializeBytes::deserialize(&mut read_buf, mode)?; + assert_enough_data_for(&read_buf, len)?; + Ok(Self::from_utf8(read_buf.copy_to_bytes(len).to_vec())?) + } +} + +impl SerializeBytes for Vec { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&self.len(), &mut write_buf, mode)?; + self.iter() + .try_for_each(|item| SerializeBytes::serialize(item, &mut write_buf, mode)) + } +} + +impl DeserializeBytes for Vec { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let len: usize = DeserializeBytes::deserialize(&mut read_buf, mode)?; + (0..len) + .map(|_| DeserializeBytes::deserialize(&mut read_buf, mode)) + .collect() + } +} + +impl SerializeBytes for Option { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match self { + Some(value) => { + SerializeBytes::serialize(&true, &mut write_buf, mode)?; + SerializeBytes::serialize(value, &mut write_buf, mode)?; + } + None => { + SerializeBytes::serialize(&false, write_buf, mode)?; + } } + Ok(()) + } +} + +impl DeserializeBytes for Option { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(match bool::deserialize(&mut read_buf, mode)? { + true => Some(T::deserialize(&mut read_buf, mode)?), + false => None, + }) + } +} + +impl SerializeBytes for (U, V) { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + U::serialize(&self.0, &mut write_buf, mode)?; + V::serialize(&self.1, write_buf, mode) + } +} + +impl DeserializeBytes for (U, V) { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok((U::deserialize(&mut read_buf, mode)?, V::deserialize(read_buf, mode)?)) + } +} + +impl> SerializeBytes for GenericArray { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, N::USIZE)?; write_buf.put_slice(self); Ok(()) } } impl> DeserializeBytes for GenericArray { - fn deserialize(mut read_buf: impl Buf) -> Result { - if read_buf.remaining() < N::USIZE { - return Err(Error::NotEnoughBytes); - } - + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result { + assert_enough_data_for(&read_buf, N::USIZE)?; let mut ret = Self::default(); read_buf.copy_to_slice(&mut ret); Ok(ret) } } +#[inline] +fn assert_enough_space_for(write_buf: &impl BufMut, size: usize) -> Result<(), SerializationError> { + if write_buf.remaining_mut() < size { + return Err(SerializationError::WriteBufferFull); + } + Ok(()) +} + +#[inline] +fn assert_enough_data_for(read_buf: &impl Buf, size: usize) -> Result<(), SerializationError> { + if read_buf.remaining() < size { + return Err(SerializationError::NotEnoughBytes); + } + Ok(()) +} + #[cfg(test)] mod tests { use generic_array::typenum::U32; @@ -60,9 +427,11 @@ mod tests { rng.fill_bytes(&mut data); let mut buf = Vec::new(); - data.serialize(&mut buf).unwrap(); + data.serialize(&mut buf, SerializationMode::Native).unwrap(); - let data_deserialized = GenericArray::::deserialize(&mut buf.as_slice()).unwrap(); + let data_deserialized = + GenericArray::::deserialize(&mut buf.as_slice(), SerializationMode::Native) + .unwrap(); assert_eq!(data_deserialized, data); } } diff --git a/crates/utils/src/thread_local_mut.rs b/crates/utils/src/thread_local_mut.rs index 0f196c5b..9bf80ad3 100644 --- a/crates/utils/src/thread_local_mut.rs +++ b/crates/utils/src/thread_local_mut.rs @@ -6,7 +6,7 @@ use thread_local::ThreadLocal; /// Creates a "scratch space" within each thread with mutable access. /// -/// This is mainly meant to be used as an optimization to avoid unneccesary allocs/frees within rayon code. +/// This is mainly meant to be used as an optimization to avoid unnecessary allocs/frees within rayon code. /// You only pay for allocation of this scratch space once per thread. /// /// Since the space is local to each thread you also don't have to worry about atomicity. diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ece6509b..0de51d97 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -25,10 +25,6 @@ rand.workspace = true tracing-profile.workspace = true tracing.workspace = true -[[example]] -name = "groestl_circuit" -path = "groestl_circuit.rs" - [[example]] name = "keccakf_circuit" path = "keccakf_circuit.rs" @@ -81,9 +77,9 @@ path = "b32_mul.rs" name = "acc-linear-combination" path = "acc-linear-combination.rs" -[[example]] -name = "acc-linear-combination-with-offset" -path = "acc-linear-combination-with-offset.rs" +#[[example]] +#name = "acc-linear-combination-with-offset" +#path = "acc-linear-combination-with-offset.rs" [[example]] name = "acc-shifted" @@ -145,6 +141,10 @@ path = "acc-step-up.rs" name = "acc-tower-basis" path = "acc-tower-basis.rs" +[[example]] +name = "acc-permutation-channels" +path = "acc-permutation-channels.rs" + [lints.clippy] needless_range_loop = "allow" @@ -154,4 +154,3 @@ aes-tower = [] bail_panic = ["binius_utils/bail_panic"] fp-tower = [] rayon = ["binius_utils/rayon"] - diff --git a/examples/acc-constants.rs b/examples/acc-constants.rs index 0947ed10..9536f172 100644 --- a/examples/acc-constants.rs +++ b/examples/acc-constants.rs @@ -3,10 +3,7 @@ use binius_core::{ constraint_system::validate::validate_witness, oracle::OracleId, transparent::constant::Constant, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, BinaryField32b}; - -type U = OptimalUnderlier; -type F128 = BinaryField128b; +use binius_field::{BinaryField1b, BinaryField32b}; type F32 = BinaryField32b; type F1 = BinaryField1b; @@ -17,7 +14,7 @@ const LOG_SIZE: usize = 4; fn constants_gadget( name: impl ToString, log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, constant_value: u32, ) -> OracleId { builder.push_namespace(name); @@ -45,7 +42,8 @@ fn constants_gadget( // Transparent column. fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); pub const SHA256_INIT: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, diff --git a/examples/acc-disjoint-product.rs b/examples/acc-disjoint-product.rs index 05f4b59d..2f6ed581 100644 --- a/examples/acc-disjoint-product.rs +++ b/examples/acc-disjoint-product.rs @@ -3,12 +3,8 @@ use binius_core::{ constraint_system::validate::validate_witness, transparent::{constant::Constant, disjoint_product::DisjointProduct, powers::Powers}, }; -use binius_field::{ - arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField8b, PackedField, -}; +use binius_field::{BinaryField, BinaryField8b, PackedField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 4; @@ -29,7 +25,7 @@ const LOG_SIZE: usize = 4; // of heights (n_vars) of Powers and Constant, so actual data could be repeated multiple times fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let generator = F8::MULTIPLICATIVE_GENERATOR; let powers = Powers::new(LOG_SIZE, generator.into()); diff --git a/examples/acc-eq-ind-partial-eval.rs b/examples/acc-eq-ind-partial-eval.rs index 7fcdc4b5..6d76fb45 100644 --- a/examples/acc-eq-ind-partial-eval.rs +++ b/examples/acc-eq-ind-partial-eval.rs @@ -2,9 +2,8 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::eq_ind::EqIndPartialEval, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, PackedField}; +use binius_field::{BinaryField128b, PackedField}; -type U = OptimalUnderlier; type F128 = BinaryField128b; const LOG_SIZE: usize = 3; @@ -21,7 +20,8 @@ const LOG_SIZE: usize = 3; // fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); // A truth table [000, 001, 010, 011 ... 111] where each row is in reversed order let rev_basis = [ diff --git a/examples/acc-linear-combination-with-offset.rs b/examples/acc-linear-combination-with-offset.rs.disabled similarity index 100% rename from examples/acc-linear-combination-with-offset.rs rename to examples/acc-linear-combination-with-offset.rs.disabled diff --git a/examples/acc-linear-combination.rs b/examples/acc-linear-combination.rs index 5cb1eab5..b069b519 100644 --- a/examples/acc-linear-combination.rs +++ b/examples/acc-linear-combination.rs @@ -1,20 +1,17 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; use binius_field::{ - arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, - BinaryField8b, ExtensionField, TowerField, + packed::set_packed_slice, BinaryField1b, BinaryField8b, ExtensionField, TowerField, }; use binius_macros::arith_expr; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production fn bytes_decomposition_gadget( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, input: OracleId, @@ -146,7 +143,7 @@ fn bytes_decomposition_gadget( } fn elder_4bits_masking_gadget( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, input: OracleId, @@ -241,12 +238,12 @@ fn elder_4bits_masking_gadget( fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 1usize; // Define set of bytes that we want to decompose - let p_in = unconstrained::(&mut builder, "p_in".to_string(), log_size).unwrap(); + let p_in = unconstrained::(&mut builder, "p_in".to_string(), log_size).unwrap(); let _ = bytes_decomposition_gadget(&mut builder, "bytes decomposition", log_size, p_in).unwrap(); diff --git a/examples/acc-multilinear-extension-transparent.rs b/examples/acc-multilinear-extension-transparent.rs index f9dd4570..033184ed 100644 --- a/examples/acc-multilinear-extension-transparent.rs +++ b/examples/acc-multilinear-extension-transparent.rs @@ -14,7 +14,7 @@ type F128 = BinaryField128b; type F1 = BinaryField1b; // From a perspective of circuits creation, MultilinearExtensionTransparent can be used naturally for decomposing integers to bits -fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: u64) { +fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: u64) { builder.push_namespace("decompose_transparent_u64"); let log_bits = log2_ceil_usize(64); @@ -42,7 +42,7 @@ fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: builder.pop_namespace(); } -fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: u32) { +fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: u32) { builder.push_namespace("decompose_transparent_u32"); let log_bits = log2_ceil_usize(32); @@ -72,7 +72,7 @@ fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); decompose_transparent_u64(&mut builder, 0xff00ff00ff00ff00); decompose_transparent_u32(&mut builder, 0x00ff00ff); diff --git a/examples/acc-packed.rs b/examples/acc-packed.rs index 6963d878..434a9d11 100644 --- a/examples/acc-packed.rs +++ b/examples/acc-packed.rs @@ -1,12 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{ - arch::OptimalUnderlier, BinaryField128b, BinaryField16b, BinaryField1b, BinaryField32b, - BinaryField8b, TowerField, -}; +use binius_field::{BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F32 = BinaryField32b; type F16 = BinaryField16b; type F8 = BinaryField8b; @@ -14,10 +9,10 @@ type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production -fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { +fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_32_bits_to_u32"); - let bits = unconstrained::(builder, "bits", F32::TOWER_LEVEL).unwrap(); + let bits = unconstrained::(builder, "bits", F32::TOWER_LEVEL).unwrap(); let packed = builder .add_packed("packed", bits, F32::TOWER_LEVEL) .unwrap(); @@ -49,10 +44,10 @@ fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { +fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_4_bytes_to_u32"); - let bytes = unconstrained::(builder, "bytes", F16::TOWER_LEVEL).unwrap(); + let bytes = unconstrained::(builder, "bytes", F16::TOWER_LEVEL).unwrap(); let packed = builder .add_packed("packed", bytes, F16::TOWER_LEVEL) .unwrap(); @@ -76,10 +71,10 @@ fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { +fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_8_bits_to_u8"); - let bits = unconstrained::(builder, "bits", F8::TOWER_LEVEL).unwrap(); + let bits = unconstrained::(builder, "bits", F8::TOWER_LEVEL).unwrap(); let packed = builder.add_packed("packed", bits, F8::TOWER_LEVEL).unwrap(); if let Some(witness) = builder.witness() { @@ -94,7 +89,7 @@ fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); packing_32_bits_to_u32(&mut builder); packing_4_bytes_to_u32(&mut builder); diff --git a/examples/acc-permutation-channels.rs b/examples/acc-permutation-channels.rs new file mode 100644 index 00000000..a628bf35 --- /dev/null +++ b/examples/acc-permutation-channels.rs @@ -0,0 +1,111 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::fixed_u32}; +use binius_core::constraint_system::{ + channel::{Boundary, FlushDirection}, + validate::validate_witness, +}; +use binius_field::{BinaryField128b, BinaryField32b}; +use bumpalo::Bump; + +type F128 = BinaryField128b; +type F32 = BinaryField32b; + +const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]; + +// Permutation is a classic construction in a traditional cryptography. It has well-defined security properties +// and high performance due to implementation via lookups. One can possible to implement gadget for permutations using +// channels API from Binius. The following examples shows how to enforce Blake3 permutation - verifier pulls pairs of +// input/output of the permutation (encoded as a BinaryField128b elements, to reduce number of flushes), +// while prover is expected to push similar IO to make channel balanced. +fn permute(m: &mut [u32; 16]) { + let mut permuted = [0; 16]; + for i in 0..16 { + permuted[i] = m[MSG_PERMUTATION[i]]; + } + *m = permuted; +} + +fn main() { + let log_size = 4usize; + + let allocator = Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let m = [ + 0xfffffff0, 0xfffffff1, 0xfffffff2, 0xfffffff3, 0xfffffff4, 0xfffffff5, 0xfffffff6, + 0xfffffff7, 0xfffffff8, 0xfffffff9, 0xfffffffa, 0xfffffffb, 0xfffffffc, 0xfffffffd, + 0xfffffffe, 0xffffffff, + ]; + + let mut m_clone = m; + permute(&mut m_clone); + + let expected = [ + 0xfffffff2, 0xfffffff6, 0xfffffff3, 0xfffffffa, 0xfffffff7, 0xfffffff0, 0xfffffff4, + 0xfffffffd, 0xfffffff1, 0xfffffffb, 0xfffffffc, 0xfffffff5, 0xfffffff9, 0xfffffffe, + 0xffffffff, 0xfffffff8, + ]; + assert_eq!(m_clone, expected); + + let u32_in = fixed_u32::(&mut builder, "in", log_size, m.to_vec()).unwrap(); + let u32_out = fixed_u32::(&mut builder, "out", log_size, expected.to_vec()).unwrap(); + + // we pack 4-u32 (F32) tuples of permutation IO into F128 columns and use them for flushing + let u128_in = builder.add_packed("in_packed", u32_in, 2).unwrap(); + let u128_out = builder.add_packed("out_packed", u32_out, 2).unwrap(); + + // populate memory layout (witness) + if let Some(witness) = builder.witness() { + let in_f32 = witness.get::(u32_in).unwrap(); + let out_f32 = witness.get::(u32_out).unwrap(); + witness.new_column::(u128_in); + witness.new_column::(u128_out); + + witness.set(u128_in, in_f32.repacked::()).unwrap(); + witness.set(u128_out, out_f32.repacked::()).unwrap(); + } + + let channel = builder.add_channel(); + // count defines how many values ( 0 .. count ) from a given columns to send (pushing to a channel) + builder.send(channel, 4, [u128_in, u128_out]).unwrap(); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + // consider our 4-u32 values from a given tupple as 4 limbs of u128 + let f = |limb0: u32, limb1: u32, limb2: u32, limb3: u32| { + let mut x = 0u128; + + x ^= (limb3 as u128) << 96; + x ^= (limb2 as u128) << 64; + x ^= (limb1 as u128) << 32; + x ^= limb0 as u128; + + F128::new(x) + }; + + // Boundaries define actual data (encoded in a set of Flushes) that verifier can push or pull from a given channel + // in order to check if prover is able to balance that channel + let mut offset = 0usize; + let boundaries = (0..4) + .map(|_| { + let boundary = Boundary { + values: vec![ + f(m[offset], m[offset + 1], m[offset + 2], m[offset + 3]), + f( + expected[offset], + expected[offset + 1], + expected[offset + 2], + expected[offset + 3], + ), + ], + channel_id: channel, + direction: FlushDirection::Pull, + multiplicity: 1, + }; + offset += 4; + boundary + }) + .collect::>>(); + + validate_witness(&cs, &boundaries, &witness).unwrap(); +} diff --git a/examples/acc-powers.rs b/examples/acc-powers.rs index 159aa113..c266191e 100644 --- a/examples/acc-powers.rs +++ b/examples/acc-powers.rs @@ -1,12 +1,7 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{ - arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField16b, BinaryField32b, - PackedField, -}; +use binius_field::{BinaryField, BinaryField16b, BinaryField32b, PackedField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F32 = BinaryField32b; type F16 = BinaryField16b; @@ -22,7 +17,7 @@ const LOG_SIZE: usize = 3; // where 'x' is a multiplicative generator - a public value that exists for every BinaryField // -fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl ToString) { +fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl ToString) { builder.push_namespace(name); let generator = F32::MULTIPLICATIVE_GENERATOR; @@ -43,7 +38,7 @@ fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl } // Only Field is being changed -fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl ToString) { +fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl ToString) { builder.push_namespace(name); let generator = F16::MULTIPLICATIVE_GENERATOR; @@ -65,7 +60,7 @@ fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); powers_gadget_f16(&mut builder, "f16"); powers_gadget_f32(&mut builder, "f32"); diff --git a/examples/acc-projected.rs b/examples/acc-projected.rs index 8d4de121..c7ac8150 100644 --- a/examples/acc-projected.rs +++ b/examples/acc-projected.rs @@ -1,8 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::ProjectionVariant}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::{BinaryField128b, BinaryField8b}; -type U = OptimalUnderlier; type F128 = BinaryField128b; type F8 = BinaryField8b; @@ -20,14 +19,13 @@ struct U8U128ProjectionInfo { // has significant impact on input data processing. // In the following example we have input column with bytes (u8) projected to the output column with u128 values. fn projection( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, projection_info: U8U128ProjectionInfo, namespace: &str, ) { builder.push_namespace(format!("projection {}", namespace)); - let input = - unconstrained::(builder, "in", projection_info.clone().log_size).unwrap(); + let input = unconstrained::(builder, "in", projection_info.clone().log_size).unwrap(); let projected = builder .add_projected( @@ -112,7 +110,7 @@ impl U8U128ProjectionInfo { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let projection_data = U8U128ProjectionInfo::new( 4usize, diff --git a/examples/acc-repeated.rs b/examples/acc-repeated.rs index 749bb6fe..ef4bf643 100644 --- a/examples/acc-repeated.rs +++ b/examples/acc-repeated.rs @@ -1,12 +1,9 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; use binius_field::{ - arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, - BinaryField8b, PackedBinaryField128x1b, + packed::set_packed_slice, BinaryField1b, BinaryField8b, PackedBinaryField128x1b, }; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; type F1 = BinaryField1b; @@ -18,10 +15,10 @@ const LOG_SIZE: usize = 8; // so new column is X times bigger than original one. The following gadget operates over bytes, e.g. // it creates column with some input bytes written and then creates one more 'Repeated' column // where the same bytes are copied multiple times. -fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { +fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("bytes_repeat_gadget"); - let bytes = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + let bytes = unconstrained::(builder, "input", LOG_SIZE).unwrap(); let repeat_times_log = 4usize; let repeating = builder @@ -57,10 +54,10 @@ fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { // repetitions Binius creates column with 8 PackedBinaryField128x1b elements totally. // Proper writing bits requires separate iterating over PackedBinaryField128x1b elements and input bytes // with extracting particular bit values from the input and setting appropriate bit in a given PackedBinaryField128x1b. -fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { +fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("bits_repeat_gadget"); - let bits = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + let bits = unconstrained::(builder, "input", LOG_SIZE).unwrap(); let repeat_times_log = 2usize; // Binius will create column with appropriate height for us @@ -112,7 +109,7 @@ fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); bytes_repeat_gadget(&mut builder); bits_repeat_gadget(&mut builder); diff --git a/examples/acc-select-row.rs b/examples/acc-select-row.rs index 4337540b..7f45bee0 100644 --- a/examples/acc-select-row.rs +++ b/examples/acc-select-row.rs @@ -2,10 +2,8 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::select_row::SelectRow, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 8; @@ -13,7 +11,7 @@ const LOG_SIZE: usize = 8; // SelectRow expects exactly one witness value at particular index to be set. fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 58; assert!(index < 1 << LOG_SIZE); diff --git a/examples/acc-shift-ind-partial-eq.rs b/examples/acc-shift-ind-partial-eq.rs index 74fa2bae..385a07b8 100644 --- a/examples/acc-shift-ind-partial-eq.rs +++ b/examples/acc-shift-ind-partial-eq.rs @@ -3,16 +3,15 @@ use binius_core::{ constraint_system::validate::validate_witness, oracle::ShiftVariant, transparent::shift_ind::ShiftIndPartialEval, }; -use binius_field::{arch::OptimalUnderlier, util::eq, BinaryField128b, Field}; +use binius_field::{util::eq, BinaryField128b, Field}; -type U = OptimalUnderlier; type F128 = BinaryField128b; // ShiftIndPartialEval is a more elaborated version of EqIndPartialEval. Same idea with challenges, but a bit more // elaborated evaluation algorithm is used fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let block_size = 3; let shift_offset = 4; diff --git a/examples/acc-shifted.rs b/examples/acc-shifted.rs index 924fb823..a81a6610 100644 --- a/examples/acc-shifted.rs +++ b/examples/acc-shifted.rs @@ -1,21 +1,19 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::ShiftVariant}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production -fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { +fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u32_right_shift"); // defined empirically and it is the same as 'block_bits' defined below let log_size = 5usize; // create column and write arbitrary bytes to it - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); // we want to shift our u32 variable on 1 bit let shift_offset = 1; @@ -43,11 +41,11 @@ fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { +fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u8_left_shift"); let log_size = 3usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); let shift_offset = 4; let shift_type = ShiftVariant::LogicalLeft; let block_bits = 3; @@ -67,11 +65,11 @@ fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { +fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u16_rotate_right"); let log_size = 4usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); let rotation_offset = 5; let rotation_type = ShiftVariant::CircularLeft; let block_bits = 4usize; @@ -92,11 +90,11 @@ fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { +fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u64_rotate_right"); let log_size = 6usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); // Right rotation to X bits is achieved using 'ShiftVariant::CircularLeft' with the offset, // computed as size in bits of the variable type - X (e.g. if we want to right-rotate u64 to 8 bits, @@ -124,7 +122,7 @@ fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); shift_right_gadget_u32(&mut builder); shift_left_gadget_u8(&mut builder); diff --git a/examples/acc-step-down.rs b/examples/acc-step-down.rs index cbe3cd6e..a2ecf3ba 100644 --- a/examples/acc-step-down.rs +++ b/examples/acc-step-down.rs @@ -2,18 +2,16 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::step_down::StepDown, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; const LOG_SIZE: usize = 8; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; // StepDown expects all bytes to be set before particular index specified as input fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 10; diff --git a/examples/acc-step-up.rs b/examples/acc-step-up.rs index e659352e..0d65820e 100644 --- a/examples/acc-step-up.rs +++ b/examples/acc-step-up.rs @@ -1,9 +1,7 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{constraint_system::validate::validate_witness, transparent::step_up::StepUp}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 8; @@ -11,7 +9,7 @@ const LOG_SIZE: usize = 8; // StepUp expects all bytes to be unset before particular index specified as input (opposite to StepDown) fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 10; diff --git a/examples/acc-tower-basis.rs b/examples/acc-tower-basis.rs index e99e6790..fcc9a383 100644 --- a/examples/acc-tower-basis.rs +++ b/examples/acc-tower-basis.rs @@ -2,16 +2,15 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::tower_basis::TowerBasis, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, Field, TowerField}; +use binius_field::{BinaryField128b, Field, TowerField}; -type U = OptimalUnderlier; type F128 = BinaryField128b; // TowerBasis expects actually basis vectors written to the witness. // The form of basis could vary depending on 'iota' and 'k' parameters fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let k = 3usize; let iota = 4usize; diff --git a/examples/acc-zeropadded.rs b/examples/acc-zeropadded.rs index 2cbb954c..b306bd63 100644 --- a/examples/acc-zeropadded.rs +++ b/examples/acc-zeropadded.rs @@ -1,9 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 4; @@ -11,9 +9,9 @@ const LOG_SIZE: usize = 4; fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let bytes = unconstrained::(&mut builder, "bytes", LOG_SIZE).unwrap(); + let bytes = unconstrained::(&mut builder, "bytes", LOG_SIZE).unwrap(); // Height of ZeroPadded column can't be smaller than input one. // If n_vars equals to LOG_SIZE, then no padding is required, diff --git a/examples/b32_mul.rs b/examples/b32_mul.rs index b0168f43..5cdbbea5 100644 --- a/examples/b32_mul.rs +++ b/examples/b32_mul.rs @@ -1,9 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, TowerField}; +use binius_field::{BinaryField32b, TowerField}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_macros::arith_expr; @@ -26,7 +26,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -42,18 +41,18 @@ fn main() -> Result<()> { let log_n_muls = log2_ceil_usize(args.n_ops as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_muls, ) .unwrap(); - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_muls, diff --git a/examples/bitwise_ops.rs b/examples/bitwise_ops.rs index 3eaec2bb..6b46e29a 100644 --- a/examples/bitwise_ops.rs +++ b/examples/bitwise_ops.rs @@ -3,11 +3,9 @@ use std::{fmt::Display, str::FromStr}; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{ - arch::OptimalUnderlier, BinaryField128b, BinaryField1b, BinaryField32b, TowerField, -}; +use binius_field::{BinaryField1b, BinaryField32b, TowerField}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_macros::arith_expr; @@ -63,7 +61,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -80,16 +77,16 @@ fn main() -> Result<()> { log2_ceil_usize(args.n_u32_ops as usize) + BinaryField32b::TOWER_LEVEL; let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); // Assuming our 32bit values have been committed as bits - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_1b_operations, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_1b_operations, diff --git a/examples/collatz.rs b/examples/collatz.rs index c30305ca..f51ce7a0 100644 --- a/examples/collatz.rs +++ b/examples/collatz.rs @@ -2,7 +2,7 @@ use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, collatz::{Advice, Collatz}, }; use binius_core::{ @@ -10,7 +10,6 @@ use binius_core::{ fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -29,9 +28,6 @@ struct Args { log_inv_rate: u32, } -type U = OptimalUnderlier; -type F = BinaryField128b; - const SECURITY_BITS: usize = 100; fn main() -> Result<()> { @@ -59,7 +55,7 @@ fn prove(x0: u32, log_inv_rate: usize) -> Result<(Advice, Proof), anyhow::Error> let advice = collatz.init_prover(); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let boundaries = collatz.build(&mut builder, advice)?; @@ -95,7 +91,7 @@ fn prove(x0: u32, log_inv_rate: usize) -> Result<(Advice, Proof), anyhow::Error> fn verify(x0: u32, advice: Advice, proof: Proof, log_inv_rate: usize) -> Result<(), anyhow::Error> { let collatz = Collatz::new(x0); - let mut builder = ConstraintSystemBuilder::::new(); + let mut builder = ConstraintSystemBuilder::new(); let boundaries = collatz.build(&mut builder, advice)?; diff --git a/examples/groestl_circuit.rs b/examples/groestl_circuit.rs.disabled similarity index 100% rename from examples/groestl_circuit.rs rename to examples/groestl_circuit.rs.disabled diff --git a/examples/keccakf_circuit.rs b/examples/keccakf_circuit.rs index ed87a920..68f12812 100644 --- a/examples/keccakf_circuit.rs +++ b/examples/keccakf_circuit.rs @@ -5,9 +5,8 @@ use std::vec; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -28,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -45,14 +43,14 @@ fn main() -> Result<()> { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = log_n_permutations; let trace_gen_scope = tracing::info_span!("generating trace").entered(); let input_witness = vec![]; let _state_out = - binius_circuits::keccakf::keccakf(&mut builder, Some(input_witness), log_size)?; + binius_circuits::keccakf::keccakf(&mut builder, &Some(input_witness), log_size)?; drop(trace_gen_scope); let witness = builder diff --git a/examples/modular_mul.rs b/examples/modular_mul.rs index b5dfb9e7..1649fa55 100644 --- a/examples/modular_mul.rs +++ b/examples/modular_mul.rs @@ -5,15 +5,14 @@ use std::array; use alloy_primitives::U512; use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::big_integer_ops::{byte_sliced_modular_mul, byte_sliced_test_utils::random_u512}, transparent, }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier128b, tower_levels::{TowerLevel4, TowerLevel8}, - BinaryField128b, BinaryField1b, BinaryField8b, Field, TowerField, + BinaryField1b, BinaryField8b, Field, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -36,8 +35,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier128b; - type F = BinaryField128b; type B8 = BinaryField8b; const SECURITY_BITS: usize = 100; const WIDTH: usize = 4; @@ -53,7 +50,7 @@ fn main() -> Result<()> { println!("Verifying {} u32 modular multiplications", args.n_multiplications); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = log2_ceil_usize(args.n_multiplications as usize); let mut rng = thread_rng(); @@ -98,7 +95,7 @@ fn main() -> Result<()> { let zero_oracle_carry = transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - let _modded_product = byte_sliced_modular_mul::<_, _, TowerLevel4, TowerLevel8>( + let _modded_product = byte_sliced_modular_mul::( &mut builder, "lasso_bytesliced_mul", &mult_a, diff --git a/examples/sha256_circuit.rs b/examples/sha256_circuit.rs index 2c5d162a..3835c79f 100644 --- a/examples/sha256_circuit.rs +++ b/examples/sha256_circuit.rs @@ -5,11 +5,14 @@ use std::array; use anyhow::Result; -use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_circuits::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, +}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -32,7 +35,6 @@ struct Args { const COMPRESSION_LOG_LEN: usize = 5; fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -48,15 +50,11 @@ fn main() -> Result<()> { let log_n_compressions = log2_ceil_usize(args.n_compressions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); let input: [OracleId; 16] = array::try_from_fn(|i| { - unconstrained::<_, _, BinaryField1b>( - &mut builder, - i, - log_n_compressions + COMPRESSION_LOG_LEN, - ) + unconstrained::(&mut builder, i, log_n_compressions + COMPRESSION_LOG_LEN) })?; let _state_out = binius_circuits::sha256::sha256( diff --git a/examples/sha256_circuit_with_lookup.rs b/examples/sha256_circuit_with_lookup.rs index 6b7cd791..f7f83435 100644 --- a/examples/sha256_circuit_with_lookup.rs +++ b/examples/sha256_circuit_with_lookup.rs @@ -5,13 +5,14 @@ use std::array; use anyhow::Result; -use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_circuits::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, +}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{ - arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField128b, BinaryField1b, -}; +use binius_field::{arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField1b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -38,7 +39,6 @@ struct Args { const COMPRESSION_LOG_LEN: usize = 5; fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -54,15 +54,11 @@ fn main() -> Result<()> { let log_n_compressions = log2_ceil_usize(args.n_compressions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating witness").entered(); let input: [OracleId; 16] = array::try_from_fn(|i| { - unconstrained::<_, _, BinaryField1b>( - &mut builder, - i, - log_n_compressions + COMPRESSION_LOG_LEN, - ) + unconstrained::(&mut builder, i, log_n_compressions + COMPRESSION_LOG_LEN) })?; let _state_out = binius_circuits::lasso::sha256( diff --git a/examples/u32_add.rs b/examples/u32_add.rs index 2feed9ef..22c446ce 100644 --- a/examples/u32_add.rs +++ b/examples/u32_add.rs @@ -1,9 +1,12 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::{arithmetic::Flags, builder::ConstraintSystemBuilder}; +use binius_circuits::{ + arithmetic::Flags, + builder::{types::U, ConstraintSystemBuilder}, +}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -24,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -40,15 +42,15 @@ fn main() -> Result<()> { let log_n_additions = log2_ceil_usize(args.n_additions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_additions + 5, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_additions + 5, diff --git a/examples/u32_mul.rs b/examples/u32_mul.rs index d7ecb1cf..8a8bfafc 100644 --- a/examples/u32_mul.rs +++ b/examples/u32_mul.rs @@ -4,7 +4,7 @@ use std::array; use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::{ batch::LookupBatch, big_integer_ops::byte_sliced_mul, @@ -14,9 +14,8 @@ use binius_circuits::{ }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier, tower_levels::{TowerLevel4, TowerLevel8}, - BinaryField128b, BinaryField1b, BinaryField32b, BinaryField8b, Field, + BinaryField1b, BinaryField32b, BinaryField8b, Field, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -38,7 +37,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -54,12 +52,12 @@ fn main() -> Result<()> { let log_n_muls = log2_ceil_usize(args.n_muls as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); // Assuming our input data is already transposed, i.e a length 4 array of B8's let in_a = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("in_a_{}", i), log_n_muls, @@ -67,7 +65,7 @@ fn main() -> Result<()> { .unwrap() }); let in_b = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("in_b_{}", i), log_n_muls, @@ -84,7 +82,7 @@ fn main() -> Result<()> { let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - let _mul_and_cout = byte_sliced_mul::<_, _, TowerLevel4, TowerLevel8>( + let _mul_and_cout = byte_sliced_mul::( &mut builder, "lasso_bytesliced_mul", &in_a, @@ -95,9 +93,9 @@ fn main() -> Result<()> { &mut lookup_batch_add, &mut lookup_batch_dci, )?; - lookup_batch_mul.execute::(&mut builder)?; - lookup_batch_add.execute::(&mut builder)?; - lookup_batch_dci.execute::(&mut builder)?; + lookup_batch_mul.execute::(&mut builder)?; + lookup_batch_add.execute::(&mut builder)?; + lookup_batch_dci.execute::(&mut builder)?; drop(trace_gen_scope); diff --git a/examples/u32add_with_lookup.rs b/examples/u32add_with_lookup.rs index 1eba4ab7..9bc06cfb 100644 --- a/examples/u32add_with_lookup.rs +++ b/examples/u32add_with_lookup.rs @@ -1,11 +1,10 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField128b, BinaryField1b, - BinaryField8b, + arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField1b, BinaryField8b, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -29,7 +28,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -45,20 +43,20 @@ fn main() -> Result<()> { let log_n_additions = log2_ceil_usize(args.n_additions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_additions + 2, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_additions + 2, )?; - let _product = binius_circuits::lasso::u32add::<_, _, BinaryField8b, BinaryField8b>( + let _product = binius_circuits::lasso::u32add::( &mut builder, "out_c", in_a, diff --git a/examples/u8mul.rs b/examples/u8mul.rs index 90d706dc..cc6e0be9 100644 --- a/examples/u8mul.rs +++ b/examples/u8mul.rs @@ -2,11 +2,11 @@ use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::{batch::LookupBatch, lookups}, }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, BinaryField8b}; +use binius_field::{BinaryField32b, BinaryField8b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -27,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -43,15 +42,15 @@ fn main() -> Result<()> { let log_n_multiplications = log2_ceil_usize(args.n_multiplications as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_multiplications, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_multiplications, @@ -70,7 +69,7 @@ fn main() -> Result<()> { args.n_multiplications as usize, )?; - lookup_batch.execute::<_, _, BinaryField32b>(&mut builder)?; + lookup_batch.execute::(&mut builder)?; drop(trace_gen_scope); let witness = builder diff --git a/examples/vision32b_circuit.rs b/examples/vision32b_circuit.rs index 61c74d76..b5afa6f5 100644 --- a/examples/vision32b_circuit.rs +++ b/examples/vision32b_circuit.rs @@ -10,11 +10,11 @@ use std::array; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, BinaryField8b}; +use binius_field::{BinaryField32b, BinaryField8b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::IsomorphicEvaluationDomainFactory; @@ -35,7 +35,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -51,11 +50,11 @@ fn main() -> Result<()> { let log_n_permutations = log2_ceil_usize(args.n_permutations as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); let state_in: [OracleId; 24] = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("p_in_{i}"), log_n_permutations, diff --git a/scripts/nightly_benchmarks.py b/scripts/nightly_benchmarks.py new file mode 100755 index 00000000..1c2f9ad0 --- /dev/null +++ b/scripts/nightly_benchmarks.py @@ -0,0 +1,265 @@ +#!/usr/bin/python3 + +import argparse +import csv +import json +import os +import re +import subprocess +from typing import Union + +ENV_VARS = { + "RUSTFLAGS": "-C target-cpu=native", +} + +SAMPLE_SIZE = 5 + +KECCAKF_PERMS = 1 << 13 +VISION32B_PERMS = 1 << 14 +SHA256_PERMS = 1 << 14 +NUM_BINARY_OPS = 1 << 22 +NUM_MULS = 1 << 20 + +HASHER_TO_RUN = { + r"keccakf": { + "type": "hasher", + "display": r"Keccak-f", + "export": "keccakf-report.csv", + "args": ["keccakf_circuit", "--", "--n-permutations"], + "n_ops": KECCAKF_PERMS, + }, + "vision32b": { + "type": "hasher", + "display": r"Vision Mark-32", + "export": "vision32b-report.csv", + "args": ["vision32b_circuit", "--", "--n-permutations"], + "n_ops": VISION32B_PERMS, + }, + "sha256": { + "type": "hasher", + "display": "SHA-256", + "export": "sha256-report.csv", + "args": ["sha256_circuit", "--", "--n-compressions"], + "n_ops": SHA256_PERMS, + }, + "b32_mul": { + "type": "binary_ops", + "display": "BinaryField32b mul", + "export": "b32-mul-report.csv", + "args": ["b32_mul", "--", "--n-ops"], + "n_ops": NUM_MULS, + }, + "u32_add": { + "type": "binary_ops", + "display": "u32 add", + "export": "u32-add-report.csv", + "args": ["u32_add", "--", "--n-additions"], + "n_ops": NUM_BINARY_OPS, + }, + "u32_mul": { + "type": "binary_ops", + "display": "u32 mul", + "export": "u32-mul-report.csv", + "args": ["u32_mul", "--", "--n-muls"], + "n_ops": NUM_MULS, + }, + "xor": { + "type": "binary_ops", + "display": "Xor", + "export": "xor-report.csv", + "args": ["bitwise_ops", "--", "--op", "xor", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, + "and": { + "type": "binary_ops", + "display": "And", + "export": "and-report.csv", + "args": ["bitwise_ops", "--", "--op", "and", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, + "or": { + "type": "binary_ops", + "display": "Or", + "export": "or-report.csv", + "args": ["bitwise_ops", "--", "--op", "or", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, +} + +HASHER_BENCHMARKS = {} +BINARY_OPS_BENCHMARKS = {} + + +def run_benchmark(benchmark_args) -> tuple[bytes, bytes]: + command = ( + ["cargo", "run", "--release", "--example"] + + benchmark_args["args"] + + [f"{benchmark_args['n_ops']}"] + ) + env_vars_to_run = { + **os.environ, + **ENV_VARS, + "PROFILE_CSV_FILE": benchmark_args["export"], + } + process = subprocess.run( + command, env=env_vars_to_run, capture_output=True, check=True + ) + return process.stdout, process.stderr + + +def parse_csv_file(file_name) -> dict: + data = {} + with open(file_name) as file: + reader = csv.reader(file) + for row in reader: + if row[0] == "generating trace": + data.update({"trace_gen_time": int(row[2])}) + elif row[0] == "constraint_system::prove": + data.update({"proving_time": int(row[2])}) + elif row[0] == "constraint_system::verify": + data.update({"verification_time": int(row[2])}) + return data + + +KIB_TO_BYTES = 1024.0 +MIB_TO_BYTES = KIB_TO_BYTES * 1024.0 +GIB_TO_BYTES = MIB_TO_BYTES * 1024.0 +KB_TO_BYTES = 1000.0 +MB_TO_BYTES = KB_TO_BYTES * 1000.0 +GB_TO_BYTES = MB_TO_BYTES * 1000.0 + +SIZE_CONVERSIONS = { + "KiB": KIB_TO_BYTES, + "MiB": MIB_TO_BYTES, + "GiB": GIB_TO_BYTES, + " B": 1, + "KB": KB_TO_BYTES, + "MB": MB_TO_BYTES, + "GB": GB_TO_BYTES, +} + + +def parse_proof_size(proof_size: bytes) -> int: + proof_size = proof_size.decode("utf-8").strip() + for unit, factor in SIZE_CONVERSIONS.items(): + if proof_size.endswith(unit): + byte_len = float(proof_size[: -len(unit)]) * factor + break + else: + raise ValueError(f"Unknown proof size format: {proof_size}") + + # Convert to KiB + return int(byte_len / KIB_TO_BYTES) + + +def nano_to_milli(nano) -> float: + return float(nano) / 1000000.0 + + +def nano_to_seconds(nano) -> float: + return float(nano) / 1000000000.0 + + +def run_and_parse_benchmark(benchmark, benchmark_args) -> tuple[dict, int]: + data = {} + stdout = None + print(f"Running benchmark: {benchmark} with {SAMPLE_SIZE} samples") + for _ in range(SAMPLE_SIZE): + stdout, _stderr = run_benchmark(benchmark_args) + result = parse_csv_file(benchmark_args["export"]) + # Parse the csv file + if len(result.keys()) != 3: + print(f"Failed to parse csv file for benchmark: {benchmark}") + exit(1) + + # Append the results to the data + for key, value in result.items(): + if data.get(key) is None: + data[key] = [] + data[key].append(value) + # Get proof sizes + found = re.search(rb"Proof size: (.*)", stdout) + if found: + return data, parse_proof_size(found.group(1)) + else: + print(f"Failed to get proof size for benchmark: {benchmark}") + exit(1) + + +def run_benchmark_group(benchmarks) -> dict: + benchmark_results = {} + for benchmark, benchmark_args in benchmarks.items(): + try: + data, proof_size = run_and_parse_benchmark(benchmark, benchmark_args) + benchmark_results[benchmark] = {"proof_size_kib": proof_size} + data["n_ops"] = benchmark_args["n_ops"] + data["display"] = benchmark_args["display"] + data["type"] = benchmark_args["type"] + benchmark_results[benchmark].update(data) + + except Exception as e: + print(f"Failed to run benchmark: {benchmark} with error {e} \nExiting...") + exit(1) + return benchmark_results + + +def value_to_bencher(value: Union[list[float], int], throughput: bool = False) -> dict: + if isinstance(value, list): + avg_value = sum(value) / len(value) + max_value = max(value) + min_value = min(value) + else: + avg_value = max_value = min_value = value + + metric_type = "throughput" if throughput else "latency" + return { + metric_type: { + "value": avg_value, + "upper_value": max_value, + "lower_value": min_value, + } + } + + +def dict_to_bencher(data: dict) -> dict: + bencher_data = {} + for benchmark, value in data.items(): + # Name is of the following format: ::::(trace_gen_time | proving_time | verification_time | proof_size_kib | n_ops) + common_name = f"{value['type']}::{value['display']}" + for key in [ + "trace_gen_time", + "proving_time", + "verification_time", + "proof_size_kib", + "n_ops", + ]: + bencher_data[f"{common_name}::{key}"] = value_to_bencher(value[key]) + return bencher_data + + +def main(): + parser = argparse.ArgumentParser( + description="Run nightly benchmarks and export results" + ) + parser.add_argument( + "--export-file", + required=False, + type=str, + help="Export benchmarks results to file (defaults to stdout)", + ) + + args = parser.parse_args() + + benchmarks = run_benchmark_group(HASHER_TO_RUN) + + bencher_data = dict_to_bencher(benchmarks) + if args.export_file is None: + print("Couldn't find export file for hashers writing to stdout instead") + print(json.dumps(bencher_data)) + else: + with open(args.export_file, "w") as file: + json.dump(bencher_data, file) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_tests_and_examples.sh b/scripts/run_tests_and_examples.sh index 85b0ffec..70da6e39 100755 --- a/scripts/run_tests_and_examples.sh +++ b/scripts/run_tests_and_examples.sh @@ -19,7 +19,7 @@ if [ -z "$CARGO_STABLE" ]; then # Execute examples. # Unfortunately there cargo doesn't support executing all examples with a single command. - # Cargo plugins such as "cargo-examples" do suport it but without a possibility to specify "release" profile. + # Cargo plugins such as "cargo-examples" do support it but without a possibility to specify "release" profile. for example in examples/*.rs do cargo run --profile $CARGO_PROFILE --example "$(basename "${example%.rs}")" $CARGO_EXTRA_FLAGS