Skip to content

Commit

Permalink
example: Linear combination with offset (#4)
Browse files Browse the repository at this point in the history
* example: Add linear-combination-with-offset usage example

* chore: Add example for bit masking using LinearCombination

* chore: Add byte decomposition constraint
  • Loading branch information
storojs72 authored Jan 31, 2025
1 parent f8b8c87 commit 0cf7bb6
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ path = "b32_mul.rs"
name = "acc-linear-combination"
path = "acc-linear-combination.rs"


[[example]]
name = "acc-linear-combination-with-offset"
path = "acc-linear-combination-with-offset.rs"

[lints.clippy]
needless_range_loop = "allow"

Expand Down
127 changes: 127 additions & 0 deletions examples/acc-linear-combination-with-offset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained};
use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId};
use binius_field::{
arch::OptimalUnderlier, packed::set_packed_slice, AESTowerField128b, AESTowerField8b,
BinaryField1b, ExtensionField, PackedField, TowerField,
};

type U = OptimalUnderlier;
type F128 = AESTowerField128b;
type F8 = AESTowerField8b;
type F1 = BinaryField1b;

fn aes_s_box(x: F8) -> F8 {
#[rustfmt::skip]
const S_BOX: [u8; 256] = [
0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5,
0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc,
0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a,
0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b,
0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85,
0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17,
0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88,
0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9,
0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6,
0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94,
0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68,
0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];
let idx = u8::from(x) as usize;
F8::from(S_BOX[idx])
}

// AES s-box is equivalent to the affine transformation defined as follows:
//
// s[i] = b[i] +
// b[(i+4) mod 8] +
// b[(i+5) mod 8] +
// b[(i+6) mod 8] +
// b[(i+7) mod 8] +
// c[i]
//
// where 'b' is input byte, 's' is output byte, 'c' is constant which is equal to 0x63 (0b01100011) and 'i' is a bit position.
// The '+' operation is defined over Rijndael finite field : GF(2^8) = GF(2) [x] / (x^8 + x^4 + x^3 + x + 1).
//
const C: F8 = F8::new(0x63);
const AES_AFFINE_TRANSFORMATION: [F8; 8] = [
F8::new(0b00011111),
F8::new(0b00111110),
F8::new(0b01111100),
F8::new(0b11111000),
F8::new(0b11110001),
F8::new(0b11100011),
F8::new(0b11000111),
F8::new(0b10001111),
];

