Skip to content

Commit

Permalink
feat: Blake3 G function gadget
Browse files Browse the repository at this point in the history
  • Loading branch information
storojs72 committed Feb 7, 2025
1 parent 8ef58ed commit 48272f9
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 7 deletions.
213 changes: 213 additions & 0 deletions crates/circuits/src/acc_blake3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use binius_core::oracle::{OracleId, ShiftVariant};
use binius_field::{
as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField,
};
use bytemuck::Pod;

use crate::{arithmetic, arithmetic::Flags, builder::ConstraintSystemBuilder};

type F1 = BinaryField1b;

// Gadget that performs two u32 variables XOR and then rotates the result
fn xor_rotate_right<U, F>(
builder: &mut ConstraintSystemBuilder<U, F>,
name: impl ToString,
log_size: usize,
a: OracleId,
b: OracleId,
rotate_right_offset: u32,
) -> OracleId
where
U: PackScalar<F> + PackScalar<F1> + Pod,
F: TowerField,
{
assert!(rotate_right_offset <= 32);

builder.push_namespace(name);

let xor = builder
.add_linear_combination("xor", log_size, [(a, F::ONE), (b, F::ONE)])
.unwrap();

let rotate = builder
.add_shifted(
"rotate",
xor,
32 - rotate_right_offset as usize,
crate::sha256::LOG_U32_BITS,
ShiftVariant::CircularLeft,
)
.unwrap();

if let Some(witness) = builder.witness() {
let a_value = witness.get::<F1>(a).unwrap().as_slice::<u32>();
let b_value = witness.get::<F1>(b).unwrap().as_slice::<u32>();

let mut xor_witness = witness.new_column::<F1>(xor);
let xor_value = xor_witness.as_mut_slice::<u32>();

for (idx, v) in xor_value.into_iter().enumerate() {
*v = a_value[idx] ^ b_value[idx];
}

let mut rotate_witness = witness.new_column::<F1>(rotate);
let rotate_value = rotate_witness.as_mut_slice::<u32>();
for (idx, v) in rotate_value.into_iter().enumerate() {
*v = xor_value[idx].rotate_right(rotate_right_offset);
}
}

builder.pop_namespace();

rotate
}

pub fn blake3_g<U, F>(
builder: &mut ConstraintSystemBuilder<U, F>,
a_in: OracleId,
b_in: OracleId,
c_in: OracleId,
d_in: OracleId,
mx: OracleId,
my: OracleId,
log_size: usize,
) -> Result<[OracleId; 4], anyhow::Error>
where
U: UnderlierType + Pod + PackScalar<F> + PackScalar<BinaryField1b>,
F: TowerField,
{
builder.push_namespace("blake3_g");

let a1 = arithmetic::u32::add3(builder, "a_in + b_in + mx", a_in, b_in, mx, Flags::Unchecked)?;

let d1 = xor_rotate_right(builder, "(d_in ^ a1).rotate_right(16)", log_size, d_in, a1, 16u32);

let c1 = arithmetic::u32::add(builder, "c_in + d1", c_in, d1, Flags::Unchecked)?;

let b1 = xor_rotate_right(builder, "(b_in ^ c1).rotate_right(12)", log_size, b_in, c1, 12u32);

let a2 =
arithmetic::u32::add3(builder, "a1 + b1 + my_in", a1, b1, my, Flags::Unchecked).unwrap();

let d2 = xor_rotate_right(builder, "(d1 ^ a2).rotate_right(8)", log_size, d1, a2, 8u32);

let c2 = arithmetic::u32::add(builder, "c1 + d2", c1, d2, Flags::Unchecked)?;

let b2 = xor_rotate_right(builder, "(b1 ^ c2).rotate_right(7)", log_size, b1, c2, 7u32);

builder.pop_namespace();

Ok([a2, b2, c2, d2])
}

#[cfg(test)]
mod tests {
use binius_core::constraint_system::validate::validate_witness;
use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b};
use binius_maybe_rayon::prelude::*;

use crate::{
acc_blake3::blake3_g,
builder::ConstraintSystemBuilder,
unconstrained::{unconstrained, variables_u32},
};

type U = OptimalUnderlier;
type F128 = BinaryField128b;
type F1 = BinaryField1b;

// The Blake3 mixing function, G, which mixes either a column or a diagonal.
// https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs
fn g(a_in: u32, b_in: u32, c_in: u32, d_in: u32, mx: u32, my: u32) -> (u32, u32, u32, u32) {
let a1 = a_in.wrapping_add(b_in).wrapping_add(mx);
let d1 = (d_in ^ a1).rotate_right(16);
let c1 = c_in.wrapping_add(d1);
let b1 = (b_in ^ c1).rotate_right(12);

let a2 = a1.wrapping_add(b1).wrapping_add(my);
let d2 = (d1 ^ a2).rotate_right(8);
let c2 = c1.wrapping_add(d2);
let b2 = (b1 ^ c2).rotate_right(7);

(a2, b2, c2, d2)
}

