diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 0c1239c9b..42518d17a 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -10,6 +10,7 @@ use crate::{ ff::{ boolean::Boolean, boolean_array::{BooleanArray, BA5, BA64, BA8}, + curve_points::RP25519, ec_prime_field::Fp25519, Serializable, U128Conversions, }, @@ -18,7 +19,7 @@ use crate::{ TotalRecords, }, protocol::{ - basics::{BooleanArrayMul, BooleanProtocols}, + basics::{BooleanArrayMul, BooleanProtocols, Reveal}, context::{ Context, DZKPUpgraded, DZKPUpgradedSemiHonestContext, MacUpgraded, SemiHonestContext, UpgradableContext, UpgradedSemiHonestContext, @@ -35,7 +36,7 @@ use crate::{ }, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, - SharedValue, TransposeFrom, + SharedValue, TransposeFrom, Vectorizable, }, seq_join::seq_join, sharding::NotSharded, @@ -303,6 +304,8 @@ where Replicated: BooleanProtocols, CONV_CHUNK>, Replicated: PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, + Replicated: + Reveal, Output = >::Array>, { let conv_records = TotalRecords::specified(div_round_up(input_rows.len(), Const::))?; diff --git a/ipa-core/src/protocol/ipa_prf/prf_eval.rs b/ipa-core/src/protocol/ipa_prf/prf_eval.rs index fa5f58848..1870c9ef7 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_eval.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_eval.rs @@ -1,13 +1,12 @@ use std::iter::zip; use futures::future::try_join; -use futures_util::FutureExt; use crate::{ error::Error, ff::{curve_points::RP25519, ec_prime_field::Fp25519}, protocol::{ - basics::{malicious_reveal, reveal, SecureMul}, + basics::{reveal, Reveal, SecureMul}, context::{ upgrade::Upgradable, UpgradableContext, UpgradedContext, UpgradedMaliciousContext, UpgradedSemiHonestContext, @@ -103,6 +102,7 @@ pub async fn eval_dy_prf( where C: UpgradedContext, AdditiveShare: PrfSharing, + AdditiveShare: Reveal>::Array>, Fp25519: FieldSimd, RP25519: Vectorizable, { @@ -129,10 +129,7 @@ where >::Array, >::Array, ) = try_join( - // TODO: these should invoke reveal via the trait when this function - // takes a context of an appropriate type. - malicious_reveal(ctx.narrow(&Step::RevealR), record_id, None, &sh_gr) - .map(|v| v.map(|arr| arr.unwrap())), + reveal(ctx.narrow(&Step::RevealR), record_id, &sh_gr), reveal(ctx.narrow(&Step::Revealz), record_id, &y), ) .await?; @@ -155,13 +152,14 @@ mod test { ff::{curve_points::RP25519, ec_prime_field::Fp25519}, helpers::{in_memory_config::MaliciousHelper, Role}, protocol::{ - context::{Context, UpgradableContext, Validator}, + basics::Reveal, + context::{Context, MacUpgraded, UpgradableContext, Validator}, ipa_prf::{ prf_eval::{eval_dy_prf, PrfSharing}, step::PrfStep, }, }, - secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares, Vectorizable}, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; @@ -179,8 +177,9 @@ mod test { ) -> Result, Error> where C: UpgradableContext, - AdditiveShare: - PrfSharing< as Validator>::Context, 1, Field = Fp25519>, + AdditiveShare: PrfSharing, 1, Field = Fp25519>, + AdditiveShare: + Reveal, Output = >::Array>, { let ctx = ctx.set_total_records(input_match_keys.len()); let validator = ctx.validator::();