Skip to content

Commit

Permalink
Merge pull request #905 from akoshelev/fallible-deser
Browse files Browse the repository at this point in the history
Fallible deserialization
  • Loading branch information
akoshelev authored Jan 5, 2024
2 parents 3c5d48b + 5c64f50 commit 432a2da
Show file tree
Hide file tree
Showing 24 changed files with 353 additions and 149 deletions.
6 changes: 5 additions & 1 deletion ipa-core/src/cli/playbook/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ where
.unwrap();

let results: Vec<F> = results
.map(|bytes| AdditiveShare::<F>::from_byte_slice(&bytes).collect::<Vec<_>>())
.map(|bytes| {
AdditiveShare::<F>::from_byte_slice(&bytes)
.collect::<Result<Vec<_>, _>>()
.unwrap()
})
.reconstruct();

let lat = mpc_time.elapsed();
Expand Down
6 changes: 5 additions & 1 deletion ipa-core/src/cli/playbook/multiply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ where

// expect replicated shares to be sent back
results
.map(|bytes| Replicated::<F>::from_byte_slice(&bytes).collect::<Vec<_>>())
.map(|bytes| {
Replicated::<F>::from_byte_slice(&bytes)
.collect::<Result<Vec<_>, _>>()
.unwrap()
})
.reconstruct()
}
22 changes: 21 additions & 1 deletion ipa-core/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{backtrace::Backtrace, fmt::Debug};
use std::{backtrace::Backtrace, convert::Infallible, fmt::Debug};

use thiserror::Error;

Expand Down Expand Up @@ -108,3 +108,23 @@ pub fn set_global_panic_hook() {
(default_hook)(panic_info);
}));
}

/// Same purpose as [`unwrap-infallible`] but fewer dependencies.
/// As usual, there is a 8 year old [`RFC`] to make this to std that hasn't been merged yet.
///
/// [`unwrap-infallible`]: https://docs.rs/unwrap-infallible/latest/unwrap_infallible
/// [`RFC`]: https://github.com/rust-lang/rfcs/issues/1723
pub trait UnwrapInfallible {
/// R to avoid clashing with result's `Ok` type.
type R;

fn unwrap_infallible(self) -> Self::R;
}

impl<T> UnwrapInfallible for Result<T, Infallible> {
type R = T;

fn unwrap_infallible(self) -> Self::R {
self.unwrap()
}
}
17 changes: 11 additions & 6 deletions ipa-core/src/ff/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,23 @@ impl From<Boolean> for bool {
}
}

#[derive(thiserror::Error, Debug)]
#[error("{0} is not a valid boolean value, only 0 and 1 are accepted.")]
pub struct ParseBooleanError(u8);

impl Serializable for Boolean {
type Size = <<Boolean as SharedValue>::Storage as Block>::Size;
type DeserializationError = ParseBooleanError;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
buf[0] = u8::from(self.0);
}

///## Panics
/// panics when u8 is not 0 or 1
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Self {
assert!(buf[0] < 2u8);
Boolean(buf[0] != 0)
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
if buf[0] > 1 {
return Err(ParseBooleanError(buf[0]));
}
Ok(Boolean(buf[0] != 0))
}
}

Expand Down Expand Up @@ -180,7 +185,7 @@ mod test {
let input = rng.gen::<Boolean>();
let mut a: GenericArray<u8, U1> = [0u8; 1].into();
input.serialize(&mut a);
let output = Boolean::deserialize(&a);
let output = Boolean::deserialize(&a).unwrap();
assert_eq!(input, output);
}

