Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add initial implementation of Blake3 G function gadget #7

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 221 additions & 0 deletions crates/circuits/src/acc_blake3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
use binius_core::oracle::{OracleId, ShiftVariant};
use binius_field::{
as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField,
};
use bytemuck::Pod;

use crate::{arithmetic, arithmetic::Flags, builder::ConstraintSystemBuilder};

type F1 = BinaryField1b;

// Gadget that performs two u32 variables XOR and then rotates the result
fn xor_rotate_right<U, F>(
builder: &mut ConstraintSystemBuilder<U, F>,
name: impl ToString,
log_size: usize,
a: OracleId,
b: OracleId,
rotate_right_offset: u32,
) -> OracleId
where
U: PackScalar<F> + PackScalar<F1> + Pod,
F: TowerField,
{
assert!(rotate_right_offset <= 32);

builder.push_namespace(name);

let xor = builder
.add_linear_combination("xor", log_size, [(a, F::ONE), (b, F::ONE)])
.unwrap();

let rotate = builder
.add_shifted(
"rotate",
xor,
32 - rotate_right_offset as usize,
crate::sha256::LOG_U32_BITS,
ShiftVariant::CircularLeft,
)
.unwrap();

if let Some(witness) = builder.witness() {
let a_value = witness.get::<F1>(a).unwrap().as_slice::<u32>();
let b_value = witness.get::<F1>(b).unwrap().as_slice::<u32>();

let mut xor_witness = witness.new_column::<F1>(xor);
let xor_value = xor_witness.as_mut_slice::<u32>();

for (idx, v) in xor_value.iter_mut().enumerate() {
*v = a_value[idx] ^ b_value[idx];
}

let mut rotate_witness = witness.new_column::<F1>(rotate);
let rotate_value = rotate_witness.as_mut_slice::<u32>();
for (idx, v) in rotate_value.iter_mut().enumerate() {
*v = xor_value[idx].rotate_right(rotate_right_offset);
}
}

builder.pop_namespace();

rotate
}

#[allow(clippy::too_many_arguments)]
pub fn blake3_g<U, F>(
builder: &mut ConstraintSystemBuilder<U, F>,
a_in: OracleId,
b_in: OracleId,
c_in: OracleId,
d_in: OracleId,
mx: OracleId,
my: OracleId,
log_size: usize,
) -> Result<[OracleId; 4], anyhow::Error>
where
U: UnderlierType + Pod + PackScalar<F> + PackScalar<BinaryField1b>,
F: TowerField,
{
builder.push_namespace("blake3_g");

let a1 = arithmetic::u32::add3(builder, "a_in + b_in + mx", a_in, b_in, mx, Flags::Unchecked)?;

let d1 = xor_rotate_right(builder, "(d_in ^ a1).rotate_right(16)", log_size, d_in, a1, 16u32);

let c1 = arithmetic::u32::add(builder, "c_in + d1", c_in, d1, Flags::Unchecked)?;

let b1 = xor_rotate_right(builder, "(b_in ^ c1).rotate_right(12)", log_size, b_in, c1, 12u32);

let a2 =
arithmetic::u32::add3(builder, "a1 + b1 + my_in", a1, b1, my, Flags::Unchecked).unwrap();

let d2 = xor_rotate_right(builder, "(d1 ^ a2).rotate_right(8)", log_size, d1, a2, 8u32);

let c2 = arithmetic::u32::add(builder, "c1 + d2", c1, d2, Flags::Unchecked)?;

let b2 = xor_rotate_right(builder, "(b1 ^ c2).rotate_right(7)", log_size, b1, c2, 7u32);

builder.pop_namespace();

Ok([a2, b2, c2, d2])
}

#[cfg(test)]
mod tests {
use binius_core::constraint_system::validate::validate_witness;
use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b};
use binius_maybe_rayon::prelude::*;

use crate::{
acc_blake3::blake3_g,
builder::ConstraintSystemBuilder,
unconstrained::{unconstrained, variables_u32},
};

type U = OptimalUnderlier;
type F128 = BinaryField128b;
type F1 = BinaryField1b;