fn main() {
let allocator = bumpalo::Bump::new();
let mut builder = ConstraintSystemBuilder::<U, F128>::new_with_witness(&allocator);

let log_size = 1usize;
let byte_in = unconstrained::<U, F128, F8>(&mut builder, "input_byte", log_size).unwrap();

let bits: [OracleId; 8] =
builder.add_committed_multiple("decomposition", log_size, F1::TOWER_LEVEL);

let byte_out = builder
.add_linear_combination_with_offset(
"lc",
log_size,
C.into(),
(0..8).map(|i| (bits[i], AES_AFFINE_TRANSFORMATION[i].into())),
)
.unwrap();

if let Some(witness) = builder.witness() {
// get initial values of input bytes
let byte_in_values = witness.get::<F8>(byte_in).unwrap().as_slice::<F8>();

// create column for expected values of the output bytes
let mut byte_out_witness = witness.new_column::<F8>(byte_out);
let byte_out_values = byte_out_witness.as_mut_slice::<F8>();

// For each (inverted!) input byte, write correspondent bits to the decomposition
let mut bits_witness = bits.map(|bit| witness.new_column::<F1>(bit));
let packed_bits = bits_witness.each_mut().map(|bit| bit.packed());

for byte_position in 0..byte_in_values.len() {
// write expected byte value to the output after applying s_box
byte_out_values[byte_position] = aes_s_box(byte_in_values[byte_position]);

// invert input byte and write it to a decomposition bits
let input_inverted = byte_in_values[byte_position].invert_or_zero();

let bases = ExtensionField::<F1>::iter_bases(&input_inverted);

for (bit_position, bit) in bases.clone().enumerate() {
set_packed_slice(packed_bits[bit_position], byte_position, bit);
}
}
}

let witness = builder.take_witness().unwrap();
let cs = builder.build().unwrap();

validate_witness(&cs, &[], &witness).unwrap();
}
152 changes: 152 additions & 0 deletions examples/acc-linear-combination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ fn bytes_decomposition_gadget(
let output_bits: [OracleId; 8] =
builder.add_committed_multiple("output_bits", log_size, F1::TOWER_LEVEL);

let coefficients: [OracleId; 8] =
builder.add_committed_multiple("coeffs", log_size, F8::TOWER_LEVEL);

let coeff_vals = [
F8::new(0b00000001),
F8::new(0b00000010),
F8::new(0b00000100),
F8::new(0b00001000),
F8::new(0b00010000),
F8::new(0b00100000),
F8::new(0b01000000),
F8::new(0b10000000),
];

// Define `output` variable that will store `input` bytes (we will compare this in our constraint below).
// Since we want to enforce decomposition, we use `LinearCombination` column which naturally fits for this purpose.
// We need to specify our coefficients now and later take care of defining bit columns and setting bit values appropriately
Expand Down Expand Up @@ -63,6 +77,18 @@ fn bytes_decomposition_gadget(
// Get its memory
let output = output.as_mut_slice::<F8>();

// Write coefficients into the witness
let mut coeff_witness =
coefficients.map(|coefficient| witness.new_column::<F8>(coefficient));
let coeff_witness = coeff_witness
.each_mut()
.map(|coeff| coeff.as_mut_slice::<F8>());
for (idx, v) in coeff_witness.into_iter().enumerate() {
for vv in v {
*vv = coeff_vals[idx];
}
}

// For each byte from the `input` we need to just copy it to the `output` and also
// we need to perform actual decomposition and write it in a form of packed bits to the `output_bits`
for z in 0..input.len() {
Expand All @@ -83,6 +109,130 @@ fn bytes_decomposition_gadget(

// We just assert that every byte from `input` equals to correspondent byte from `output`
builder.assert_zero("s_box", [input, output], arith_expr!([i, o] = i - o).convert_field());

// Assert decomposition
builder.assert_zero(
"decomposition",
[
input,
output_bits[0],
output_bits[1],
output_bits[2],
output_bits[3],
output_bits[4],
output_bits[5],
output_bits[6],
output_bits[7],
coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
coefficients[4],
coefficients[5],
coefficients[6],
coefficients[7],
],
arith_expr!(
[i, b0, b1, b2, b3, b4, b5, b6, b7, c0, c1, c2, c3, c4, c5, c6, c7] =
b0 * c0 + b1 * c1 + b2 * c2 + b3 * c3 + b4 * c4 + b5 * c5 + b6 * c6 + b7 * c7 - i
)
.convert_field(),
);

builder.pop_namespace();
Ok(output)
}

fn elder_4bits_masking_gadget(
builder: &mut ConstraintSystemBuilder<U, F128>,
name: impl ToString,
log_size: usize,
input: OracleId,
) -> Result<OracleId, anyhow::Error> {
builder.push_namespace(name);
let output_bits: [OracleId; 8] =
builder.add_committed_multiple("output_bits", log_size, F1::TOWER_LEVEL);

// we want to mask 4 elder bits in input byte
let lc_coefficients = [
F8::new(0b00000001),
F8::new(0b00000010),
F8::new(0b00000100),
F8::new(0b00001000),
F8::new(0b00000000),
F8::new(0b00000000),
F8::new(0b00000000),
F8::new(0b00000000),
];

let coefficients: [OracleId; 8] =
builder.add_committed_multiple("coeffs", log_size, F8::TOWER_LEVEL);

let output = builder.add_linear_combination(
"output",
log_size,
(0..8).map(|b| (output_bits[b], lc_coefficients[b].into())),
)?;

if let Some(witness) = builder.witness() {
// Write coefficients into the witness
let mut coeff_witness =
coefficients.map(|coefficient| witness.new_column::<F8>(coefficient));
let coeff_witness = coeff_witness
.each_mut()
.map(|coeff| coeff.as_mut_slice::<F8>());
for (idx, v) in coeff_witness.into_iter().enumerate() {
for vv in v {
*vv = lc_coefficients[idx];
}
}

let input = witness.get::<F8>(input)?.as_slice::<F8>();
let mut output_bits_witness: [_; 8] = output_bits.map(|id| witness.new_column::<F1>(id));
let output_bits = output_bits_witness.each_mut().map(|bit| bit.packed());
let mut output = witness.new_column::<F8>(output);
let output = output.as_mut_slice::<F8>();
for z in 0..input.len() {
// apply mask to the input byte
let byte_out_val = u8::from(input[z]) & 0x0F;
output[z] = F8::from(byte_out_val);

let input_bits_bases = ExtensionField::<F1>::iter_bases(&input[z]);
for (b, bit) in input_bits_bases.enumerate() {
set_packed_slice(output_bits[b], z, bit);
}
}
}

// Assert decomposition
builder.assert_zero(
"decomposition",
[
output,
output_bits[0],
output_bits[1],
output_bits[2],
output_bits[3],
output_bits[4],
output_bits[5],
output_bits[6],
output_bits[7],
coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
coefficients[4],
coefficients[5],
coefficients[6],
coefficients[7],
],
arith_expr!(
[o, b0, b1, b2, b3, b4, b5, b6, b7, c0, c1, c2, c3, c4, c5, c6, c7] =
b0 * c0 + b1 * c1 + b2 * c2 + b3 * c3 + b4 * c4 + b5 * c5 + b6 * c6 + b7 * c7 - o
)
.convert_field(),
);

builder.pop_namespace();
Ok(output)
}
Expand All @@ -99,6 +249,8 @@ fn main() {
let _ =
bytes_decomposition_gadget(&mut builder, "bytes decomposition", log_size, p_in).unwrap();

let _ = elder_4bits_masking_gadget(&mut builder, "masking", log_size, p_in).unwrap();

let witness = builder.take_witness().unwrap();
let cs = builder.build().unwrap();

Expand Down

0 comments on commit 0cf7bb6

Please sign in to comment.