Skip to content

Commit

Permalink
[field] Byte-sliced fields changes (IrreducibleOSS#21)
Browse files Browse the repository at this point in the history
* Refactor TowerLevel slightly: move the packed field parameter from the TowerLevel trait itself to its Data associated type. This also makes the generic bounds a bit cleaner, since TowerLevel itself no longer depends on a concrete packed field type.
* Add support for byte-sliced fields with arbitrary register size, e.g. 128b, 256b, 512b.
* Add shifts and unpack low/high within 128-bit lanes to UnderlierWithBitOps. This allows implementing transposition in an efficient way.
* Add the transparent implementation of UnderlierWithBitOps for PackedScaledUnderlier as we need it to re-use PackedScaledField.
  • Loading branch information
GraDKh authored Feb 24, 2025
1 parent 23b3eba commit e9991ce
Show file tree
Hide file tree
Showing 24 changed files with 1,072 additions and 474 deletions.
8 changes: 4 additions & 4 deletions crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ use crate::{
type B1 = BinaryField1b;
type B8 = BinaryField8b;

pub fn byte_sliced_add<Level: TowerLevel<OracleId, Data: Sized>>(
pub fn byte_sliced_add<Level: TowerLevel<Data<OracleId>: Sized>>(
builder: &mut ConstraintSystemBuilder,
name: impl ToString + Clone,
x_in: &Level::Data,
y_in: &Level::Data,
x_in: &Level::Data<OracleId>,
y_in: &Level::Data<OracleId>,
carry_in: OracleId,
log_size: usize,
lookup_batch_add: &mut LookupBatch,
) -> Result<(OracleId, Level::Data), anyhow::Error> {
) -> Result<(OracleId, Level::Data<OracleId>), anyhow::Error> {
if Level::WIDTH == 1 {
let (carry_out, sum) =
u8add(builder, lookup_batch_add, name, x_in[0], y_in[0], carry_in, log_size)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ type B1 = BinaryField1b;
type B8 = BinaryField8b;

#[allow(clippy::too_many_arguments)]
pub fn byte_sliced_add_carryfree<Level: TowerLevel<OracleId, Data: Sized>>(
pub fn byte_sliced_add_carryfree<Level: TowerLevel<Data<OracleId>: Sized>>(
builder: &mut ConstraintSystemBuilder,
name: impl ToString,
x_in: &Level::Data,
y_in: &Level::Data,
x_in: &Level::Data<OracleId>,
y_in: &Level::Data<OracleId>,
carry_in: OracleId,
log_size: usize,
lookup_batch_add: &mut LookupBatch,
lookup_batch_add_carryfree: &mut LookupBatch,
) -> Result<Level::Data, anyhow::Error> {
) -> Result<Level::Data<OracleId>, anyhow::Error> {
if Level::WIDTH == 1 {
let sum = u8add_carryfree(
builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ type B1 = BinaryField1b;
type B8 = BinaryField8b;

#[allow(clippy::too_many_arguments)]
pub fn byte_sliced_double_conditional_increment<Level: TowerLevel<OracleId, Data: Sized>>(
pub fn byte_sliced_double_conditional_increment<Level: TowerLevel<Data<OracleId>: Sized>>(
builder: &mut ConstraintSystemBuilder,
name: impl ToString,
x_in: &Level::Data,
x_in: &Level::Data<OracleId>,
first_carry_in: OracleId,
second_carry_in: OracleId,
log_size: usize,
zero_oracle_carry: usize,
lookup_batch_dci: &mut LookupBatch,
) -> Result<(OracleId, Level::Data), anyhow::Error> {
) -> Result<(OracleId, Level::Data<OracleId>), anyhow::Error> {
if Level::WIDTH == 1 {
let (carry_out, sum) = u8_double_conditional_increment(
builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,16 @@ use crate::{
type B8 = BinaryField8b;

#[allow(clippy::too_many_arguments)]
pub fn byte_sliced_modular_mul<
LevelIn: TowerLevel<OracleId>,
LevelOut: TowerLevel<OracleId, Base = LevelIn>,
>(
pub fn byte_sliced_modular_mul<LevelIn: TowerLevel, LevelOut: TowerLevel<Base = LevelIn>>(
builder: &mut ConstraintSystemBuilder,
name: impl ToString,
mult_a: &LevelIn::Data,
mult_b: &LevelIn::Data,
mult_a: &LevelIn::Data<OracleId>,
mult_b: &LevelIn::Data<OracleId>,
modulus_input: &[u8],
log_size: usize,
zero_byte_oracle: OracleId,
zero_carry_oracle: OracleId,
) -> Result<LevelIn::Data, anyhow::Error> {
) -> Result<LevelIn::Data<OracleId>, anyhow::Error> {
builder.push_namespace(name);

let lookup_t_mul = mul_lookup(builder, "mul table")?;
Expand Down
11 changes: 4 additions & 7 deletions crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,17 @@ use crate::{
type B8 = BinaryField8b;

#[allow(clippy::too_many_arguments)]
pub fn byte_sliced_mul<
LevelIn: TowerLevel<OracleId>,
LevelOut: TowerLevel<OracleId, Base = LevelIn>,
>(
pub fn byte_sliced_mul<LevelIn: TowerLevel, LevelOut: TowerLevel<Base = LevelIn>>(
builder: &mut ConstraintSystemBuilder,
name: impl ToString,
mult_a: &LevelIn::Data,
mult_b: &LevelIn::Data,
mult_a: &LevelIn::Data<OracleId>,
mult_b: &LevelIn::Data<OracleId>,
log_size: usize,
zero_carry_oracle: OracleId,
lookup_batch_mul: &mut LookupBatch,
lookup_batch_add: &mut LookupBatch,
lookup_batch_dci: &mut LookupBatch,
) -> Result<LevelOut::Data, anyhow::Error> {
) -> Result<LevelOut::Data<OracleId>, anyhow::Error> {
if LevelIn::WIDTH == 1 {
let result_of_u8mul = u8mul_bytesliced(
builder,
Expand Down
31 changes: 13 additions & 18 deletions crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,12 @@ pub fn random_u512(rng: &mut impl Rng) -> U512 {

pub fn test_bytesliced_add<const WIDTH: usize, TL>()
where
TL: TowerLevel<OracleId, Data = [OracleId; WIDTH]>,
TL: TowerLevel,
{
test_circuit(|builder| {
let log_size = 14;
let x_in =
array::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
let y_in =
array::from_fn(|_| unconstrained::<BinaryField8b>(builder, "y", log_size).unwrap());
let x_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
let y_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "y", log_size).unwrap());
let c_in = unconstrained::<BinaryField1b>(builder, "cin first", log_size)?;
let lookup_t_add = add_lookup(builder, "add table")?;
let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
Expand All @@ -61,14 +59,14 @@ where

pub fn test_bytesliced_add_carryfree<const WIDTH: usize, TL>()
where
TL: TowerLevel<OracleId, Data = [OracleId; WIDTH]>,
TL: TowerLevel,
{
test_circuit(|builder| {
let log_size = 14;
let x_in =
array::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL));
TL::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL));
let y_in =
array::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL));
TL::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL));
let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL);

if let Some(witness) = builder.witness() {
Expand Down Expand Up @@ -136,12 +134,11 @@ where

pub fn test_bytesliced_double_conditional_increment<const WIDTH: usize, TL>()
where
TL: TowerLevel<OracleId, Data = [OracleId; WIDTH]>,
TL: TowerLevel,
{
test_circuit(|builder| {
let log_size = 14;
let x_in =
array::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
let x_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
let first_c_in = unconstrained::<BinaryField1b>(builder, "cin first", log_size)?;
let second_c_in = unconstrained::<BinaryField1b>(builder, "cin second", log_size)?;
let zero_oracle_carry =
Expand All @@ -166,15 +163,14 @@ where

pub fn test_bytesliced_mul<const WIDTH: usize, TL>()
where
TL: TowerLevel<OracleId>,
TL::Base: TowerLevel<OracleId, Data = [OracleId; WIDTH]>,
TL: TowerLevel,
{
test_circuit(|builder| {
let log_size = 14;
let mult_a =
array::from_fn(|_| unconstrained::<BinaryField8b>(builder, "a", log_size).unwrap());
TL::Base::from_fn(|_| unconstrained::<BinaryField8b>(builder, "a", log_size).unwrap());
let mult_b =
array::from_fn(|_| unconstrained::<BinaryField8b>(builder, "b", log_size).unwrap());
TL::Base::from_fn(|_| unconstrained::<BinaryField8b>(builder, "b", log_size).unwrap());
let zero_oracle_carry =
transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?;
let lookup_t_mul = mul_lookup(builder, "mul lookup")?;
Expand All @@ -201,9 +197,8 @@ where

pub fn test_bytesliced_modular_mul<const WIDTH: usize, TL>()
where
TL: TowerLevel<OracleId>,
TL::Base: TowerLevel<OracleId, Data = [OracleId; WIDTH]>,
<TL as TowerLevel<usize>>::Data: Debug,
TL: TowerLevel<Data<usize>: Debug>,
TL::Base: TowerLevel<Data<usize> = [OracleId; WIDTH]>,
{
test_circuit(|builder| {
let log_size = 14;
Expand Down
54 changes: 48 additions & 6 deletions crates/field/benches/packed_field_element_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
use std::array;

use binius_field::{
PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b,
PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b,
PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b,
PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b,
PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField,
ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b,
ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b,
ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b,
ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b,
PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b,
PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b,
PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b,
PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b,
PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField,
};
use criterion::{
criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput,
Expand Down Expand Up @@ -86,5 +90,43 @@ fn packed_512(c: &mut Criterion) {
benchmark_get_set!(PackedBinaryField4x128b, group);
}

criterion_group!(get_set, packed_128, packed_256, packed_512);
/// Registers the `benchmark_get_set!` benchmarks for the `ByteSlicedAES16x*` types
/// under the Criterion group named `bytes_sliced_128`.
fn byte_sliced_128(c: &mut Criterion) {
    let mut bench_group = c.benchmark_group("bytes_sliced_128");

    // One benchmark per scalar width: 8, 16, 32, 64 and 128 bits.
    benchmark_get_set!(ByteSlicedAES16x8b, bench_group);
    benchmark_get_set!(ByteSlicedAES16x16b, bench_group);
    benchmark_get_set!(ByteSlicedAES16x32b, bench_group);
    benchmark_get_set!(ByteSlicedAES16x64b, bench_group);
    benchmark_get_set!(ByteSlicedAES16x128b, bench_group);
}

/// Registers the `benchmark_get_set!` benchmarks for the `ByteSlicedAES32x*` types
/// under the Criterion group named `bytes_sliced_256`.
fn byte_sliced_256(c: &mut Criterion) {
    let mut bench_group = c.benchmark_group("bytes_sliced_256");

    // One benchmark per scalar width: 8, 16, 32, 64 and 128 bits.
    benchmark_get_set!(ByteSlicedAES32x8b, bench_group);
    benchmark_get_set!(ByteSlicedAES32x16b, bench_group);
    benchmark_get_set!(ByteSlicedAES32x32b, bench_group);
    benchmark_get_set!(ByteSlicedAES32x64b, bench_group);
    benchmark_get_set!(ByteSlicedAES32x128b, bench_group);
}

/// Registers the `benchmark_get_set!` benchmarks for the `ByteSlicedAES64x*` types
/// under the Criterion group named `bytes_sliced_512`.
fn byte_sliced_512(c: &mut Criterion) {
    let mut bench_group = c.benchmark_group("bytes_sliced_512");

    // One benchmark per scalar width: 8, 16, 32, 64 and 128 bits.
    benchmark_get_set!(ByteSlicedAES64x8b, bench_group);
    benchmark_get_set!(ByteSlicedAES64x16b, bench_group);
    benchmark_get_set!(ByteSlicedAES64x32b, bench_group);
    benchmark_get_set!(ByteSlicedAES64x64b, bench_group);
    benchmark_get_set!(ByteSlicedAES64x128b, bench_group);
}

// Collects the benchmark functions listed below into the `get_set` group:
// the three `packed_*` functions followed by the three `byte_sliced_*` functions.
criterion_group!(
get_set,
packed_128,
packed_256,
packed_512,
byte_sliced_128,
byte_sliced_256,
byte_sliced_512
);
// Hands the `get_set` group to Criterion's entry-point macro.
criterion_main!(get_set);
54 changes: 48 additions & 6 deletions crates/field/benches/packed_field_init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
use std::array;

use binius_field::{
PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b,
PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b,
PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b,
PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b,
PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField,
ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b,
ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b,
ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b,
ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b,
PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b,
PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b,
PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b,
PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b,
PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField,
};
use criterion::{
criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput,
Expand Down Expand Up @@ -71,5 +75,43 @@ fn packed_512(c: &mut Criterion) {
benchmark_from_fn!(PackedBinaryField4x128b, group);
}

criterion_group!(initialization, packed_128, packed_256, packed_512);
/// Registers the `benchmark_from_fn!` benchmarks for the `ByteSlicedAES16x*` types
/// under the Criterion group named `bytes_sliced_128`.
fn byte_sliced_128(c: &mut Criterion) {
    let mut bench_group = c.benchmark_group("bytes_sliced_128");

    // One benchmark per scalar width: 8, 16, 32, 64 and 128 bits.
    benchmark_from_fn!(ByteSlicedAES16x8b, bench_group);
    benchmark_from_fn!(ByteSlicedAES16x16b, bench_group);
    benchmark_from_fn!(ByteSlicedAES16x32b, bench_group);
    benchmark_from_fn!(ByteSlicedAES16x64b, bench_group);
    benchmark_from_fn!(ByteSlicedAES16x128b, bench_group);
}

/// Registers the `benchmark_from_fn!` benchmarks for the `ByteSlicedAES32x*` types
/// under the Criterion group named `bytes_sliced_256`.
fn byte_sliced_256(c: &mut Criterion) {
    let mut bench_group = c.benchmark_group("bytes_sliced_256");

    // One benchmark per scalar width: 8, 16, 32, 64 and 128 bits.
    benchmark_from_fn!(ByteSlicedAES32x8b, bench_group);
    benchmark_from_fn!(ByteSlicedAES32x16b, bench_group);
    benchmark_from_fn!(ByteSlicedAES32x32b, bench_group);
    benchmark_from_fn!(ByteSlicedAES32x64b, bench_group);
    benchmark_from_fn!(ByteSlicedAES32x128b, bench_group);
}

/// Registers the `benchmark_from_fn!` benchmarks for the `ByteSlicedAES64x*` types
/// under the Criterion group named `bytes_sliced_512`.
fn byte_sliced_512(c: &mut Criterion) {
    let mut bench_group = c.benchmark_group("bytes_sliced_512");

    // One benchmark per scalar width: 8, 16, 32, 64 and 128 bits.
    benchmark_from_fn!(ByteSlicedAES64x8b, bench_group);
    benchmark_from_fn!(ByteSlicedAES64x16b, bench_group);
    benchmark_from_fn!(ByteSlicedAES64x32b, bench_group);
    benchmark_from_fn!(ByteSlicedAES64x64b, bench_group);
    benchmark_from_fn!(ByteSlicedAES64x128b, bench_group);
}

// Collects the benchmark functions listed below into the `initialization` group:
// the three `packed_*` functions followed by the three `byte_sliced_*` functions.
criterion_group!(
initialization,
packed_128,
packed_256,
packed_512,
byte_sliced_128,
byte_sliced_256,
byte_sliced_512
);
// Hands the `initialization` group to Criterion's entry-point macro.
criterion_main!(initialization);
12 changes: 12 additions & 0 deletions crates/field/benches/packed_field_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,23 @@ macro_rules! benchmark_packed_operation {
PackedBinaryPolyval4x128b

// Byte sliced AES fields
ByteSlicedAES16x8b
ByteSlicedAES16x16b
ByteSlicedAES16x32b
ByteSlicedAES16x64b
ByteSlicedAES16x128b

ByteSlicedAES32x8b
ByteSlicedAES32x16b
ByteSlicedAES32x32b
ByteSlicedAES32x64b
ByteSlicedAES32x128b

ByteSlicedAES64x8b
ByteSlicedAES64x16b
ByteSlicedAES64x32b
ByteSlicedAES64x64b
ByteSlicedAES64x128b
]);
};
}
Expand Down
38 changes: 36 additions & 2 deletions crates/field/src/arch/aarch64/m128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use crate::{
arch::binary_utils::{as_array_mut, as_array_ref},
arithmetic_traits::Broadcast,
underlier::{
impl_divisible, impl_iteration, NumCast, Random, SmallU, UnderlierType,
UnderlierWithBitOps, WithUnderlier, U1, U2, U4,
impl_divisible, impl_iteration, unpack_lo_128b_fallback, NumCast, Random, SmallU,
UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4,
},
BinaryField,
};
Expand Down Expand Up @@ -337,6 +337,40 @@ impl UnderlierWithBitOps for M128 {
_ => panic!("unsupported bit count"),
}
}

#[inline(always)]
fn shl_128b_lanes(self, rhs: usize) -> Self {
Self(self.0 << rhs)
}

#[inline(always)]
fn shr_128b_lanes(self, rhs: usize) -> Self {
Self(self.0 >> rhs)
}

/// Interleaves blocks of `2^log_block_len` bits from `self` and `rhs` using the
/// `vzip1q_*` intrinsics (8/16/32/64-bit elements for `log_block_len` 3/4/5/6).
///
/// Sub-byte block sizes (`log_block_len < 3`) are delegated to
/// `unpack_lo_128b_fallback`.
///
/// # Panics
///
/// Panics with "Unsupported block length" when `log_block_len > 6`.
#[inline(always)]
fn unpack_lo_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self {
    // No intrinsic is used for blocks narrower than one byte.
    if log_block_len < 3 {
        return unpack_lo_128b_fallback(self, rhs, log_block_len);
    }

    // SAFETY: NOTE(review): presumably sound because this file is only built for
    // aarch64 targets where the NEON zip intrinsics are available — confirm.
    match log_block_len {
        3 => unsafe { vzip1q_u8(self.into(), rhs.into()).into() },
        4 => unsafe { vzip1q_u16(self.into(), rhs.into()).into() },
        5 => unsafe { vzip1q_u32(self.into(), rhs.into()).into() },
        6 => unsafe { vzip1q_u64(self.into(), rhs.into()).into() },
        _ => panic!("Unsupported block length"),
    }
}

#[inline(always)]
fn unpack_hi_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self {
match log_block_len {
0..3 => unpack_lo_128b_fallback(self, rhs, log_block_len),
3 => unsafe { vzip2q_u8(self.into(), rhs.into()).into() },
4 => unsafe { vzip2q_u16(self.into(), rhs.into()).into() },
5 => unsafe { vzip2q_u32(self.into(), rhs.into()).into() },
6 => unsafe { vzip2q_u64(self.into(), rhs.into()).into() },
_ => panic!("Unsupported block length"),
}
}
}

impl UnderlierWithBitConstants for M128 {
Expand Down
Loading

0 comments on commit e9991ce

Please sign in to comment.