diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index 3b6c3925f..63bb595d1 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -90,8 +90,8 @@ macro_rules! bitarr_one { // Macro for boolean arrays <= 128 bits. macro_rules! boolean_array_impl_small { - ($modname:ident, $name:ident, $bits:tt) => { - boolean_array_impl!($modname, $name, $bits); + ($modname:ident, $name:ident, $bits:tt, $deser_type:tt) => { + boolean_array_impl!($modname, $name, $bits, $deser_type); // TODO(812): remove this impl; BAs are not field elements. impl Field for $name { @@ -156,9 +156,103 @@ macro_rules! boolean_array_impl_small { }; } -//macro for implementing Boolean array, only works for a byte size for which Block is defined +#[derive(thiserror::Error, Debug)] +#[error("The provided byte slice contains non-zero value(s) {0:?} in padding bits [{1}..]")] +pub struct NonZeroPadding(GenericArray, usize); + +/// Macro to implement `Serializable` trait for boolean arrays. Depending on the size, conversion from [u8; N] to `BAXX` +/// can be fallible (if N is not a multiple of 8) or infallible. This macro takes care of it and provides the correct +/// implementation. Because macros can't do math, a hint is required to advise it which implementation it should provide. +macro_rules! impl_serializable_trait { + ($name: ident, $bits: tt, $store: ty, fallible) => { + impl Serializable for $name { + type Size = <$store as Block>::Size; + type DeserializationError = NonZeroPadding<$store>; + + fn serialize(&self, buf: &mut GenericArray) { + buf.copy_from_slice(self.0.as_raw_slice()); + } + + fn deserialize( + buf: &GenericArray, + ) -> Result { + let raw_val = <$store>::new(assert_copy(*buf).into()); + + // make sure trailing bits (padding) are zeroes. + if raw_val[$bits..].not_any() { + Ok(Self(raw_val)) + } else { + Err(NonZeroPadding( + GenericArray::from_array(raw_val.into_inner()), + $bits, + )) + } + } + } + + #[cfg(all(test, unit_test))] + mod fallible_serialization_tests { + use super::*; + + /// [`https://github.com/private-attribution/ipa/issues/911`] + #[test] + fn deals_with_padding() { + fn deserialize(val: $store) -> Result<$name, NonZeroPadding<$store>> { + $name::deserialize(&GenericArray::from_array(val.into_inner())) + } + + assert_ne!( + 0, + $bits % 8, + "Padding only makes sense for lengths that are not multiples of 8." + ); + + let mut non_zero_padding = $name::ZERO.0; + non_zero_padding.set($bits, true); + assert_eq!( + GenericArray::from_array(non_zero_padding.into_inner()), + deserialize(non_zero_padding).unwrap_err().0 + ); + + let min_value = $name::ZERO.0; + deserialize(min_value).unwrap(); + + let one = $name::ONE.0; + deserialize(one).unwrap(); + + let mut max_value = $name::ZERO.0; + max_value[..$bits].fill(true); + deserialize(max_value).unwrap(); + } + } + }; + + ($name: ident, $bits: tt, $store: ty, infallible) => { + const _SAFEGUARD: () = assert!( + $bits % 8 == 0, + "Infallible deserialization is defined for lengths that are multiples of 8 only" + ); + + impl Serializable for $name { + type Size = <$store as Block>::Size; + type DeserializationError = std::convert::Infallible; + + fn serialize(&self, buf: &mut GenericArray) { + buf.copy_from_slice(self.0.as_raw_slice()); + } + + fn deserialize( + buf: &GenericArray, + ) -> Result { + Ok(Self(<$store>::new(assert_copy(*buf).into()))) + } + } + }; +} + +// macro for implementing Boolean array, only works for a byte size for which Block is defined macro_rules! boolean_array_impl { - ($modname:ident, $name:ident, $bits:tt) => { + ($modname:ident, $name:ident, $bits:tt, $deser_type: tt) => { #[allow(clippy::suspicious_arithmetic_impl)] #[allow(clippy::suspicious_op_assign_impl)] mod $modname { @@ -175,7 +269,7 @@ macro_rules! boolean_array_impl { /// A Boolean array with $bits bits. #[derive(Clone, Copy, PartialEq, Eq, Debug)] - pub struct $name(pub Store); + pub struct $name(pub(super) Store); impl ArrayAccess for $name { type Output = Boolean; @@ -207,20 +301,7 @@ macro_rules! boolean_array_impl { const ZERO: Self = Self(::ZERO); } - impl Serializable for $name { - type Size = ::Size; - type DeserializationError = std::convert::Infallible; - - fn serialize(&self, buf: &mut GenericArray) { - buf.copy_from_slice(self.0.as_raw_slice()); - } - - fn deserialize( - buf: &GenericArray, - ) -> Result { - Ok(Self(::new(assert_copy(*buf).into()))) - } - } + impl_serializable_trait!($name, $bits, Store, $deser_type); impl std::ops::Add for $name { type Output = Self; @@ -372,6 +453,18 @@ macro_rules! boolean_array_impl { } } } + + #[test] + fn serde() { + let ba = thread_rng().gen::<$name>(); + let mut buf = GenericArray::default(); + ba.serialize(&mut buf); + assert_eq!( + ba, + $name::deserialize(&buf).unwrap(), + "Failed to deserialize a valid value: {ba:?}" + ); + } } } @@ -389,17 +482,17 @@ store_impl!(U14, 112); store_impl!(U32, 256); //impl BA3 -boolean_array_impl_small!(boolean_array_3, BA3, 3); -boolean_array_impl_small!(boolean_array_4, BA4, 4); -boolean_array_impl_small!(boolean_array_5, BA5, 5); -boolean_array_impl_small!(boolean_array_6, BA6, 6); -boolean_array_impl_small!(boolean_array_7, BA7, 7); -boolean_array_impl_small!(boolean_array_8, BA8, 8); -boolean_array_impl_small!(boolean_array_20, BA20, 20); -boolean_array_impl_small!(boolean_array_32, BA32, 32); -boolean_array_impl_small!(boolean_array_64, BA64, 64); -boolean_array_impl_small!(boolean_array_112, BA112, 112); -boolean_array_impl!(boolean_array_256, BA256, 256); +boolean_array_impl_small!(boolean_array_3, BA3, 3, fallible); +boolean_array_impl_small!(boolean_array_4, BA4, 4, fallible); +boolean_array_impl_small!(boolean_array_5, BA5, 5, fallible); +boolean_array_impl_small!(boolean_array_6, BA6, 6, fallible); +boolean_array_impl_small!(boolean_array_7, BA7, 7, fallible); +boolean_array_impl_small!(boolean_array_8, BA8, 8, infallible); +boolean_array_impl_small!(boolean_array_20, BA20, 20, fallible); +boolean_array_impl_small!(boolean_array_32, BA32, 32, infallible); +boolean_array_impl_small!(boolean_array_64, BA64, 64, infallible); +boolean_array_impl_small!(boolean_array_112, BA112, 112, infallible); +boolean_array_impl!(boolean_array_256, BA256, 256, infallible); // used to convert into Fp25519 impl From<(u128, u128)> for BA256 {