// The Blake3 mixing function, G, which mixes either a column or a diagonal.
// https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs
const fn g(
a_in: u32,
b_in: u32,
c_in: u32,
d_in: u32,
mx: u32,
my: u32,
) -> (u32, u32, u32, u32) {
let a1 = a_in.wrapping_add(b_in).wrapping_add(mx);
let d1 = (d_in ^ a1).rotate_right(16);
let c1 = c_in.wrapping_add(d1);
let b1 = (b_in ^ c1).rotate_right(12);

let a2 = a1.wrapping_add(b1).wrapping_add(my);
let d2 = (d1 ^ a2).rotate_right(8);
let c2 = c1.wrapping_add(d2);
let b2 = (b1 ^ c2).rotate_right(7);

(a2, b2, c2, d2)
}

#[test]
fn test_vector() {
// Let's use some fixed data input to check that our in-circuit computation
// produces same output as out-of-circuit one
let a = 0xaaaaaaaau32;
let b = 0xbbbbbbbbu32;
let c = 0xccccccccu32;
let d = 0xddddddddu32;
let mx = 0xffff00ffu32;
let my = 0xff00ffffu32;

let (expected_0, expected_1, expected_2, expected_3) = g(a, b, c, d, mx, my);

let log_size = 8usize;
let size = 1 << log_size;

let allocator = bumpalo::Bump::new();
let mut builder = ConstraintSystemBuilder::<U, F128>::new_with_witness(&allocator);

let a_in =
variables_u32::<U, F128, F1>(&mut builder, "a", log_size, vec![a; size]).unwrap();
let b_in =
variables_u32::<U, F128, F1>(&mut builder, "b", log_size, vec![b; size]).unwrap();
let c_in =
variables_u32::<U, F128, F1>(&mut builder, "c", log_size, vec![c; size]).unwrap();
let d_in =
variables_u32::<U, F128, F1>(&mut builder, "d", log_size, vec![d; size]).unwrap();
let mx_in =
variables_u32::<U, F128, F1>(&mut builder, "mx", log_size, vec![mx; size]).unwrap();
let my_in =
variables_u32::<U, F128, F1>(&mut builder, "my", log_size, vec![my; size]).unwrap();

let output =
blake3_g(&mut builder, a_in, b_in, c_in, d_in, mx_in, my_in, log_size).unwrap();

if let Some(witness) = builder.witness() {
(
witness.get::<F1>(output[0]).unwrap().as_slice::<u32>(),
witness.get::<F1>(output[1]).unwrap().as_slice::<u32>(),
witness.get::<F1>(output[2]).unwrap().as_slice::<u32>(),
witness.get::<F1>(output[3]).unwrap().as_slice::<u32>(),
)
.into_par_iter()
.for_each(|(actual_0, actual_1, actual_2, actual_3)| {
assert_eq!(*actual_0, expected_0);
assert_eq!(*actual_1, expected_1);
assert_eq!(*actual_2, expected_2);
assert_eq!(*actual_3, expected_3);
});
}

let witness = builder.take_witness().unwrap();
let constraints_system = builder.build().unwrap();

validate_witness(&constraints_system, &[], &witness).unwrap();
}

#[test]
fn test_random_input() {
let log_size = 8usize;

let allocator = bumpalo::Bump::new();
let mut builder = ConstraintSystemBuilder::<U, F128>::new_with_witness(&allocator);

let a_in = unconstrained::<U, F128, F1>(&mut builder, "a", log_size).unwrap();
let b_in = unconstrained::<U, F128, F1>(&mut builder, "b", log_size).unwrap();
let c_in = unconstrained::<U, F128, F1>(&mut builder, "c", log_size).unwrap();
let d_in = unconstrained::<U, F128, F1>(&mut builder, "d", log_size).unwrap();
let mx_in = unconstrained::<U, F128, F1>(&mut builder, "mx", log_size).unwrap();
let my_in = unconstrained::<U, F128, F1>(&mut builder, "my", log_size).unwrap();

blake3_g(&mut builder, a_in, b_in, c_in, d_in, mx_in, my_in, log_size).unwrap();

let witness = builder.take_witness().unwrap();
let constraints_system = builder.build().unwrap();

validate_witness(&constraints_system, &[], &witness).unwrap();
}
}
63 changes: 63 additions & 0 deletions crates/circuits/src/arithmetic/u32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,69 @@ where
Ok(zout)
}

