Skip to content

Commit

Permalink
Final version
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Jan 9, 2024
1 parent 57aa42e commit 7b912dc
Showing 1 changed file with 110 additions and 69 deletions.
179 changes: 110 additions & 69 deletions ipa-core/src/ff/boolean_array.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::{Binary, Formatter};

use bitvec::{
prelude::{bitarr, BitArr, Lsb0},
slice::Iter,
Expand Down Expand Up @@ -42,55 +44,57 @@ impl<'a> Iterator for BAIterator<'a> {
}
}

// TODO: indicate where and the source value
#[derive(thiserror::Error, Debug)]
#[error("The provided byte slice contains non-zero bits in padding")]
pub struct NonZeroPadding;

/// A value of ONE has a one in the first element of the bit array, followed by `$bits-1` zeros.
/// A value of ONE has a one in the first element of the bit array, followed by `$bits-1` $bit values.
/// This macro uses a bit of recursive repetition to produce those zeros.
///
/// The longest call is 8 bits, which involves `2(n+1)` macro expansions in addition to `bitarr!`.
macro_rules! bitarr_one {
macro_rules! bitarr_init {

// The binary value of `$bits-1` is expanded in MSB order for each of the values we care about.
// e.g., 20 =(-1)=> 19 =(binary)=> 0b10011 =(expand)=> 1 0 0 1 1

(2) => { bitarr_one!(1) };
(3) => { bitarr_one!(1 0) };
(4) => { bitarr_one!(1 1) };
(5) => { bitarr_one!(1 0 0) };
(6) => { bitarr_one!(1 0 1) };
(7) => { bitarr_one!(1 1 0) };
(8) => { bitarr_one!(1 1 1) };
(20) => { bitarr_one!(1 0 0 1 1) };
(32) => { bitarr_one!(1 1 1 1 1) };
(64) => { bitarr_one!(1 1 1 1 1 1) };
(112) => { bitarr_one!(1 1 0 1 1 1 1) };
(256) => { bitarr_one!(1 1 1 1 1 1 1 1) };

// Incrementally convert 1 or 0 into `[0,]` or `[]` as needed for the recursion step.
[$bit:tt; 2] => { bitarr_init!($bit, 1) };
[$bit:tt; 3] => { bitarr_init!($bit, 1 0) };
[$bit:tt; 4] => { bitarr_init!($bit, 1 1) };
[$bit:tt; 5] => { bitarr_init!($bit, 1 0 0) };
[$bit:tt; 6] => { bitarr_init!($bit, 1 0 1) };
[$bit:tt; 7] => { bitarr_init!($bit, 1 1 0) };
[$bit:tt; 8] => { bitarr_init!($bit, 1 1 1) };
[$bit:tt; 20] => { bitarr_init!($bit, 1 0 0 1 1) };
[$bit:tt; 32] => { bitarr_init!($bit, 1 1 1 1 1) };
[$bit:tt; 64] => { bitarr_init!($bit, 1 1 1 1 1 1) };
[$bit:tt; 112] => { bitarr_init!($bit, 1 1 0 1 1 1 1) };
[$bit:tt; 256] => { bitarr_init!($bit, 1 1 1 1 1 1 1 1) };

// Incrementally convert 1 or 0 into `[$bit,]` or `[]` as needed for the recursion step.
// This also reverses the bit order so that the MSB comes last, as needed for recursion.

// This passes a value back once the conversion is done.
($([$($x:tt)*])*) => { bitarr_one!(@r $([$($x)*])*) };
// This converts one 1 into `[0,]`.
($([$($x:tt)*])* 1 $($y:tt)*) => { bitarr_one!([0,] $([$($x)*])* $($y)*) };
($bit:tt, $([$($x:tt)*])*) => { bitarr_init!(@r $bit, $([$($x)*])*) };
// This converts one 1 into `[$bit,]`.
($bit:tt, $([$($x:tt)*])* 1 $($y:tt)*) => { bitarr_init!($bit, [$bit,] $([$($x)*])* $($y)*) };
// This converts one 0 into `[]`.
($([$($x:tt)*])* 0 $($y:tt)*) => { bitarr_one!([] $([$($x)*])* $($y)*) };
($bit:tt, $([$($x:tt)*])* 0 $($y:tt)*) => { bitarr_init!($bit, [] $([$($x)*])* $($y)*) };

// Recursion step.

// This is where recursion ends with a `BitArray`.
(@r [$($x:tt)*]) => { bitarr![const u8, Lsb0; 1, $($x)*] };
(@r $bit:tt, [$($x:tt)*]) => { bitarr![const u8, Lsb0; 1, $($x)*] };
// This is the recursion workhorse. It takes a list of lists. The outer lists are bracketed.
// The inner lists contain any form that can be repeated and concatenated, which probably
// means comma-separated values with a trailing comma.
// The first value is repeated once.
// The second value is repeated twice and merged into the first value.
// The third and subsequent values are repeated twice and shifted along one place.
// One-valued bits are represented as `[0,]`, zero-valued bits as `[]`.
(@r [$($x:tt)*] [$($y:tt)*] $([$($z:tt)*])*) => { bitarr_one!(@r [$($x)* $($y)* $($y)*] $([$($z)* $($z)*])*) };
// One-valued bits are represented as `[$bits,]`, zero-valued bits as `[]`.
(@r $bit:tt, [$($x:tt)*] [$($y:tt)*] $([$($z:tt)*])*) => { bitarr_init!(@r $bit, [$($x)* $($y)* $($y)*] $([$($z)* $($z)*])*) };
}

/// A value of ONE has a one in the first element of the bit array, followed by `$bits-1` zeros.
/// This macro uses a bit of recursive repetition to produce those zeros.
///
/// The longest call is 8 bits, which involves `2(n+1)` macro expansions in addition to `bitarr!`.
macro_rules! bitarr_one {
($bits: tt) => { bitarr_init![0; $bits] }
}

// Macro for boolean arrays <= 128 bits.
Expand Down Expand Up @@ -161,8 +165,26 @@ macro_rules! boolean_array_impl_small {
};
}

