diff --git a/crates/core/component/dex/src/swap_claim/proof.rs b/crates/core/component/dex/src/swap_claim/proof.rs index 60e52ff6f1..272451c1cb 100644 --- a/crates/core/component/dex/src/swap_claim/proof.rs +++ b/crates/core/component/dex/src/swap_claim/proof.rs @@ -741,4 +741,102 @@ mod tests { assert!(check_circuit_satisfaction(public, private).is_err()); } } + + prop_compose! { + // This strategy is invalid because the block height of the swap commitment does not match + // the height of the batch swap output data. + fn arb_invalid_swapclaim_swap_commitment_height()(seed_phrase_randomness in any::<[u8; 32]>(), rseed_randomness in any::<[u8; 32]>(), value1_amount in 2..200u64, fee_amount in any::(), test_bsod in unfilled_bsod_strategy()) -> (SwapClaimProofPublic, SwapClaimProofPrivate) { + let seed_phrase = SeedPhrase::from_randomness(&seed_phrase_randomness); + let sk_recipient = SpendKey::from_seed_phrase_bip44(seed_phrase, &Bip44Path::new(0)); + let fvk_recipient = sk_recipient.full_viewing_key(); + let ivk_recipient = fvk_recipient.incoming(); + let (claim_address, _dtk_d) = ivk_recipient.payment_address(0u32.into()); + let nk = *sk_recipient.nullifier_key(); + + let gm = asset::Cache::with_known_assets().get_unit("gm").unwrap(); + let gn = asset::Cache::with_known_assets().get_unit("gn").unwrap(); + let trading_pair = TradingPair::new(gm.id(), gn.id()); + + let delta_1_i = Amount::from(value1_amount); + let delta_2_i = Amount::from(0u64); + let fee = Fee::default(); + + let rseed = Rseed(rseed_randomness); + let swap_plaintext = SwapPlaintext { + trading_pair, + delta_1_i, + delta_2_i, + claim_fee: fee, + claim_address, + rseed, + }; + let incorrect_fee = Fee::from_staking_token_amount(Amount::from(fee_amount)); + let mut sct = tct::Tree::new(); + let swap_commitment = swap_plaintext.swap_commitment(); + sct.insert(tct::Witness::Keep, swap_commitment).unwrap(); + let anchor = sct.root(); + let state_commitment_proof = sct.witness(swap_commitment).unwrap(); + let position = state_commitment_proof.position(); + let nullifier = Nullifier::derive(&nk, position, &swap_commitment); + + // End the block, and then add a dummy commitment that we'll use + // to compute the position and block height that the BSOD corresponds to. + sct.end_block().expect("can end block"); + let dummy_swap_commitment = tct::StateCommitment(Fq::from(1)); + sct.insert(tct::Witness::Keep, dummy_swap_commitment).unwrap(); + let dummy_state_commitment_proof = sct.witness(swap_commitment).unwrap(); + let dummy_position = dummy_state_commitment_proof.position(); + + let epoch_duration = 20; + let height = epoch_duration * dummy_position.epoch() + dummy_position.block(); + + let output_data = BatchSwapOutputData { + delta_1: test_bsod.delta_1, + delta_2: test_bsod.delta_2, + lambda_1: test_bsod.lambda_1, + lambda_2: test_bsod.lambda_2, + unfilled_1: test_bsod.unfilled_1, + unfilled_2: test_bsod.unfilled_2, + height: height.into(), + trading_pair: swap_plaintext.trading_pair, + epoch_starting_height: (epoch_duration * dummy_position.epoch()).into(), + }; + let (lambda_1, lambda_2) = output_data.pro_rata_outputs((delta_1_i, delta_2_i)); + + let (output_rseed_1, output_rseed_2) = swap_plaintext.output_rseeds(); + let note_blinding_1 = output_rseed_1.derive_note_blinding(); + let note_blinding_2 = output_rseed_2.derive_note_blinding(); + let (output_1_note, output_2_note) = swap_plaintext.output_notes(&output_data); + let note_commitment_1 = output_1_note.commit(); + let note_commitment_2 = output_2_note.commit(); + + let public = SwapClaimProofPublic { + anchor, + nullifier, + claim_fee: incorrect_fee, + output_data, + note_commitment_1, + note_commitment_2, + }; + let private = SwapClaimProofPrivate { + swap_plaintext, + state_commitment_proof, + nk, + lambda_1, + lambda_2, + note_blinding_1, + note_blinding_2, + }; + + (public, private) + } + } + + proptest! { + #[test] + fn swap_claim_proof_invalid_swap_commitment_height((public, private) in arb_invalid_swapclaim_swap_commitment_height()) { + assert!(check_satisfaction(&public, &private).is_err()); + assert!(check_circuit_satisfaction(public, private).is_err()); + } + } }