Skip to content

Commit

Permalink
Merge pull request private-attribution#1069 from danielmasny/bitslice
Browse files Browse the repository at this point in the history
Bitslice Fix
  • Loading branch information
danielmasny authored May 22, 2024
2 parents 711ebcd + d084979 commit 2284946
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 52 deletions.
37 changes: 34 additions & 3 deletions ipa-core/src/ff/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ impl FromRandomU128 for Boolean {
impl DZKPCompatibleField for Boolean {
fn as_segment_entry(array: &<Self as Vectorizable<1>>::Array) -> SegmentEntry<'_> {
if bool::from(Boolean::from_array(array)) {
SegmentEntry::from_bitslice(BitSlice::from_element(&1))
SegmentEntry::from_bitslice(BitSlice::from_element(&1u8).get(0..1).unwrap())
} else {
SegmentEntry::from_bitslice(BitSlice::from_element(&0))
SegmentEntry::from_bitslice(BitSlice::from_element(&0u8).get(0..1).unwrap())
}
}
}
Expand All @@ -234,7 +234,11 @@ mod test {
use rand::{thread_rng, Rng};
use typenum::U1;

use crate::ff::{boolean::Boolean, ArrayAccess, Serializable};
use crate::{
ff::{boolean::Boolean, ArrayAccess, Field, Serializable},
protocol::context::dzkp_field::DZKPCompatibleField,
secret_sharing::{SharedValue, Vectorizable},
};

impl Arbitrary for Boolean {
type Parameters = <bool as Arbitrary>::Parameters;
Expand Down Expand Up @@ -275,6 +279,33 @@ mod test {
assert_ne!(a, !a);
}

#[test]
fn boolean_bitslice() {
let one = Boolean::ONE;
let zero = Boolean::ZERO;

// convert into vectorizable
let one_vec: <Boolean as Vectorizable<1>>::Array = one.into_array();
let zero_vec: <Boolean as Vectorizable<1>>::Array = zero.into_array();

// generate slices
let slice_one = <Boolean as DZKPCompatibleField<1>>::as_segment_entry(&one_vec);
let slice_zero = <Boolean as DZKPCompatibleField<1>>::as_segment_entry(&zero_vec);

// check length
assert_eq!(slice_one.len(), 1usize);
assert_eq!(slice_zero.len(), 1usize);

// check content
assert_ne!(*slice_one.as_bitslice(), *slice_zero.as_bitslice());
assert!(slice_one.as_bitslice().first().unwrap());
assert!(!slice_zero.as_bitslice().first().unwrap());

// there is nothing more
assert!(slice_one.as_bitslice().get(1).is_none());
assert!(slice_zero.as_bitslice().get(1).is_none());
}

/// test `ArrayAccess` for Boolean
#[test]
fn test_array_access() {
Expand Down
23 changes: 22 additions & 1 deletion ipa-core/src/ff/boolean_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ macro_rules! boolean_array_impl {
#[inline]
#[must_use]
pub fn as_bitslice(&self) -> &BitSlice<u8, Lsb0> {
self.0.as_bitslice()
self.0.as_bitslice().get(0..$bits).unwrap()
}
}

Expand Down Expand Up @@ -554,6 +554,7 @@ macro_rules! boolean_array_impl {
proptest,
};
use rand::{thread_rng, Rng};
use bitvec::bits;

use super::*;

Expand Down Expand Up @@ -730,6 +731,26 @@ macro_rules! boolean_array_impl {
let actual = format!("{:?}", $name::ZERO);
assert_eq!(expected, actual);
}

#[test]
fn bitslice() {
let zero = $name::ZERO;
let random = thread_rng().gen::<$name>();

// generate slices
let slice_zero = zero.as_bitslice();
let slice_random = random.as_bitslice();

// check length
assert_eq!(slice_zero.len(), $bits);
assert_eq!(slice_random.len(), $bits);

// // check content
assert_eq!(*slice_zero, bits![0;$bits]);
slice_random.iter().enumerate().for_each(|(i,bit)| {
assert_eq!(bit,bool::from(random.get(i).unwrap()));
});
}
}
}

Expand Down
141 changes: 117 additions & 24 deletions ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,11 @@ impl<'a> Segment<'a> {
debug_assert_eq!(x_left.len(), prss_left.len());
debug_assert_eq!(x_left.len(), prss_right.len());
debug_assert_eq!(x_left.len(), z_right.len());
// check that length is either multiple of 256 or 256 is multiple of length
debug_assert_eq!(
(
x_left.len(),
x_left.len() % 256 == 0 || 256 % x_left.len() == 0
),
(x_left.len(), true)
// check that length is either smaller or a multiple of 256
debug_assert!(
x_left.len() <= 256 || x_left.len() % 256 == 0,
"length {} needs to be smaller or a multiple of 256",
x_left.len()
);
// asserts passed, create struct
Self {
Expand Down Expand Up @@ -209,6 +207,11 @@ impl<'a> SegmentEntry<'a> {
SegmentEntry(entry)
}

#[must_use]
pub fn as_bitslice(&self) -> &'a BitSliceType {
self.0
}

/// This function returns the size in bits.
#[must_use]
pub fn len(&self) -> usize {
Expand Down Expand Up @@ -283,38 +286,60 @@ impl MultiplicationInputsBatch {
self.is_empty
}

/// `insert_segment` allows to include a new segment in `MultiplicationInputsBatch`
/// `insert_segment` allows to include a new segment in `MultiplicationInputsBatch`.
/// It supports `segments` that are either smaller than 256 bits or multiple of 256 bits.
///
/// ## Panics
/// Panics when segments have different lengths across records, the `record_id` is less than
/// `first_record` or when `record_id` is more than `first_record + max_multiplications`,
/// i.e. not enough space has been allocated.
/// Panics when segments have different lengths across records.
/// It also Panics when the `record_id` is smaller
/// than the first record of the batch, i.e. `first_record`
/// or too large, i.e. `first_record+max_multiplications`
fn insert_segment(&mut self, record_id: RecordId, segment: Segment) {
// check segment size
debug_assert_eq!(segment.len(), self.multiplication_bit_size);

// panics when record_id is out of bounds
assert!(record_id >= self.first_record);
assert!(
record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record))
);

// update last record
self.last_record = cmp::max(self.last_record, record_id);

// panics when record_id is less than first_record
let id_within_batch = usize::from(record_id) - usize::from(self.first_record);
let block_id = (segment.len() * id_within_batch) >> BIT_ARRAY_SHIFT;

// panics when record_id is too large to fit in, i.e. when it is out of bounds
if 256 % segment.len() == 0 {
self.insert_segment_small(id_within_batch, block_id, segment);
if segment.len() <= 256 {
self.insert_segment_small(record_id, segment);
} else {
self.insert_segment_large(block_id, &segment);
self.insert_segment_large(record_id, &segment);
}
}

/// insert `segments` for `segments` that divide 256
/// insert `segments` that are smaller than or equal to 256
///
/// ## Panics
/// Panics when `bit_length` and `block_id` are out of bounds.
fn insert_segment_small(&mut self, id_within_batch: usize, block_id: usize, segment: Segment) {
/// It also Panics when the `record_id` is smaller
/// than the first record of the batch, i.e. `first_record`
/// or too large, i.e. `first_record+max_multiplications`
fn insert_segment_small(&mut self, record_id: RecordId, segment: Segment) {
// check length
debug_assert!(segment.len() <= 256);

// panics when record_id is out of bounds
assert!(record_id >= self.first_record);
assert!(
record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record))
);

// panics when record_id is less than first_record
let id_within_batch = usize::from(record_id) - usize::from(self.first_record);
// round up segment length to a power of two since we want to have divisors of 256
let length = segment.len().next_power_of_two();

let block_id = (length * id_within_batch) >> BIT_ARRAY_SHIFT;
// segments are small, pack one or more in each entry of `vec`
let position_within_block_start = (segment.len() * id_within_batch) % 256;
let position_within_block_start = (length * id_within_batch) % 256;
let position_within_block_end = position_within_block_start + segment.len();

let block = &mut self.vec[block_id];
Expand All @@ -337,11 +362,26 @@ impl MultiplicationInputsBatch {
}
}

/// insert `segments` for `segments` that are multiples of 256
/// insert `segments` that are multiples of 256
///
/// ## Panics
/// Panics when segment is not a multiple of 256 or is out of bounds.
fn insert_segment_large(&mut self, block_id: usize, segment: &Segment) {
/// It also Panics when the `record_id` is smaller
/// than the first record of the batch, i.e. `first_record`
/// or too large, i.e. `first_record+max_multiplications`
fn insert_segment_large(&mut self, record_id: RecordId, segment: &Segment) {
// check length
debug_assert_eq!(segment.len() % 256, 0);

// panics when record_id is out of bounds
assert!(record_id >= self.first_record);
assert!(
record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record))
);

let id_within_batch = usize::from(record_id) - usize::from(self.first_record);
let block_id = (segment.len() * id_within_batch) >> BIT_ARRAY_SHIFT;

let length_in_blocks = segment.len() >> BIT_ARRAY_SHIFT;
for i in 0..length_in_blocks {
MultiplicationInputsBlock::set(
Expand Down Expand Up @@ -650,7 +690,7 @@ impl<'a> Drop for MaliciousDZKPValidator<'a> {
mod tests {
use std::iter::{repeat, zip};

use bitvec::vec::BitVec;
use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec};
use futures::TryStreamExt;
use futures_util::stream::iter;
use proptest::{prop_compose, proptest, sample::select};
Expand Down Expand Up @@ -966,4 +1006,57 @@ mod tests {
assert_eq!(prover.1, verifier_right);
});
}

#[test]
fn powers_of_two() {
let bits = BitArray::<[u8; 32], Lsb0>::new([255u8; 32]);

// Boolean
assert_eq!(
1usize,
SegmentEntry::from_bitslice(bits.get(0..1).unwrap())
.len()
.next_power_of_two()
);

// BA3
assert_eq!(
4usize,
SegmentEntry::from_bitslice(bits.get(0..3).unwrap())
.len()
.next_power_of_two()
);

// BA8
assert_eq!(
8usize,
SegmentEntry::from_bitslice(bits.get(0..8).unwrap())
.len()
.next_power_of_two()
);

// BA20
assert_eq!(
32usize,
SegmentEntry::from_bitslice(bits.get(0..20).unwrap())
.len()
.next_power_of_two()
);

// BA64
assert_eq!(
64usize,
SegmentEntry::from_bitslice(bits.get(0..64).unwrap())
.len()
.next_power_of_two()
);

// BA256
assert_eq!(
256usize,
SegmentEntry::from_bitslice(bits.get(0..256).unwrap())
.len()
.next_power_of_two()
);
}
}
45 changes: 21 additions & 24 deletions ipa-core/src/secret_sharing/vector/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,37 +84,34 @@ macro_rules! boolean_vector {
let a = rng.gen::<$vec>();
let b = rng.gen::<$vec>();

// required by DZKP storage strategy
if (a.as_bitslice().len() % 256 == 0 || 256 % a.as_bitslice().len() == 0) {
let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);
let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);

let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares)))
.map(|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(1);
let m_ctx = v.context();
let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares)))
.map(|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(1);
let m_ctx = v.context();

let result = select(
m_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;
let result = select(
m_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;

v.validate::<Fp61BitPrime>().await?;
v.validate::<Fp61BitPrime>().await?;

Ok::<_, Error>(result)
});
Ok::<_, Error>(result)
});

let [ab0, ab1, ab2] = join3v(futures).await;
let [ab0, ab1, ab2] = join3v(futures).await;

let ab = [ab0, ab1, ab2].reconstruct();
let ab = [ab0, ab1, ab2].reconstruct();

assert_eq!(ab, if bit.into() { a } else { b });
}
assert_eq!(ab, if bit.into() { a } else { b });
}

#[tokio::test]
Expand Down

0 comments on commit 2284946

Please sign in to comment.