Expand Down
9 changes: 6 additions & 3 deletions ipa-core/src/ff/boolean_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,16 @@ macro_rules! boolean_array_impl {

impl Serializable for $name {
type Size = <Store as Block>::Size;
type DeserializationError = std::convert::Infallible;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
buf.copy_from_slice(self.0.as_raw_slice());
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Self {
Self(<Store>::new(assert_copy(*buf).into()))
fn deserialize(
buf: &GenericArray<u8, Self::Size>,
) -> Result<Self, Self::DeserializationError> {
Ok(Self(<Store>::new(assert_copy(*buf).into())))
}
}

Expand Down Expand Up @@ -407,7 +410,7 @@ impl From<(u128, u128)> for BA256 {
.into_iter()
.chain(value.1.to_le_bytes());
let arr = GenericArray::<u8, U32>::try_from_iter(iter).unwrap();
BA256::deserialize(&arr)
BA256::deserialize_infallible(&arr)
}
}

Expand Down
27 changes: 23 additions & 4 deletions ipa-core/src/ff/curve_points.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,25 @@ impl SharedValue for RP25519 {
const ZERO: Self = Self(CompressedRistretto([0_u8; 32]));
}

#[derive(thiserror::Error, Debug)]
#[error("{0:?} is not the canonical encoding of a Ristretto point.")]
pub struct NonCanonicalEncoding(CompressedRistretto);

impl Serializable for RP25519 {
type Size = <<RP25519 as SharedValue>::Storage as Block>::Size;
type DeserializationError = NonCanonicalEncoding;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
*buf.as_mut() = self.0.to_bytes();
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Self {
debug_assert!(CompressedRistretto((*buf).into()).decompress().is_some());
RP25519(CompressedRistretto((*buf).into()))
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
let point = CompressedRistretto((*buf).into());
if cfg!(debug_assertions) && point.decompress().is_none() {
return Err(NonCanonicalEncoding(point));
}

Ok(RP25519(point))
}
}

Expand Down Expand Up @@ -180,6 +189,7 @@ mod test {
use rand::{thread_rng, Rng};
use typenum::U32;

use super::*;
use crate::{
ff::{curve_points::RP25519, ec_prime_field::Fp25519, Serializable},
secret_sharing::SharedValue,
Expand All @@ -194,7 +204,7 @@ mod test {
let input = rng.gen::<RP25519>();
let mut a: GenericArray<u8, U32> = [0u8; 32].into();
input.serialize(&mut a);
let output = RP25519::deserialize(&a);
let output = RP25519::deserialize(&a).unwrap();
assert_eq!(input, output);
}

Expand Down Expand Up @@ -236,4 +246,13 @@ mod test {
assert_ne!(0u64, u64::from(fp_a));
assert_ne!(0u32, u32::from(fp_a));
}

#[test]
fn non_canonical() {
const ZERO: u128 = 0;
// 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF is not a valid Ristretto point
let buf: [u8; 32] = unsafe { std::mem::transmute([!ZERO, !ZERO]) };
let err = RP25519::deserialize(GenericArray::from_slice(&buf)).unwrap_err();
assert!(matches!(err, NonCanonicalEncoding(_)));
}
}
16 changes: 10 additions & 6 deletions ipa-core/src/ff/ec_prime_field.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::Infallible;

use curve25519_dalek::scalar::Scalar;
use generic_array::GenericArray;
use hkdf::Hkdf;
Expand Down Expand Up @@ -48,14 +50,15 @@ impl From<Fp25519> for Scalar {

impl Serializable for Fp25519 {
type Size = <<Fp25519 as SharedValue>::Storage as Block>::Size;
type DeserializationError = Infallible;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
*buf.as_mut() = self.0.to_bytes();
}

/// Deserialized values are reduced modulo the field order.
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Self {
Fp25519(Scalar::from_bytes_mod_order((*buf).into()))
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
Ok(Fp25519(Scalar::from_bytes_mod_order((*buf).into())))
}
}

Expand Down Expand Up @@ -135,7 +138,7 @@ impl From<BA256> for Fp25519 {
let mut buf: GenericArray<u8, U32> = [0u8; 32].into();
s.serialize(&mut buf);
// Reduces mod order
Fp25519::deserialize(&buf)
Fp25519::deserialize_infallible(&buf)
}
}

Expand All @@ -159,11 +162,12 @@ macro_rules! sc_hash_impl {
fn from(s: $u_type) -> Self {
use hkdf::Hkdf;
use sha2::Sha256;

let hk = Hkdf::<Sha256>::new(None, &s.to_le_bytes());
let mut okm = [0u8; 32];
//error invalid length from expand only happens when okm is very large
hk.expand(&[], &mut okm).unwrap();
Fp25519::deserialize(&okm.into())
Fp25519::deserialize_infallible(&okm.into())
}
}
};
Expand Down Expand Up @@ -195,7 +199,7 @@ impl FromRandomU128 for Fp25519 {
let mut okm = [0u8; 32];
//error invalid length from expand only happens when okm is very large
hk.expand(&[], &mut okm).unwrap();
Fp25519::deserialize(&okm.into())
Fp25519::deserialize_infallible(&okm.into())
}
}

