From 7b912dce685e6480276d58f12812377ec820b80f Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 8 Jan 2024 17:23:20 -0800 Subject: [PATCH] Final version --- ipa-core/src/ff/boolean_array.rs | 179 +++++++++++++++++++------------ 1 file changed, 110 insertions(+), 69 deletions(-) diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index e56065041..fa029740b 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -1,3 +1,5 @@ +use std::fmt::{Binary, Formatter}; + use bitvec::{ prelude::{bitarr, BitArr, Lsb0}, slice::Iter, @@ -42,55 +44,57 @@ impl<'a> Iterator for BAIterator<'a> { } } -// TODO: indicate where and the source value -#[derive(thiserror::Error, Debug)] -#[error("The provided byte slice contains non-zero bits in padding")] -pub struct NonZeroPadding; - -/// A value of ONE has a one in the first element of the bit array, followed by `$bits-1` zeros. +/// A value of ONE has a one in the first element of the bit array, followed by `$bits-1` $bit values. /// This macro uses a bit of recursive repetition to produce those zeros. /// /// The longest call is 8 bits, which involves `2(n+1)` macro expansions in addition to `bitarr!`. -macro_rules! bitarr_one { +macro_rules! bitarr_init { // The binary value of `$bits-1` is expanded in MSB order for each of the values we care about. // e.g., 20 =(-1)=> 19 =(binary)=> 0b10011 =(expand)=> 1 0 0 1 1 - - (2) => { bitarr_one!(1) }; - (3) => { bitarr_one!(1 0) }; - (4) => { bitarr_one!(1 1) }; - (5) => { bitarr_one!(1 0 0) }; - (6) => { bitarr_one!(1 0 1) }; - (7) => { bitarr_one!(1 1 0) }; - (8) => { bitarr_one!(1 1 1) }; - (20) => { bitarr_one!(1 0 0 1 1) }; - (32) => { bitarr_one!(1 1 1 1 1) }; - (64) => { bitarr_one!(1 1 1 1 1 1) }; - (112) => { bitarr_one!(1 1 0 1 1 1 1) }; - (256) => { bitarr_one!(1 1 1 1 1 1 1 1) }; - - // Incrementally convert 1 or 0 into `[0,]` or `[]` as needed for the recursion step. + [$bit:tt; 2] => { bitarr_init!($bit, 1) }; + [$bit:tt; 3] => { bitarr_init!($bit, 1 0) }; + [$bit:tt; 4] => { bitarr_init!($bit, 1 1) }; + [$bit:tt; 5] => { bitarr_init!($bit, 1 0 0) }; + [$bit:tt; 6] => { bitarr_init!($bit, 1 0 1) }; + [$bit:tt; 7] => { bitarr_init!($bit, 1 1 0) }; + [$bit:tt; 8] => { bitarr_init!($bit, 1 1 1) }; + [$bit:tt; 20] => { bitarr_init!($bit, 1 0 0 1 1) }; + [$bit:tt; 32] => { bitarr_init!($bit, 1 1 1 1 1) }; + [$bit:tt; 64] => { bitarr_init!($bit, 1 1 1 1 1 1) }; + [$bit:tt; 112] => { bitarr_init!($bit, 1 1 0 1 1 1 1) }; + [$bit:tt; 256] => { bitarr_init!($bit, 1 1 1 1 1 1 1 1) }; + + // Incrementally convert 1 or 0 into `[$bit,]` or `[]` as needed for the recursion step. // This also reverses the bit order so that the MSB comes last, as needed for recursion. // This passes a value back once the conversion is done. - ($([$($x:tt)*])*) => { bitarr_one!(@r $([$($x)*])*) }; - // This converts one 1 into `[0,]`. - ($([$($x:tt)*])* 1 $($y:tt)*) => { bitarr_one!([0,] $([$($x)*])* $($y)*) }; + ($bit:tt, $([$($x:tt)*])*) => { bitarr_init!(@r $bit, $([$($x)*])*) }; + // This converts one 1 into `[$bit,]`. + ($bit:tt, $([$($x:tt)*])* 1 $($y:tt)*) => { bitarr_init!($bit, [$bit,] $([$($x)*])* $($y)*) }; // This converts one 0 into `[]`. - ($([$($x:tt)*])* 0 $($y:tt)*) => { bitarr_one!([] $([$($x)*])* $($y)*) }; + ($bit:tt, $([$($x:tt)*])* 0 $($y:tt)*) => { bitarr_init!($bit, [] $([$($x)*])* $($y)*) }; // Recursion step. // This is where recursion ends with a `BitArray`. - (@r [$($x:tt)*]) => { bitarr![const u8, Lsb0; 1, $($x)*] }; + (@r $bit:tt, [$($x:tt)*]) => { bitarr![const u8, Lsb0; 1, $($x)*] }; // This is the recursion workhorse. It takes a list of lists. The outer lists are bracketed. // The inner lists contain any form that can be repeated and concatenated, which probably // means comma-separated values with a trailing comma. // The first value is repeated once. // The second value is repeated twice and merged into the first value. // The third and subsequent values are repeated twice and shifted along one place. - // One-valued bits are represented as `[0,]`, zero-valued bits as `[]`. - (@r [$($x:tt)*] [$($y:tt)*] $([$($z:tt)*])*) => { bitarr_one!(@r [$($x)* $($y)* $($y)*] $([$($z)* $($z)*])*) }; + // One-valued bits are represented as `[$bits,]`, zero-valued bits as `[]`. + (@r $bit:tt, [$($x:tt)*] [$($y:tt)*] $([$($z:tt)*])*) => { bitarr_init!(@r $bit, [$($x)* $($y)* $($y)*] $([$($z)* $($z)*])*) }; +} + +/// A value of ONE has a one in the first element of the bit array, followed by `$bits-1` zeros. +/// This macro uses a bit of recursive repetition to produce those zeros. +/// +/// The longest call is 8 bits, which involves `2(n+1)` macro expansions in addition to `bitarr!`. +macro_rules! bitarr_one { + ($bits: tt) => { bitarr_init![0; $bits] } } // Macro for boolean arrays <= 128 bits. @@ -161,8 +165,26 @@ macro_rules! boolean_array_impl_small { }; } +/// Macro to implement `Serializable` trait for boolean arrays. Depending on the size, conversion from [u8; N] to `BAXX` +/// can be fallible (if N is not a multiple of 8) or infallible. This macro takes care of it and provides the correct +/// implementation. Because macros can't do math, a hint is required to advise it which implementation it should provide. +#[rustfmt::skip] macro_rules! impl_serializable_trait { ($name: ident, $bits: tt, fallible) => { + #[derive(thiserror::Error, Debug)] + #[error( + "The provided byte slice contains non-zero value(s) {0:?} in padding bits [{}..]", + $bits + )] + pub struct NonZeroPadding(Store); + + /// The maximum value this boolean array can represent: `2^N - 1` where `N` is the number of bits. + /// Because of the [`restrictions`](https://docs.rs/bitvec/1.0.1/bitvec/array/struct.BitArray.html#usage), + /// for arrays with size `N` that is not a multiple of `8`, the `Store` capacity is greater than the + /// maximum value. For example: `BA20` (20 bit array) requires 3 bytes in memory and + /// store capacity (~16M) is greater than the maximum value for `BA20` which is (~1M). + const MAX_VALUE: Store = bitarr_init![1; $bits]; + impl Serializable for $name { type Size = ::Size; type DeserializationError = NonZeroPadding; @@ -174,12 +196,53 @@ macro_rules! impl_serializable_trait { fn deserialize( buf: &GenericArray, ) -> Result { - Ok(Self(::new(assert_copy(*buf).into()))) + let raw_val = ::new(assert_copy(*buf).into()); + + // make sure trailing bits (padding) are zeroes. + if MAX_VALUE == MAX_VALUE | raw_val { + Ok(Self(raw_val)) + } else { + Err(NonZeroPadding(raw_val)) + } + } + } + + #[cfg(all(test, unit_test))] + mod fallible_serialization_tests { + use super::*; + + /// [`https://github.com/private-attribution/ipa/issues/911`] + #[test] + fn deals_with_padding() { + assert_ne!( + 0, + $bits % 8, + "Padding only makes sense for lengths that are not multiples of 8." + ); + let mut non_zero_padding: Store = $name::ZERO.0; + non_zero_padding.set($bits, true); + let min_value: Store = $name::ZERO.0; + let one = $name::ONE.0; + + let err = + $name::deserialize(&GenericArray::from_array(non_zero_padding.into_inner())) + .unwrap_err(); + assert_eq!(non_zero_padding, err.0); + let _ = + $name::deserialize(&GenericArray::from_array(min_value.into_inner())).unwrap(); + let _ = $name::deserialize(&GenericArray::from_array(one.into_inner())).unwrap(); + let _ = + $name::deserialize(&GenericArray::from_array(MAX_VALUE.into_inner())).unwrap(); } } }; ($name: ident, $bits: tt, infallible) => { + const _SAFEGUARD: () = assert!( + $bits % 8 == 0, + "Infallible deserialization is defined for lengths that are multiples of 8 only" + ); + impl Serializable for $name { type Size = ::Size; type DeserializationError = std::convert::Infallible; @@ -211,12 +274,12 @@ macro_rules! boolean_array_impl { SharedValue, }, }; - - type Store = BitArr!(for $bits, in u8, Lsb0); + // formatting does not set the indent properly + type Store = BitArr!(for $bits, in u8, Lsb0); /// A Boolean array with $bits bits. #[derive(Clone, Copy, PartialEq, Eq, Debug)] - pub struct $name(pub Store); + pub struct $name(pub(super) Store); impl ArrayAccess for $name { type Output = Boolean; @@ -250,21 +313,6 @@ macro_rules! boolean_array_impl { impl_serializable_trait!($name, $bits, $deser_type); - // impl Serializable for $name { - // type Size = ::Size; - // type DeserializationError = std::convert::Infallible; - // - // fn serialize(&self, buf: &mut GenericArray) { - // buf.copy_from_slice(self.0.as_raw_slice()); - // } - // - // fn deserialize( - // buf: &GenericArray, - // ) -> Result { - // Ok(Self(::new(assert_copy(*buf).into()))) - // } - // } - impl std::ops::Add for $name { type Output = Self; fn add(self, rhs: Self) -> Self::Output { @@ -415,6 +463,18 @@ macro_rules! boolean_array_impl { } } } + + #[test] + fn serde() { + let ba = thread_rng().gen::<$name>(); + let mut buf = GenericArray::default(); + ba.serialize(&mut buf); + assert_eq!( + ba, + $name::deserialize(&buf).unwrap(), + "Failed to deserialize a valid value: {ba:?}" + ); + } } } @@ -433,7 +493,7 @@ store_impl!(U32, 256); //impl BA3 boolean_array_impl_small!(boolean_array_3, BA3, 3, fallible); -boolean_array_impl_small!(boolean_array_4, BA4, 4, infallible); +boolean_array_impl_small!(boolean_array_4, BA4, 4, fallible); boolean_array_impl_small!(boolean_array_5, BA5, 5, fallible); boolean_array_impl_small!(boolean_array_6, BA6, 6, fallible); boolean_array_impl_small!(boolean_array_7, BA7, 7, fallible); @@ -441,7 +501,7 @@ boolean_array_impl_small!(boolean_array_8, BA8, 8, infallible); boolean_array_impl_small!(boolean_array_20, BA20, 20, fallible); boolean_array_impl_small!(boolean_array_32, BA32, 32, infallible); boolean_array_impl_small!(boolean_array_64, BA64, 64, infallible); -boolean_array_impl_small!(boolean_array_112, BA112, 112, fallible); +boolean_array_impl_small!(boolean_array_112, BA112, 112, infallible); boolean_array_impl!(boolean_array_256, BA256, 256, infallible); // used to convert into Fp25519 @@ -469,22 +529,3 @@ impl rand::distributions::Distribution for rand::distributions::Standard (rng.gen(), rng.gen()).into() } } - -impl Binary for BA3 { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - Binary::fmt(&self.0, f) - } -} - -#[cfg(all(test, unit_test))] -mod tests { - use super::*; - - /// [`https://github.com/private-attribution/ipa/issues/911`] - #[test] - fn non_zero_padding_is_rejected() { - let src = 7_u8 | 1 << 5; - - let _ = BA3::deserialize(&GenericArray::from_array([src])).unwrap_err(); - } -} \ No newline at end of file