Skip to content

Commit

Permalink
Malicious PRF works for many records
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Aug 5, 2024
1 parent b0bd158 commit 68af5d3
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 155 deletions.
6 changes: 2 additions & 4 deletions ipa-core/src/protocol/basics/reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,19 +294,17 @@ where
}
}

pub fn mac_validated_reveal<C, V, S, F>(
pub fn mac_validated_reveal<C, S, F>(
ctx: C,
validator: V,
record_id: RecordId,
v: S,
) -> impl Future<Output = Result<S::Output, Error>> + Send
where
C: UpgradedContext<Field = F>,
V: Validator<C>,
S: Reveal<C> + Send + Sync,
{
async move {
validator.validate_record(record_id).await?;
ctx.validate_record(record_id).await?;
assert_send(v.reveal(ctx, record_id)).await
}
}
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ impl<'a, F: ExtendableField> UpgradedContext for Upgraded<'a, F> {
type Field = F;
type Share = MaliciousReplicated<F>;

async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> {
unimplemented!("validate_record is not implemented for UpgradedContext")
}


async fn upgrade_one(
&self,
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,12 @@ pub trait UpgradedContext: Context {
assert_send(input.upgrade(record_id, self)).await
}

async fn validate_record(&self, record_id: RecordId) -> Result<(), Error>;

/// TODO: this is very promising to make work with new validator. this is the exact interface
/// I need to upgrade in different contexts using the same record id
///
/// TODO: delete this method as we use `upgrade_record` now
async fn upgrade_one(
&self,
record_id: RecordId,
Expand Down
11 changes: 8 additions & 3 deletions ipa-core/src/protocol/context/semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,17 @@ impl<'a, B: ShardBinding> super::Context for Context<'a, B> {

impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> {
type UpgradedContext<F: ExtendableField> = Upgraded<'a, B, F>;
type BatchUpgradedContext<F: ExtendableField> = Upgraded<'a, B, F>;
type BatchUpgradedContext<F: ExtendableField> = Self::UpgradedContext<F>;
type Validator<F: ExtendableField> = Validator<'a, B, F>;
type BatchValidator<F: ExtendableField> = Validator<'a, B, F>;
type BatchValidator<F: ExtendableField> = Self::Validator<F>;

fn validator<F: ExtendableField>(self) -> Self::Validator<F> {
Self::Validator::new(self.inner)
}

fn batch_validator<F: ExtendableField>(self, total_records: TotalRecords) -> Self::BatchValidator<F> {
Self::Validator::new(self.inner)
use crate::protocol::context::Context;
Self::Validator::new(self.inner.set_total_records(total_records))
}

type DZKPValidator = SemiHonestDZKPValidator<'a, B>;
Expand Down Expand Up @@ -293,6 +294,10 @@ impl<'a, B: ShardBinding, F: ExtendableField> UpgradedContext for Upgraded<'a, B
type Field = F;
type Share = Replicated<F>;

async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> {
Ok(())
}

async fn upgrade_one(
&self,
_record_id: RecordId,
Expand Down
35 changes: 19 additions & 16 deletions ipa-core/src/protocol/context/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,7 @@ where <V as ExtendableField>::ExtendedField: FieldSimd<N>,
where UpgradedMaliciousContext<'ctx, V>: 'ctx, 'a: 'ctx
{
async move {
// let b = {
// let batch_ref = context.batch.upgrade().unwrap();
// let mut batch = batch_ref.lock().unwrap();
// let b = batch.get_batch(record_id);
// let accumulator = MaliciousAccumulator { inner: Arc::downgrade(b.batch.u_and_w.lock().unwrap()) };
// let r_share = b.batch.r_share.clone();
//
// UpgradedMaliciousContext::new(
// b.batch.context().clone()
// };
upgrade_one(context.malicious_ctx(record_id), record_id, self).await
// b.batch.context().upgrade_one(record_id, b.batch.context()).await
// b.batch.protocol_ctx.upgrade().upgrade_one(record_id, self).await
// let malicious_ctx = UpgradedMaliciousContext::new(
// context.base_ctx.as_base(),
// )
// upgrade_one(malicious_ctx, record_id, self).await
}
}
}
Expand Down Expand Up @@ -502,6 +486,25 @@ impl <F: ExtendableField> UpgradedContext for BatchUpgradedContext<'_, F> {
type Field = F;
type Share = malicious::AdditiveShare<F>;

async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> {
let batch_ref = self.batch.upgrade().expect("Validator is not dropped");
let r = {
let mut batch = batch_ref.lock().unwrap();
batch.validate_record(record_id)
};
match r {
Either::Left((_, batch)) => {
batch.batch.validate(()).await?;
batch.notify.notify_waiters();
Ok(())
}
Either::Right(notify) => {
notify.notified().await;
Ok(())
}
}
}

async fn upgrade_one(&self, record_id: RecordId, x: Replicated<Self::Field>) -> Result<Self::Share, Error> {
todo!()
}
Expand Down
12 changes: 5 additions & 7 deletions ipa-core/src/protocol/ipa_prf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ use crate::{
seq_join::SeqJoin,
};
use crate::protocol::context::validator::Upgradeable;
use crate::protocol::ipa_prf::prf_eval::compute_prf;
use crate::protocol::prss::SharedRandomness;

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -300,7 +299,6 @@ where
C::UpgradedContext<Boolean>: UpgradedContext<Field = Boolean, Share = Replicated<Boolean>>,
C::UpgradedContext<Fp25519>: UpgradedContext<Field = Fp25519>,
Replicated<Fp25519, { Fp25519::VECTORIZE }> :Upgradeable<C::UpgradedContext<Fp25519>>,
<Replicated<Fp25519, { Fp25519::VECTORIZE }> as Upgradeable<C::UpgradedContext<Fp25519>>>::Output: PrfSharing<C::UpgradedContext<Fp25519>, {Fp25519::VECTORIZE}>,
Replicated<Fp25519, { Fp25519::VECTORIZE}>: FromPrss,
BK: BooleanArray,
TV: BooleanArray,
Expand All @@ -309,8 +307,9 @@ where
<<C as UpgradableContext>::DZKPValidator as DZKPValidator>::Context,
CONV_CHUNK,
>,
Replicated<Fp25519, PRF_CHUNK>:
BasicProtocols<C::UpgradedContext<Fp25519>, PRF_CHUNK, ProtocolField = Fp25519>, // todo: default vectorize
// Replicated<Fp25519, PRF_CHUNK>:
// BasicProtocols<C::UpgradedContext<Fp25519>, PRF_CHUNK, ProtocolField = Fp25519>, // todo: default vectorize
Replicated<Fp25519, PRF_CHUNK>: PrfSharing<C::UpgradedContext<Fp25519>, PRF_CHUNK, Field = Fp25519>
{
let conv_records =
TotalRecords::specified(div_round_up(input_rows.len(), Const::<CONV_CHUNK>))?;
Expand Down Expand Up @@ -345,16 +344,15 @@ where
.narrow(&Step::EvalPrf);
// .set_total_records(eval_records);

let prf_key = &gen_prf_key(eval_ctx.clone()).await?;
let prf_key = &gen_prf_key(eval_ctx.clone());
let eval_ctx = eval_ctx.set_total_records(eval_records);

let prf_of_match_keys = seq_join(
eval_ctx.active_work(),
stream::iter(zip(curve_pts, zip(repeat(eval_ctx), repeat(validator)))).enumerate().map(|(i, (curve_pts, (eval_ctx, validator)))| {
let record_id = RecordId::from(i);
curve_pts.then(move |pts| async move {
let pts = eval_ctx.clone().upgrade_record(record_id, pts).await?;
compute_prf(eval_ctx, validator, record_id, &prf_key, pts).await
eval_dy_prf(eval_ctx, record_id, prf_key, pts).await
})
}),
)
Expand Down
Loading

0 comments on commit 68af5d3

Please sign in to comment.