diff --git a/ipa-core/src/protocol/prss/crypto.rs b/ipa-core/src/protocol/prss/crypto.rs index f9c9f5cd2..b596667ff 100644 --- a/ipa-core/src/protocol/prss/crypto.rs +++ b/ipa-core/src/protocol/prss/crypto.rs @@ -277,7 +277,7 @@ impl Generator { /// Generate the value at the given index. /// This uses the MMO^{\pi} function described in . #[must_use] - pub(crate) fn generate>(&self, index: I) -> u128 { + pub(super) fn generate>(&self, index: I) -> u128 { let index = index.into(); #[cfg(debug_assertions)] self.used.use_index(index).unwrap(); diff --git a/ipa-core/src/protocol/prss/mod.rs b/ipa-core/src/protocol/prss/mod.rs index 667c60245..dda02aba3 100644 --- a/ipa-core/src/protocol/prss/mod.rs +++ b/ipa-core/src/protocol/prss/mod.rs @@ -1,17 +1,12 @@ mod crypto; -use std::{ - collections::HashMap, - fmt::{Debug, Display, Formatter}, - marker::PhantomData, - num::TryFromIntError, - ops::AddAssign, -}; +use std::{collections::HashMap, fmt::Debug, marker::PhantomData, ops::AddAssign}; pub use crypto::{ FromPrss, FromRandom, FromRandomU128, Generator, GeneratorFactory, KeyExchange, SharedRandomness, }; use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; +pub(super) use internal::PrssIndex128; use x25519_dalek::PublicKey; use crate::{ @@ -20,57 +15,109 @@ use crate::{ sync::{Arc, Mutex}, }; -/// Internal PRSS index. -/// -/// `PrssIndex128` values are directly input to the block cipher used for pseudo-random generation. -/// Each invocation must use a distinct `PrssIndex128` value. Most code should use the `PrssIndex` -/// type instead, which often corresponds to record IDs. `PrssIndex128` values are produced by -/// the `PrssIndex::offset` function and include the primary `PrssIndex` plus a possible offset -/// when more than 128 bits of randomness are required to generate the requested value. +/// This module restricts access to internal PRSS index's private fields +/// and enforces constructing it via `new` even for impl blocks +/// defined in this module. /// -/// This is public so that it can be used by the instrumentation wrappers in -/// `ipa_core::protocol::context`. It should not generally be used outside the PRSS implementation. -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub(crate) struct PrssIndex128 { - index: PrssIndex, - offset: u32, -} +/// This helps to make sure an invalid [`PrssIndex128`] cannot be created +/// and all instantiations occur via `From` calls. +mod internal { + use std::{ + fmt::{Debug, Display, Formatter}, + num::TryFromIntError, + }; -impl Display for PrssIndex128 { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}:{}", self.index.0, self.offset) + use crate::protocol::prss::PrssIndex; + + /// Internal PRSS index. + /// + /// `PrssIndex128` values are directly input to the block cipher used for pseudo-random generation. + /// Each invocation must use a distinct `PrssIndex128` value. Most code should use the `PrssIndex` + /// type instead, which often corresponds to record IDs. `PrssIndex128` values are produced by + /// the `PrssIndex::offset` function and include the primary `PrssIndex` plus a possible offset + /// when more than 128 bits of randomness are required to generate the requested value. + #[derive(Clone, Copy, PartialEq, Eq, Hash)] + pub struct PrssIndex128 { + index: PrssIndex, + offset: u32, + } + + impl Display for PrssIndex128 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}:{}", self.index.0, self.offset) + } } -} -impl Debug for PrssIndex128 { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{self}") + impl Debug for PrssIndex128 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } } -} -#[cfg(debug_assertions)] -impl From for PrssIndex128 { - fn from(value: u64) -> Self { - Self::try_from(u128::from(value)).unwrap() + #[cfg(debug_assertions)] + impl From for PrssIndex128 { + fn from(value: u64) -> Self { + Self::try_from(u128::from(value)).unwrap() + } } -} -impl From for u128 { - fn from(value: PrssIndex128) -> Self { - u128::from((u64::from(value.index.0) << 32) + u64::from(value.offset)) + impl From for u128 { + fn from(value: PrssIndex128) -> Self { + u128::from(u64::from(value)) + } } -} -impl TryFrom for PrssIndex128 { - type Error = TryFromIntError; + impl From for u64 { + fn from(value: PrssIndex128) -> Self { + (u64::from(value.index.0) << 32) + u64::from(value.offset) + } + } + + impl TryFrom for PrssIndex128 { + type Error = PrssIndexError; + + fn try_from(value: u128) -> Result { + let value64 = u64::try_from(value)?; + let index = PrssIndex::from(u32::try_from(value64 >> 32).unwrap()); + let offset = usize::try_from(value64 & u64::from(u32::MAX)).unwrap(); + + Self::new(index, offset) + } + } + + impl PrssIndex128 { + /// The absolute maximum number of times we can encrypt + /// using the same AES key inside PRSS is 2^43. We reserve + /// 32 bits for index, leaving the remaining 11 for the offset. + /// That puts a limit to 32k maximum entropy generated + /// from PRSS using the same record id. + /// [`proof`]: + pub(super) const MAX_OFFSET: u32 = 1 << 11; + + pub fn new(index: PrssIndex, offset: usize) -> Result { + let this = Self { + index, + offset: offset.try_into()?, + }; + + if this.offset <= Self::MAX_OFFSET { + Ok(this) + } else { + Err(PrssIndexError::OutOfRange(this.into())) + } + } - fn try_from(value: u128) -> Result { - let value = u64::try_from(value)?; + pub fn index(self) -> PrssIndex { + self.index + } + } - Ok(Self { - index: u32::try_from(value >> 32).unwrap().into(), - offset: u32::try_from(value & u64::from(u32::MAX)).unwrap(), - }) + #[derive(Debug, thiserror::Error)] + pub enum PrssIndexError { + #[error("Type conversion failed")] + ConversionError(#[from] TryFromIntError), + #[error("PRSS index is out of range: {0}")] + OutOfRange(u128), } } @@ -116,10 +163,7 @@ impl AddAssign for PrssIndex { impl PrssIndex { fn offset(self, offset: usize) -> PrssIndex128 { - PrssIndex128 { - index: self, - offset: offset.try_into().expect("PRSS offset out of range"), - } + PrssIndex128::new(self, offset).expect("PRSS offset must not be out of range") } } @@ -419,7 +463,7 @@ pub mod test { let (g1, g2) = make(SEED); assert_eq!(g1.generate(0), g2.generate(0)); assert_eq!(g1.generate(1), g2.generate(1)); - assert_eq!(g1.generate(u64::MAX), g2.generate(u64::MAX)); + assert_eq!(g1.generate(1 << 32), g2.generate(1 << 32)); // Calling generators seeded with the same key produce the same output assert_eq!(g1.generate(12), make(SEED).0.generate(12)); @@ -680,6 +724,15 @@ pub mod test { base += 1; } + #[test] + fn index_upper_bound() { + let bad_index = (u128::from(u32::MAX) << 32) + u128::from(PrssIndex128::MAX_OFFSET + 1); + let good_index = (u128::from(u32::MAX) << 32) + u128::from(PrssIndex128::MAX_OFFSET); + + assert!(PrssIndex128::try_from(bad_index).is_err()); + assert!(PrssIndex128::try_from(good_index).is_ok()); + } + fn assert_8_byte_index_is_valid(index: u32, offset: usize) { let index = PrssIndex(index); let index128 = u128::from(index.offset(offset)); @@ -691,7 +744,7 @@ pub mod test { proptest! { #[test] - fn prss_index_128_conversions(index in 0..u32::MAX, offset in 0..u32::MAX) { + fn prss_index_128_conversions(index in 0..u32::MAX, offset in 0..PrssIndex128::MAX_OFFSET) { assert_8_byte_index_is_valid(index ,usize::try_from(offset).unwrap()); } }