Expand Down Expand Up @@ -233,7 +237,7 @@ mod test {
let input = rng.gen::<Fp25519>();
let mut a: GenericArray<u8, U32> = [0u8; 32].into();
input.serialize(&mut a);
let output = Fp25519::deserialize(&a);
let output = Fp25519::deserialize_infallible(&a);
assert_eq!(input, output);
}

Expand Down
7 changes: 4 additions & 3 deletions ipa-core/src/ff/galois_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,13 +464,14 @@ macro_rules! bit_array_impl {

impl Serializable for $name {
type Size = <$store as Block>::Size;
type DeserializationError = std::convert::Infallible;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
buf.copy_from_slice(self.0.as_raw_slice());
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Self {
Self(<$store>::new(assert_copy(*buf).into()))
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
Ok(Self(<$store>::new(assert_copy(*buf).into())))
}
}

Expand Down Expand Up @@ -598,7 +599,7 @@ macro_rules! bit_array_impl {
let mut buf = GenericArray::default();
a.clone().serialize(&mut buf);

assert_eq!(a, $name::deserialize(&buf));
assert_eq!(a, $name::deserialize_infallible(&buf));
}
}

Expand Down
28 changes: 26 additions & 2 deletions ipa-core/src/ff/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ mod field;
mod galois_field;
mod prime_field;

use std::ops::{Add, AddAssign, Sub, SubAssign};
use std::{
convert::Infallible,
ops::{Add, AddAssign, Sub, SubAssign},
};

pub use field::{Field, FieldType};
pub use galois_field::{GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf8Bit, Gf9Bit};
Expand All @@ -19,6 +22,8 @@ use generic_array::{ArrayLength, GenericArray};
pub use prime_field::Fp31;
pub use prime_field::{Fp32BitPrime, PrimeField};

use crate::error::UnwrapInfallible;

#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum Error {
#[error("unknown field type {type_str}")]
Expand All @@ -43,6 +48,8 @@ impl<T, Rhs> AddSubAssign<Rhs> for T where T: AddAssign<Rhs> + SubAssign<Rhs> {}
pub trait Serializable: Sized {
/// Required number of bytes to store this message on disk/network
type Size: ArrayLength;
/// The error type that can be returned if an error occurs during deserialization.
type DeserializationError: std::error::Error + Send + Sync + 'static;

/// Serialize this message to a mutable slice. It is enforced at compile time or on the caller
/// side that this slice is sized to fit this instance. Implementations do not need to check
Expand All @@ -53,7 +60,24 @@ pub trait Serializable: Sized {
/// buffer has enough capacity to fit instances of this trait.
///
/// [`serialize`]: Self::serialize
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Self;
///
/// ## Errors
/// In general, deserialization may fail even if buffer size is enough. The bytes may
/// not represent a valid value in the domain, in this case implementations will return an error.
fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError>;

/// Same as [`deserialize`] but returns an actual value if it is known at compile time that deserialization
/// is infallible.
///
/// [`deserialize`]: Self::deserialize
fn deserialize_infallible(buf: &GenericArray<u8, Self::Size>) -> Self
where
Infallible: From<Self::DeserializationError>,
{
Self::deserialize(buf)
.map_err(Into::into)
.unwrap_infallible()
}
}

pub trait ArrayAccess {
Expand Down
Loading

0 comments on commit 432a2da

Please sign in to comment.