/// Macro to implement `Serializable` trait for boolean arrays. Depending on the size, conversion from [u8; N] to `BAXX`
/// can be fallible (if N is not a multiple of 8) or infallible. This macro takes care of it and provides the correct
/// implementation. Because macros can't do math, a hint is required to advise it which implementation it should provide.
#[rustfmt::skip]
macro_rules! impl_serializable_trait {
($name: ident, $bits: tt, fallible) => {
#[derive(thiserror::Error, Debug)]
#[error(
"The provided byte slice contains non-zero value(s) {0:?} in padding bits [{}..]",
$bits
)]
pub struct NonZeroPadding(Store);

/// The maximum value this boolean array can represent: `2^N - 1` where `N` is the number of bits.
/// Because of the [`restrictions`](https://docs.rs/bitvec/1.0.1/bitvec/array/struct.BitArray.html#usage),
/// for arrays with size `N` that is not a multiple of `8`, the `Store` capacity is greater than the
/// maximum value. For example: `BA20` (20 bit array) requires 3 bytes in memory and
/// store capacity (~16M) is greater than the maximum value for `BA20` which is (~1M).
const MAX_VALUE: Store = bitarr_init![1; $bits];

impl Serializable for $name {
type Size = <Store as Block>::Size;
type DeserializationError = NonZeroPadding;
Expand All @@ -174,12 +196,53 @@ macro_rules! impl_serializable_trait {
fn deserialize(
buf: &GenericArray<u8, Self::Size>,
) -> Result<Self, Self::DeserializationError> {
Ok(Self(<Store>::new(assert_copy(*buf).into())))
let raw_val = <Store>::new(assert_copy(*buf).into());

// make sure trailing bits (padding) are zeroes.
if MAX_VALUE == MAX_VALUE | raw_val {
Ok(Self(raw_val))
} else {
Err(NonZeroPadding(raw_val))
}
}
}

#[cfg(all(test, unit_test))]
mod fallible_serialization_tests {
use super::*;

/// [`https://github.com/private-attribution/ipa/issues/911`]
#[test]
fn deals_with_padding() {
assert_ne!(
0,
$bits % 8,
"Padding only makes sense for lengths that are not multiples of 8."
);
let mut non_zero_padding: Store = $name::ZERO.0;
non_zero_padding.set($bits, true);
let min_value: Store = $name::ZERO.0;
let one = $name::ONE.0;

let err =
$name::deserialize(&GenericArray::from_array(non_zero_padding.into_inner()))
.unwrap_err();
assert_eq!(non_zero_padding, err.0);
let _ =
$name::deserialize(&GenericArray::from_array(min_value.into_inner())).unwrap();
let _ = $name::deserialize(&GenericArray::from_array(one.into_inner())).unwrap();
let _ =
$name::deserialize(&GenericArray::from_array(MAX_VALUE.into_inner())).unwrap();
}
}
};

