From cb8c14700656f78cac4a6a407307aa3afc14655b Mon Sep 17 00:00:00 2001 From: stiegerc Date: Mon, 3 Mar 2025 21:26:52 +0100 Subject: [PATCH] chore: refactor to use test_strategy in Canonical State proptests (#4200) Same story has [this](https://github.com/dfinity/ic/pull/4163). --- Cargo.lock | 2 + rs/canonical_state/BUILD.bazel | 10 +- rs/canonical_state/Cargo.toml | 1 + rs/canonical_state/tests/compatibility.rs | 353 +++++++++++------- .../tests/size_limit_visitor.rs | 81 ++-- rs/canonical_state/tree_hash/BUILD.bazel | 6 + rs/canonical_state/tree_hash/Cargo.toml | 1 + .../tree_hash/tests/hash_tree.rs | 13 +- 8 files changed, 285 insertions(+), 182 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab219aeda03c..40f9bfa69261 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6572,6 +6572,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", + "test-strategy 0.4.0", ] [[package]] @@ -6589,6 +6590,7 @@ dependencies = [ "rand 0.8.5", "rand_chacha 0.3.1", "scoped_threadpool", + "test-strategy 0.4.0", "thiserror 2.0.11", ] diff --git a/rs/canonical_state/BUILD.bazel b/rs/canonical_state/BUILD.bazel index 101def3cce67..a8556b986fe0 100644 --- a/rs/canonical_state/BUILD.bazel +++ b/rs/canonical_state/BUILD.bazel @@ -47,7 +47,9 @@ DEV_DEPENDENCIES = [ "@crate_index//:tempfile", ] -DEV_MACRO_DEPENDENCIES = [ +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:test-strategy", ] rust_library( @@ -68,20 +70,20 @@ rust_test( rust_test( name = "compatibility_test", srcs = ["tests/compatibility.rs"], - proc_macro_deps = DEV_MACRO_DEPENDENCIES, + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":canonical_state"], ) rust_test( name = "size_limit_visitor_test", srcs = ["tests/size_limit_visitor.rs"], - proc_macro_deps = DEV_MACRO_DEPENDENCIES, + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":canonical_state"], ) rust_test( name = "hash_tree_test", srcs = ["tests/hash_tree.rs"], - proc_macro_deps = DEV_MACRO_DEPENDENCIES, + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":canonical_state"], ) diff --git a/rs/canonical_state/Cargo.toml b/rs/canonical_state/Cargo.toml index 4abd235fde55..900f9f4f0b26 100644 --- a/rs/canonical_state/Cargo.toml +++ b/rs/canonical_state/Cargo.toml @@ -41,4 +41,5 @@ ic-wasm-types = { path = "../types/wasm_types" } lazy_static = { workspace = true } maplit = "1.0.2" proptest = { workspace = true } +test-strategy = "0.4.0" tempfile = { workspace = true } diff --git a/rs/canonical_state/tests/compatibility.rs b/rs/canonical_state/tests/compatibility.rs index e45dd64fd10b..f389c4e648e6 100644 --- a/rs/canonical_state/tests/compatibility.rs +++ b/rs/canonical_state/tests/compatibility.rs @@ -149,71 +149,104 @@ lazy_static! { ]; } -proptest! { - /// Tests that given a `StreamHeader` that is valid for a given certification - /// version range (e.g. no `reject_signals` before certification version 8) all - /// supported canonical type (e.g. `StreamHeaderV6` or `StreamHeader`) and - /// certification version combinations produce the exact same encoding. - #[test] - fn stream_header_unique_encoding((header, version_range) in arb_valid_versioned_stream_header(100)) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*STREAM_HEADER_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&header, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `StreamHeader` that is valid for a given certification +/// version range (e.g. no `reject_signals` before certification version 8) all +/// supported canonical type (e.g. `StreamHeaderV6` or `StreamHeader`) and +/// certification version combinations produce the exact same encoding. +#[test_strategy::proptest] +fn stream_header_unique_encoding( + #[strategy(arb_valid_versioned_stream_header( + 100, // max_signal_count + ))] + test_header: (StreamHeader, RangeInclusive), +) { + let (header, version_range) = test_header; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*STREAM_HEADER_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&header, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } +} + +/// Tests that, given a `StreamHeader` that is valid for a given certification +/// version range (e.g. no `reject_signals` before certification version 8), +/// all supported encodings will decode back into the same `StreamHeader`. +#[test_strategy::proptest] +fn stream_header_roundtrip_encoding( + #[strategy(arb_valid_versioned_stream_header( + 100, // max_signal_count + ))] + test_header: (StreamHeader, RangeInclusive), +) { + let (header, version_range) = test_header; + + for version in iter(version_range) { + for encoding in &*STREAM_HEADER_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&header, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + let result = (encoding.decode)(&bytes) + .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); - /// Tests that, given a `StreamHeader` that is valid for a given certification - /// version range (e.g. no `reject_signals` before certification version 8), - /// all supported encodings will decode back into the same `StreamHeader`. - #[test] - fn stream_header_roundtrip_encoding((header, version_range) in arb_valid_versioned_stream_header(100)) { - for version in iter(version_range) { - for encoding in &*STREAM_HEADER_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&header, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - let result = (encoding.decode)(&bytes) - .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); - - assert_eq!(header, result, "Roundtrip encoding {}@{:?} failed", encoding.name, version); - } + assert_eq!( + header, result, + "Roundtrip encoding {}@{:?} failed", + encoding.name, version + ); } } } +} + +/// Tests that, given a `StreamHeader` that is invalid for a given certification +/// version range (e.g. `reject_signals` before certification version 8), +/// encoding will panic. +/// +/// Be aware that the output generated by this test failing includes all panics +/// (e.g. stack traces), including those produced by previous iterations where +/// panics were caught by `std::panic::catch_unwind`. +#[test_strategy::proptest] +fn stream_header_encoding_panic_on_invalid( + #[strategy(arb_invalid_versioned_stream_header( + 100, // max_signal_count + ))] + test_header: (StreamHeader, RangeInclusive), +) { + let (header, version_range) = test_header; + for version in iter(version_range) { + for encoding in &*STREAM_HEADER_ENCODINGS { + if encoding.version_range.contains(&version) { + let result = std::panic::catch_unwind(|| (encoding.encode)((&header, version))); - /// Tests that, given a `StreamHeader` that is invalid for a given certification - /// version range (e.g. `reject_signals` before certification version 8), - /// encoding will panic. - /// - /// Be aware that the output generated by this test failing includes all panics - /// (e.g. stack traces), including those produced by previous iterations where - /// panics were caught by `std::panic::catch_unwind`. - #[test] - fn stream_header_encoding_panic_on_invalid((header, version_range) in arb_invalid_versioned_stream_header(100)) { - for version in iter(version_range) { - for encoding in &*STREAM_HEADER_ENCODINGS { - if encoding.version_range.contains(&version) { - let result = std::panic::catch_unwind(|| { - (encoding.encode)((&header, version)) - }); - - assert!(result.is_err(), "Encoding of invalid {}@{:?} succeeded", encoding.name, version); - } + assert!( + result.is_err(), + "Encoding of invalid {}@{:?} succeeded", + encoding.name, + version + ); } } } @@ -256,49 +289,73 @@ lazy_static! { ]; } -proptest! { - /// Tests that given a `RequestOrResponse` that is valid for a given certification - /// version range (e.g. no `reject_signals` before certification version 8) all - /// supported canonical type (e.g. `RequestOrResponseV3` or `RequestOrResponse`) - /// and certification version combinations produce the exact same encoding. - #[test] - fn message_unique_encoding((message, version_range) in arb_valid_versioned_message()) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*MESSAGE_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&message, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `RequestOrResponse` that is valid for a given certification +/// version range (e.g. no `reject_signals` before certification version 8) all +/// supported canonical type (e.g. `RequestOrResponseV3` or `RequestOrResponse`) +/// and certification version combinations produce the exact same encoding. +#[test_strategy::proptest] +fn message_unique_encoding( + #[strategy(arb_valid_versioned_message())] test_message: ( + RequestOrResponse, + RangeInclusive, + ), +) { + let (message, version_range) = test_message; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*MESSAGE_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&message, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } +} + +/// Tests that, given a `RequestOrResponse` that is valid for a given +/// certification version range, all supported encodings will decode back into +/// the same `RequestOrResponse`. +#[test_strategy::proptest] +fn message_roundtrip_encoding( + #[strategy(arb_valid_versioned_message())] test_message: ( + RequestOrResponse, + RangeInclusive, + ), +) { + let (message, version_range) = test_message; - /// Tests that, given a `RequestOrResponse` that is valid for a given - /// certification version range, all supported encodings will decode back into - /// the same `RequestOrResponse`. - #[test] - fn message_roundtrip_encoding((message, version_range) in arb_valid_versioned_message()) { - for version in iter(version_range) { - for encoding in &*MESSAGE_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&message, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - let result = (encoding.decode)(&bytes) - .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); - - assert_eq!(message, result, "Roundtrip encoding {}@{:?} failed", encoding.name, version); - } + for version in iter(version_range) { + for encoding in &*MESSAGE_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&message, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + let result = (encoding.decode)(&bytes) + .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); + + assert_eq!( + message, result, + "Roundtrip encoding {}@{:?} failed", + encoding.name, version + ); } } } @@ -350,31 +407,44 @@ pub(crate) fn arb_valid_system_metadata( ] } -proptest! { - /// Tests that given a `SystemMetadata` that is valid for a given certification - /// version range, all supported canonical type (e.g. `SystemMetadataV9` or - /// `SystemMetadataV10`) and certification version combinations produce the - /// exact same encoding. - #[test] - fn system_metadata_unique_encoding((metadata, version_range) in arb_valid_system_metadata()) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*SYSTEM_METADATA_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&metadata, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `SystemMetadata` that is valid for a given certification +/// version range, all supported canonical type (e.g. `SystemMetadataV9` or +/// `SystemMetadataV10`) and certification version combinations produce the +/// exact same encoding. +#[test_strategy::proptest] +fn system_metadata_unique_encoding( + #[strategy(arb_valid_system_metadata())] test_metadata: ( + SystemMetadata, + RangeInclusive, + ), +) { + let (metadata, version_range) = test_metadata; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*SYSTEM_METADATA_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&metadata, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } } @@ -402,30 +472,43 @@ pub(crate) fn arb_valid_subnet_metrics( )] } -proptest! { - /// Tests that given a `SubnetMetrics` that is valid for a given certification - /// version range, all supported canonical type and certification version - /// combinations produce the exact same encoding. - #[test] - fn subnet_metrics_unique_encoding((subnet_metrics, version_range) in arb_valid_subnet_metrics()) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*SUBNET_METRICS_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&subnet_metrics, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `SubnetMetrics` that is valid for a given certification +/// version range, all supported canonical type and certification version +/// combinations produce the exact same encoding. +#[test_strategy::proptest] +fn subnet_metrics_unique_encoding( + #[strategy(arb_valid_subnet_metrics())] test_subnet_metrics: ( + SubnetMetrics, + RangeInclusive, + ), +) { + let (subnet_metrics, version_range) = test_subnet_metrics; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*SUBNET_METRICS_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&subnet_metrics, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } } diff --git a/rs/canonical_state/tests/size_limit_visitor.rs b/rs/canonical_state/tests/size_limit_visitor.rs index d24639efcc0b..873ed2ab0db3 100644 --- a/rs/canonical_state/tests/size_limit_visitor.rs +++ b/rs/canonical_state/tests/size_limit_visitor.rs @@ -68,46 +68,53 @@ prop_compose! { } } -proptest! { - #[test] - fn size_limit_proptest(fixture in arb_fixture(10)) { - let Fixture{ state, end, slice_begin, size_limit, .. } = fixture; - - // Produce a size-limited slice starting from `slice_begin`. - let pattern = vec![ - Label(b"streams".to_vec()), - Any, - Label(b"messages".to_vec()), - Any, - ]; - let subtree_pattern = make_slice_pattern(slice_begin, end); - let visitor = SizeLimitVisitor::new( - pattern, - size_limit, - SubtreeVisitor::new(&subtree_pattern, MessageSpyVisitor::default()), +#[test_strategy::proptest] +fn size_limit_proptest(#[strategy(arb_fixture(10))] fixture: Fixture) { + let Fixture { + state, + end, + slice_begin, + size_limit, + .. + } = fixture; + + // Produce a size-limited slice starting from `slice_begin`. + let pattern = vec![ + Label(b"streams".to_vec()), + Any, + Label(b"messages".to_vec()), + Any, + ]; + let subtree_pattern = make_slice_pattern(slice_begin, end); + let visitor = SizeLimitVisitor::new( + pattern, + size_limit, + SubtreeVisitor::new(&subtree_pattern, MessageSpyVisitor::default()), + ); + let (actual_size, actual_begin, actual_end) = traverse(&state, visitor); + + if let (Some(actual_begin), Some(actual_end)) = (actual_begin, actual_end) { + // Non-empty slice. + assert_eq!(slice_begin, actual_begin); + assert!(actual_end <= end); + + // Size is below the limit or the slice consists of a single message. + assert!(actual_size <= size_limit || actual_end - actual_begin == 1); + // And must match the computed slice size. + assert_eq!( + compute_message_sizes(&state, actual_begin, actual_end), + actual_size ); - let (actual_size, actual_begin, actual_end) = traverse(&state, visitor); - - if let (Some(actual_begin), Some(actual_end)) = (actual_begin, actual_end) { - // Non-empty slice. - assert_eq!(slice_begin, actual_begin); - assert!(actual_end <= end); - - // Size is below the limit or the slice consists of a single message. - assert!(actual_size <= size_limit || actual_end - actual_begin == 1); - // And must match the computed slice size. - assert_eq!(compute_message_sizes(&state, actual_begin, actual_end), actual_size); - if actual_end < end { - // Including one more message should exceed `size_limit`. - assert!(compute_message_sizes(&state, actual_begin, actual_end + 1) > size_limit); - } - } else { - // Empty slice. - assert_eq!(0, actual_size); - // May only happen if `slice_begin == stream.messages.end`. - assert_eq!(slice_begin, end); + if actual_end < end { + // Including one more message should exceed `size_limit`. + assert!(compute_message_sizes(&state, actual_begin, actual_end + 1) > size_limit); } + } else { + // Empty slice. + assert_eq!(0, actual_size); + // May only happen if `slice_begin == stream.messages.end`. + assert_eq!(slice_begin, end); } } diff --git a/rs/canonical_state/tree_hash/BUILD.bazel b/rs/canonical_state/tree_hash/BUILD.bazel index c98ee0cdb385..239979d4ebae 100644 --- a/rs/canonical_state/tree_hash/BUILD.bazel +++ b/rs/canonical_state/tree_hash/BUILD.bazel @@ -22,6 +22,11 @@ DEV_DEPENDENCIES = [ "@crate_index//:rand_chacha", ] +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:test-strategy", +] + rust_library( name = "tree_hash", srcs = glob(["src/**/*.rs"]), @@ -33,5 +38,6 @@ rust_library( rust_test_suite( name = "tree_hash_integration", srcs = glob(["tests/**/*.rs"]), + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":tree_hash"], ) diff --git a/rs/canonical_state/tree_hash/Cargo.toml b/rs/canonical_state/tree_hash/Cargo.toml index 6142b75efd28..c3793bff62ad 100644 --- a/rs/canonical_state/tree_hash/Cargo.toml +++ b/rs/canonical_state/tree_hash/Cargo.toml @@ -21,3 +21,4 @@ ic-crypto-tree-hash-test-utils = { path = "../../crypto/tree_hash/test_utils" } proptest = { workspace = true } rand = { workspace = true } rand_chacha = { workspace = true } +test-strategy = "0.4.0" diff --git a/rs/canonical_state/tree_hash/tests/hash_tree.rs b/rs/canonical_state/tree_hash/tests/hash_tree.rs index 93163525808a..5b765cb61efd 100644 --- a/rs/canonical_state/tree_hash/tests/hash_tree.rs +++ b/rs/canonical_state/tree_hash/tests/hash_tree.rs @@ -174,10 +174,11 @@ fn test_non_existence_proof() { ); } -proptest! { - #[test] - fn same_witness(t in arbitrary_labeled_tree(), seed in prop::array::uniform32(any::())) { - let rng = &mut ChaCha20Rng::from_seed(seed); - test_membership_witness(&t, rng); - } +#[test_strategy::proptest] +fn same_witness( + #[strategy(arbitrary_labeled_tree())] t: LabeledTree>, + #[strategy(prop::array::uniform32(any::()))] seed: [u8; 32], +) { + let rng = &mut ChaCha20Rng::from_seed(seed); + test_membership_witness(&t, rng); }