// Gadget that adds three u32 at once
pub fn add3<U, F>(
builder: &mut ConstraintSystemBuilder<U, F>,
name: impl ToString,
xin: OracleId,
yin: OracleId,
zin: OracleId,
flags: super::Flags,
) -> Result<OracleId, anyhow::Error>
where
U: PackScalar<F> + PackScalar<BinaryField1b> + Pod,
F: TowerField,
{
builder.push_namespace(name);
let log_rows = builder.log_rows([xin, yin, zin])?;
let left = builder.add_linear_combination(
"left",
log_rows,
[(xin, F::ONE), (yin, F::ONE), (zin, F::ONE)],
)?;
let right = builder.add_committed("right", log_rows, BinaryField1b::TOWER_LEVEL);

if let Some(witness) = builder.witness() {
let x_vals = witness.get::<BinaryField1b>(xin)?.as_slice::<u32>();
let y_vals = witness.get::<BinaryField1b>(yin)?.as_slice::<u32>();
let z_vals = witness.get::<BinaryField1b>(zin)?.as_slice::<u32>();

let mut left_values = witness.new_column::<BinaryField1b>(left);
let mut right_values = witness.new_column::<BinaryField1b>(right);

// In order to reduce our task to a simpler two integers addition (that we have gadget for) we use a trick from
// https://stackoverflow.com/questions/26228262/how-does-this-function-sum-3-integers-using-only-bit-wise-operators
(x_vals, y_vals, z_vals, left_values.as_mut_slice(), right_values.as_mut_slice())
.into_par_iter()
.for_each(|(x, y, z, left, right)| {
*left = (*x ^ *y) ^ *z;
*right = (*x) & (*y) | (*x) & (*z) | (*y & *z);
});
}

// right << 1
let right_shifted = shl(builder, "right_shifted", right, 1)?;

builder.assert_zero(
"left",
[xin, yin, zin, left],
arith_expr!([x, y, z, left] = x + y + z - left).convert_field(),
);

// We apply following rule: a OR b = a XOR b XOR (a AND B) to the expression of 'right' column defined above.
builder.assert_zero(
"right",
[xin, yin, zin, right],
arith_expr!(
[x, y, z, right] = x * (y + z) + y * z * (1 + x * (1 + (y + z + x * y * z))) - right
)
.convert_field(),
);

builder.pop_namespace();
add(builder, "add3 -> add2", left, right_shifted, flags)
}

pub fn sub<U, F>(
builder: &mut ConstraintSystemBuilder<U, F>,
name: impl ToString,
Expand Down
1 change: 1 addition & 0 deletions crates/circuits/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#![feature(array_try_map, array_try_from_fn)]
#![allow(clippy::module_inception)]

pub mod acc_blake3;
pub mod arithmetic;
pub mod bitwise;
pub mod builder;
Expand Down
2 changes: 1 addition & 1 deletion crates/circuits/src/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use itertools::izip;

use crate::{arithmetic, builder::ConstraintSystemBuilder};

const LOG_U32_BITS: usize = checked_log_2(32);
pub const LOG_U32_BITS: usize = checked_log_2(32);

type B1 = BinaryField1b;

Expand Down
27 changes: 27 additions & 0 deletions crates/circuits/src/unconstrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,30 @@ where

Ok(rng)
}

pub fn variables_u32<U, F, FS>(
builder: &mut ConstraintSystemBuilder<U, F>,
name: impl ToString,
log_size: usize,
value: Vec<u32>,
) -> Result<OracleId, anyhow::Error>
where
U: UnderlierType + Pod + PackScalar<F> + PackScalar<FS>,
F: TowerField + ExtensionField<FS>,
FS: TowerField,
{
let rng = builder.add_committed(name, log_size, FS::TOWER_LEVEL);

if let Some(witness) = builder.witness() {
witness
.new_column::<FS>(rng)
.as_mut_slice::<u32>()
.into_par_iter()
.zip(value.into_par_iter())
.for_each(|(data, value)| {
*data = value;
});
}

Ok(rng)
}