#[test]
fn test_vector() {
// Let's use some fixed data input to check that our in-circuit computation
// produces same output as out-of-circuit one
let a = 0xaaaaaaaau32;
let b = 0xbbbbbbbbu32;
let c = 0xccccccccu32;
let d = 0xddddddddu32;
let mx = 0xffff00ffu32;
let my = 0xff00ffffu32;

let (expected_0, expected_1, expected_2, expected_3) = g(a, b, c, d, mx, my);

let log_size = 8usize;
let size = 1 << log_size;

let allocator = bumpalo::Bump::new();
let mut builder = ConstraintSystemBuilder::<U, F128>::new_with_witness(&allocator);

let a_in =
variables_u32::<U, F128, F1>(&mut builder, "a", log_size, vec![a; size]).unwrap();
let b_in =
variables_u32::<U, F128, F1>(&mut builder, "b", log_size, vec![b; size]).unwrap();
let c_in =
variables_u32::<U, F128, F1>(&mut builder, "c", log_size, vec![c; size]).unwrap();
let d_in =
variables_u32::<U, F128, F1>(&mut builder, "d", log_size, vec![d; size]).unwrap();
let mx_in =
variables_u32::<U, F128, F1>(&mut builder, "mx", log_size, vec![mx; size]).unwrap();
let my_in =
variables_u32::<U, F128, F1>(&mut builder, "my", log_size, vec![my; size]).unwrap();

let output =
blake3_g(&mut builder, a_in, b_in, c_in, d_in, mx_in, my_in, log_size).unwrap();

if let Some(witness) = builder.witness() {
(
witness.get::<F1>(output[0]).unwrap().as_slice::<u32>(),
witness.get::<F1>(output[1]).unwrap().as_slice::<u32>(),
witness.get::<F1>(output[2]).unwrap().as_slice::<u32>(),
witness.get::<F1>(output[3]).unwrap().as_slice::<u32>(),
)
.into_par_iter()
.for_each(|(actual_0, actual_1, actual_2, actual_3)| {
assert_eq!(*actual_0, expected_0);
assert_eq!(*actual_1, expected_1);
assert_eq!(*actual_2, expected_2);
assert_eq!(*actual_3, expected_3);
});
}

let witness = builder.take_witness().unwrap();
let constraints_system = builder.build().unwrap();

validate_witness(&constraints_system, &[], &witness).unwrap();
}

#[test]
fn test_random_input() {
let log_size = 8usize;

let allocator = bumpalo::Bump::new();
let mut builder = ConstraintSystemBuilder::<U, F128>::new_with_witness(&allocator);

let a_in = unconstrained::<U, F128, F1>(&mut builder, "a", log_size).unwrap();
let b_in = unconstrained::<U, F128, F1>(&mut builder, "b", log_size).unwrap();
let c_in = unconstrained::<U, F128, F1>(&mut builder, "c", log_size).unwrap();
let d_in = unconstrained::<U, F128, F1>(&mut builder, "d", log_size).unwrap();
let mx_in = unconstrained::<U, F128, F1>(&mut builder, "mx", log_size).unwrap();
let my_in = unconstrained::<U, F128, F1>(&mut builder, "my", log_size).unwrap();

blake3_g(&mut builder, a_in, b_in, c_in, d_in, mx_in, my_in, log_size).unwrap();

let witness = builder.take_witness().unwrap();
let constraints_system = builder.build().unwrap();

validate_witness(&constraints_system, &[], &witness).unwrap();
}
}
11 changes: 5 additions & 6 deletions crates/circuits/src/arithmetic/u32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ pub fn add3<U, F>(
zin: OracleId,
flags: super::Flags,
) -> Result<OracleId, anyhow::Error>
where
U: PackScalar<F> + PackScalar<BinaryField1b> + Pod,
F: TowerField,
where
U: PackScalar<F> + PackScalar<BinaryField1b> + Pod,
F: TowerField,
{
builder.push_namespace(name);
let log_rows = builder.log_rows([xin, yin, zin])?;
Expand Down Expand Up @@ -205,10 +205,9 @@ pub fn add3<U, F>(
"right",
[xin, yin, zin, right],
arith_expr!(
[x, y, z, right] =
x * (y + z) + y * z * (1 + x * (1 + (y + z + x * y * z))) - right
[x, y, z, right] = x * (y + z) + y * z * (1 + x * (1 + (y + z + x * y * z))) - right
)
.convert_field(),
.convert_field(),
);

builder.pop_namespace();
Expand Down
1 change: 1 addition & 0 deletions crates/circuits/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#![feature(array_try_map, array_try_from_fn)]
#![allow(clippy::module_inception)]

pub mod acc_blake3;
pub mod arithmetic;
pub mod bitwise;
pub mod builder;
Expand Down
2 changes: 1 addition & 1 deletion crates/circuits/src/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use itertools::izip;

use crate::{arithmetic, builder::ConstraintSystemBuilder};

const LOG_U32_BITS: usize = checked_log_2(32);
pub const LOG_U32_BITS: usize = checked_log_2(32);

type B1 = BinaryField1b;

Expand Down
27 changes: 27 additions & 0 deletions crates/circuits/src/unconstrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,30 @@ where

Ok(rng)
}

pub fn variables_u32<U, F, FS>(
builder: &mut ConstraintSystemBuilder<U, F>,
name: impl ToString,
log_size: usize,
value: Vec<u32>,
) -> Result<OracleId, anyhow::Error>
where
U: UnderlierType + Pod + PackScalar<F> + PackScalar<FS>,
F: TowerField + ExtensionField<FS>,
FS: TowerField,
{
let rng = builder.add_committed(name, log_size, FS::TOWER_LEVEL);

if let Some(witness) = builder.witness() {
witness
.new_column::<FS>(rng)
.as_mut_slice::<u32>()
.into_par_iter()
.zip(value.into_par_iter())
.for_each(|(data, value)| {
*data = value;
});
}

Ok(rng)
}

0 comments on commit 48272f9

Please sign in to comment.