($name: ident, $bits: tt, infallible) => {
const _SAFEGUARD: () = assert!(
$bits % 8 == 0,
"Infallible deserialization is defined for lengths that are multiples of 8 only"
);

impl Serializable for $name {
type Size = <Store as Block>::Size;
type DeserializationError = std::convert::Infallible;
Expand Down Expand Up @@ -211,12 +274,12 @@ macro_rules! boolean_array_impl {
SharedValue,
},
};

type Store = BitArr!(for $bits, in u8, Lsb0);
// formatting does not set the indent properly
type Store = BitArr!(for $bits, in u8, Lsb0);

/// A Boolean array with $bits bits.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct $name(pub Store);
pub struct $name(pub(super) Store);

impl ArrayAccess for $name {
type Output = Boolean;
Expand Down Expand Up @@ -250,21 +313,6 @@ macro_rules! boolean_array_impl {

impl_serializable_trait!($name, $bits, $deser_type);

// impl Serializable for $name {
// type Size = <Store as Block>::Size;
// type DeserializationError = std::convert::Infallible;
//
// fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
// buf.copy_from_slice(self.0.as_raw_slice());
// }
//
// fn deserialize(
// buf: &GenericArray<u8, Self::Size>,
// ) -> Result<Self, Self::DeserializationError> {
// Ok(Self(<Store>::new(assert_copy(*buf).into())))
// }
// }

impl std::ops::Add for $name {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Expand Down Expand Up @@ -415,6 +463,18 @@ macro_rules! boolean_array_impl {
}
}
}

#[test]
fn serde() {
let ba = thread_rng().gen::<$name>();
let mut buf = GenericArray::default();
ba.serialize(&mut buf);
assert_eq!(
ba,
$name::deserialize(&buf).unwrap(),
"Failed to deserialize a valid value: {ba:?}"
);
}
}
}

Expand All @@ -433,15 +493,15 @@ store_impl!(U32, 256);

//impl BA3
boolean_array_impl_small!(boolean_array_3, BA3, 3, fallible);
boolean_array_impl_small!(boolean_array_4, BA4, 4, infallible);
boolean_array_impl_small!(boolean_array_4, BA4, 4, fallible);
boolean_array_impl_small!(boolean_array_5, BA5, 5, fallible);
boolean_array_impl_small!(boolean_array_6, BA6, 6, fallible);
boolean_array_impl_small!(boolean_array_7, BA7, 7, fallible);
boolean_array_impl_small!(boolean_array_8, BA8, 8, infallible);
boolean_array_impl_small!(boolean_array_20, BA20, 20, fallible);
boolean_array_impl_small!(boolean_array_32, BA32, 32, infallible);
boolean_array_impl_small!(boolean_array_64, BA64, 64, infallible);
boolean_array_impl_small!(boolean_array_112, BA112, 112, fallible);
boolean_array_impl_small!(boolean_array_112, BA112, 112, infallible);
boolean_array_impl!(boolean_array_256, BA256, 256, infallible);

// used to convert into Fp25519
Expand Down Expand Up @@ -469,22 +529,3 @@ impl rand::distributions::Distribution<BA256> for rand::distributions::Standard
(rng.gen(), rng.gen()).into()
}
}

impl Binary for BA3 {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Binary::fmt(&self.0, f)
}
}

#[cfg(all(test, unit_test))]
mod tests {
use super::*;

/// [`https://github.com/private-attribution/ipa/issues/911`]
#[test]
fn non_zero_padding_is_rejected() {
let src = 7_u8 | 1 << 5;

let _ = BA3::deserialize(&GenericArray::from_array([src])).unwrap_err();
}
}

0 comments on commit 7b912dc

Please sign in to comment.