diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a188139b418..7e879aa132e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -53,6 +53,7 @@ go_deps.bzl @dfinity/idx /packages/ic-dummy-getrandom-for-wasm/ @dfinity/crypto-team /packages/ic-ed25519/ @dfinity/crypto-team /packages/ic-ethereum-types/ @dfinity/cross-chain-team +/packages/ic-hpke/ @dfinity/crypto-team /packages/ic-metrics-assert/ @dfinity/cross-chain-team /packages/ic-secp256k1/ @dfinity/crypto-team /packages/ic-sha3/ @dfinity/crypto-team @@ -162,7 +163,6 @@ go_deps.bzl @dfinity/idx /rs/memory_tracker/ @dfinity/execution /rs/messaging/ @dfinity/ic-message-routing-owners /rs/monitoring/ @dfinity/consensus -/rs/monitoring/backtrace/ @dfinity/consensus @dfinity/ic-message-routing-owners /rs/monitoring/metrics @dfinity/consensus @dfinity/ic-message-routing-owners /rs/monitoring/pprof/ @dfinity/consensus @dfinity/ic-message-routing-owners /rs/nervous_system/ @dfinity/nns-team diff --git a/.github/actions/bazel-test-all/action.yaml b/.github/actions/bazel-test-all/action.yaml index b7e3f0d39d0..0f432cbc581 100644 --- a/.github/actions/bazel-test-all/action.yaml +++ b/.github/actions/bazel-test-all/action.yaml @@ -33,6 +33,7 @@ runs: if [ -z "${SSH_AUTH_SOCK:-}" ]; then eval "$(ssh-agent -s)" ssh-add - <<< '${{ inputs.SSH_PRIVATE_KEY_BACKUP_POD }}' + echo "SSH_AUTH_SOCK=$SSH_AUTH_SOCK" >> "$GITHUB_ENV" fi rm -rf ~/.ssh diff --git a/.github/workflows-source/ci-main.yml b/.github/workflows-source/ci-main.yml index 1e940b5554c..ad743617077 100644 --- a/.github/workflows-source/ci-main.yml +++ b/.github/workflows-source/ci-main.yml @@ -94,6 +94,7 @@ jobs: - <<: *checkout - name: Set BAZEL_EXTRA_ARGS shell: bash + id: bazel-extra-args run: | set -xeuo pipefail # Determine which tests to skip and append 'long_test' for pull requests, merge groups or push on dev-gh-* @@ -120,21 +121,20 @@ jobs: # Prepend tags with '-' and join them with commas for Bazel TEST_TAG_FILTERS=$(IFS=,; echo "${EXCLUDED_TEST_TAGS[*]/#/-}") # Determine BAZEL_EXTRA_ARGS based on event type or branch name - BAZEL_EXTRA_ARGS="--test_tag_filters=$TEST_TAG_FILTERS" + BAZEL_EXTRA_ARGS=( "--test_tag_filters=$TEST_TAG_FILTERS" ) if [[ "$CI_EVENT_NAME" == 'merge_group' ]]; then - BAZEL_EXTRA_ARGS+=" --test_timeout_filters=short,moderate --flaky_test_attempts=3" + BAZEL_EXTRA_ARGS+=( --test_timeout_filters=short,moderate --flaky_test_attempts=3 ) elif [[ $BRANCH_NAME =~ ^hotfix-.* ]]; then - BAZEL_EXTRA_ARGS+=" --test_timeout_filters=short,moderate" + BAZEL_EXTRA_ARGS+=( --test_timeout_filters=short,moderate ) else - BAZEL_EXTRA_ARGS+=" --keep_going" + BAZEL_EXTRA_ARGS+=( --keep_going ) fi - # Export BAZEL_EXTRA_ARGS to environment - echo "BAZEL_EXTRA_ARGS=$BAZEL_EXTRA_ARGS" >> $GITHUB_ENV + echo "BAZEL_EXTRA_ARGS=${BAZEL_EXTRA_ARGS[@]}" >> $GITHUB_OUTPUT - name: Run Bazel Test All id: bazel-test-all uses: ./.github/actions/bazel-test-all/ with: - BAZEL_COMMAND: test --config=ci ${{ env.BAZEL_EXTRA_ARGS }} + BAZEL_COMMAND: test --config=ci ${{ steps.bazel-extra-args.outputs.BAZEL_EXTRA_ARGS }} BAZEL_TARGETS: //... 
BUILDEVENT_APIKEY: ${{ secrets.HONEYCOMB_TOKEN }} GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} diff --git a/.github/workflows/ci-main.yml b/.github/workflows/ci-main.yml index 909bbd7b43e..7f459b6b695 100644 --- a/.github/workflows/ci-main.yml +++ b/.github/workflows/ci-main.yml @@ -52,6 +52,7 @@ jobs: fetch-depth: ${{ github.event_name == 'pull_request' && 256 || 0 }} - name: Set BAZEL_EXTRA_ARGS shell: bash + id: bazel-extra-args run: | set -xeuo pipefail # Determine which tests to skip and append 'long_test' for pull requests, merge groups or push on dev-gh-* @@ -78,21 +79,20 @@ jobs: # Prepend tags with '-' and join them with commas for Bazel TEST_TAG_FILTERS=$(IFS=,; echo "${EXCLUDED_TEST_TAGS[*]/#/-}") # Determine BAZEL_EXTRA_ARGS based on event type or branch name - BAZEL_EXTRA_ARGS="--test_tag_filters=$TEST_TAG_FILTERS" + BAZEL_EXTRA_ARGS=( "--test_tag_filters=$TEST_TAG_FILTERS" ) if [[ "$CI_EVENT_NAME" == 'merge_group' ]]; then - BAZEL_EXTRA_ARGS+=" --test_timeout_filters=short,moderate --flaky_test_attempts=3" + BAZEL_EXTRA_ARGS+=( --test_timeout_filters=short,moderate --flaky_test_attempts=3 ) elif [[ $BRANCH_NAME =~ ^hotfix-.* ]]; then - BAZEL_EXTRA_ARGS+=" --test_timeout_filters=short,moderate" + BAZEL_EXTRA_ARGS+=( --test_timeout_filters=short,moderate ) else - BAZEL_EXTRA_ARGS+=" --keep_going" + BAZEL_EXTRA_ARGS+=( --keep_going ) fi - # Export BAZEL_EXTRA_ARGS to environment - echo "BAZEL_EXTRA_ARGS=$BAZEL_EXTRA_ARGS" >> $GITHUB_ENV + echo "BAZEL_EXTRA_ARGS=${BAZEL_EXTRA_ARGS[@]}" >> $GITHUB_OUTPUT - name: Run Bazel Test All id: bazel-test-all uses: ./.github/actions/bazel-test-all/ with: - BAZEL_COMMAND: test --config=ci ${{ env.BAZEL_EXTRA_ARGS }} + BAZEL_COMMAND: test --config=ci ${{ steps.bazel-extra-args.outputs.BAZEL_EXTRA_ARGS }} BAZEL_TARGETS: //... 
BUILDEVENT_APIKEY: ${{ secrets.HONEYCOMB_TOKEN }} GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} diff --git a/.github/workflows/update-mainnet-revisions.yaml b/.github/workflows/update-mainnet-revisions.yaml index 75048ccde07..61686fce43e 100644 --- a/.github/workflows/update-mainnet-revisions.yaml +++ b/.github/workflows/update-mainnet-revisions.yaml @@ -28,6 +28,8 @@ jobs: token: ${{ steps.app-token.outputs.token }} - name: Update IC versions file + env: + GH_TOKEN: ${{ steps.app-token.outputs.token }} run: | set -eEuxo pipefail @@ -46,8 +48,8 @@ jobs: uses: actions/create-github-app-token@v1 id: app-token with: - app-id: ${{ vars.PR_CREATION_BOT_PRIVATE_KEY }} - private-key: ${{ secrets.PR_CREATION_BOT_APP_ID }} + app-id: ${{ vars.PR_CREATION_BOT_APP_ID }} + private-key: ${{ secrets.PR_CREATION_BOT_PRIVATE_KEY }} - name: Checkout repository uses: actions/checkout@v4 @@ -61,6 +63,8 @@ jobs: version: 2.53.0 - name: Update Mainnet canisters file + env: + GH_TOKEN: ${{ steps.app-token.outputs.token }} run: | set -eEuxo pipefail diff --git a/Cargo.Bazel.Fuzzing.json.lock b/Cargo.Bazel.Fuzzing.json.lock index d2836cfbb32..19f781fb144 100644 --- a/Cargo.Bazel.Fuzzing.json.lock +++ b/Cargo.Bazel.Fuzzing.json.lock @@ -1,5 +1,5 @@ { - "checksum": "08d6ff34bbb2dba9b147373032c8d6016a785821f41dea400d071133689f395b", + "checksum": "6e56d15113d0514effecb38c75341abadec87f1a36b329d4c79de9c77e1d8fab", "crates": { "abnf 0.12.0": { "name": "abnf", @@ -1304,6 +1304,7 @@ "crate_features": { "common": [ "alloc", + "default", "getrandom", "rand_core" ], @@ -19874,6 +19875,10 @@ "id": "hmac 0.12.1", "target": "hmac" }, + { + "id": "hpke 0.12.0", + "target": "hpke" + }, { "id": "http 1.2.0", "target": "http" @@ -20553,6 +20558,10 @@ "id": "slog-term 2.9.1", "target": "slog_term" }, + { + "id": "slotmap 1.0.7", + "target": "slotmap" + }, { "id": "socket2 0.5.7", "target": "socket2" @@ -22326,6 +22335,7 @@ "arithmetic", "default", "digest", + "ecdh", "ff", "group", "hazmat", @@ -22362,6 +22372,10 @@ "id": "group 0.13.0", "target": "group" }, + { + "id": "hkdf 0.12.4", + "target": "hkdf" + }, { "id": "pem-rfc7468 0.7.0", "target": "pem_rfc7468" @@ -29780,6 +29794,105 @@ ], "license_file": "LICENSE" }, + "hpke 0.12.0": { + "name": "hpke", + "version": "0.12.0", + "package_url": "https://github.com/rozbb/rust-hpke", + "repository": { + "Http": { + "url": "https://static.crates.io/crates/hpke/0.12.0/download", + "sha256": "4917627a14198c3603282c5158b815ad5534795451d3c074b53cf3cee0960b11" + } + }, + "targets": [ + { + "Library": { + "crate_name": "hpke", + "crate_root": "src/lib.rs", + "srcs": { + "allow_empty": true, + "include": [ + "**/*.rs" + ] + } + } + } + ], + "library_target_name": "hpke", + "common_attrs": { + "compile_data_glob": [ + "**" + ], + "crate_features": { + "common": [ + "alloc", + "p384" + ], + "selects": {} + }, + "deps": { + "common": [ + { + "id": "aead 0.5.2", + "target": "aead" + }, + { + "id": "aes-gcm 0.10.3", + "target": "aes_gcm" + }, + { + "id": "chacha20poly1305 0.10.1", + "target": "chacha20poly1305" + }, + { + "id": "digest 0.10.7", + "target": "digest" + }, + { + "id": "generic-array 0.14.7", + "target": "generic_array" + }, + { + "id": "hkdf 0.12.4", + "target": "hkdf" + }, + { + "id": "hmac 0.12.1", + "target": "hmac" + }, + { + "id": "p384 0.13.1", + "target": "p384" + }, + { + "id": "rand_core 0.6.4", + "target": "rand_core" + }, + { + "id": "sha2 0.10.8", + "target": "sha2" + }, + { + "id": "subtle 2.6.1", + "target": "subtle" + }, + { + "id": "zeroize 1.8.1", + "target": 
"zeroize" + } + ], + "selects": {} + }, + "edition": "2021", + "version": "0.12.0" + }, + "license": "MIT/Apache-2.0", + "license_ids": [ + "Apache-2.0", + "MIT" + ], + "license_file": "LICENSE-APACHE" + }, "html5ever 0.26.0": { "name": "html5ever", "version": "0.26.0", @@ -32866,7 +32979,7 @@ "target": "serde_bytes" }, { - "id": "slotmap 1.0.6", + "id": "slotmap 1.0.7", "target": "slotmap" } ], @@ -49871,6 +49984,65 @@ ], "license_file": "LICENSE-APACHE" }, + "p384 0.13.1": { + "name": "p384", + "version": "0.13.1", + "package_url": "https://github.com/RustCrypto/elliptic-curves/tree/master/p384", + "repository": { + "Http": { + "url": "https://static.crates.io/crates/p384/0.13.1/download", + "sha256": "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" + } + }, + "targets": [ + { + "Library": { + "crate_name": "p384", + "crate_root": "src/lib.rs", + "srcs": { + "allow_empty": true, + "include": [ + "**/*.rs" + ] + } + } + } + ], + "library_target_name": "p384", + "common_attrs": { + "compile_data_glob": [ + "**" + ], + "crate_features": { + "common": [ + "arithmetic", + "ecdh" + ], + "selects": {} + }, + "deps": { + "common": [ + { + "id": "elliptic-curve 0.13.8", + "target": "elliptic_curve" + }, + { + "id": "primeorder 0.13.2", + "target": "primeorder" + } + ], + "selects": {} + }, + "edition": "2021", + "version": "0.13.1" + }, + "license": "Apache-2.0 OR MIT", + "license_ids": [ + "Apache-2.0", + "MIT" + ], + "license_file": "LICENSE-APACHE" + }, "pairing 0.23.0": { "name": "pairing", "version": "0.23.0", @@ -71120,14 +71292,14 @@ ], "license_file": "LICENSE-APACHE" }, - "slotmap 1.0.6": { + "slotmap 1.0.7": { "name": "slotmap", - "version": "1.0.6", + "version": "1.0.7", "package_url": "https://github.com/orlp/slotmap", "repository": { "Http": { - "url": "https://static.crates.io/crates/slotmap/1.0.6/download", - "sha256": "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" + "url": "https://static.crates.io/crates/slotmap/1.0.7/download", + "sha256": "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" } }, "targets": [ @@ -71171,14 +71343,14 @@ "deps": { "common": [ { - "id": "slotmap 1.0.6", + "id": "slotmap 1.0.7", "target": "build_script_build" } ], "selects": {} }, "edition": "2018", - "version": "1.0.6" + "version": "1.0.7" }, "build_script_attrs": { "compile_data_glob": [ @@ -91291,6 +91463,7 @@ "hex-literal 0.4.1", "hkdf 0.12.4", "hmac 0.12.1", + "hpke 0.12.0", "http 1.2.0", "http-body 1.0.1", "http-body-util 0.1.2", @@ -91468,6 +91641,7 @@ "slog-json 2.6.1", "slog-scope 4.4.0", "slog-term 2.9.1", + "slotmap 1.0.7", "socket2 0.5.7", "ssh2 0.9.4", "static_assertions 1.1.0", diff --git a/Cargo.Bazel.Fuzzing.toml.lock b/Cargo.Bazel.Fuzzing.toml.lock index 1097c3b08a7..2efda2a6f17 100644 --- a/Cargo.Bazel.Fuzzing.toml.lock +++ b/Cargo.Bazel.Fuzzing.toml.lock @@ -3274,6 +3274,7 @@ dependencies = [ "hex-literal", "hkdf", "hmac", + "hpke", "http 1.2.0", "http-body 1.0.1", "http-body-util", @@ -3451,6 +3452,7 @@ dependencies = [ "slog-json", "slog-scope", "slog-term", + "slotmap", "socket2 0.5.7", "ssh2", "static_assertions", @@ -3769,6 +3771,7 @@ dependencies = [ "ff 0.13.0", "generic-array", "group 0.13.0", + "hkdf", "pem-rfc7468 0.7.0", "pkcs8 0.10.2", "rand_core 0.6.4", @@ -4954,6 +4957,26 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "hpke" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4917627a14198c3603282c5158b815ad5534795451d3c074b53cf3cee0960b11" +dependencies = [ 
+ "aead", + "aes-gcm", + "chacha20poly1305", + "digest 0.10.7", + "generic-array", + "hkdf", + "hmac", + "p384", + "rand_core 0.6.4", + "sha2 0.10.8", + "subtle", + "zeroize", +] + [[package]] name = "html5ever" version = "0.26.0" @@ -8201,6 +8224,16 @@ dependencies = [ "sha2 0.10.8", ] +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "elliptic-curve 0.13.8", + "primeorder", +] + [[package]] name = "pairing" version = "0.23.0" @@ -11169,9 +11202,9 @@ dependencies = [ [[package]] name = "slotmap" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" dependencies = [ "version_check", ] diff --git a/Cargo.Bazel.json.lock b/Cargo.Bazel.json.lock index 8524bd63a6b..8ff50908e41 100644 --- a/Cargo.Bazel.json.lock +++ b/Cargo.Bazel.json.lock @@ -1,5 +1,5 @@ { - "checksum": "df8f83940c75aa96f2b7ddeddf586c11aacb29401985b0424589fccd6f5125de", + "checksum": "c503d0e2b21e87e0ccf4248de59b7866351190a9976a27481568e3d30ae4869e", "crates": { "abnf 0.12.0": { "name": "abnf", @@ -1308,6 +1308,7 @@ "crate_features": { "common": [ "alloc", + "default", "getrandom", "rand_core" ], @@ -19702,6 +19703,10 @@ "id": "hmac 0.12.1", "target": "hmac" }, + { + "id": "hpke 0.12.0", + "target": "hpke" + }, { "id": "http 1.2.0", "target": "http" @@ -20381,6 +20386,10 @@ "id": "slog-term 2.9.1", "target": "slog_term" }, + { + "id": "slotmap 1.0.7", + "target": "slotmap" + }, { "id": "socket2 0.5.7", "target": "socket2" @@ -22154,6 +22163,7 @@ "arithmetic", "default", "digest", + "ecdh", "ff", "group", "hazmat", @@ -22190,6 +22200,10 @@ "id": "group 0.13.0", "target": "group" }, + { + "id": "hkdf 0.12.4", + "target": "hkdf" + }, { "id": "pem-rfc7468 0.7.0", "target": "pem_rfc7468" @@ -29635,6 +29649,105 @@ ], "license_file": "LICENSE" }, + "hpke 0.12.0": { + "name": "hpke", + "version": "0.12.0", + "package_url": "https://github.com/rozbb/rust-hpke", + "repository": { + "Http": { + "url": "https://static.crates.io/crates/hpke/0.12.0/download", + "sha256": "4917627a14198c3603282c5158b815ad5534795451d3c074b53cf3cee0960b11" + } + }, + "targets": [ + { + "Library": { + "crate_name": "hpke", + "crate_root": "src/lib.rs", + "srcs": { + "allow_empty": true, + "include": [ + "**/*.rs" + ] + } + } + } + ], + "library_target_name": "hpke", + "common_attrs": { + "compile_data_glob": [ + "**" + ], + "crate_features": { + "common": [ + "alloc", + "p384" + ], + "selects": {} + }, + "deps": { + "common": [ + { + "id": "aead 0.5.2", + "target": "aead" + }, + { + "id": "aes-gcm 0.10.3", + "target": "aes_gcm" + }, + { + "id": "chacha20poly1305 0.10.1", + "target": "chacha20poly1305" + }, + { + "id": "digest 0.10.7", + "target": "digest" + }, + { + "id": "generic-array 0.14.7", + "target": "generic_array" + }, + { + "id": "hkdf 0.12.4", + "target": "hkdf" + }, + { + "id": "hmac 0.12.1", + "target": "hmac" + }, + { + "id": "p384 0.13.1", + "target": "p384" + }, + { + "id": "rand_core 0.6.4", + "target": "rand_core" + }, + { + "id": "sha2 0.10.8", + "target": "sha2" + }, + { + "id": "subtle 2.6.1", + "target": "subtle" + }, + { + "id": "zeroize 1.8.1", + "target": "zeroize" + } + ], + "selects": {} + }, + "edition": "2021", + "version": "0.12.0" + }, + "license": "MIT/Apache-2.0", + 
"license_ids": [ + "Apache-2.0", + "MIT" + ], + "license_file": "LICENSE-APACHE" + }, "html5ever 0.26.0": { "name": "html5ever", "version": "0.26.0", @@ -32721,7 +32834,7 @@ "target": "serde_bytes" }, { - "id": "slotmap 1.0.6", + "id": "slotmap 1.0.7", "target": "slotmap" } ], @@ -49711,6 +49824,65 @@ ], "license_file": "LICENSE-APACHE" }, + "p384 0.13.1": { + "name": "p384", + "version": "0.13.1", + "package_url": "https://github.com/RustCrypto/elliptic-curves/tree/master/p384", + "repository": { + "Http": { + "url": "https://static.crates.io/crates/p384/0.13.1/download", + "sha256": "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" + } + }, + "targets": [ + { + "Library": { + "crate_name": "p384", + "crate_root": "src/lib.rs", + "srcs": { + "allow_empty": true, + "include": [ + "**/*.rs" + ] + } + } + } + ], + "library_target_name": "p384", + "common_attrs": { + "compile_data_glob": [ + "**" + ], + "crate_features": { + "common": [ + "arithmetic", + "ecdh" + ], + "selects": {} + }, + "deps": { + "common": [ + { + "id": "elliptic-curve 0.13.8", + "target": "elliptic_curve" + }, + { + "id": "primeorder 0.13.2", + "target": "primeorder" + } + ], + "selects": {} + }, + "edition": "2021", + "version": "0.13.1" + }, + "license": "Apache-2.0 OR MIT", + "license_ids": [ + "Apache-2.0", + "MIT" + ], + "license_file": "LICENSE-APACHE" + }, "pairing 0.23.0": { "name": "pairing", "version": "0.23.0", @@ -70999,14 +71171,14 @@ ], "license_file": "LICENSE-APACHE" }, - "slotmap 1.0.6": { + "slotmap 1.0.7": { "name": "slotmap", - "version": "1.0.6", + "version": "1.0.7", "package_url": "https://github.com/orlp/slotmap", "repository": { "Http": { - "url": "https://static.crates.io/crates/slotmap/1.0.6/download", - "sha256": "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" + "url": "https://static.crates.io/crates/slotmap/1.0.7/download", + "sha256": "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" } }, "targets": [ @@ -71050,14 +71222,14 @@ "deps": { "common": [ { - "id": "slotmap 1.0.6", + "id": "slotmap 1.0.7", "target": "build_script_build" } ], "selects": {} }, "edition": "2018", - "version": "1.0.6" + "version": "1.0.7" }, "build_script_attrs": { "compile_data_glob": [ @@ -91204,6 +91376,7 @@ "hex-literal 0.4.1", "hkdf 0.12.4", "hmac 0.12.1", + "hpke 0.12.0", "http 1.2.0", "http-body 1.0.1", "http-body-util 0.1.2", @@ -91381,6 +91554,7 @@ "slog-json 2.6.1", "slog-scope 4.4.0", "slog-term 2.9.1", + "slotmap 1.0.7", "socket2 0.5.7", "ssh2 0.9.4", "static_assertions 1.1.0", diff --git a/Cargo.Bazel.toml.lock b/Cargo.Bazel.toml.lock index 7726a717b79..ec1e7c8c239 100644 --- a/Cargo.Bazel.toml.lock +++ b/Cargo.Bazel.toml.lock @@ -3263,6 +3263,7 @@ dependencies = [ "hex-literal", "hkdf", "hmac", + "hpke", "http 1.2.0", "http-body 1.0.1", "http-body-util", @@ -3440,6 +3441,7 @@ dependencies = [ "slog-json", "slog-scope", "slog-term", + "slotmap", "socket2 0.5.7", "ssh2", "static_assertions", @@ -3758,6 +3760,7 @@ dependencies = [ "ff 0.13.0", "generic-array", "group 0.13.0", + "hkdf", "pem-rfc7468 0.7.0", "pkcs8 0.10.2", "rand_core 0.6.4", @@ -4944,6 +4947,26 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "hpke" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4917627a14198c3603282c5158b815ad5534795451d3c074b53cf3cee0960b11" +dependencies = [ + "aead", + "aes-gcm", + "chacha20poly1305", + "digest 0.10.7", + "generic-array", + "hkdf", + "hmac", + "p384", + "rand_core 0.6.4", + "sha2 0.10.8", + 
"subtle", + "zeroize", +] + [[package]] name = "html5ever" version = "0.26.0" @@ -8192,6 +8215,16 @@ dependencies = [ "sha2 0.10.8", ] +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "elliptic-curve 0.13.8", + "primeorder", +] + [[package]] name = "pairing" version = "0.23.0" @@ -11165,9 +11198,9 @@ dependencies = [ [[package]] name = "slotmap" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" dependencies = [ "version_check", ] diff --git a/Cargo.lock b/Cargo.lock index 63f312e0aa3..3a26eb719a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3920,6 +3920,7 @@ dependencies = [ "ff 0.13.0", "generic-array", "group 0.13.0", + "hkdf", "pem-rfc7468 0.7.0", "pkcs8 0.10.2", "rand_core 0.6.4", @@ -5253,6 +5254,26 @@ dependencies = [ "utils", ] +[[package]] +name = "hpke" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4917627a14198c3603282c5158b815ad5534795451d3c074b53cf3cee0960b11" +dependencies = [ + "aead", + "aes-gcm", + "chacha20poly1305", + "digest 0.10.7", + "generic-array", + "hkdf", + "hmac", + "p384", + "rand_core 0.6.4", + "sha2 0.10.8", + "subtle", + "zeroize", +] + [[package]] name = "html5ever" version = "0.26.0" @@ -6551,6 +6572,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", + "test-strategy 0.4.0", ] [[package]] @@ -6568,6 +6590,7 @@ dependencies = [ "rand 0.8.5", "rand_chacha 0.3.1", "scoped_threadpool", + "test-strategy 0.4.0", "thiserror 2.0.11", ] @@ -6992,6 +7015,7 @@ dependencies = [ "ic-consensus-dkg", "ic-consensus-mocks", "ic-consensus-utils", + "ic-consensus-vetkd", "ic-crypto", "ic-crypto-prng", "ic-crypto-temp-crypto", @@ -7174,6 +7198,41 @@ dependencies = [ "slog", ] +[[package]] +name = "ic-consensus-vetkd" +version = "0.9.0" +dependencies = [ + "assert_matches", + "ic-artifact-pool", + "ic-consensus-mocks", + "ic-consensus-utils", + "ic-crypto", + "ic-error-types", + "ic-interfaces", + "ic-interfaces-registry", + "ic-interfaces-state-manager", + "ic-logger", + "ic-management-canister-types-private", + "ic-metrics", + "ic-protobuf", + "ic-registry-client-fake", + "ic-registry-client-helpers", + "ic-registry-keys", + "ic-registry-subnet-features", + "ic-replicated-state", + "ic-test-utilities", + "ic-test-utilities-registry", + "ic-test-utilities-state", + "ic-test-utilities-types", + "ic-types", + "ic-types-test-utils", + "prometheus", + "prost 0.13.4", + "slog", + "strum", + "strum_macros", +] + [[package]] name = "ic-crypto" version = "0.9.0" @@ -8246,6 +8305,7 @@ dependencies = [ "serde", "serde_bytes", "serde_cbor", + "test-strategy 0.4.0", "thiserror 2.0.11", ] @@ -8258,6 +8318,7 @@ dependencies = [ "ic-crypto-tree-hash", "proptest 1.6.0", "rand 0.8.5", + "test-strategy 0.4.0", "thiserror 2.0.11", ] @@ -8667,6 +8728,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ic-hpke" +version = "0.1.0" +dependencies = [ + "hex", + "hpke", + "rand 0.8.5", + "rand_chacha 0.3.1", +] + [[package]] name = "ic-http-certification" version = "3.0.2" @@ -10404,15 +10475,26 @@ dependencies = [ "async-trait", "candid", "canister-test", + "futures", "ic-cdk 0.17.1", - "ic-cdk-timers", "ic-config", + "ic-metrics-encoder", + 
"ic-nervous-system-time-helpers", + "ic-nervous-system-timers", "ic-registry-subnet-type", "ic-state-machine-tests", "ic-types", "serde", ] +[[package]] +name = "ic-nervous-system-timers" +version = "0.9.0" +dependencies = [ + "ic-cdk-timers", + "slotmap", +] + [[package]] name = "ic-nervous-system-timestamp" version = "0.0.1" @@ -10563,6 +10645,7 @@ dependencies = [ "ic-nervous-system-temporary", "ic-nervous-system-time-helpers", "ic-nervous-system-timer-task", + "ic-nervous-system-timers", "ic-neurons-fund", "ic-nns-common", "ic-nns-constants", @@ -11741,6 +11824,7 @@ dependencies = [ "ic-consensus-dkg", "ic-consensus-manager", "ic-consensus-utils", + "ic-consensus-vetkd", "ic-crypto-interfaces-sig-verification", "ic-crypto-tls-interfaces", "ic-cycles-account-manager", @@ -12727,6 +12811,7 @@ dependencies = [ "prost 0.13.4", "scoped_threadpool", "slog", + "test-strategy 0.4.0", ] [[package]] @@ -12799,6 +12884,7 @@ dependencies = [ "slog", "slog-term", "tempfile", + "test-strategy 0.4.0", "tokio", "tokio-util", "tower 0.5.2", @@ -12866,6 +12952,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", + "test-strategy 0.4.0", "tree-deserializer", "uuid", ] @@ -14030,6 +14117,7 @@ dependencies = [ "reqwest 0.12.12", "slog", "tempfile", + "test-strategy 0.4.0", "thiserror 2.0.11", "tokio", "url", @@ -17228,6 +17316,16 @@ dependencies = [ "sha2 0.10.8", ] +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "elliptic-curve 0.13.8", + "primeorder", +] + [[package]] name = "pairing" version = "0.23.0" @@ -22267,6 +22365,7 @@ dependencies = [ "proptest 1.6.0", "proptest-derive", "serde", + "test-strategy 0.4.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 391389cbd7c..1c63ec10d0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "packages/ic-ed25519", "packages/ic-ethereum-types", "packages/ic-error-types", + "packages/ic-hpke", "packages/ic-metrics-assert", "packages/ic-secp256k1", "packages/ic-sha3", @@ -203,6 +204,7 @@ members = [ "rs/nervous_system/string", "rs/nervous_system/temporary", "rs/nervous_system/time_helpers", + "rs/nervous_system/timers", "rs/nervous_system/tools/release-runscript", "rs/nervous_system/tools/sync-with-released-nervous-system-wasms", "rs/nns/constants", @@ -700,6 +702,7 @@ slog = { version = "2.7.0", features = [ slog-async = { version = "2.8.0", features = ["nested-values"] } slog-scope = "4.4.0" slog-term = "2.9.1" +slotmap = "1.0.7" socket2 = { version = "0.5.7", features = ["all"] } ssh2 = "0.9.4" static_assertions = "1.1.0" diff --git a/bazel/conf/.bazelrc.build b/bazel/conf/.bazelrc.build index 776250fd770..d487c0ff7cd 100644 --- a/bazel/conf/.bazelrc.build +++ b/bazel/conf/.bazelrc.build @@ -72,10 +72,6 @@ test:systest --test_output=streamed --test_tag_filters= build:testnet --build_tag_filters= test:testnet --test_output=streamed --test_tag_filters= -# For sandboxed actions, mount an empty, writable directory at this absolute path -# (if supported by the sandboxing implementation, ignored otherwise). -test --sandbox_tmpfs_path=/tmp - # TODO(IDX-2374): enable alltests in CI when we will have actual system tests. 
#test:ci --config=alltests diff --git a/bazel/external_crates.bzl b/bazel/external_crates.bzl index 1efd95d0dbe..c32a105fd67 100644 --- a/bazel/external_crates.bzl +++ b/bazel/external_crates.bzl @@ -537,6 +537,11 @@ def external_crates_repository(name, cargo_lockfile, lockfile, sanitizers_enable "hmac": crate.spec( version = "^0.12", ), + "hpke": crate.spec( + version = "^0.12", + default_features = False, + features = ["p384", "alloc"], + ), "humantime": crate.spec( version = "^2.1.0", ), @@ -1236,6 +1241,9 @@ def external_crates_repository(name, cargo_lockfile, lockfile, sanitizers_enable "slog-term": crate.spec( version = "^2.9.1", ), + "slotmap": crate.spec( + version = "^1.0.7", + ), "socket2": crate.spec( version = "^0.5.7", features = [ diff --git a/ci/bazel-scripts/main.sh b/ci/bazel-scripts/main.sh index a91a08b2fa1..a0805cfaf7c 100755 --- a/ci/bazel-scripts/main.sh +++ b/ci/bazel-scripts/main.sh @@ -38,6 +38,9 @@ fi # if bazel targets is empty we don't need to run any tests if [ -z "${BAZEL_TARGETS:-}" ]; then echo "No bazel targets to build" + # create empty SHA256SUMS for build determinism + # (not ideal but temporary until we can improve or get rid of diff.sh) + touch SHA256SUMS exit 0 fi diff --git a/ic-os/boundary-guestos/context/docker-base.prod b/ic-os/boundary-guestos/context/docker-base.prod index 06160ff34c5..4729f2301e0 100644 --- a/ic-os/boundary-guestos/context/docker-base.prod +++ b/ic-os/boundary-guestos/context/docker-base.prod @@ -1 +1 @@ -ghcr.io/dfinity/boundaryos-base@sha256:a120ba3ea895b1889ffa7123b9131defb8398d347fcf7dbdc30b08b6123aa8d6 +ghcr.io/dfinity/boundaryos-base@sha256:74091ddbfb5353e0547509b567b8ca49e39f16881d0ce1ed21b4dd30d7453206 diff --git a/ic-os/components/selinux/ic-node/ic-node.te b/ic-os/components/selinux/ic-node/ic-node.te index 9c8fed12c57..5cb37e7ca0b 100644 --- a/ic-os/components/selinux/ic-node/ic-node.te +++ b/ic-os/components/selinux/ic-node/ic-node.te @@ -312,9 +312,10 @@ allow ic_canister_sandbox_t ic_replica_t : unix_stream_socket { setopt read writ # labelled differently eventually because allowing tmpfs is fairly broad. require { type tmpfs_t; } allow ic_canister_sandbox_t tmpfs_t : file { map read write getattr }; -# Also allow read-only access to checkpoint files (also given passed via -# file descriptor from replica). -allow ic_canister_sandbox_t ic_data_t : file { map read getattr }; +# Also allow read and execute access to checkpoint files (also passed via +# file descriptor from replica). Execute is necessary to run the compiled Wasm modules +# stored in the tmp directory in ic_data_t.
+allow ic_canister_sandbox_t ic_data_t : file { map read execute getattr }; # Allow read/write access to files that back the heap delta for both sandbox and replica # The workflow is that the replica creates the files but passes a file descriptor to the sandbox # We explicitly do not allow the sandbox to open files because they should only be open by the replica diff --git a/ic-os/guestos/context/docker-base.dev b/ic-os/guestos/context/docker-base.dev index 2c76b04ad20..0c34e665be9 100644 --- a/ic-os/guestos/context/docker-base.dev +++ b/ic-os/guestos/context/docker-base.dev @@ -1 +1 @@ -ghcr.io/dfinity/guestos-base-dev@sha256:7c525336d4471f94a15e393512c6d7a809f460fce9b87858dccdb50aae8658fe +ghcr.io/dfinity/guestos-base-dev@sha256:c1af096b08acd6b8419457f2a53b1dbaf24517b1f838427dabff7fcb9e74adfa diff --git a/ic-os/guestos/context/docker-base.prod b/ic-os/guestos/context/docker-base.prod index 8da15950525..80cdf6ccfc1 100644 --- a/ic-os/guestos/context/docker-base.prod +++ b/ic-os/guestos/context/docker-base.prod @@ -1 +1 @@ -ghcr.io/dfinity/guestos-base@sha256:fcae741136de3d6ebe990786d2c887d33c58b5a2cc110c052e26178959998c8e +ghcr.io/dfinity/guestos-base@sha256:04d9836f55c2e1d2f3f629e20180c9c868dfe5b1109dbaf6dfca07da09ce16d9 diff --git a/ic-os/setupos/context/docker-base.dev b/ic-os/setupos/context/docker-base.dev index de2016497b0..434afecd1ba 100644 --- a/ic-os/setupos/context/docker-base.dev +++ b/ic-os/setupos/context/docker-base.dev @@ -1 +1 @@ -ghcr.io/dfinity/setupos-base-dev@sha256:a7d4991d8a851850aea449639aa0a63dc57e8be53f96efb1c1c17773c49be7b1 +ghcr.io/dfinity/setupos-base-dev@sha256:e64e9f9706bd3cce6af8da662dc54c3f361f056b9214a2b049b73939baa0c3b6 diff --git a/ic-os/setupos/context/docker-base.prod b/ic-os/setupos/context/docker-base.prod index 19804bcc087..aa18c6733d0 100644 --- a/ic-os/setupos/context/docker-base.prod +++ b/ic-os/setupos/context/docker-base.prod @@ -1 +1 @@ -ghcr.io/dfinity/setupos-base@sha256:ca3855a5128c56379e57d6524ae7c205600506e1fa2cca55167887ce37cf8406 +ghcr.io/dfinity/setupos-base@sha256:eca2d6085bcb7ea03d4033670873fe20e73082a6464d61ccb79ee3320ee94bc6 diff --git a/mainnet-subnet-revisions.json b/mainnet-subnet-revisions.json index 3d9d262c86e..7acb154f6f3 100644 --- a/mainnet-subnet-revisions.json +++ b/mainnet-subnet-revisions.json @@ -1,6 +1,6 @@ { "subnets": { - "tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe": "2008d47a169c4984631c87f2efaa88798e6f14dc", - "io67a-2jmkw-zup3h-snbwi-g6a5n-rm5dn-b6png-lvdpl-nqnto-yih6l-gqe": "7147f471c7ac27f518e6c0eeb2015952b5e93e1b" + "tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe": "7147f471c7ac27f518e6c0eeb2015952b5e93e1b", + "io67a-2jmkw-zup3h-snbwi-g6a5n-rm5dn-b6png-lvdpl-nqnto-yih6l-gqe": "6e64281a8e0b4faa1d859f115fc138eee6e136f8" } } \ No newline at end of file diff --git a/packages/ic-hpke/BUILD.bazel b/packages/ic-hpke/BUILD.bazel new file mode 100644 index 00000000000..d2724445790 --- /dev/null +++ b/packages/ic-hpke/BUILD.bazel @@ -0,0 +1,58 @@ +load("@rules_rust//rust:defs.bzl", "rust_doc", "rust_doc_test", "rust_library", "rust_test", "rust_test_suite") + +package(default_visibility = ["//visibility:public"]) + +DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:hpke", +] + +MACRO_DEPENDENCIES = [] + +DEV_DEPENDENCIES = [ + # Keep sorted. 
+ "@crate_index//:hex", + "@crate_index//:rand", + "@crate_index//:rand_chacha", +] + +MACRO_DEV_DEPENDENCIES = [] + +ALIASES = {} + +rust_library( + name = "ic-hpke", + srcs = glob(["src/**/*.rs"]), + aliases = ALIASES, + crate_name = "ic_hpke", + proc_macro_deps = MACRO_DEPENDENCIES, + version = "0.1.0", + deps = DEPENDENCIES, +) + +rust_doc( + name = "doc", + crate = ":ic-hpke", +) + +rust_doc_test( + name = "doc_test", + crate = ":ic-hpke", + deps = [":ic-hpke"] + DEPENDENCIES + DEV_DEPENDENCIES, +) + +rust_test( + name = "test", + aliases = ALIASES, + crate = ":ic-hpke", + proc_macro_deps = MACRO_DEPENDENCIES + MACRO_DEV_DEPENDENCIES, + deps = DEPENDENCIES + DEV_DEPENDENCIES, +) + +rust_test_suite( + name = "integration_tests", + srcs = glob(["tests/**/*.rs"]), + aliases = ALIASES, + proc_macro_deps = MACRO_DEPENDENCIES + MACRO_DEV_DEPENDENCIES, + deps = [":ic-hpke"] + DEPENDENCIES + DEV_DEPENDENCIES, +) diff --git a/packages/ic-hpke/CHANGELOG.md b/packages/ic-hpke/CHANGELOG.md new file mode 100644 index 00000000000..d4d0c9c1b29 --- /dev/null +++ b/packages/ic-hpke/CHANGELOG.md @@ -0,0 +1,10 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [0.1.0] - Not Yet Released + +Initial release. diff --git a/packages/ic-hpke/Cargo.toml b/packages/ic-hpke/Cargo.toml new file mode 100644 index 00000000000..c81d496b0e4 --- /dev/null +++ b/packages/ic-hpke/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "ic-hpke" +version = "0.1.0" +description = "A package created for the Internet Computer Protocol for Hybrid Public Key Encryption" +license = "Apache-2.0" +readme = "README.md" +include = ["src", "Cargo.toml", "CHANGELOG.md", "LICENSE", "README.md"] +repository = "https://github.com/dfinity/ic" +authors.workspace = true +edition.workspace = true +documentation.workspace = true + +[dependencies] +hpke = { version = "0.12", default-features = false, features = [ "p384", "alloc" ] } + +[dev-dependencies] +hex = "0.4" +rand = { version = "0.8", default-features = false, features = ["getrandom"] } +rand_chacha = { version = "0.3", default-features = false } diff --git a/packages/ic-hpke/LICENSE b/packages/ic-hpke/LICENSE new file mode 120000 index 00000000000..c87d0654aa6 --- /dev/null +++ b/packages/ic-hpke/LICENSE @@ -0,0 +1 @@ +../../licenses/Apache-2.0.txt \ No newline at end of file diff --git a/packages/ic-hpke/README.md b/packages/ic-hpke/README.md new file mode 100644 index 00000000000..2906fcbb48d --- /dev/null +++ b/packages/ic-hpke/README.md @@ -0,0 +1,5 @@ +# IC HPKE + +A package created for the Internet Computer Protocol for encrypting messages to +a recipient public key. + diff --git a/packages/ic-hpke/src/lib.rs b/packages/ic-hpke/src/lib.rs new file mode 100644 index 00000000000..5b21efdd01e --- /dev/null +++ b/packages/ic-hpke/src/lib.rs @@ -0,0 +1,443 @@ +#![forbid(unsafe_code)] +#![forbid(missing_docs)] + +//! Public Key Encryption Utility +//! +//! This crate offers functionality for encrypting messages using a public key, +//! with optional sender authentication. +//! +//! All binary strings produced by this crate include protocol and version +//! identifiers, which will allow algorithm rotation in the future should this +//! be necessary (for example to support a post quantum scheme) +//! +//! Two different modes are offered, namely authenticated and non-authenticated. 
+//! +//! When sending an authenticated message, the sender also uses their private key. +//! Decrypting the message takes as input both the recipient's private key and the +//! purported sender's public key. Decryption will only succeed if the sender of that +//! ciphertext did in fact have access to the associated private key. +//! +//! A non-authenticated message encrypts the message to the public key, but does not +//! provide any kind of source authentication. Thus the receiver can decrypt the message, +//! but does not have any idea who it came from; anyone can encrypt a message to the +//! recipient's public key. +//! +//! Both modes can make use of an `associated_data` parameter. The `associated_data` field is +//! information which is not encrypted, nor is it included in the returned ciphertext +//! blob. However it is implicitly authenticated by a successful decryption; that is, if +//! the decrypting side uses the same `associated_data` parameter during decryption, then +//! decryption will succeed and the decryptor knows that the `associated_data` field they +//! used is also authentic, and is associated with that ciphertext message. If the +//! encryptor and decryptor disagree on the `associated_data` field, then decryption will +//! fail. Commonly, the `associated_data` is used to bind additional information about the +//! context which both the sender and receiver will know, for example a protocol identifier. +//! If no such information is available, the associated data can be set to an empty slice. +//! +//! # Example (Authenticated Encryption) +//! +//! ``` +//! let mut rng = rand::rngs::OsRng; +//! +//! let a_sk = ic_hpke::PrivateKey::generate(&mut rng); +//! let a_pk = a_sk.public_key(); +//! +//! let b_sk = ic_hpke::PrivateKey::generate(&mut rng); +//! let b_pk = b_sk.public_key(); +//! +//! // We assume the two public keys can be exchanged in a trusted way beforehand +//! +//! let msg = b"this is only a test"; +//! let associated_data = b"example-protocol-v2-with-auth"; +//! +//! let ctext = a_pk.encrypt(msg, associated_data, &b_sk, &mut rng).unwrap(); +//! +//! let recovered_msg = a_sk.decrypt(&ctext, associated_data, &b_pk).unwrap(); +//! assert_eq!(recovered_msg, msg, "failed to decrypt message"); +//! +//! // If recipient accidentally tries to decrypt without authentication, decryption fails +//! assert!(a_sk.decrypt_noauth(&ctext, associated_data).is_err()); +//! // If associated data is incorrect, decryption fails +//! assert!(a_sk.decrypt(&ctext, b"wrong-associated-data", &b_pk).is_err()); +//! // If the wrong public key is used, decryption fails +//! assert!(a_sk.decrypt(&ctext, associated_data, &a_pk).is_err()); +//! ``` +//! +//! # Example (Non-Authenticated Encryption) +//! +//! ``` +//! let mut rng = rand::rngs::OsRng; +//! +//! // perform key generation: +//! let sk = ic_hpke::PrivateKey::generate(&mut rng); +//! let sk_bytes = sk.serialize(); +//! // save sk_bytes to secure storage... +//! let pk_bytes = sk.public_key().serialize(); +//! // publish pk_bytes +//! +//! // Now someone can encrypt a message to your key: +//! let msg = b"attack at dawn"; +//! let associated_data = b"example-protocol-v1"; +//! let pk = ic_hpke::PublicKey::deserialize(&pk_bytes).unwrap(); +//! let ctext = pk.encrypt_noauth(msg, associated_data, &mut rng).unwrap(); +//! +//! // Upon receipt, decrypt the ciphertext: +//! let recovered_msg = sk.decrypt_noauth(&ctext, associated_data).unwrap(); +//! assert_eq!(recovered_msg, msg, "failed to decrypt message"); +//! +//!
// If associated data is incorrect, decryption fails +//! assert!(sk.decrypt_noauth(&ctext, b"wrong-associated-data").is_err()); +//! ``` + +use hpke::rand_core::{CryptoRng, RngCore}; +use hpke::{ + aead::AesGcm256, kdf::HkdfSha384, kem::DhP384HkdfSha384, Deserializable, Kem, Serializable, +}; + +/* + * All artifacts produced by this crate - public keys, private keys, and + * ciphertexts, are prefixed with a header. + * + * The header starts with a 56 bit long "magic" field. This makes the values + * easy to identify - someone searching for this hex string will likely find + * this crate right away. It also ensures that values we accept were intended + * specifically for us. The string is the ASCII bytes of "IC HPKE". + * + * The next part is an 8 bit version field. This is currently 1; we only + * implement this one version. The field allows us extensibility in the future if + * needed, eg to handle a transition to post-quantum. + */ + +// An arbitrary magic number that is prefixed to all artifacts to identify them +// plus our initial version (01) +const MAGIC: u64 = 0x49432048504b4501; + +// Current header is just the magic + version field +const HEADER_SIZE: usize = 8; + +/* + * V1 KEM + * ======== + * + * The V1 kem uses HPKE from RFC 9180 using P-384 with HKDF-SHA-384 + * and AES-256 in GCM mode. + */ +type V1Kem = DhP384HkdfSha384; +type V1Kdf = HkdfSha384; +type V1Aead = AesGcm256; + +// The amount of random material which is used to derive a secret key +// +// RFC 9180 requires this be at least 48 bytes (for P-384), we somewhat arbitrarily use 64 +const V1_IKM_LENGTH: usize = 64; + +type V1PublicKey = <V1Kem as Kem>::PublicKey; +type V1PrivateKey = <V1Kem as Kem>::PrivateKey; + +/* + * A helper macro for reading the header and optionally checking the length + */ +macro_rules! check_header { + (@common $err:ty, $val:expr) => { + if $val.len() < HEADER_SIZE { + Err(<$err>::InvalidLength) + } else { + let magic = u64::from_be_bytes( + <[u8; 8]>::try_from(&$val[0..HEADER_SIZE]).expect("Conversion cannot fail"), + ); + if magic != MAGIC { + Err(<$err>::UnknownMagic) + } else { + Ok($val.len() - HEADER_SIZE) + } + } + }; + ($err:ty, $val:expr) => { + check_header!(@common $err, $val) + }; + ($err:ty, $val:expr, $req_len:expr) => { + match check_header!(@common $err, $val) { + Ok(len) => { + if len == $req_len { + Ok(()) + } else { + Err(<$err>::InvalidLength) + } + } + Err(e) => Err(e), + } + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +/// An error occurred while deserializing a key +pub enum KeyDeserializationError { + /// The protocol identifier or version field was unknown to us + UnknownMagic, + /// The key was of a length that is not possibly valid + InvalidLength, + /// The header was valid but the key was invalid + InvalidKey, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +/// An error occurred during encryption +/// +/// Logically there is no reason for encryption to fail in this +/// context, but unfortunately some of the implementation has +/// fallible interfaces.
Rather than unwrapping and risking +/// a panic, we pass this fallibility onto the caller +pub enum EncryptionError { + /// Some error occurred during encryption + InternalError, +} + +#[derive(Clone)] +/// A public key usable for encryption +pub struct PublicKey { + pk: V1PublicKey, +} + +impl PublicKey { + /// Serialize the public key to a bytestring + pub fn serialize(&self) -> Vec<u8> { + let len = <V1PublicKey as Serializable>::size(); + let mut buf = vec![0u8; HEADER_SIZE + len]; + buf[0..HEADER_SIZE].copy_from_slice(&MAGIC.to_be_bytes()); + self.pk.write_exact(&mut buf[HEADER_SIZE..]); + buf + } + + /// Deserialize a public key + pub fn deserialize(bytes: &[u8]) -> Result<Self, KeyDeserializationError> { + let len = <V1PublicKey as Serializable>::size(); + + check_header!(KeyDeserializationError, bytes, len)?; + + match V1PublicKey::from_bytes(&bytes[HEADER_SIZE..]) { + Ok(pk) => Ok(Self { pk }), + Err(_) => Err(KeyDeserializationError::InvalidKey), + } + } + + /// Encrypt a message with sender authentication + /// + /// This encrypts a message using the recipient's public key, and + /// additionally authenticates the message using the provided private key. + /// The decrypting side must know the sender's public key in order to + /// decrypt the message. + /// + /// The `associated_data` field is information which is not encrypted, nor is + /// it included in the returned blob. However it is implicitly authenticated + /// by a successful decryption; that is, if the decrypting side uses the + /// same `associated_data` parameter during decryption, then decryption will + /// succeed and the decryptor knows that this field is also authentic. If + /// the encryptor and decryptor disagree on the `associated_data` field, then + /// decryption will fail. If not needed, `associated_data` can be set to an + /// empty slice + /// + /// The recipient must use [`PrivateKey::decrypt`] to decrypt + pub fn encrypt<R: CryptoRng + RngCore>( + &self, + msg: &[u8], + associated_data: &[u8], + sender: &PrivateKey, + rng: &mut R, + ) -> Result<Vec<u8>, EncryptionError> { + let opmode = hpke::OpModeS::<V1Kem>::Auth((sender.sk.clone(), sender.pk.clone())); + self._v1_encrypt(&opmode, msg, associated_data, rng) + } + + fn _v1_encrypt<R: CryptoRng + RngCore>( + &self, + opmode: &hpke::OpModeS<V1Kem>, + msg: &[u8], + associated_data: &[u8], + rng: &mut R, + ) -> Result<Vec<u8>, EncryptionError> { + let mut buf = vec![]; + buf.extend_from_slice(&MAGIC.to_be_bytes()); + + // Note that &buf, containing the header, is passed as the "info" + // parameter; this is then used during the HKDF expansion, ensuring this + // ciphertext can only be decrypted in this context and not, for example, + // under a different version + // + // An alternative, had the info parameter not been available, would be to + // concatenate the user-provided associated data and the header + // together. However this separation is neater and avoids having to + // allocate and copy the associated data. + + let (hpke_key, hpke_ctext) = hpke::single_shot_seal::<V1Aead, V1Kdf, V1Kem, R>( + opmode, + &self.pk, + &buf, + msg, + associated_data, + rng, + ) + .map_err(|_| EncryptionError::InternalError)?; + + buf.extend_from_slice(&hpke_key.to_bytes()); + + buf.extend_from_slice(&hpke_ctext); + + Ok(buf) + } + + /// Encrypt a message without sender authentication + /// + /// This encrypts a message to the public key such that whoever + /// knows the associated private key can decrypt the message. + /// + /// The `associated_data` field is information which is not encrypted, nor is + /// it included in the returned blob.
However it is implicitly authenticated + /// by a successful decryption; that is, if the decrypting side uses the + /// same `associated_data` parameter during decryption, then decryption will + /// succeed and the decryptor knows that this field is also authentic. If + /// the encryptor and decryptor disagree on the `associated_data` field, then + /// decryption will fail. If not needed, `associated_data` can be set to an + /// empty slice. + /// + /// This function provides no guarantees to the recipient about who sent it; + /// anyone can encrypt a message with this function. + /// + /// The recipient must use [`PrivateKey::decrypt_noauth`] to decrypt + pub fn encrypt_noauth<R: CryptoRng + RngCore>( + &self, + msg: &[u8], + associated_data: &[u8], + rng: &mut R, + ) -> Result<Vec<u8>, EncryptionError> { + let opmode = hpke::OpModeS::<V1Kem>::Base; + self._v1_encrypt(&opmode, msg, associated_data, rng) + } + + fn new(pk: V1PublicKey) -> Self { + Self { pk } + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +/// An error occurred while decrypting a message +pub enum DecryptionError { + /// The protocol identifier/version field did not match + UnknownMagic, + /// The length was wrong + InvalidLength, + /// The header was valid but the decryption failed + InvalidCiphertext, +} + +#[derive(Clone)] +/// A private key usable for decryption +pub struct PrivateKey { + sk: V1PrivateKey, + pk: V1PublicKey, +} + +impl PrivateKey { + /// Generate a new random private key + pub fn generate<R: CryptoRng + RngCore>(rng: &mut R) -> Self { + let mut ikm = [0; V1_IKM_LENGTH]; + rng.fill_bytes(&mut ikm); + let (sk, pk) = <V1Kem as Kem>::derive_keypair(&ikm); + Self { sk, pk } + } + + /// Return the associated public key + pub fn public_key(&self) -> PublicKey { + PublicKey::new(self.pk.clone()) + } + + /// Serialize this private key + pub fn serialize(&self) -> Vec<u8> { + let len = <V1PrivateKey as Serializable>::size(); + let mut buf = vec![0u8; HEADER_SIZE + len]; + buf[0..HEADER_SIZE].copy_from_slice(&MAGIC.to_be_bytes()); + self.sk.write_exact(&mut buf[HEADER_SIZE..]); + buf + } + + /// Deserialize a private key previously serialized with [`PrivateKey::serialize`] + pub fn deserialize(bytes: &[u8]) -> Result<Self, KeyDeserializationError> { + let len = <V1PrivateKey as Serializable>::size(); + + check_header!(KeyDeserializationError, bytes, len)?; + + match V1PrivateKey::from_bytes(&bytes[HEADER_SIZE..]) { + Ok(sk) => { + let pk = <V1Kem as Kem>::sk_to_pk(&sk); + Ok(Self { sk, pk }) + } + Err(_) => Err(KeyDeserializationError::InvalidKey), + } + } + + /// Decrypt a message with sender authentication + /// + /// This is the counterpart to [`PublicKey::encrypt`] + /// + /// This function provides sender authentication; if decryption succeeds + /// then it is mathematically guaranteed that the sender had access + /// to the secret key associated with `sender` + pub fn decrypt( + &self, + msg: &[u8], + associated_data: &[u8], + sender: &PublicKey, + ) -> Result<Vec<u8>, DecryptionError> { + let opmode = hpke::OpModeR::<V1Kem>::Auth(sender.pk.clone()); + self._v1_decrypt(&opmode, msg, associated_data) + } + + /// Decrypt a message without sender authentication + /// + /// This is the counterpart to [`PublicKey::encrypt_noauth`] + /// + /// This function *cannot* decrypt messages created using [`PublicKey::encrypt`] + /// + /// # Warning + /// + /// Remember that without sender authentication there is no guarantee that the message + /// you decrypt was sent by anyone in particular. Using this function safely requires + /// some out of band authentication mechanism.
+ pub fn decrypt_noauth( + &self, + msg: &[u8], + associated_data: &[u8], + ) -> Result<Vec<u8>, DecryptionError> { + let opmode = hpke::OpModeR::<V1Kem>::Base; + self._v1_decrypt(&opmode, msg, associated_data) + } + + fn _v1_decrypt( + &self, + opmode: &hpke::OpModeR<V1Kem>, + msg: &[u8], + associated_data: &[u8], + ) -> Result<Vec<u8>, DecryptionError> { + let encap_key_len = <V1Kem as Kem>::EncappedKey::size(); + + if check_header!(DecryptionError, msg)? < encap_key_len { + return Err(DecryptionError::InvalidLength); + } + + let encap_key_bytes = &msg[HEADER_SIZE..HEADER_SIZE + encap_key_len]; + let encap_key = <V1Kem as Kem>::EncappedKey::from_bytes(encap_key_bytes) + .map_err(|_| DecryptionError::InvalidCiphertext)?; + + let ciphertext = &msg[HEADER_SIZE + encap_key_len..]; + + match hpke::single_shot_open::<V1Aead, V1Kdf, V1Kem>( + opmode, + &self.sk, + &encap_key, + &msg[0..HEADER_SIZE], + ciphertext, + associated_data, + ) { + Ok(ptext) => Ok(ptext), + Err(_) => Err(DecryptionError::InvalidCiphertext), + } + } +} diff --git a/packages/ic-hpke/tests/tests.rs b/packages/ic-hpke/tests/tests.rs new file mode 100644 index 00000000000..ebd4d87e5be --- /dev/null +++ b/packages/ic-hpke/tests/tests.rs @@ -0,0 +1,205 @@ +use ic_hpke::*; +use rand::{Rng, RngCore, SeedableRng}; + +#[test] +fn key_generation_and_noauth_encrypt_is_stable() { + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); + + let sk = PrivateKey::generate(&mut rng); + + let sk_bytes = sk.serialize(); + assert_eq!(sk_bytes.len(), 56); + assert_eq!(hex::encode(sk_bytes), "49432048504b4501ca9b347c733a5375e97fb372763bd3b3478fce3ab7c7c340521d410051eff1c5cea6efa33cbf0910b919730726c42397"); + + let pk = sk.public_key(); + let pk_bytes = pk.serialize(); + assert_eq!(pk_bytes.len(), 105); + assert_eq!(hex::encode(pk_bytes), "49432048504b4501041e0fff52d0cb05f440c47614c50c2ace65db74194fc85fc55345140a9543bf1228dab6f0e4254505c7eaf692f7d8478eb8f027d944acc65d2c6818101b55a28861abc6386e6c85ded766e48e211253184ccaf7243685fe7ac36526a9ac7a4311"); + + let msg = b"this is a test"; + let aad = b"test associated data"; + + let ctext = pk + .encrypt_noauth(msg, aad, &mut rng) + .expect("encryption failed"); + + // 8 bytes version, 1+2*48 bytes P-384 point, 16 bytes GCM tag + assert_eq!(ctext.len(), 8 + (1 + 2 * 48) + 16 + msg.len()); + + assert_eq!(hex::encode(ctext), "49432048504b45010489d22b35f935051f9dd57c2e7909c388e97c2d960129ee92a0478710e4da9cfd6cda78881c297610d4776a27d73283675f743e1d20ce9019706d659b77aa78fc30f93000468ff2304b0c442c134f094e5e8d99b784f03eafee1c8f802cf661075a3c6e8e08338ef78f1a0a6731d5db1f0eb5131ce37ad7a0082a6be021d3"); +} + +#[test] +fn key_generation_and_authenticated_encrypt_is_stable() { + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); + + let sk_a = PrivateKey::generate(&mut rng); + let sk_b = PrivateKey::generate(&mut rng); + + let pk_a = sk_a.public_key(); + let pk_b = sk_b.public_key(); + + let msg = b"this is a test"; + let aad = b"test associated data"; + + let ctext_b_to_a = { + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(1); + pk_a.encrypt(msg, aad, &sk_b, &mut rng) + .expect("encryption failed") + }; + + let ctext_a_to_b = { + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(1); + pk_b.encrypt(msg, aad, &sk_a, &mut rng) + .expect("encryption failed") + }; + + assert_eq!(hex::encode(&ctext_b_to_a), "49432048504b450104c9613312faa2b1d5739c89a08ed6d3cb4935b5ad28db5855c19e32eb0f1f6fe6bb9f164da4524c8998d6c1529b99d429c93436a0f2dcdf3c58f806e0824c266dd43d29f2784176d56f2df1632ef1cf454da0ff52e9532eb452c928150f6710f6b80fe5a3ac17b9dc8b5c443f58c985f6365de6c756102e6bdb9e60432e11");
+ + assert_eq!(hex::encode(&ctext_a_to_b), "49432048504b450104c9613312faa2b1d5739c89a08ed6d3cb4935b5ad28db5855c19e32eb0f1f6fe6bb9f164da4524c8998d6c1529b99d429c93436a0f2dcdf3c58f806e0824c266dd43d29f2784176d56f2df1632ef1cf454da0ff52e9532eb452c928150f6710f6d95b9ef4b2286c8ad9f0199bb716844ad13dec45cdb7bb265d4838369f72"); +} + +#[test] +fn smoke_test_noauth() { + let mut rng = rand::rngs::OsRng; + let sk = PrivateKey::generate(&mut rng); + let pk = sk.public_key(); + + for ptext_len in 0..128 { + let mut ptext = vec![0u8; ptext_len]; + rng.fill_bytes(&mut ptext); + let aad = rng.gen::<[u8; 32]>(); + let ctext = pk.encrypt_noauth(&ptext, &aad, &mut rng).unwrap(); + let rec = sk.decrypt_noauth(&ctext, &aad).unwrap(); + assert_eq!(rec, ptext); + } +} + +#[test] +fn smoke_test_auth() { + let mut rng = rand::rngs::OsRng; + + let a_sk = PrivateKey::generate(&mut rng); + let a_pk = a_sk.public_key(); + + let b_sk = PrivateKey::generate(&mut rng); + let b_pk = b_sk.public_key(); + + let aad = rng.gen::<[u8; 32]>(); + + for ptext_len in 0..128 { + let mut ptext = vec![0u8; ptext_len]; + rng.fill_bytes(&mut ptext); + let ctext = a_pk.encrypt(&ptext, &aad, &b_sk, &mut rng).unwrap(); + let rec = a_sk.decrypt(&ctext, &aad, &b_pk).unwrap(); + assert_eq!(rec, ptext); + + assert!(a_sk.decrypt_noauth(&ctext, &aad).is_err()); + } +} + +#[test] +fn any_bit_flip_causes_rejection_noauth() { + let mut rng = rand::rngs::OsRng; + + let a_sk = PrivateKey::generate(&mut rng); + let a_pk = a_sk.public_key(); + + let ptext = rng.gen::<[u8; 16]>().to_vec(); + let aad = rng.gen::<[u8; 32]>(); + + let mut ctext = a_pk.encrypt_noauth(&ptext, &aad, &mut rng).unwrap(); + + let bits = ctext.len() * 8; + + for bit in 0..bits { + ctext[bit / 8] ^= 1 << (bit % 8); + assert!(a_sk.decrypt_noauth(&ctext, &aad).is_err()); + + // restore the bit we just flipped + ctext[bit / 8] ^= 1 << (bit % 8); + } + + assert_eq!(a_sk.decrypt_noauth(&ctext, &aad), Ok(ptext)); +} + +#[test] +fn any_bit_flip_causes_rejection_auth() { + let mut rng = rand::rngs::OsRng; + + let a_sk = PrivateKey::generate(&mut rng); + let a_pk = a_sk.public_key(); + + let b_sk = PrivateKey::generate(&mut rng); + let b_pk = b_sk.public_key(); + + let ptext = rng.gen::<[u8; 16]>().to_vec(); + let aad = rng.gen::<[u8; 32]>(); + + let mut ctext = a_pk.encrypt(&ptext, &aad, &b_sk, &mut rng).unwrap(); + + let bits = ctext.len() * 8; + + for bit in 0..bits { + ctext[bit / 8] ^= 1 << (bit % 8); + assert!(a_sk.decrypt(&ctext, &aad, &b_pk).is_err()); + + // restore the bit we just flipped + ctext[bit / 8] ^= 1 << (bit % 8); + } + + assert_eq!(a_sk.decrypt(&ctext, &aad, &b_pk), Ok(ptext)); +} + +#[test] +fn any_truncation_causes_rejection_noauth() { + let mut rng = rand::rngs::OsRng; + + let a_sk = PrivateKey::generate(&mut rng); + let a_pk = a_sk.public_key(); + + let ptext = rng.gen::<[u8; 16]>().to_vec(); + let aad = rng.gen::<[u8; 32]>(); + + let mut ctext = a_pk.encrypt_noauth(&ptext, &aad, &mut rng).unwrap(); + + assert_eq!(a_sk.decrypt_noauth(&ctext, &aad), Ok(ptext)); + + loop { + ctext.pop(); + + assert!(a_sk.decrypt_noauth(&ctext, &aad).is_err()); + + if ctext.is_empty() { + break; + } + } +} + +#[test] +fn any_truncation_causes_rejection_auth() { + let mut rng = rand::rngs::OsRng; + + let a_sk = PrivateKey::generate(&mut rng); + let a_pk = a_sk.public_key(); + + let b_sk = PrivateKey::generate(&mut rng); + let b_pk = b_sk.public_key(); + + let ptext = rng.gen::<[u8; 16]>().to_vec(); + let aad = rng.gen::<[u8; 32]>(); + + let mut ctext = a_pk.encrypt(&ptext, &aad, &b_sk, 
&mut rng).unwrap(); + + assert_eq!(a_sk.decrypt(&ctext, &aad, &b_pk), Ok(ptext)); + + loop { + ctext.pop(); + + assert!(a_sk.decrypt(&ctext, &aad, &b_pk).is_err()); + + if ctext.is_empty() { + break; + } + } +} diff --git a/packages/pocket-ic/HOWTO.md b/packages/pocket-ic/HOWTO.md index e3e6b8dc6ff..d171fd2a1bf 100644 --- a/packages/pocket-ic/HOWTO.md +++ b/packages/pocket-ic/HOWTO.md @@ -42,48 +42,93 @@ See the [examples](README.md#examples) for more. ## Live Mode -Since version 4.0.0, the PocketIC server also exposes the IC's HTTP interface, just like the IC mainnet and the replica launched by dfx. This means that PocketIC instances can now be targeted by agent-based tools (agent.rs, agent.js, IC-Repl, etc). Note that PocketIC instances, if launched in the regular way, do not "make progress" by themselves, i.e., the state machines that represent the IC do not execute any messages without a call to `tick()` and their timestamps do not advance without a call to `advance_time(...)`. But the agent-based tools expect their target to make progress automatically (as the IC mainnet and the replica launched by dfx do) and use the current time as the IC time, since they dispatch asynchronous requests and poll for the result, checking for its freshness with respect to the current time. +The PocketIC server exposes the ICP's HTTP interface (as defined in the [Interface Specification](https://internetcomputer.org/docs/references/ic-interface-spec#http-interface)) used by the ICP mainnet. This means that PocketIC instances can also be targeted by agent-based tools, e.g., the [Rust](https://crates.io/crates/ic-agent) and [JavaScript](https://www.npmjs.com/package/@dfinity/agent) agents. -For that reason, you need to explicitly make an instance "live" by calling `make_live()` on it. This will do three things: +Note that PocketIC instances do not "make progress" by default, i.e., they do not execute any messages and time does not advance unless dedicated operations are triggered by separate HTTP requests. The "live" mode enabled by calling the function `PocketIc::make_live()` automates those steps by launching a background thread that -- It launches a thread that calls `tick()` and `advance_time(...)` on the instance regularly - several times per second. -- It creates a gateway which points to this live instance. -- It returns a gateway URL which can then be passed to agent-like tools. +- sets the current time as the PocketIC instance time; +- advances time on the PocketIC instance regularly; +- executes messages on the PocketIC instance; +- executes canister HTTP outcalls of the PocketIC instance. -Of course, other instances on the same PocketIC server remain unchanged - neither do they receive `tick`s nor can the gateway route requests to them. +The function `PocketIc::make_live()` also creates an HTTP gateway serving + - the ICP's HTTP interface (as defined in the [Interface Specification](https://internetcomputer.org/docs/references/ic-interface-spec#http-interface)) + - and the ICP's HTTP gateway interface (as defined in the [HTTP Gateway Protocol Specification](https://internetcomputer.org/docs/references/http-gateway-protocol-spec)) +and returns its URL. -**Attention**: Enabling auto-progress makes instances non-deterministic! There is no way to guarantee message order when agents dispatch async requests, which may interleave with each other and with the `tick`s from the auto-progress thread. If you need determinism, use the old, manually-`tick`ed API. 
+**Attention**: Enabling the "live" mode makes the PocketIC instance non-deterministic! For instance, there is no way to tell in which order messages are going to be executed.
+The function `PocketIc::stop_live` can be used to disable the "live" mode: it stops the HTTP gateway and the background thread ensuring progress on the PocketIC instance.
+However, the non-deterministic state changes during the "live" mode (e.g., time changes) could affect the PocketIC instance even after disabling the "live" mode.
 
-**Attention**: It is strongly discouraged to use the PocketIC library for interacting with a live instance.
-Live instances can be made non-live again by disabling auto-progress and disabling the gateway.
-This is done by calling `stop_live()` on the instance.
-Once this call returns, you can use the PocketIC library for testing again.
-The instance will only make progress when you call `tick()` - but the state in which the instance halts is not deterministic.
-So be extra careful with tests which are setup with a live phase and which then transition to non-live for the test section.
+**Attention**: The "live" mode requires the PocketIC instance to have an NNS subnet.
 
-Here is a sketch on how to use the live mode:
+**Attention**: It is strongly discouraged to override the time of a "live" PocketIC instance.
+
+Here is a sketch of how to use the PocketIC library to make an update call in the "live" mode:
 
 ```rust
+// We create a PocketIC instance with an NNS subnet
+// (the "live" mode requires the NNS subnet).
 let mut pic = PocketIcBuilder::new()
     .with_nns_subnet()
     .with_application_subnet()
     .build();
+
+// Enable the "live" mode.
+let _ = pic.make_live(None);
+
+// Create and install a test canister.
+// ...
+
+// Submit an update call to the test canister making a canister http outcall.
+let call_id = pic
+    .submit_call(
+        canister_id,
+        Principal::anonymous(),
+        "canister_http",
+        encode_one(()).unwrap(),
+    )
+    .unwrap();
+
+// Await the update call without making additional progress (the PocketIC instance
+// is already in the "live" mode making progress automatically).
+let reply = pic.await_call_no_ticks(call_id).unwrap();
+
+// Process the reply.
+// ...
+```
+
+Here is a sketch of how to use the IC agent in the "live" mode:
+
+```rust
+// We create a PocketIC instance with an NNS subnet
+// (the "live" mode requires the NNS subnet).
+let mut pic = PocketIcBuilder::new()
+    .with_nns_subnet()
+    .with_application_subnet()
+    .build();
+
+// Enable the "live" mode.
 let endpoint = pic.make_live(None);
-// the local agent needs a runtime
-let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
-let res = rt.block_on(async {
+
+// We use a tokio runtime to run the asynchronous IC agent.
+let rt = tokio::runtime::Builder::new_current_thread()
+    .enable_all()
+    .build()
+    .unwrap();
+rt.block_on(async {
+    // We create an IC agent.
     let agent = ic_agent::Agent::builder()
-        .with_url(endpoint.clone())
-        .build()
-        .unwrap();
-    // proof that the agent can communicate with the instance
+        .with_url(endpoint)
+        .build()
+        .unwrap();
+
+    // We fetch the PocketIC (i.e., non-mainnet) root key to successfully verify responses.
     agent.fetch_root_key().await.unwrap();
-    // do something useful with the agent
-    let res = agent.[...]
-    res
+
+    // Finally, we use the IC agent in tests.
+    // ...
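+
+    // As a minimal, hedged illustration (not part of the original sketch): one call that
+    // works against any instance is fetching the status endpoint served by the HTTP
+    // gateway; calls to your own canisters follow the same async pattern via the agent.
+    let _status = agent.status().await.unwrap();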
}); -// stop the HTTP gateway and auto progress -pic.stop_live(); ``` ## Concurrent update calls diff --git a/packages/pocket-ic/tests/tests.rs b/packages/pocket-ic/tests/tests.rs index bb42723916f..5d11b428638 100644 --- a/packages/pocket-ic/tests/tests.rs +++ b/packages/pocket-ic/tests/tests.rs @@ -1198,6 +1198,44 @@ fn test_canister_http() { assert_eq!(http_response.unwrap().body, body); } +#[test] +fn test_canister_http_in_live_mode() { + // We create a PocketIC instance with an NNS subnet + // (the "live" mode requires the NNS subnet). + let mut pic = PocketIcBuilder::new() + .with_nns_subnet() + .with_application_subnet() + .build(); + + // Enable the "live" mode. + let _ = pic.make_live(None); + + // Create a canister and charge it with 2T cycles. + let canister_id = pic.create_canister(); + pic.add_cycles(canister_id, INIT_CYCLES); + + // Install the test canister wasm file on the canister. + let test_wasm = test_canister_wasm(); + pic.install_canister(canister_id, test_wasm, vec![], None); + + // Submit an update call to the test canister making a canister http outcall. + let call_id = pic + .submit_call( + canister_id, + Principal::anonymous(), + "canister_http", + encode_one(()).unwrap(), + ) + .unwrap(); + + // Await the update call without making additional progress (the PocketIC instance + // is already in the "live" mode making progress automatically). + let reply = pic.await_call_no_ticks(call_id).unwrap(); + let http_response: Result = + decode_one(&reply).unwrap(); + http_response.unwrap(); +} + #[test] fn test_canister_http_with_transform() { let pic = PocketIc::new(); diff --git a/rs/canister_sandbox/src/replica_controller/sandboxed_execution_controller.rs b/rs/canister_sandbox/src/replica_controller/sandboxed_execution_controller.rs index a21b28f1c80..37df7db3846 100644 --- a/rs/canister_sandbox/src/replica_controller/sandboxed_execution_controller.rs +++ b/rs/canister_sandbox/src/replica_controller/sandboxed_execution_controller.rs @@ -2304,7 +2304,10 @@ mod tests { canister_module, PathBuf::new(), canister_id, - Arc::new(CompilationCache::new(MAX_COMPILATION_CACHE_SIZE)), + Arc::new(CompilationCache::new( + MAX_COMPILATION_CACHE_SIZE, + tempfile::tempdir().unwrap(), + )), ) .unwrap(); let sandbox_pid = match controller diff --git a/rs/canonical_state/BUILD.bazel b/rs/canonical_state/BUILD.bazel index 101def3cce6..a8556b986fe 100644 --- a/rs/canonical_state/BUILD.bazel +++ b/rs/canonical_state/BUILD.bazel @@ -47,7 +47,9 @@ DEV_DEPENDENCIES = [ "@crate_index//:tempfile", ] -DEV_MACRO_DEPENDENCIES = [ +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. 
+ "@crate_index//:test-strategy", ] rust_library( @@ -68,20 +70,20 @@ rust_test( rust_test( name = "compatibility_test", srcs = ["tests/compatibility.rs"], - proc_macro_deps = DEV_MACRO_DEPENDENCIES, + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":canonical_state"], ) rust_test( name = "size_limit_visitor_test", srcs = ["tests/size_limit_visitor.rs"], - proc_macro_deps = DEV_MACRO_DEPENDENCIES, + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":canonical_state"], ) rust_test( name = "hash_tree_test", srcs = ["tests/hash_tree.rs"], - proc_macro_deps = DEV_MACRO_DEPENDENCIES, + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":canonical_state"], ) diff --git a/rs/canonical_state/Cargo.toml b/rs/canonical_state/Cargo.toml index 4abd235fde5..900f9f4f0b2 100644 --- a/rs/canonical_state/Cargo.toml +++ b/rs/canonical_state/Cargo.toml @@ -41,4 +41,5 @@ ic-wasm-types = { path = "../types/wasm_types" } lazy_static = { workspace = true } maplit = "1.0.2" proptest = { workspace = true } +test-strategy = "0.4.0" tempfile = { workspace = true } diff --git a/rs/canonical_state/tests/compatibility.rs b/rs/canonical_state/tests/compatibility.rs index e45dd64fd10..f389c4e648e 100644 --- a/rs/canonical_state/tests/compatibility.rs +++ b/rs/canonical_state/tests/compatibility.rs @@ -149,71 +149,104 @@ lazy_static! { ]; } -proptest! { - /// Tests that given a `StreamHeader` that is valid for a given certification - /// version range (e.g. no `reject_signals` before certification version 8) all - /// supported canonical type (e.g. `StreamHeaderV6` or `StreamHeader`) and - /// certification version combinations produce the exact same encoding. - #[test] - fn stream_header_unique_encoding((header, version_range) in arb_valid_versioned_stream_header(100)) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*STREAM_HEADER_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&header, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `StreamHeader` that is valid for a given certification +/// version range (e.g. no `reject_signals` before certification version 8) all +/// supported canonical type (e.g. `StreamHeaderV6` or `StreamHeader`) and +/// certification version combinations produce the exact same encoding. 
+#[test_strategy::proptest] +fn stream_header_unique_encoding( + #[strategy(arb_valid_versioned_stream_header( + 100, // max_signal_count + ))] + test_header: (StreamHeader, RangeInclusive), +) { + let (header, version_range) = test_header; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*STREAM_HEADER_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&header, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } +} + +/// Tests that, given a `StreamHeader` that is valid for a given certification +/// version range (e.g. no `reject_signals` before certification version 8), +/// all supported encodings will decode back into the same `StreamHeader`. +#[test_strategy::proptest] +fn stream_header_roundtrip_encoding( + #[strategy(arb_valid_versioned_stream_header( + 100, // max_signal_count + ))] + test_header: (StreamHeader, RangeInclusive), +) { + let (header, version_range) = test_header; + + for version in iter(version_range) { + for encoding in &*STREAM_HEADER_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&header, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + let result = (encoding.decode)(&bytes) + .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); - /// Tests that, given a `StreamHeader` that is valid for a given certification - /// version range (e.g. no `reject_signals` before certification version 8), - /// all supported encodings will decode back into the same `StreamHeader`. - #[test] - fn stream_header_roundtrip_encoding((header, version_range) in arb_valid_versioned_stream_header(100)) { - for version in iter(version_range) { - for encoding in &*STREAM_HEADER_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&header, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - let result = (encoding.decode)(&bytes) - .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); - - assert_eq!(header, result, "Roundtrip encoding {}@{:?} failed", encoding.name, version); - } + assert_eq!( + header, result, + "Roundtrip encoding {}@{:?} failed", + encoding.name, version + ); } } } +} + +/// Tests that, given a `StreamHeader` that is invalid for a given certification +/// version range (e.g. `reject_signals` before certification version 8), +/// encoding will panic. +/// +/// Be aware that the output generated by this test failing includes all panics +/// (e.g. 
stack traces), including those produced by previous iterations where +/// panics were caught by `std::panic::catch_unwind`. +#[test_strategy::proptest] +fn stream_header_encoding_panic_on_invalid( + #[strategy(arb_invalid_versioned_stream_header( + 100, // max_signal_count + ))] + test_header: (StreamHeader, RangeInclusive), +) { + let (header, version_range) = test_header; + for version in iter(version_range) { + for encoding in &*STREAM_HEADER_ENCODINGS { + if encoding.version_range.contains(&version) { + let result = std::panic::catch_unwind(|| (encoding.encode)((&header, version))); - /// Tests that, given a `StreamHeader` that is invalid for a given certification - /// version range (e.g. `reject_signals` before certification version 8), - /// encoding will panic. - /// - /// Be aware that the output generated by this test failing includes all panics - /// (e.g. stack traces), including those produced by previous iterations where - /// panics were caught by `std::panic::catch_unwind`. - #[test] - fn stream_header_encoding_panic_on_invalid((header, version_range) in arb_invalid_versioned_stream_header(100)) { - for version in iter(version_range) { - for encoding in &*STREAM_HEADER_ENCODINGS { - if encoding.version_range.contains(&version) { - let result = std::panic::catch_unwind(|| { - (encoding.encode)((&header, version)) - }); - - assert!(result.is_err(), "Encoding of invalid {}@{:?} succeeded", encoding.name, version); - } + assert!( + result.is_err(), + "Encoding of invalid {}@{:?} succeeded", + encoding.name, + version + ); } } } @@ -256,49 +289,73 @@ lazy_static! { ]; } -proptest! { - /// Tests that given a `RequestOrResponse` that is valid for a given certification - /// version range (e.g. no `reject_signals` before certification version 8) all - /// supported canonical type (e.g. `RequestOrResponseV3` or `RequestOrResponse`) - /// and certification version combinations produce the exact same encoding. - #[test] - fn message_unique_encoding((message, version_range) in arb_valid_versioned_message()) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*MESSAGE_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&message, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `RequestOrResponse` that is valid for a given certification +/// version range (e.g. no `reject_signals` before certification version 8) all +/// supported canonical type (e.g. `RequestOrResponseV3` or `RequestOrResponse`) +/// and certification version combinations produce the exact same encoding. 
+#[test_strategy::proptest] +fn message_unique_encoding( + #[strategy(arb_valid_versioned_message())] test_message: ( + RequestOrResponse, + RangeInclusive, + ), +) { + let (message, version_range) = test_message; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*MESSAGE_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&message, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } +} + +/// Tests that, given a `RequestOrResponse` that is valid for a given +/// certification version range, all supported encodings will decode back into +/// the same `RequestOrResponse`. +#[test_strategy::proptest] +fn message_roundtrip_encoding( + #[strategy(arb_valid_versioned_message())] test_message: ( + RequestOrResponse, + RangeInclusive, + ), +) { + let (message, version_range) = test_message; - /// Tests that, given a `RequestOrResponse` that is valid for a given - /// certification version range, all supported encodings will decode back into - /// the same `RequestOrResponse`. - #[test] - fn message_roundtrip_encoding((message, version_range) in arb_valid_versioned_message()) { - for version in iter(version_range) { - for encoding in &*MESSAGE_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&message, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - let result = (encoding.decode)(&bytes) - .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); - - assert_eq!(message, result, "Roundtrip encoding {}@{:?} failed", encoding.name, version); - } + for version in iter(version_range) { + for encoding in &*MESSAGE_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&message, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + let result = (encoding.decode)(&bytes) + .unwrap_or_else(|_| panic!("Failed to decode {}@{:?}", encoding.name, version)); + + assert_eq!( + message, result, + "Roundtrip encoding {}@{:?} failed", + encoding.name, version + ); } } } @@ -350,31 +407,44 @@ pub(crate) fn arb_valid_system_metadata( ] } -proptest! { - /// Tests that given a `SystemMetadata` that is valid for a given certification - /// version range, all supported canonical type (e.g. `SystemMetadataV9` or - /// `SystemMetadataV10`) and certification version combinations produce the - /// exact same encoding. 
- #[test] - fn system_metadata_unique_encoding((metadata, version_range) in arb_valid_system_metadata()) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*SYSTEM_METADATA_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&metadata, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `SystemMetadata` that is valid for a given certification +/// version range, all supported canonical type (e.g. `SystemMetadataV9` or +/// `SystemMetadataV10`) and certification version combinations produce the +/// exact same encoding. +#[test_strategy::proptest] +fn system_metadata_unique_encoding( + #[strategy(arb_valid_system_metadata())] test_metadata: ( + SystemMetadata, + RangeInclusive, + ), +) { + let (metadata, version_range) = test_metadata; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*SYSTEM_METADATA_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&metadata, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } } @@ -402,30 +472,43 @@ pub(crate) fn arb_valid_subnet_metrics( )] } -proptest! { - /// Tests that given a `SubnetMetrics` that is valid for a given certification - /// version range, all supported canonical type and certification version - /// combinations produce the exact same encoding. - #[test] - fn subnet_metrics_unique_encoding((subnet_metrics, version_range) in arb_valid_subnet_metrics()) { - let mut results = vec![]; - for version in iter(version_range) { - let results_before = results.len(); - for encoding in &*SUBNET_METRICS_ENCODINGS { - if encoding.version_range.contains(&version) { - let bytes = (encoding.encode)((&subnet_metrics, version)) - .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); - results.push((version, encoding.name, bytes)); - } +/// Tests that given a `SubnetMetrics` that is valid for a given certification +/// version range, all supported canonical type and certification version +/// combinations produce the exact same encoding. 
+#[test_strategy::proptest] +fn subnet_metrics_unique_encoding( + #[strategy(arb_valid_subnet_metrics())] test_subnet_metrics: ( + SubnetMetrics, + RangeInclusive, + ), +) { + let (subnet_metrics, version_range) = test_subnet_metrics; + + let mut results = vec![]; + for version in iter(version_range) { + let results_before = results.len(); + for encoding in &*SUBNET_METRICS_ENCODINGS { + if encoding.version_range.contains(&version) { + let bytes = (encoding.encode)((&subnet_metrics, version)) + .unwrap_or_else(|_| panic!("Failed to encode {}@{:?}", encoding.name, version)); + results.push((version, encoding.name, bytes)); } - assert!(results.len() > results_before, "No supported encodings for certification version {:?}", version); } + assert!( + results.len() > results_before, + "No supported encodings for certification version {:?}", + version + ); + } - if results.len() > 1 { - let (current_version, current_name, current_bytes) = results.pop().unwrap(); - for (version, name, bytes) in &results { - assert_eq!(¤t_bytes, bytes, "Different encodings: {}@{:?} and {}@{:?}", current_name, current_version, name, version); - } + if results.len() > 1 { + let (current_version, current_name, current_bytes) = results.pop().unwrap(); + for (version, name, bytes) in &results { + assert_eq!( + ¤t_bytes, bytes, + "Different encodings: {}@{:?} and {}@{:?}", + current_name, current_version, name, version + ); } } } diff --git a/rs/canonical_state/tests/size_limit_visitor.rs b/rs/canonical_state/tests/size_limit_visitor.rs index d24639efcc0..873ed2ab0db 100644 --- a/rs/canonical_state/tests/size_limit_visitor.rs +++ b/rs/canonical_state/tests/size_limit_visitor.rs @@ -68,46 +68,53 @@ prop_compose! { } } -proptest! { - #[test] - fn size_limit_proptest(fixture in arb_fixture(10)) { - let Fixture{ state, end, slice_begin, size_limit, .. } = fixture; - - // Produce a size-limited slice starting from `slice_begin`. - let pattern = vec![ - Label(b"streams".to_vec()), - Any, - Label(b"messages".to_vec()), - Any, - ]; - let subtree_pattern = make_slice_pattern(slice_begin, end); - let visitor = SizeLimitVisitor::new( - pattern, - size_limit, - SubtreeVisitor::new(&subtree_pattern, MessageSpyVisitor::default()), +#[test_strategy::proptest] +fn size_limit_proptest(#[strategy(arb_fixture(10))] fixture: Fixture) { + let Fixture { + state, + end, + slice_begin, + size_limit, + .. + } = fixture; + + // Produce a size-limited slice starting from `slice_begin`. + let pattern = vec![ + Label(b"streams".to_vec()), + Any, + Label(b"messages".to_vec()), + Any, + ]; + let subtree_pattern = make_slice_pattern(slice_begin, end); + let visitor = SizeLimitVisitor::new( + pattern, + size_limit, + SubtreeVisitor::new(&subtree_pattern, MessageSpyVisitor::default()), + ); + let (actual_size, actual_begin, actual_end) = traverse(&state, visitor); + + if let (Some(actual_begin), Some(actual_end)) = (actual_begin, actual_end) { + // Non-empty slice. + assert_eq!(slice_begin, actual_begin); + assert!(actual_end <= end); + + // Size is below the limit or the slice consists of a single message. + assert!(actual_size <= size_limit || actual_end - actual_begin == 1); + // And must match the computed slice size. + assert_eq!( + compute_message_sizes(&state, actual_begin, actual_end), + actual_size ); - let (actual_size, actual_begin, actual_end) = traverse(&state, visitor); - - if let (Some(actual_begin), Some(actual_end)) = (actual_begin, actual_end) { - // Non-empty slice. 
- assert_eq!(slice_begin, actual_begin); - assert!(actual_end <= end); - - // Size is below the limit or the slice consists of a single message. - assert!(actual_size <= size_limit || actual_end - actual_begin == 1); - // And must match the computed slice size. - assert_eq!(compute_message_sizes(&state, actual_begin, actual_end), actual_size); - if actual_end < end { - // Including one more message should exceed `size_limit`. - assert!(compute_message_sizes(&state, actual_begin, actual_end + 1) > size_limit); - } - } else { - // Empty slice. - assert_eq!(0, actual_size); - // May only happen if `slice_begin == stream.messages.end`. - assert_eq!(slice_begin, end); + if actual_end < end { + // Including one more message should exceed `size_limit`. + assert!(compute_message_sizes(&state, actual_begin, actual_end + 1) > size_limit); } + } else { + // Empty slice. + assert_eq!(0, actual_size); + // May only happen if `slice_begin == stream.messages.end`. + assert_eq!(slice_begin, end); } } diff --git a/rs/canonical_state/tree_hash/BUILD.bazel b/rs/canonical_state/tree_hash/BUILD.bazel index c98ee0cdb38..239979d4eba 100644 --- a/rs/canonical_state/tree_hash/BUILD.bazel +++ b/rs/canonical_state/tree_hash/BUILD.bazel @@ -22,6 +22,11 @@ DEV_DEPENDENCIES = [ "@crate_index//:rand_chacha", ] +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:test-strategy", +] + rust_library( name = "tree_hash", srcs = glob(["src/**/*.rs"]), @@ -33,5 +38,6 @@ rust_library( rust_test_suite( name = "tree_hash_integration", srcs = glob(["tests/**/*.rs"]), + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":tree_hash"], ) diff --git a/rs/canonical_state/tree_hash/Cargo.toml b/rs/canonical_state/tree_hash/Cargo.toml index 6142b75efd2..c3793bff62a 100644 --- a/rs/canonical_state/tree_hash/Cargo.toml +++ b/rs/canonical_state/tree_hash/Cargo.toml @@ -21,3 +21,4 @@ ic-crypto-tree-hash-test-utils = { path = "../../crypto/tree_hash/test_utils" } proptest = { workspace = true } rand = { workspace = true } rand_chacha = { workspace = true } +test-strategy = "0.4.0" diff --git a/rs/canonical_state/tree_hash/tests/hash_tree.rs b/rs/canonical_state/tree_hash/tests/hash_tree.rs index 93163525808..5b765cb61ef 100644 --- a/rs/canonical_state/tree_hash/tests/hash_tree.rs +++ b/rs/canonical_state/tree_hash/tests/hash_tree.rs @@ -174,10 +174,11 @@ fn test_non_existence_proof() { ); } -proptest! 
{ - #[test] - fn same_witness(t in arbitrary_labeled_tree(), seed in prop::array::uniform32(any::())) { - let rng = &mut ChaCha20Rng::from_seed(seed); - test_membership_witness(&t, rng); - } +#[test_strategy::proptest] +fn same_witness( + #[strategy(arbitrary_labeled_tree())] t: LabeledTree>, + #[strategy(prop::array::uniform32(any::()))] seed: [u8; 32], +) { + let rng = &mut ChaCha20Rng::from_seed(seed); + test_membership_witness(&t, rng); } diff --git a/rs/consensus/BUILD.bazel b/rs/consensus/BUILD.bazel index 3b77c11b8f4..a0fa40d731d 100644 --- a/rs/consensus/BUILD.bazel +++ b/rs/consensus/BUILD.bazel @@ -9,6 +9,7 @@ DEPENDENCIES = [ "//rs/config", "//rs/consensus/dkg", "//rs/consensus/utils", + "//rs/consensus/vetkd", "//rs/crypto", "//rs/crypto/prng", "//rs/crypto/utils/threshold_sig_der", diff --git a/rs/consensus/Cargo.toml b/rs/consensus/Cargo.toml index 874c75942bb..21a38f4e6f1 100644 --- a/rs/consensus/Cargo.toml +++ b/rs/consensus/Cargo.toml @@ -10,6 +10,7 @@ documentation.workspace = true ic-config = { path = "../config" } ic-consensus-dkg = { path = "./dkg" } ic-consensus-utils = { path = "./utils" } +ic-consensus-vetkd = { path = "./vetkd" } ic-crypto = { path = "../crypto" } ic-crypto-prng = { path = "../crypto/prng" } ic-crypto-test-utils-canister-threshold-sigs = { path = "../crypto/test_utils/canister_threshold_sigs" } diff --git a/rs/consensus/benches/validate_payload.rs b/rs/consensus/benches/validate_payload.rs index babddccd7c9..fb0e19cad0f 100644 --- a/rs/consensus/benches/validate_payload.rs +++ b/rs/consensus/benches/validate_payload.rs @@ -170,6 +170,7 @@ where Arc::new(FakeSelfValidatingPayloadBuilder::new()), Arc::new(FakeCanisterHttpPayloadBuilder::new()), Arc::new(MockBatchPayloadBuilder::new().expect_noop()), + Arc::new(MockBatchPayloadBuilder::new().expect_noop()), metrics_registry, no_op_logger(), )); diff --git a/rs/consensus/src/consensus.rs b/rs/consensus/src/consensus.rs index 4b627cc7ad4..7a962e89120 100644 --- a/rs/consensus/src/consensus.rs +++ b/rs/consensus/src/consensus.rs @@ -143,6 +143,7 @@ impl ConsensusImpl { self_validating_payload_builder: Arc, canister_http_payload_builder: Arc, query_stats_payload_builder: Arc, + vetkd_payload_builder: Arc, dkg_pool: Arc>, idkg_pool: Arc>, dkg_key_manager: Arc>, @@ -171,6 +172,7 @@ impl ConsensusImpl { self_validating_payload_builder, canister_http_payload_builder, query_stats_payload_builder, + vetkd_payload_builder, metrics_registry.clone(), logger.clone(), )); @@ -682,6 +684,7 @@ mod tests { Arc::new(FakeSelfValidatingPayloadBuilder::new()), Arc::new(FakeCanisterHttpPayloadBuilder::new()), Arc::new(MockBatchPayloadBuilder::new().expect_noop()), + Arc::new(MockBatchPayloadBuilder::new().expect_noop()), dkg_pool, idkg_pool, Arc::new(Mutex::new(DkgKeyManager::new( diff --git a/rs/consensus/src/consensus/batch_delivery.rs b/rs/consensus/src/consensus/batch_delivery.rs index ba32b71e8cd..2d302ff9c29 100644 --- a/rs/consensus/src/consensus/batch_delivery.rs +++ b/rs/consensus/src/consensus/batch_delivery.rs @@ -13,6 +13,7 @@ use ic_consensus_dkg::get_vetkey_public_keys; use ic_consensus_utils::{ crypto_hashable_to_seed, membership::Membership, pool_reader::PoolReader, }; +use ic_consensus_vetkd::VetKdPayloadBuilderImpl; use ic_https_outcalls_consensus::payload_builder::CanisterHttpPayloadBuilderImpl; use ic_interfaces::{ batch_payload::IntoMessages, @@ -285,6 +286,10 @@ pub fn generate_responses_to_subnet_calls( CanisterHttpPayloadBuilderImpl::into_messages(&block_payload.batch.canister_http); 
consensus_responses.append(&mut http_responses); stats.canister_http = http_stats; + + let mut vetkd_responses = + VetKdPayloadBuilderImpl::into_messages(&block_payload.batch.vetkd); + consensus_responses.append(&mut vetkd_responses); } consensus_responses } diff --git a/rs/consensus/src/consensus/payload.rs b/rs/consensus/src/consensus/payload.rs index 6ec5319fe20..5950c4018c3 100644 --- a/rs/consensus/src/consensus/payload.rs +++ b/rs/consensus/src/consensus/payload.rs @@ -43,6 +43,7 @@ pub(crate) enum BatchPayloadSectionBuilder { SelfValidating(Arc), CanisterHttp(Arc), QueryStats(Arc), + VetKd(Arc), } impl BatchPayloadSectionBuilder { @@ -92,6 +93,7 @@ impl BatchPayloadSectionBuilder { Self::SelfValidating(_) => "self_validating", Self::CanisterHttp(_) => "canister_http", Self::QueryStats(_) => "query_stats", + Self::VetKd(_) => "vetkd", } } @@ -320,6 +322,44 @@ impl BatchPayloadSectionBuilder { } } } + Self::VetKd(builder) => { + let past_payloads: Vec = + filter_past_payloads(past_payloads, |_, _, payload| { + if payload.is_summary() { + None + } else { + Some(&payload.as_ref().as_data().batch.vetkd) + } + }); + + let vetkd = builder.build_payload( + height, + max_size, + &past_payloads, + proposal_context.validation_context, + ); + let size = NumBytes::new(vetkd.len() as u64); + + // Check validation as safety measure + match builder.validate_payload(height, proposal_context, &vetkd, &past_payloads) { + Ok(()) => { + payload.vetkd = vetkd; + size + } + Err(err) => { + error!( + logger, + "VetKd payload did not pass validation, this is a bug, {:?} @{}", + err, + CRITICAL_ERROR_VALIDATION_NOT_PASSED + ); + + metrics.critical_error_validation_not_passed.inc(); + payload.vetkd = vec![]; + NumBytes::new(0) + } + } + } } } @@ -405,6 +445,25 @@ impl BatchPayloadSectionBuilder { Ok(NumBytes::new(payload.query_stats.len() as u64)) } + Self::VetKd(builder) => { + let past_payloads: Vec = + filter_past_payloads(past_payloads, |_, _, payload| { + if payload.is_summary() { + None + } else { + Some(&payload.as_ref().as_data().batch.vetkd) + } + }); + + builder.validate_payload( + height, + proposal_context, + &payload.vetkd, + &past_payloads, + )?; + + Ok(NumBytes::new(payload.vetkd.len() as u64)) + } } } } diff --git a/rs/consensus/src/consensus/payload_builder.rs b/rs/consensus/src/consensus/payload_builder.rs index 235185ec29d..3d560d48756 100644 --- a/rs/consensus/src/consensus/payload_builder.rs +++ b/rs/consensus/src/consensus/payload_builder.rs @@ -49,6 +49,7 @@ impl PayloadBuilderImpl { self_validating_payload_builder: Arc, canister_http_payload_builder: Arc, query_stats_payload_builder: Arc, + vetkd_payload_builder: Arc, metrics: MetricsRegistry, logger: ReplicaLogger, ) -> Self { @@ -58,6 +59,7 @@ impl PayloadBuilderImpl { BatchPayloadSectionBuilder::XNet(xnet_payload_builder), BatchPayloadSectionBuilder::CanisterHttp(canister_http_payload_builder), BatchPayloadSectionBuilder::QueryStats(query_stats_payload_builder), + BatchPayloadSectionBuilder::VetKd(vetkd_payload_builder), ]; Self { @@ -272,6 +274,7 @@ pub(crate) mod test { let canister_http_payload_builder = FakeCanisterHttpPayloadBuilder::new().with_responses(canister_http_responses); let query_stats_payload_builder = MockBatchPayloadBuilder::new().expect_noop(); + let vetkd_payload_builder = MockBatchPayloadBuilder::new().expect_noop(); PayloadBuilderImpl::new( subnet_test_id(0), @@ -282,6 +285,7 @@ pub(crate) mod test { Arc::new(self_validating_payload_builder), Arc::new(canister_http_payload_builder), 
Arc::new(query_stats_payload_builder), + Arc::new(vetkd_payload_builder), MetricsRegistry::new(), no_op_logger(), ) diff --git a/rs/consensus/src/idkg.rs b/rs/consensus/src/idkg.rs index 89a3d62013f..587ba5bd85d 100644 --- a/rs/consensus/src/idkg.rs +++ b/rs/consensus/src/idkg.rs @@ -534,9 +534,12 @@ fn compute_bouncer( BouncerValue::MaybeWantsLater } } - IDkgMessageId::VetKdKeyShare(_, _) => { - // TODO(CON-1424): Accept VetKd shares - BouncerValue::Unwanted + IDkgMessageId::VetKdKeyShare(_, data) => { + if data.get_ref().height <= args.certified_height + Height::from(LOOK_AHEAD) { + BouncerValue::Wants + } else { + BouncerValue::MaybeWantsLater + } } IDkgMessageId::Complaint(_, data) => { if data.get_ref().height <= args.finalized_height + Height::from(LOOK_AHEAD) { @@ -563,7 +566,7 @@ mod tests { use ic_test_utilities::state_manager::RefMockStateManager; use ic_types::consensus::idkg::{ complaint_prefix, dealing_prefix, dealing_support_prefix, ecdsa_sig_share_prefix, - opening_prefix, schnorr_sig_share_prefix, IDkgArtifactIdData, + opening_prefix, schnorr_sig_share_prefix, vetkd_key_share_prefix, IDkgArtifactIdData, }; use ic_types::{ consensus::idkg::{RequestId, SigShareIdData}, @@ -711,6 +714,13 @@ mod tests { ), BouncerValue::Wants, ), + ( + IDkgMessageId::VetKdKeyShare( + vetkd_key_share_prefix(&request_id_fetch_1, &NODE_1), + get_fake_share_id_data(&request_id_fetch_1).into(), + ), + BouncerValue::Wants, + ), ( IDkgMessageId::EcdsaSigShare( ecdsa_sig_share_prefix(&request_id_fetch_2, &NODE_1), @@ -725,6 +735,13 @@ mod tests { ), BouncerValue::Wants, ), + ( + IDkgMessageId::VetKdKeyShare( + vetkd_key_share_prefix(&request_id_fetch_2, &NODE_1), + get_fake_share_id_data(&request_id_fetch_2).into(), + ), + BouncerValue::Wants, + ), ( IDkgMessageId::EcdsaSigShare( ecdsa_sig_share_prefix(&request_id_stash, &NODE_1), @@ -739,6 +756,13 @@ mod tests { ), BouncerValue::MaybeWantsLater, ), + ( + IDkgMessageId::VetKdKeyShare( + vetkd_key_share_prefix(&request_id_stash, &NODE_1), + get_fake_share_id_data(&request_id_stash).into(), + ), + BouncerValue::MaybeWantsLater, + ), ]; for (id, expected) in tests { diff --git a/rs/consensus/tests/framework/runner.rs b/rs/consensus/tests/framework/runner.rs index e660f4b1c4f..6ca085ea6b8 100644 --- a/rs/consensus/tests/framework/runner.rs +++ b/rs/consensus/tests/framework/runner.rs @@ -160,6 +160,7 @@ impl<'a> ConsensusRunner<'a> { deps.self_validating_payload_builder.clone(), deps.canister_http_payload_builder.clone(), deps.query_stats_payload_builder.clone(), + deps.vetkd_payload_builder.clone(), deps.dkg_pool.clone(), deps.idkg_pool.clone(), dkg_key_manager.clone(), diff --git a/rs/consensus/tests/framework/types.rs b/rs/consensus/tests/framework/types.rs index 90983fc40eb..c33f26332f1 100644 --- a/rs/consensus/tests/framework/types.rs +++ b/rs/consensus/tests/framework/types.rs @@ -170,6 +170,7 @@ pub struct ConsensusDependencies { pub(crate) self_validating_payload_builder: Arc, pub(crate) canister_http_payload_builder: Arc, pub(crate) query_stats_payload_builder: Arc, + pub(crate) vetkd_payload_builder: Arc, pub consensus_pool: Arc>, pub dkg_pool: Arc>, pub idkg_pool: Arc>, @@ -228,6 +229,7 @@ impl ConsensusDependencies { self_validating_payload_builder: Arc::new(FakeSelfValidatingPayloadBuilder::new()), canister_http_payload_builder: Arc::new(FakeCanisterHttpPayloadBuilder::new()), query_stats_payload_builder: Arc::new(MockBatchPayloadBuilder::new().expect_noop()), + vetkd_payload_builder: Arc::new(MockBatchPayloadBuilder::new().expect_noop()), 
state_manager, metrics_registry, replica_config, diff --git a/rs/consensus/tests/payload.rs b/rs/consensus/tests/payload.rs index d5c3456d652..7132ed01a69 100644 --- a/rs/consensus/tests/payload.rs +++ b/rs/consensus/tests/payload.rs @@ -62,6 +62,9 @@ fn consensus_produces_expected_batches() { let query_stats_payload_builder = MockBatchPayloadBuilder::new().expect_noop(); let query_stats_payload_builder = Arc::new(query_stats_payload_builder); + let vetkd_payload_builder = MockBatchPayloadBuilder::new().expect_noop(); + let vetkd_payload_builder = Arc::new(vetkd_payload_builder); + let mut state_manager = MockStateManager::new(); state_manager.expect_remove_states_below().return_const(()); state_manager @@ -157,6 +160,7 @@ fn consensus_produces_expected_batches() { Arc::clone(&self_validating_payload_builder) as Arc<_>, Arc::clone(&canister_http_payload_builder) as Arc<_>, query_stats_payload_builder, + vetkd_payload_builder, Arc::clone(&dkg_pool) as Arc<_>, Arc::clone(&idkg_pool) as Arc<_>, dkg_key_manager.clone(), diff --git a/rs/consensus/vetkd/src/lib.rs b/rs/consensus/vetkd/src/lib.rs index 4d5edd08626..c1d4bb1d229 100644 --- a/rs/consensus/vetkd/src/lib.rs +++ b/rs/consensus/vetkd/src/lib.rs @@ -348,6 +348,8 @@ impl VetKdPayloadBuilderImpl { invalid_artifact(InvalidVetKdPayloadReason::VetKdKeyVerificationError(err)) } else { warn!(self.log, "VetKD payload validation failure: {err:?}"); + self.metrics + .payload_errors_inc("validation_failed", &context.key_id()); validation_failed(VetKdPayloadValidationFailure::VetKdKeyVerificationError( err, )) diff --git a/rs/crypto/tree_hash/BUILD.bazel b/rs/crypto/tree_hash/BUILD.bazel index e351077d93e..47a888e3662 100644 --- a/rs/crypto/tree_hash/BUILD.bazel +++ b/rs/crypto/tree_hash/BUILD.bazel @@ -27,6 +27,11 @@ DEV_DEPENDENCIES = [ "@crate_index//:serde_cbor", ] +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:test-strategy", +] + rust_library( name = "tree_hash", srcs = glob(["src/**"]), @@ -48,6 +53,7 @@ rust_test( rust_test_suite( name = "tree_hash_integration", srcs = glob(["tests/**"]), + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":tree_hash"], ) diff --git a/rs/crypto/tree_hash/Cargo.toml b/rs/crypto/tree_hash/Cargo.toml index d424d489d20..57f3ebabcba 100644 --- a/rs/crypto/tree_hash/Cargo.toml +++ b/rs/crypto/tree_hash/Cargo.toml @@ -24,6 +24,7 @@ proptest = { workspace = true } prost = { workspace = true } rand = { workspace = true } serde_cbor = { workspace = true } +test-strategy = "0.4.0" [[bench]] name = "tree_hash" diff --git a/rs/crypto/tree_hash/test_utils/BUILD.bazel b/rs/crypto/tree_hash/test_utils/BUILD.bazel index 2a2188069cd..9f91e591594 100644 --- a/rs/crypto/tree_hash/test_utils/BUILD.bazel +++ b/rs/crypto/tree_hash/test_utils/BUILD.bazel @@ -17,6 +17,11 @@ DEV_DEPENDENCIES = [ "//rs/crypto/test_utils/reproducible_rng", ] +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. 
+ "@crate_index//:test-strategy", +] + rust_library( name = "test_utils", testonly = True, @@ -33,5 +38,6 @@ rust_library( rust_test_suite( name = "test_utils_integration", srcs = glob(["tests/**"]), + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = DEPENDENCIES + DEV_DEPENDENCIES + [":test_utils"], ) diff --git a/rs/crypto/tree_hash/test_utils/Cargo.toml b/rs/crypto/tree_hash/test_utils/Cargo.toml index c97645ccfb9..4b3046a3990 100644 --- a/rs/crypto/tree_hash/test_utils/Cargo.toml +++ b/rs/crypto/tree_hash/test_utils/Cargo.toml @@ -15,3 +15,4 @@ thiserror = { workspace = true } [dev-dependencies] ic-crypto-test-utils-reproducible-rng = { path = "../../test_utils/reproducible_rng" } +test-strategy = "0.4.0" diff --git a/rs/crypto/tree_hash/test_utils/tests/conversions.rs b/rs/crypto/tree_hash/test_utils/tests/conversions.rs index 278c7d7cd50..405d5938fb7 100644 --- a/rs/crypto/tree_hash/test_utils/tests/conversions.rs +++ b/rs/crypto/tree_hash/test_utils/tests/conversions.rs @@ -1,22 +1,21 @@ use assert_matches::assert_matches; -use ic_crypto_tree_hash::{prune_witness, HashTreeBuilder, Witness, WitnessGenerator}; +use ic_crypto_tree_hash::{prune_witness, HashTreeBuilder, LabeledTree, Witness, WitnessGenerator}; use ic_crypto_tree_hash_test_utils::{ arbitrary::arbitrary_labeled_tree, hash_tree_builder_from_labeled_tree, }; -use proptest::prelude::*; -proptest! { - #[test] - fn hash_tree_builder_from_labeled_tree_works_correctly(tree in arbitrary_labeled_tree()){ - let builder = hash_tree_builder_from_labeled_tree(&tree); - // check that the witness is correct by pruning it completely - let wg = builder - .witness_generator() - .expect("Failed to retrieve a witness constructor"); - let witness = wg - .witness(&tree) - .expect("Failed to build a witness for the whole tree"); - let witness = prune_witness(&witness, &tree).expect("failed to prune witness"); - assert_matches!(witness, Witness::Pruned { digest: _ }); - } +#[test_strategy::proptest] +fn hash_tree_builder_from_labeled_tree_works_correctly( + #[strategy(arbitrary_labeled_tree())] tree: LabeledTree>, +) { + let builder = hash_tree_builder_from_labeled_tree(&tree); + // check that the witness is correct by pruning it completely + let wg = builder + .witness_generator() + .expect("Failed to retrieve a witness constructor"); + let witness = wg + .witness(&tree) + .expect("Failed to build a witness for the whole tree"); + let witness = prune_witness(&witness, &tree).expect("failed to prune witness"); + assert_matches!(witness, Witness::Pruned { digest: _ }); } diff --git a/rs/crypto/tree_hash/tests/conversion.rs b/rs/crypto/tree_hash/tests/conversion.rs index b3a94390e9f..63ce40983b1 100644 --- a/rs/crypto/tree_hash/tests/conversion.rs +++ b/rs/crypto/tree_hash/tests/conversion.rs @@ -5,7 +5,6 @@ use ic_crypto_tree_hash::{ use ic_crypto_tree_hash_test_utils::{ arbitrary::arbitrary_well_formed_mixed_hash_tree, MAX_HASH_TREE_DEPTH, }; -use proptest::prelude::*; use std::convert::TryInto; type TreeOfBlobs = LabeledTree>; @@ -22,12 +21,17 @@ fn labeled(s: &str, b: &[u8]) -> MixedHashTree { MixedHashTree::Labeled(s.into(), Box::new(MixedHashTree::Leaf(b.to_vec()))) } -proptest! 
{ - #[test] - fn prop_well_formed_trees_are_convertible(t in arbitrary_well_formed_mixed_hash_tree()) { - let r: Result = t.clone().try_into(); - assert!(r.is_ok(), "Failed to convert a well-formed mixed hash tree {:?} into a labeled tree: {:?}", t, r); - } +#[test_strategy::proptest] +fn prop_well_formed_trees_are_convertible( + #[strategy(arbitrary_well_formed_mixed_hash_tree())] t: MixedHashTree, +) { + let r: Result = t.clone().try_into(); + assert!( + r.is_ok(), + "Failed to convert a well-formed mixed hash tree {:?} into a labeled tree: {:?}", + t, + r + ); } type T = TreeOfBlobs; diff --git a/rs/crypto/tree_hash/tests/encoding.rs b/rs/crypto/tree_hash/tests/encoding.rs index ef6665edf27..40fd55759fc 100644 --- a/rs/crypto/tree_hash/tests/encoding.rs +++ b/rs/crypto/tree_hash/tests/encoding.rs @@ -57,61 +57,76 @@ fn arbitrary_valid_cbor_encoding() -> impl Strategy { ) } -proptest! { - #[test] - fn prop_tree_to_cbor_roundtrip(t in arbitrary_mixed_hash_tree()) { - let cbor = serde_cbor::to_vec(&t).expect("failed to encode into CBOR"); - let decoded: MixedHashTree = serde_cbor::from_slice(&cbor[..]).expect("failed to decode CBOR"); - assert_eq!(t, decoded); - } - - #[test] - fn prop_cbor_to_tree_roundtrip(v in arbitrary_valid_cbor_encoding()) { - let t: MixedHashTree = serde_cbor::value::from_value(v.clone()).expect("failed to decode CBOR"); - let v_encoded = serde_cbor::value::to_value(&t).expect("failed to encode into CBOR"); - assert_eq!(v, v_encoded); - } +#[test_strategy::proptest] +fn prop_tree_to_cbor_roundtrip(#[strategy(arbitrary_mixed_hash_tree())] t: MixedHashTree) { + let cbor = serde_cbor::to_vec(&t).expect("failed to encode into CBOR"); + let decoded: MixedHashTree = serde_cbor::from_slice(&cbor[..]).expect("failed to decode CBOR"); + assert_eq!(t, decoded); +} +#[test_strategy::proptest] +fn prop_cbor_to_tree_roundtrip(#[strategy(arbitrary_valid_cbor_encoding())] v: Cbor) { + let t: MixedHashTree = serde_cbor::value::from_value(v.clone()).expect("failed to decode CBOR"); + let v_encoded = serde_cbor::value::to_value(&t).expect("failed to encode into CBOR"); + assert_eq!(v, v_encoded); +} - #[test] - fn prop_encoding_fails_on_invalid_cbor(v in arbitrary_invalid_cbor_encoding()) { - let r: Result = serde_cbor::value::from_value(v.clone()); +#[test_strategy::proptest] +fn prop_encoding_fails_on_invalid_cbor(#[strategy(arbitrary_invalid_cbor_encoding())] v: Cbor) { + let r: Result = serde_cbor::value::from_value(v.clone()); - assert!(r.is_err(), "Successfully parsed a MixedHashTree {:?} from invalid CBOR {:?}", r.unwrap(), v); - } + assert!( + r.is_err(), + "Successfully parsed a MixedHashTree {:?} from invalid CBOR {:?}", + r.unwrap(), + v + ); +} - #[test] - fn prop_fails_on_extra_array_items(v in arbitrary_valid_cbor_encoding()) { - use std::string::ToString; +#[test_strategy::proptest] +fn prop_fails_on_extra_array_items(#[strategy(arbitrary_valid_cbor_encoding())] v: Cbor) { + use std::string::ToString; - if let Cbor::Array(mut vec) = v { - vec.push(Cbor::Array(vec![])); + if let Cbor::Array(mut vec) = v { + vec.push(Cbor::Array(vec![])); - let v = Cbor::Array(vec); - let r: Result = serde_cbor::value::from_value(v.clone()); - match r { - Ok(_) => panic!("Successfully parsed a MixedHashTree from invalid CBOR {:?}", v), - Err(err) => assert!(err.to_string().contains("length"), "Expected invalid length error, got {:?}", err), - } + let v = Cbor::Array(vec); + let r: Result = serde_cbor::value::from_value(v.clone()); + match r { + Ok(_) => panic!( + "Successfully parsed a 
MixedHashTree from invalid CBOR {:?}", + v + ), + Err(err) => assert!( + err.to_string().contains("length"), + "Expected invalid length error, got {:?}", + err + ), } } +} - #[test] - fn prop_fails_on_missing_array_items(v in arbitrary_valid_cbor_encoding()) { - use std::string::ToString; +#[test_strategy::proptest] +fn prop_fails_on_missing_array_items(#[strategy(arbitrary_valid_cbor_encoding())] v: Cbor) { + use std::string::ToString; - if let Cbor::Array(mut vec) = v { - vec.pop(); + if let Cbor::Array(mut vec) = v { + vec.pop(); - let v = Cbor::Array(vec); - let r: Result = serde_cbor::value::from_value(v.clone()); - match r { - Ok(_) => panic!("Successfully parsed a MixedHashTree from invalid CBOR {:?}", v), - Err(err) => assert!(err.to_string().contains("length"), "Expected invalid length error, got {:?}", err), - } + let v = Cbor::Array(vec); + let r: Result = serde_cbor::value::from_value(v.clone()); + match r { + Ok(_) => panic!( + "Successfully parsed a MixedHashTree from invalid CBOR {:?}", + v + ), + Err(err) => assert!( + err.to_string().contains("length"), + "Expected invalid length error, got {:?}", + err + ), } } - } #[test] diff --git a/rs/crypto/tree_hash/tests/merge.rs b/rs/crypto/tree_hash/tests/merge.rs index d87e380d656..af685c108af 100644 --- a/rs/crypto/tree_hash/tests/merge.rs +++ b/rs/crypto/tree_hash/tests/merge.rs @@ -2,7 +2,6 @@ use assert_matches::assert_matches; use ic_crypto_sha2::Sha256; use ic_crypto_tree_hash::{Digest, Label, WitnessBuilder, WitnessGenerationError}; use ic_crypto_tree_hash_test_utils::MAX_HASH_TREE_DEPTH; -use proptest::prelude::*; mod mixed_hash_tree { use super::*; @@ -10,19 +9,31 @@ mod mixed_hash_tree { use ic_crypto_tree_hash_test_utils::arbitrary::arbitrary_mixed_hash_tree; use MixedHashTree::*; - proptest! { - #[test] - fn merge_of_big_tree_is_idempotent(t in arbitrary_mixed_hash_tree()) { - assert_eq!(Ok(t.clone()), MixedHashTree::merge_trees(t.clone(), t)); - } + #[test_strategy::proptest] + fn merge_of_big_tree_is_idempotent(#[strategy(arbitrary_mixed_hash_tree())] t: MixedHashTree) { + assert_eq!(Ok(t.clone()), MixedHashTree::merge_trees(t.clone(), t)); + } - #[test] - fn merge_of_pruned_with_anything_else_is_idempotent(t in arbitrary_mixed_hash_tree()) { - assert_eq!(Ok(t.clone()), MixedHashTree::merge_trees(t.clone(), prune_leaves(&t))); - assert_eq!(Ok(t.clone()), MixedHashTree::merge_trees(t.clone(), prune_left_forks(&t))); - assert_eq!(Ok(t.clone()), MixedHashTree::merge_trees(t.clone(), prune_right_forks(&t))); - assert_eq!(Ok(t.clone()), MixedHashTree::merge_trees(t.clone(), prune_labels(&t))); - } + #[test_strategy::proptest] + fn merge_of_pruned_with_anything_else_is_idempotent( + #[strategy(arbitrary_mixed_hash_tree())] t: MixedHashTree, + ) { + assert_eq!( + Ok(t.clone()), + MixedHashTree::merge_trees(t.clone(), prune_leaves(&t)) + ); + assert_eq!( + Ok(t.clone()), + MixedHashTree::merge_trees(t.clone(), prune_left_forks(&t)) + ); + assert_eq!( + Ok(t.clone()), + MixedHashTree::merge_trees(t.clone(), prune_right_forks(&t)) + ); + assert_eq!( + Ok(t.clone()), + MixedHashTree::merge_trees(t.clone(), prune_labels(&t)) + ); } #[test] diff --git a/rs/crypto/tree_hash/tests/mixed_hash_tree_proto.rs b/rs/crypto/tree_hash/tests/mixed_hash_tree_proto.rs index df958fff7d0..6e367a9083a 100644 --- a/rs/crypto/tree_hash/tests/mixed_hash_tree_proto.rs +++ b/rs/crypto/tree_hash/tests/mixed_hash_tree_proto.rs @@ -5,11 +5,12 @@ use ic_protobuf::proxy::{ProtoProxy, ProxyDecodeError}; use proptest::prelude::*; -proptest! 
{ - #[test] - fn encoding_roundtrip(t in arbitrary_mixed_hash_tree()) { - prop_assert_eq!(t.clone(), PbTree::proxy_decode(&PbTree::proxy_encode(t)).unwrap()) - } +#[test_strategy::proptest] +fn encoding_roundtrip(#[strategy(arbitrary_mixed_hash_tree())] t: T) { + prop_assert_eq!( + t.clone(), + PbTree::proxy_decode(&PbTree::proxy_encode(t)).unwrap() + ) } fn encode(t: &PbTree) -> Vec { diff --git a/rs/crypto/tree_hash/tests/tree_hash.rs b/rs/crypto/tree_hash/tests/tree_hash.rs index 96f66fab8ab..5356c99c5f5 100644 --- a/rs/crypto/tree_hash/tests/tree_hash.rs +++ b/rs/crypto/tree_hash/tests/tree_hash.rs @@ -3,7 +3,6 @@ use ic_crypto_sha2::Sha256; use ic_crypto_test_utils_reproducible_rng::reproducible_rng; use ic_crypto_tree_hash::*; use ic_crypto_tree_hash_test_utils::*; -use proptest::prelude::*; use rand::Rng; use rand::{CryptoRng, RngCore}; use std::collections::BTreeMap; @@ -1375,17 +1374,15 @@ fn tree_with_three_levels() -> HashTreeBuilderImpl { builder } -proptest! { - #[test] - fn recompute_digest_for_mixed_hash_tree_iteratively_and_recursively_produces_same_digest( - tree in arbitrary::arbitrary_well_formed_mixed_hash_tree() - ){ - let rec_or_error = mixed_hash_tree_digest_recursive(&tree); - // ignore the error case, since the iterative algorithm is infallible - if let Ok(rec) = rec_or_error { - let iter = tree.digest(); - assert_eq!(rec, iter); - } +#[test_strategy::proptest] +fn recompute_digest_for_mixed_hash_tree_iteratively_and_recursively_produces_same_digest( + #[strategy(arbitrary::arbitrary_well_formed_mixed_hash_tree())] tree: MixedHashTree, +) { + let rec_or_error = mixed_hash_tree_digest_recursive(&tree); + // ignore the error case, since the iterative algorithm is infallible + if let Ok(rec) = rec_or_error { + let iter = tree.digest(); + assert_eq!(rec, iter); } } diff --git a/rs/embedders/fuzz/BUILD.bazel b/rs/embedders/fuzz/BUILD.bazel index 493b36377ef..2935d7a9d4a 100644 --- a/rs/embedders/fuzz/BUILD.bazel +++ b/rs/embedders/fuzz/BUILD.bazel @@ -44,6 +44,7 @@ rust_library( "@crate_index//:futures", "@crate_index//:lazy_static", "@crate_index//:libfuzzer-sys", + "@crate_index//:tempfile", "@crate_index//:tokio", "@crate_index//:wasm-encoder", "@crate_index//:wasm-smith", diff --git a/rs/embedders/fuzz/src/wasm_executor.rs b/rs/embedders/fuzz/src/wasm_executor.rs index b3043b45ab4..7130ae116a0 100644 --- a/rs/embedders/fuzz/src/wasm_executor.rs +++ b/rs/embedders/fuzz/src/wasm_executor.rs @@ -89,7 +89,10 @@ fn setup_wasm_execution_input(func_ref: FuncRef) -> WasmExecutionInput { let api_type = ApiType::init(UNIX_EPOCH, vec![], user_test_id(24).get()); let canister_current_memory_usage = NumBytes::new(0); let canister_current_message_memory_usage = MessageMemoryUsage::ZERO; - let compilation_cache = Arc::new(CompilationCache::new(NumBytes::new(0))); + let compilation_cache = Arc::new(CompilationCache::new( + NumBytes::new(0), + tempfile::tempdir().unwrap(), + )); WasmExecutionInput { api_type: api_type.clone(), sandbox_safe_system_state: get_system_state(api_type), diff --git a/rs/embedders/src/compilation_cache.rs b/rs/embedders/src/compilation_cache.rs index 106a5c7c50d..25bbb0e1176 100644 --- a/rs/embedders/src/compilation_cache.rs +++ b/rs/embedders/src/compilation_cache.rs @@ -55,9 +55,11 @@ impl MemoryDiskBytes for CompilationCache { } impl CompilationCache { - pub fn new(capacity: NumBytes) -> Self { - Self::Memory { - cache: Mutex::new(LruCache::new(capacity, NumBytes::from(GB))), + pub fn new(capacity: NumBytes, dir: TempDir) -> Self { + Self::Disk { + dir, 
+ cache: Mutex::new(LruCache::new(capacity, NumBytes::from(100 * GB))), + counter: AtomicU64::new(0), } } @@ -211,7 +213,10 @@ impl StoredCompilation { /// each other if they all try to insert in the cache at the same time. #[test] fn concurrent_insertions() { - let cache = CompilationCache::new(NumBytes::from(30 * 1024 * 1024)); + let cache = CompilationCache::new( + NumBytes::from(30 * 1024 * 1024), + tempfile::tempdir().unwrap(), + ); let wasm = wat::parse_str("(module)").unwrap(); let canister_module = CanisterModule::new(wasm.clone()); let binary = ic_wasm_types::BinaryEncodedWasm::new(wasm.clone()); diff --git a/rs/execution_environment/src/hypervisor.rs b/rs/execution_environment/src/hypervisor.rs index b07c5bdcca3..462c8913b34 100644 --- a/rs/execution_environment/src/hypervisor.rs +++ b/rs/execution_environment/src/hypervisor.rs @@ -191,7 +191,7 @@ impl Hypervisor { state_reader: Arc>, // TODO(EXC-1821): Create a temp dir in this directory for use in the // compilation cache. - _temp_dir: &Path, + temp_dir: &Path, ) -> Self { let mut embedder_config = config.embedders_config.clone(); embedder_config.dirty_page_overhead = dirty_page_overhead; @@ -225,7 +225,10 @@ impl Hypervisor { own_subnet_id, log, cycles_account_manager, - compilation_cache: Arc::new(CompilationCache::new(MAX_COMPILATION_CACHE_SIZE)), + compilation_cache: Arc::new(CompilationCache::new( + MAX_COMPILATION_CACHE_SIZE, + tempfile::tempdir_in(temp_dir).unwrap(), + )), deterministic_time_slicing: config.deterministic_time_slicing, cost_to_compile_wasm_instruction: config .embedders_config @@ -253,7 +256,10 @@ impl Hypervisor { own_subnet_id, log, cycles_account_manager, - compilation_cache: Arc::new(CompilationCache::new(MAX_COMPILATION_CACHE_SIZE)), + compilation_cache: Arc::new(CompilationCache::new( + MAX_COMPILATION_CACHE_SIZE, + tempfile::tempdir().unwrap(), + )), deterministic_time_slicing, cost_to_compile_wasm_instruction, dirty_page_overhead, diff --git a/rs/ledger_suite/icrc1/ledger/src/main.rs b/rs/ledger_suite/icrc1/ledger/src/main.rs index 31ce5b260a7..bb59f0dfb7a 100644 --- a/rs/ledger_suite/icrc1/ledger/src/main.rs +++ b/rs/ledger_suite/icrc1/ledger/src/main.rs @@ -785,6 +785,10 @@ fn supported_standards() -> Vec { name: "ICRC-3".to_string(), url: "https://github.com/dfinity/ICRC-1/tree/main/standards/ICRC-3".to_string(), }, + StandardRecord { + name: "ICRC-10".to_string(), + url: "https://github.com/dfinity/ICRC/blob/main/ICRCs/ICRC-10/ICRC-10.md".to_string(), + }, StandardRecord { name: "ICRC-21".to_string(), url: "https://github.com/dfinity/wg-identity-authentication/blob/main/topics/ICRC-21/icrc_21_consent_msg.md".to_string(), diff --git a/rs/ledger_suite/tests/sm-tests/src/lib.rs b/rs/ledger_suite/tests/sm-tests/src/lib.rs index 77e4647deac..4d1203c4c50 100644 --- a/rs/ledger_suite/tests/sm-tests/src/lib.rs +++ b/rs/ledger_suite/tests/sm-tests/src/lib.rs @@ -1084,7 +1084,10 @@ where standards.push(standard.name); } standards.sort(); - assert_eq!(standards, vec!["ICRC-1", "ICRC-2", "ICRC-21", "ICRC-3"]); + assert_eq!( + standards, + vec!["ICRC-1", "ICRC-10", "ICRC-2", "ICRC-21", "ICRC-3"] + ); } pub fn test_total_supply(ledger_wasm: Vec, encode_init_args: fn(InitArgs) -> T) diff --git a/rs/nervous_system/timer_task/BUILD.bazel b/rs/nervous_system/timer_task/BUILD.bazel index 37925185056..fbb0879da24 100644 --- a/rs/nervous_system/timer_task/BUILD.bazel +++ b/rs/nervous_system/timer_task/BUILD.bazel @@ -6,9 +6,12 @@ package(default_visibility = ["//rs/nervous_system:default_visibility"]) DEPENDENCIES 
= [ # Keep sorted. + "//rs/nervous_system/time_helpers", + "//rs/nervous_system/timers", "@crate_index//:candid", + "@crate_index//:futures", "@crate_index//:ic-cdk", - "@crate_index//:ic-cdk-timers", + "@crate_index//:ic-metrics-encoder", "@crate_index//:serde", ] diff --git a/rs/nervous_system/timer_task/Cargo.toml b/rs/nervous_system/timer_task/Cargo.toml index 8803c30c26f..e811ff901e8 100644 --- a/rs/nervous_system/timer_task/Cargo.toml +++ b/rs/nervous_system/timer_task/Cargo.toml @@ -12,10 +12,13 @@ path = "src/lib.rs" [dependencies] async-trait = { workspace = true } -ic-cdk = { workspace = true } candid = { workspace = true } +ic-cdk = { workspace = true } +ic-metrics-encoder = "1.1.1" +ic-nervous-system-time-helpers = { path = "../time_helpers" } +ic-nervous-system-timers = { path = "../timers" } +futures = { workspace = true } serde = { workspace = true } -ic-cdk-timers = { workspace = true } [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] ic-config = { path = "../../config" } diff --git a/rs/nervous_system/timer_task/src/lib.rs b/rs/nervous_system/timer_task/src/lib.rs index 074f5c32782..28139e27928 100644 --- a/rs/nervous_system/timer_task/src/lib.rs +++ b/rs/nervous_system/timer_task/src/lib.rs @@ -49,7 +49,7 @@ //! } //! fn initial_delay(&self) -> Duration { Duration::from_secs(0) } //! -//! const NAME: &'static str = "SomeRecurringSyncTask"; +//! const NAME: &'static str = "some_recurring_sync_task"; //! } //! //! @@ -62,7 +62,7 @@ //! self.state.with_borrow_mut(|state| state.do_something()); //! } //! -//! const NAME: &'static str = "SomePeriodicTask"; +//! const NAME: &'static str = "some_periodic_sync_task"; //! const INTERVAL: Duration = Duration::from_secs(10); //! } //! @@ -83,7 +83,7 @@ //! //! fn initial_delay(&self) -> Duration { Duration::from_secs(0) } //! -//! const NAME: &'static str = "SomeRecurringAsyncTask"; +//! const NAME: &'static str = "some_recurring_async_task"; //! } //! //! struct SomePeriodicAsyncTask { @@ -96,7 +96,7 @@ //! self.state.with_borrow_mut(|state| state.do_something()).await; //! } //! -//! const NAME: &'static str = "SomePeriodicAsyncTask"; +//! const NAME: &'static str = "some_periodic_async_task"; //! const INTERVAL: Duration = Duration::from_secs(10); //! } //! @@ -113,25 +113,87 @@ //! } //! ``` +mod metrics; + +pub use metrics::MetricsRegistry as TimerTaskMetricsRegistry; + use async_trait::async_trait; +#[cfg(not(target_arch = "wasm32"))] +use futures::FutureExt; +#[cfg(target_arch = "wasm32")] use ic_cdk::spawn; -use ic_cdk_timers::{set_timer, set_timer_interval}; +use ic_nervous_system_time_helpers::now_seconds; +pub use ic_nervous_system_timers::{set_timer, set_timer_interval}; +use metrics::{with_async_metrics, with_sync_metrics, MetricsRegistryRef}; +use std::future::Future; use std::time::Duration; +/// This function is used to spawn a future in a way that is compatible with both the WASM and +/// non-WASM environments that are used for testing. This only actually spawns in the case where +/// the WASM is running in the IC, or has some other source of asynchrony. Otherwise, it +/// immediately executes.s +fn spawn_in_canister_env(future: impl Future + Sized + 'static) { + #[cfg(target_arch = "wasm32")] + { + spawn(future); + } + // This is needed for tests + #[cfg(not(target_arch = "wasm32"))] + { + future + .now_or_never() + .expect("Future could not execute in non-WASM environment"); + } +} + +/// Returns the number of instructions executed in the current message. Returns 0 if not running in +/// a WASM. 
+fn instruction_counter() -> u64 { + #[cfg(target_arch = "wasm32")] + { + ic_cdk::api::instruction_counter() + } + #[cfg(not(target_arch = "wasm32"))] + { + 0 + } +} + +/// Returns the number of instructions executed in the current call context. Useful for measuring +/// instructions across multiple messages. Returns 0 if not running in a WASM. +fn call_context_instruction_counter() -> u64 { + #[cfg(target_arch = "wasm32")] + { + ic_cdk::api::call_context_instruction_counter() + } + #[cfg(not(target_arch = "wasm32"))] + { + 0 + } +} + pub trait RecurringSyncTask: Sized + 'static { fn execute(self) -> (Duration, Self); fn initial_delay(&self) -> Duration; - fn schedule_with_delay(self, delay: Duration) { + fn schedule_with_delay(self, delay: Duration, metrics_registry: MetricsRegistryRef) { set_timer(delay, move || { + let instructions_before = instruction_counter(); + let (new_delay, new_task) = self.execute(); - new_task.schedule_with_delay(new_delay); + + let instructions_used = instruction_counter() - instructions_before; + with_sync_metrics(metrics_registry, Self::NAME, |metrics| { + metrics.record(instructions_used, now_seconds()); + }); + + new_task.schedule_with_delay(new_delay, metrics_registry); }); } - fn schedule(self) { + fn schedule(self, metrics_registry: MetricsRegistryRef) { let initial_delay = self.initial_delay(); - self.schedule_with_delay(initial_delay); + self.schedule_with_delay(initial_delay, metrics_registry); } const NAME: &'static str; @@ -142,18 +204,28 @@ pub trait RecurringAsyncTask: Sized + 'static { async fn execute(self) -> (Duration, Self); fn initial_delay(&self) -> Duration; - fn schedule_with_delay(self, delay: Duration) { + fn schedule_with_delay(self, delay: Duration, metrics_registry: MetricsRegistryRef) { set_timer(delay, move || { - spawn(async move { + spawn_in_canister_env(async move { + let instructions_before = call_context_instruction_counter(); + with_async_metrics(metrics_registry, Self::NAME, |metrics| { + metrics.record_start(now_seconds()); + }); + let (new_delay, new_task) = self.execute().await; - new_task.schedule_with_delay(new_delay); + + let instructions_used = call_context_instruction_counter() - instructions_before; + with_async_metrics(metrics_registry, Self::NAME, |metrics| { + metrics.record_finish(instructions_used, now_seconds()); + }); + new_task.schedule_with_delay(new_delay, metrics_registry); }); }); } - fn schedule(self) { + fn schedule(self, metrics_registry: MetricsRegistryRef) { let initial_delay = self.initial_delay(); - self.schedule_with_delay(initial_delay); + self.schedule_with_delay(initial_delay, metrics_registry); } const NAME: &'static str; @@ -163,9 +235,16 @@ pub trait PeriodicSyncTask: Copy + Sized + 'static { // TODO: can periodic tasks have a state that is mutable across invocations? 
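// A minimal sketch (hypothetical canister code, not part of this patch) of how a canister might
// wire a task into the metrics-aware scheduling introduced above. The canister, task, and state
// names are illustrative; the types (`TimerTaskMetricsRegistry`, `PeriodicSyncTask`) and the
// `schedule(&METRICS_REGISTRY)` signature are the ones defined in this file and in metrics.rs.
use ic_nervous_system_timer_task::{PeriodicSyncTask, TimerTaskMetricsRegistry};
use std::{cell::RefCell, time::Duration};

thread_local! {
    // One registry per canister; every scheduled task records into it under its NAME label.
    static METRICS_REGISTRY: RefCell<TimerTaskMetricsRegistry> =
        RefCell::new(TimerTaskMetricsRegistry::default());
}

#[derive(Clone, Copy, Default)]
struct PruneExpiredEntriesTask;

impl PeriodicSyncTask for PruneExpiredEntriesTask {
    fn execute(self) {
        // ... canister-specific work ...
    }

    // NAME doubles as the `task_name` label in the exported metrics, which is why the task
    // names in this change were renamed to snake_case.
    const NAME: &'static str = "prune_expired_entries";
    const INTERVAL: Duration = Duration::from_secs(60 * 60);
}

fn schedule_timer_tasks() {
    // Called from #[init] / #[post_upgrade]. Instruction usage and timestamps for each run are
    // recorded into METRICS_REGISTRY and can later be exported with
    // `METRICS_REGISTRY.with_borrow(|r| r.encode("my_canister", &mut encoder))`.
    PruneExpiredEntriesTask.schedule(&METRICS_REGISTRY);
}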
fn execute(self); - fn schedule(self) { + fn schedule(self, metrics_registry: MetricsRegistryRef) { set_timer_interval(Self::INTERVAL, move || { + let instructions_before = instruction_counter(); + self.execute(); + + let instructions_used = instruction_counter() - instructions_before; + with_sync_metrics(metrics_registry, Self::NAME, |metrics| { + metrics.record(instructions_used, now_seconds()); + }); }); } @@ -177,10 +256,20 @@ pub trait PeriodicSyncTask: Copy + Sized + 'static { pub trait PeriodicAsyncTask: Copy + Sized + 'static { async fn execute(self); - fn schedule(self) { + fn schedule(self, metrics_registry: MetricsRegistryRef) { set_timer_interval(Self::INTERVAL, move || { - spawn(async move { + spawn_in_canister_env(async move { + let instructions_before = call_context_instruction_counter(); + with_async_metrics(metrics_registry, Self::NAME, |metrics| { + metrics.record_start(now_seconds()); + }); + self.execute().await; + + let instructions_used = call_context_instruction_counter() - instructions_before; + with_async_metrics(metrics_registry, Self::NAME, |metrics| { + metrics.record_finish(instructions_used, now_seconds()); + }); }); }); } diff --git a/rs/nervous_system/timer_task/src/metrics.rs b/rs/nervous_system/timer_task/src/metrics.rs new file mode 100644 index 00000000000..bb789ab6c9a --- /dev/null +++ b/rs/nervous_system/timer_task/src/metrics.rs @@ -0,0 +1,207 @@ +use ic_metrics_encoder::MetricsEncoder; +use std::{cell::RefCell, collections::HashMap, thread::LocalKey}; + +/// Metrics for a synchronous task. +#[derive(Default)] +pub(crate) struct SyncTaskMetrics { + instruction: InstructionMetrics, + last_executed: u64, +} + +impl SyncTaskMetrics { + pub(crate) fn record(&mut self, instructions_used: u64, time_seconds: u64) { + self.last_executed = time_seconds; + self.instruction.record(instructions_used); + } +} + +/// Metrics for an asynchronous task. +#[derive(Default)] +pub(crate) struct AsyncTaskMetrics { + outstanding_count: u64, + instruction: InstructionMetrics, + last_started: u64, + last_finished: u64, +} + +impl AsyncTaskMetrics { + pub(crate) fn record_start(&mut self, time_seconds: u64) { + self.outstanding_count += 1; + self.last_started = time_seconds; + } + + pub(crate) fn record_finish(&mut self, instructions_used: u64, time_seconds: u64) { + self.outstanding_count -= 1; + self.last_finished = time_seconds; + self.instruction.record(instructions_used); + } +} + +/// Metrics for the number of instructions used by a task (synchronous or +/// asynchronous). 
+#[derive(Default)] +pub(crate) struct InstructionMetrics { + sum: u128, + histogram: InstructionHistogram, +} + +pub const INSTRUCTION_BUCKET_COUNT: usize = 29; +pub const INSTRUCTION_BUCKETS: [u64; INSTRUCTION_BUCKET_COUNT] = [ + 10_000, + 20_000, + 50_000, + 100_000, + 200_000, + 500_000, + 1_000_000, + 2_000_000, + 5_000_000, + 10_000_000, + 20_000_000, + 50_000_000, + 100_000_000, + 200_000_000, + 500_000_000, + 1_000_000_000, + 2_000_000_000, + 5_000_000_000, + 10_000_000_000, + 20_000_000_000, + 50_000_000_000, + 100_000_000_000, + 200_000_000_000, + 500_000_000_000, + 1_000_000_000_000, + 2_000_000_000_000, + 5_000_000_000_000, + 10_000_000_000_000, + u64::MAX, +]; +type InstructionHistogram = [u64; INSTRUCTION_BUCKET_COUNT]; + +impl InstructionMetrics { + fn record(&mut self, instruction_count: u64) { + self.sum += instruction_count as u128; + for (i, &bucket) in INSTRUCTION_BUCKETS.iter().enumerate() { + if instruction_count <= bucket { + self.histogram[i] += 1; + break; + } + } + } + + fn encode( + &self, + task_name: &str, + histogram: ic_metrics_encoder::LabeledHistogramBuilder>, + ) -> std::io::Result<()> { + let buckets = INSTRUCTION_BUCKETS + .iter() + .cloned() + .zip(self.histogram.iter().cloned()) + .map(|(b, m)| (b as f64, m as f64)); + histogram.histogram(&[("task_name", task_name)], buckets, self.sum as f64)?; + Ok(()) + } +} + +#[derive(Default)] +pub struct MetricsRegistry { + sync_metrics: HashMap, + async_metrics: HashMap, +} + +pub(crate) type MetricsRegistryRef = &'static LocalKey>; + +pub(crate) fn with_sync_metrics( + metrics_registry: MetricsRegistryRef, + task_name: &'static str, + f: impl FnOnce(&mut SyncTaskMetrics), +) { + metrics_registry.with_borrow_mut(|metrics_registry| { + let task_metrics = metrics_registry + .sync_metrics + .entry(task_name.to_string()) + .or_default(); + f(task_metrics); + }); +} + +pub(crate) fn with_async_metrics( + metrics_registry: MetricsRegistryRef, + task_name: &'static str, + f: impl FnOnce(&mut AsyncTaskMetrics), +) { + metrics_registry.with_borrow_mut(|metrics_registry| { + let task_metrics = metrics_registry + .async_metrics + .entry(task_name.to_string()) + .or_default(); + f(task_metrics); + }); +} + +impl MetricsRegistry { + /// Encodes the metrics into the given encoder. + pub fn encode( + &self, + prefix: &'static str, + encoder: &mut MetricsEncoder>, + ) -> std::io::Result<()> { + let instruction_histogram_metric_name = format!("{}_task_instruction", prefix); + + let sync_last_executed_metric_name = format!("{}_sync_task_last_executed", prefix); + for (task_name, metrics) in &self.sync_metrics { + metrics.instruction.encode( + task_name, + encoder.histogram_vec( + &instruction_histogram_metric_name, + "The number of instructions used by a task.", + )?, + )?; + encoder + .gauge_vec( + &sync_last_executed_metric_name, + "The time when the task was last executed", + )? + .value(&[("task_name", task_name)], metrics.last_executed as f64)?; + } + + let async_outstanding_counter_metric_name = + format!("{}_async_task_outstanding_count", prefix); + let async_last_started_metric_name = format!("{}_async_task_last_started", prefix); + let async_last_finished_metric_name = format!("{}_async_task_last_finished", prefix); + for (task_name, metrics) in &self.async_metrics { + metrics.instruction.encode( + task_name, + encoder.histogram_vec( + &instruction_histogram_metric_name, + "The number of instructions used by a task. 
Note that the instructions \ + of an async task are counted across multiple messages", + )?, + )?; + encoder + .counter_vec( + &async_outstanding_counter_metric_name, + "The number of async tasks that have been started but not finished", + )? + .value( + &[("task_name", task_name)], + metrics.outstanding_count as f64, + )?; + encoder + .gauge_vec( + &async_last_started_metric_name, + "The time when the task was last started", + )? + .value(&[("task_name", task_name)], metrics.last_started as f64)?; + encoder + .gauge_vec( + &async_last_finished_metric_name, + "The time when the task was last finished", + )? + .value(&[("task_name", task_name)], metrics.last_finished as f64)?; + } + Ok(()) + } +} diff --git a/rs/nervous_system/timer_task/tests/test_canisters/timer_task_canister.rs b/rs/nervous_system/timer_task/tests/test_canisters/timer_task_canister.rs index 91bc83f4cec..d7e9cb097f5 100644 --- a/rs/nervous_system/timer_task/tests/test_canisters/timer_task_canister.rs +++ b/rs/nervous_system/timer_task/tests/test_canisters/timer_task_canister.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; use ic_cdk::{init, query}; +use ic_metrics_encoder::MetricsEncoder; use ic_nervous_system_timer_task::{ PeriodicAsyncTask, PeriodicSyncTask, RecurringAsyncTask, RecurringSyncTask, + TimerTaskMetricsRegistry, }; use std::{cell::RefCell, collections::BTreeMap, time::Duration}; @@ -14,29 +16,44 @@ fn increase_counter(name: &'static str) { thread_local! { static COUNTERS : RefCell> = const { RefCell::new(BTreeMap::new()) }; + static METRICS_REGISTRY: RefCell = RefCell::new(TimerTaskMetricsRegistry::default()); } fn schedule(name: &str) { match name { - SuccessRecurringSyncTask::NAME => SuccessRecurringSyncTask::default().schedule(), + SuccessRecurringSyncTask::NAME => { + SuccessRecurringSyncTask::default().schedule(&METRICS_REGISTRY) + } IncrementalDelayRecurringSyncTask::NAME => { - IncrementalDelayRecurringSyncTask::default().schedule() + IncrementalDelayRecurringSyncTask::default().schedule(&METRICS_REGISTRY) + } + PanicRecurringSyncTask::NAME => { + PanicRecurringSyncTask::default().schedule(&METRICS_REGISTRY) } - PanicRecurringSyncTask::NAME => PanicRecurringSyncTask::default().schedule(), OutOfInstructionsRecurringSyncTask::NAME => { - OutOfInstructionsRecurringSyncTask::default().schedule() + OutOfInstructionsRecurringSyncTask::default().schedule(&METRICS_REGISTRY) + } + SuccessRecurringAsyncTask::NAME => { + SuccessRecurringAsyncTask::default().schedule(&METRICS_REGISTRY) + } + PanicRecurringAsyncTask::NAME => { + PanicRecurringAsyncTask::default().schedule(&METRICS_REGISTRY) } - SuccessRecurringAsyncTask::NAME => SuccessRecurringAsyncTask::default().schedule(), - PanicRecurringAsyncTask::NAME => PanicRecurringAsyncTask::default().schedule(), OutOfInstructionsBeforeCallRecurringAsyncTask::NAME => { - OutOfInstructionsBeforeCallRecurringAsyncTask::default().schedule() + OutOfInstructionsBeforeCallRecurringAsyncTask::default().schedule(&METRICS_REGISTRY) } OutOfInstructionsAfterCallRecurringAsyncTask::NAME => { - OutOfInstructionsAfterCallRecurringAsyncTask::default().schedule() + OutOfInstructionsAfterCallRecurringAsyncTask::default().schedule(&METRICS_REGISTRY) + } + SuccessPeriodicSyncTask::NAME => { + SuccessPeriodicSyncTask::default().schedule(&METRICS_REGISTRY) + } + SuccessPeriodicAsyncTask::NAME => { + SuccessPeriodicAsyncTask::default().schedule(&METRICS_REGISTRY) + } + PanicPeriodicAsyncTask::NAME => { + PanicPeriodicAsyncTask::default().schedule(&METRICS_REGISTRY) } - 
SuccessPeriodicSyncTask::NAME => SuccessPeriodicSyncTask::default().schedule(), - SuccessPeriodicAsyncTask::NAME => SuccessPeriodicAsyncTask::default().schedule(), - PanicPeriodicAsyncTask::NAME => PanicPeriodicAsyncTask::default().schedule(), _ => panic!("Unknown task: {}", name), } } @@ -53,6 +70,17 @@ fn get_counter(name: String) -> u64 { COUNTERS.with_borrow(|counters| *counters.get(&name).unwrap_or(&0)) } +#[query] +fn get_metrics() -> String { + METRICS_REGISTRY.with_borrow(|metrics_registry| { + let mut encoder = MetricsEncoder::new(vec![], (ic_cdk::api::time() / 1_000_000) as i64); + metrics_registry + .encode("test_canister", &mut encoder) + .unwrap(); + String::from_utf8(encoder.into_inner()).unwrap() + }) +} + #[query] fn __self_call() {} @@ -77,7 +105,7 @@ impl RecurringSyncTask for SuccessRecurringSyncTask { Duration::from_secs(0) } - const NAME: &'static str = "SuccessRecurringSyncTask"; + const NAME: &'static str = "success_recurring_sync_task"; } #[derive(Default)] @@ -102,7 +130,7 @@ impl RecurringSyncTask for IncrementalDelayRecurringSyncTask { Duration::from_secs(0) } - const NAME: &'static str = "IncrementalDelayRecurringSyncTask"; + const NAME: &'static str = "incremental_delay_recurring_sync_task"; } #[derive(Default)] @@ -118,7 +146,7 @@ impl RecurringSyncTask for PanicRecurringSyncTask { Duration::from_secs(0) } - const NAME: &'static str = "PanicRecurringSyncTask"; + const NAME: &'static str = "panic_recurring_sync_task"; } #[derive(Default)] @@ -138,7 +166,7 @@ impl RecurringSyncTask for OutOfInstructionsRecurringSyncTask { Duration::from_secs(0) } - const NAME: &'static str = "OutOfInstructionsRecurringSyncTask"; + const NAME: &'static str = "out_of_instructions_recurring_sync_task"; } #[derive(Default)] @@ -157,7 +185,7 @@ impl RecurringAsyncTask for SuccessRecurringAsyncTask { Duration::from_secs(0) } - const NAME: &'static str = "SuccessRecurringAsyncTask"; + const NAME: &'static str = "success_recurring_async_task"; } #[derive(Default)] @@ -175,7 +203,7 @@ impl RecurringAsyncTask for PanicRecurringAsyncTask { Duration::from_secs(0) } - const NAME: &'static str = "PanicRecurringAsyncTask"; + const NAME: &'static str = "panic_recurring_async_task"; } #[derive(Default)] @@ -199,7 +227,7 @@ impl RecurringAsyncTask for OutOfInstructionsBeforeCallRecurringAsyncTask { Duration::from_secs(0) } - const NAME: &'static str = "OutOfInstructionsBeforeCallRecurringAsyncTask"; + const NAME: &'static str = "out_of_instructions_before_call_recurring_async_task"; } #[derive(Default)] @@ -223,7 +251,7 @@ impl RecurringAsyncTask for OutOfInstructionsAfterCallRecurringAsyncTask { Duration::from_secs(0) } - const NAME: &'static str = "OutOfInstructionsAfterCallRecurringAsyncTask"; + const NAME: &'static str = "out_of_instructions_after_call_recurring_async_task"; } #[derive(Default, Clone, Copy)] @@ -234,7 +262,7 @@ impl PeriodicSyncTask for SuccessPeriodicSyncTask { increase_counter(Self::NAME); } - const NAME: &'static str = "SuccessPeriodicSyncTask"; + const NAME: &'static str = "success_periodic_sync_task"; const INTERVAL: Duration = Duration::from_secs(1); } @@ -247,7 +275,7 @@ impl PeriodicAsyncTask for SuccessPeriodicAsyncTask { increase_counter(Self::NAME); } - const NAME: &'static str = "SuccessPeriodicAsyncTask"; + const NAME: &'static str = "success_periodic_async_task"; const INTERVAL: Duration = Duration::from_secs(1); } @@ -262,6 +290,6 @@ impl PeriodicAsyncTask for PanicPeriodicAsyncTask { panic!("This task always panics"); } - const NAME: &'static str = 
"PanicPeriodicAsyncTask"; + const NAME: &'static str = "panic_periodic_async_task"; const INTERVAL: Duration = Duration::from_secs(1); } diff --git a/rs/nervous_system/timer_task/tests/tests.rs b/rs/nervous_system/timer_task/tests/tests.rs index 30ec9f0ea8c..3d97aef8d3c 100644 --- a/rs/nervous_system/timer_task/tests/tests.rs +++ b/rs/nervous_system/timer_task/tests/tests.rs @@ -39,6 +39,16 @@ fn get_counter(state_machine: &StateMachine, canister_id: CanisterId, name: &str Decode!(&reply, u64).unwrap() } +fn get_metrics(state_machine: &StateMachine, canister_id: CanisterId) -> String { + let result = state_machine + .query(canister_id, "get_metrics", Encode!(&()).unwrap()) + .unwrap(); + let WasmResult::Reply(reply) = result else { + panic!("Query failed: {:?}", result); + }; + Decode!(&reply, String).unwrap() +} + fn set_up_canister_with_tasks(state_machine: &StateMachine, task_names: Vec) -> CanisterId { let timer_task_canister_wasm = Project::cargo_bin_maybe_from_env("timer-task-canister", &[]); state_machine @@ -50,12 +60,66 @@ fn set_up_canister_with_tasks(state_machine: &StateMachine, task_names: Vec= 100, "{} counter {}", name, counter); + } + + // We just make sure that the metrics are present without checking the values, to prevent + // the test from becoming too brittle. + let metrics = get_metrics(&state_machine, canister_id); + for metric_snippet in &[ + "test_canister_task_instruction_bucket{task_name=\"success_recurring_sync_task\",le=\"10000\"}", + "test_canister_task_instruction_bucket{task_name=\"success_recurring_async_task\",le=\"10000\"}", + "test_canister_task_instruction_bucket{task_name=\"success_periodic_sync_task\",le=\"10000\"}", + "test_canister_task_instruction_bucket{task_name=\"success_periodic_async_task\",le=\"10000\"}", + "test_canister_task_instruction_sum{task_name=\"success_recurring_sync_task\"}", + "test_canister_task_instruction_sum{task_name=\"success_recurring_async_task\"}", + "test_canister_task_instruction_sum{task_name=\"success_periodic_sync_task\"}", + "test_canister_task_instruction_sum{task_name=\"success_periodic_async_task\"}", + "test_canister_task_instruction_count{task_name=\"success_recurring_sync_task\"}", + "test_canister_task_instruction_count{task_name=\"success_recurring_async_task\"}", + "test_canister_task_instruction_count{task_name=\"success_periodic_sync_task\"}", + "test_canister_task_instruction_count{task_name=\"success_periodic_async_task\"}", + "test_canister_sync_task_last_executed{task_name=\"success_recurring_sync_task\"}", + "test_canister_sync_task_last_executed{task_name=\"success_periodic_sync_task\"}", + "test_canister_async_task_outstanding_count{task_name=\"success_recurring_async_task\"}", + "test_canister_async_task_outstanding_count{task_name=\"success_periodic_async_task\"}", + "test_canister_async_task_last_started{task_name=\"success_recurring_async_task\"}", + "test_canister_async_task_last_started{task_name=\"success_periodic_async_task\"}", + "test_canister_async_task_last_finished{task_name=\"success_recurring_async_task\"}", + "test_canister_async_task_last_finished{task_name=\"success_periodic_async_task\"}", + ] { + assert!( + metrics.contains(metric_snippet), + "Metrics missing snippet: {}", + metric_snippet + ); + } +} + #[test] fn test_incremental_delay() { let state_machine = state_machine_for_test(); let canister_id = set_up_canister_with_tasks( &state_machine, - vec!["IncrementalDelayRecurringSyncTask".to_string()], + vec!["incremental_delay_recurring_sync_task".to_string()], ); for _ in 
0..10 { @@ -66,21 +130,23 @@ fn test_incremental_delay() { let counter = get_counter( &state_machine, canister_id, - "IncrementalDelayRecurringSyncTask", + "incremental_delay_recurring_sync_task", ); assert_eq!(counter, 5); } #[test] fn test_out_of_instruction_tasks() { + // In this test, we make sure out-of-instruction tasks of various types don't + // prevent other tasks from executing. let state_machine = state_machine_for_test(); let canister_id = set_up_canister_with_tasks( &state_machine, vec![ - "SuccessRecurringSyncTask".to_string(), - "OutOfInstructionsRecurringSyncTask".to_string(), - "OutOfInstructionsBeforeCallRecurringAsyncTask".to_string(), - "OutOfInstructionsAfterCallRecurringAsyncTask".to_string(), + "success_periodic_sync_task".to_string(), + "out_of_instructions_recurring_sync_task".to_string(), + "out_of_instructions_before_call_recurring_async_task".to_string(), + "out_of_instructions_after_call_recurring_async_task".to_string(), ], ); @@ -89,7 +155,7 @@ fn test_out_of_instruction_tasks() { state_machine.tick(); } - let successful_counter = get_counter(&state_machine, canister_id, "SuccessRecurringSyncTask"); + let successful_counter = get_counter(&state_machine, canister_id, "success_periodic_sync_task"); assert!( successful_counter > 20, "successful_counter {}", @@ -99,14 +165,14 @@ fn test_out_of_instruction_tasks() { let out_of_instructions_sync_counter = get_counter( &state_machine, canister_id, - "OutOfInstructionsRecurringSyncTask", + "out_of_instructions_recurring_sync_task", ); assert_eq!(out_of_instructions_sync_counter, 0); let out_of_instructions_after_call_async_counter = get_counter( &state_machine, canister_id, - "OutOfInstructionsAfterCallRecurringAsyncTask", + "out_of_instructions_after_call_recurring_async_task", ); assert_eq!(out_of_instructions_after_call_async_counter, 1); } @@ -114,66 +180,49 @@ fn test_out_of_instruction_tasks() { #[test] fn test_panic_recurring_sync_task() { let state_machine = state_machine_for_test(); - let canister_id = - set_up_canister_with_tasks(&state_machine, vec!["PanicRecurringSyncTask".to_string()]); + let canister_id = set_up_canister_with_tasks( + &state_machine, + vec!["panic_recurring_sync_task".to_string()], + ); for _ in 0..100 { state_machine.advance_time(std::time::Duration::from_secs(1)); state_machine.tick(); } - let counter = get_counter(&state_machine, canister_id, "PanicRecurringSyncTask"); + let counter = get_counter(&state_machine, canister_id, "panic_recurring_sync_task"); assert_eq!(counter, 0); } #[test] fn test_panic_recurring_async_task() { let state_machine = state_machine_for_test(); - let canister_id = - set_up_canister_with_tasks(&state_machine, vec!["PanicRecurringAsyncTask".to_string()]); - + let canister_id = set_up_canister_with_tasks( + &state_machine, + vec!["panic_recurring_async_task".to_string()], + ); for _ in 0..100 { state_machine.advance_time(std::time::Duration::from_secs(1)); state_machine.tick(); } - let counter = get_counter(&state_machine, canister_id, "PanicRecurringAsyncTask"); + let counter = get_counter(&state_machine, canister_id, "panic_recurring_async_task"); assert_eq!(counter, 1); } -#[test] -fn test_success_tasks() { - let state_machine = state_machine_for_test(); - let task_names = vec![ - "SuccessRecurringSyncTask".to_string(), - "SuccessRecurringAsyncTask".to_string(), - "SuccessPeriodicSyncTask".to_string(), - "SuccessPeriodicAsyncTask".to_string(), - ]; - let canister_id = set_up_canister_with_tasks(&state_machine, task_names.clone()); - - for _ in 0..100 { - 
state_machine.advance_time(std::time::Duration::from_secs(1)); - state_machine.tick(); - } - - for name in task_names { - let counter = get_counter(&state_machine, canister_id, &name); - assert!(counter >= 100, "{} counter {}", name, counter); - } -} - #[test] fn test_panic_periodic_async_task() { let state_machine = state_machine_for_test(); - let canister_id = - set_up_canister_with_tasks(&state_machine, vec!["PanicPeriodicAsyncTask".to_string()]); + let canister_id = set_up_canister_with_tasks( + &state_machine, + vec!["panic_periodic_async_task".to_string()], + ); for _ in 0..100 { state_machine.advance_time(std::time::Duration::from_secs(1)); state_machine.tick(); } - let counter = get_counter(&state_machine, canister_id, "PanicPeriodicAsyncTask"); + let counter = get_counter(&state_machine, canister_id, "panic_periodic_async_task"); assert!(counter >= 100, "counter {}", counter); } diff --git a/rs/nervous_system/timers/BUILD.bazel b/rs/nervous_system/timers/BUILD.bazel new file mode 100644 index 00000000000..e4b3eada7a3 --- /dev/null +++ b/rs/nervous_system/timers/BUILD.bazel @@ -0,0 +1,37 @@ +load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test", "rust_test_suite") + +package(default_visibility = ["//visibility:public"]) + +DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:ic-cdk-timers", + "@crate_index//:slotmap", +] + +DEV_DEPENDENCIES = DEPENDENCIES + [ +] + +LIB_SRCS = glob( + ["src/**"], + exclude = ["**/*tests.rs"], ) + +rust_library( + name = "timers", + srcs = LIB_SRCS, + crate_name = "ic_nervous_system_timers", + version = "0.0.1", + deps = DEPENDENCIES, +) + +rust_test_suite( + name = "timers_integration_test", + srcs = glob(["tests/**/*.rs"]), + deps = [":timers"] + DEPENDENCIES + DEV_DEPENDENCIES, +) + +rust_test( + name = "timers_test", + srcs = glob(["src/**/*.rs"]), + deps = DEV_DEPENDENCIES, ) diff --git a/rs/nervous_system/timers/Cargo.toml b/rs/nervous_system/timers/Cargo.toml new file mode 100644 index 00000000000..927ba1d4864 --- /dev/null +++ b/rs/nervous_system/timers/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "ic-nervous-system-timers" +version.workspace = true +authors.workspace = true +description.workspace = true +documentation.workspace = true +edition.workspace = true + +[dependencies] +ic-cdk-timers = { workspace = true } +slotmap = { workspace = true } \ No newline at end of file diff --git a/rs/nervous_system/timers/src/lib.rs b/rs/nervous_system/timers/src/lib.rs new file mode 100644 index 00000000000..697705a0b46 --- /dev/null +++ b/rs/nervous_system/timers/src/lib.rs @@ -0,0 +1,34 @@ +//! This crate is meant to provide a test-safe version of timers. +//! +//! It is suggested to create your own timers mod in your own crate that looks like the following: +//! +//! mod timers { +//! #[cfg(target_arch = "wasm32")] +//! pub use ic_cdk_timers::{clear_timer, set_timer, set_timer_interval}; +//! #[cfg(not(target_arch = "wasm32"))] +//! pub use ic_nervous_system_timers::test::{clear_timer, set_timer, set_timer_interval}; +//! } +//! +//! At this point, you should rely on `mod timers` from your own crate instead of ic-cdk-timers. +//! +//! This will ensure that you use the ic_cdk_timers version of the functions in production, while +//! having access to some useful test functions in unit tests such as: +//! +//! get_time_for_timers() +//! set_time_for_timers(duration_since_epoch: Duration) +//! advance_time_for_timers(duration: Duration) +//! run_pending_timers() +//! has_timer_task(timer_id: TimerId) +//! +//! These functions will allow you to test everything thoroughly. +//!
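// A minimal sketch (hypothetical consumer code, not part of this file) of a unit test that drives
// timers deterministically with the helpers listed above. On non-wasm32 targets the crate's
// re-exports resolve to the `test` implementations, so no real time needs to pass.
#[cfg(test)]
mod example_usage {
    use ic_nervous_system_timers::set_timer;
    use ic_nervous_system_timers::test::{advance_time_for_timers, run_pending_timers};
    use std::{cell::RefCell, time::Duration};

    thread_local! {
        static FIRED: RefCell<u32> = const { RefCell::new(0) };
    }

    #[test]
    fn timer_fires_only_after_simulated_delay() {
        set_timer(Duration::from_secs(30), || {
            FIRED.with_borrow_mut(|n| *n += 1);
        });

        // Nothing runs until simulated time reaches the delay.
        run_pending_timers();
        FIRED.with_borrow(|n| assert_eq!(*n, 0));

        advance_time_for_timers(Duration::from_secs(30));
        run_pending_timers();
        FIRED.with_borrow(|n| assert_eq!(*n, 1));
    }
}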
+ +pub use ic_cdk_timers::TimerId; + +#[cfg(target_arch = "wasm32")] +pub use ic_cdk_timers::{clear_timer, set_timer, set_timer_interval}; + +#[cfg(not(target_arch = "wasm32"))] +pub use test::{clear_timer, set_timer, set_timer_interval}; + +pub mod test; diff --git a/rs/nervous_system/timers/src/test.rs b/rs/nervous_system/timers/src/test.rs new file mode 100644 index 00000000000..f7fa3b14d03 --- /dev/null +++ b/rs/nervous_system/timers/src/test.rs @@ -0,0 +1,247 @@ +use ic_cdk_timers::TimerId; +use slotmap::SlotMap; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::mem; +use std::time::{Duration, SystemTime}; + +enum TimerTask { + OneShot(OneShotTimerTask), + Recurring(RecurringTimerTask), +} + +impl Default for TimerTask { + fn default() -> Self { + TimerTask::OneShot(OneShotTimerTask::default()) + } +} + +impl TimerTask { + fn next_run(&self) -> Duration { + match self { + TimerTask::Recurring(task) => task.run_at_duration_after_epoch, + TimerTask::OneShot(task) => task.run_at_duration_after_epoch, + } + } +} + +struct RecurringTimerTask { + pub interval: Duration, + pub run_at_duration_after_epoch: Duration, + pub func: Box, +} + +impl Default for RecurringTimerTask { + fn default() -> Self { + Self { + interval: Duration::from_secs(u64::MAX), + run_at_duration_after_epoch: Duration::default(), + func: Box::new(|| {}), + } + } +} + +struct OneShotTimerTask { + pub run_at_duration_after_epoch: Duration, + pub func: Box, +} + +impl Default for OneShotTimerTask { + fn default() -> Self { + Self { + run_at_duration_after_epoch: Duration::default(), + func: Box::new(|| {}), + } + } +} + +thread_local! { + // This could be improved to use some other kind of time that would be perhaps + // common to all time-based functions in the system. However, using systemtime + // would make the tests slow, so that is not a good option. 
+ pub static CURRENT_TIME: RefCell = RefCell::new(SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap()); + pub static TIMER_TASKS: RefCell> = RefCell::default(); +} + +pub fn set_timer(delay: Duration, func: impl FnOnce() + 'static) -> TimerId { + let current_time = CURRENT_TIME.with(|current_time| *current_time.borrow()); + TIMER_TASKS.with(|timer_tasks| { + timer_tasks + .borrow_mut() + .insert(TimerTask::OneShot(OneShotTimerTask { + run_at_duration_after_epoch: current_time + delay, + func: Box::new(func), + })) + }) +} + +pub fn set_timer_interval(interval: Duration, func: impl FnMut() + 'static) -> TimerId { + let current_time = CURRENT_TIME.with(|current_time| *current_time.borrow()); + TIMER_TASKS.with(|timer_tasks| { + timer_tasks + .borrow_mut() + .insert(TimerTask::Recurring(RecurringTimerTask { + interval, + run_at_duration_after_epoch: current_time + interval, + func: Box::new(func), + })) + }) +} + +pub fn clear_timer(id: TimerId) { + TIMER_TASKS.with(|timer_intervals| { + timer_intervals.borrow_mut().remove(id); + }); +} +// Set the time as seconds since unix epoch +pub fn set_time_for_timers(duration_since_epoch: Duration) { + CURRENT_TIME.with(|current_time| { + *current_time.borrow_mut() = duration_since_epoch; + }); +} + +pub fn get_time_for_timers() -> Duration { + CURRENT_TIME.with(|current_time| *current_time.borrow()) +} + +pub fn advance_time_for_timers(duration: Duration) { + CURRENT_TIME.with(|current_time| { + *current_time.borrow_mut() += duration; + }); +} + +pub fn run_pending_timers() { + let current_time = CURRENT_TIME.with(|current_time| *current_time.borrow()); + + let tasks: BTreeMap = TIMER_TASKS.with(|timer_tasks| { + let mut timer_tasks = timer_tasks.borrow_mut(); + let mut runnable_ids = vec![]; + for (id, timer_task) in timer_tasks.iter_mut() { + if current_time >= timer_task.next_run() { + runnable_ids.push(id); + } + } + runnable_ids + .into_iter() + .map(|id| (id, timer_tasks.get_mut(id).map(mem::take).unwrap())) + .collect() + }); + + for (id, task) in tasks.into_iter() { + match task { + TimerTask::OneShot(task) => { + (task.func)(); + TIMER_TASKS.with(|timer_tasks| { + timer_tasks.borrow_mut().remove(id); + }); + } + TimerTask::Recurring(mut task) => { + (task.func)(); + task.run_at_duration_after_epoch += task.interval; + TIMER_TASKS.with(|timer_tasks| { + if let Some(slot) = timer_tasks.borrow_mut().get_mut(id) { + *slot = TimerTask::Recurring(task) + }; + }); + } + } + } +} + +pub fn run_pending_timers_every_interval_for_count(interval: Duration, count: u64) { + for _ in 0..count { + advance_time_for_timers(interval); + run_pending_timers(); + } +} + +pub fn has_timer_task(timer_id: TimerId) -> bool { + TIMER_TASKS.with(|timers| timers.borrow().contains_key(timer_id)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_timers_setting_running_and_clearing() { + thread_local! 
{ + static TIMER_1_COUNT: RefCell = const { RefCell::new(0) }; + static TIMER_2_COUNT: RefCell = const { RefCell::new(0) }; + } + + let timer_1_id = set_timer(Duration::from_secs(10), || { + TIMER_1_COUNT.with(|count| { + *count.borrow_mut() += 1; + }); + }); + let timer_2_id = set_timer_interval(Duration::from_secs(5), || { + TIMER_2_COUNT.with(|count| { + *count.borrow_mut() += 1; + }); + }); + assert!(has_timer_task(timer_1_id)); + assert!(has_timer_task(timer_2_id)); + + let current_time = get_time_for_timers(); + + // Run the timers + run_pending_timers(); + + // Check nothing ran yet + TIMER_1_COUNT.with(|count| { + assert_eq!(*count.borrow(), 0); + }); + TIMER_2_COUNT.with(|count| { + assert_eq!(*count.borrow(), 0); + }); + + // Advance time by 5 seconds + set_time_for_timers(current_time + Duration::from_secs(1)); + advance_time_for_timers(Duration::from_secs(4)); + + // Run the timers + run_pending_timers(); + + // Check that the second timer ran + TIMER_1_COUNT.with(|count| { + assert_eq!(*count.borrow(), 0); + }); + TIMER_2_COUNT.with(|count| { + assert_eq!(*count.borrow(), 1); + }); + + run_pending_timers_every_interval_for_count(Duration::from_secs(5), 2); + + // Check that the first timer ran + TIMER_1_COUNT.with(|count| { + assert_eq!(*count.borrow(), 1); + }); + TIMER_2_COUNT.with(|count| { + assert_eq!(*count.borrow(), 3); + }); + + // Timer 1 should no longer exist, but timer 2 is an interval. + assert!(!has_timer_task(timer_1_id)); + assert!(has_timer_task(timer_2_id)); + + clear_timer(timer_2_id); + + assert!(!has_timer_task(timer_2_id)); + + run_pending_timers_every_interval_for_count(Duration::from_secs(5), 2); + + // Check that second timer is in fact not running + TIMER_2_COUNT.with(|count| { + assert_eq!(*count.borrow(), 3); + }); + + // Time internally advances as expected + assert_eq!( + get_time_for_timers(), + current_time + Duration::from_secs(25) + ); + } +} diff --git a/rs/nns/governance/BUILD.bazel b/rs/nns/governance/BUILD.bazel index 250e55b383d..f6042c8db7c 100644 --- a/rs/nns/governance/BUILD.bazel +++ b/rs/nns/governance/BUILD.bazel @@ -58,6 +58,7 @@ DEPENDENCIES = [ "//rs/nervous_system/temporary", "//rs/nervous_system/time_helpers", "//rs/nervous_system/timer_task", + "//rs/nervous_system/timers", "//rs/nns/cmc", "//rs/nns/common", "//rs/nns/constants", diff --git a/rs/nns/governance/CHANGELOG.md b/rs/nns/governance/CHANGELOG.md index 085bf4ce7e3..4b6e96d9168 100644 --- a/rs/nns/governance/CHANGELOG.md +++ b/rs/nns/governance/CHANGELOG.md @@ -11,6 +11,20 @@ here were moved from the adjacent `unreleased_changelog.md` file. INSERT NEW RELEASES HERE +# 2025-03-01: Proposal 135613 + +http://dashboard.internetcomputer.org/proposal/135613 + +## Added + +* Define API for disburse maturity. While disburse maturity is not yet enabled, clients may already + start preparing for this new NNS neuron operation. + +## Deprecated + +* NnsCanisterUpgrade/NnsRootUpgrade NNS funtions are made obsolete. + + # 2025-02-21: Proposal 135436 http://dashboard.internetcomputer.org/proposal/135436 @@ -160,4 +174,4 @@ the neuron. More precisely, b. Its influence on proposals goes to 0. 
-END \ No newline at end of file +END diff --git a/rs/nns/governance/Cargo.toml b/rs/nns/governance/Cargo.toml index 077751de0d4..64c8adcf08a 100644 --- a/rs/nns/governance/Cargo.toml +++ b/rs/nns/governance/Cargo.toml @@ -57,6 +57,7 @@ ic-nervous-system-runtime = { path = "../../nervous_system/runtime" } ic-nervous-system-proto = { path = "../../nervous_system/proto" } ic-nervous-system-temporary = { path = "../../nervous_system/temporary" } ic-nervous-system-time-helpers = { path = "../../nervous_system/time_helpers" } +ic-nervous-system-timers = { path = "../../nervous_system/timers" } ic-neurons-fund = { path = "../../nervous_system/neurons_fund" } ic-nns-common = { path = "../common" } ic-nns-constants = { path = "../constants" } diff --git a/rs/nns/governance/canister/canister.rs b/rs/nns/governance/canister/canister.rs index e3cf60624d9..06b503d62a8 100644 --- a/rs/nns/governance/canister/canister.rs +++ b/rs/nns/governance/canister/canister.rs @@ -226,8 +226,6 @@ fn canister_init_(init_payload: ApiGovernanceProto) { init_payload.neurons.len() ); - schedule_timers(); - let governance_proto = InternalGovernanceProto::from(init_payload); set_governance(Governance::new( governance_proto, @@ -236,6 +234,10 @@ fn canister_init_(init_payload: ApiGovernanceProto) { Arc::new(CMCCanister::::new()), Box::new(CanisterRandomnessGenerator::new()), )); + + // Timers etc should not be scheduled until after Governance has been initialized, since + // some of them may rely on Governance state to determine when they should run. + schedule_timers(); } #[pre_upgrade] @@ -273,7 +275,6 @@ fn canister_post_upgrade() { restored_state.xdr_conversion_rate, ); - schedule_timers(); set_governance(Governance::new_restored( restored_state, Arc::new(CanisterEnv::new()), @@ -283,6 +284,10 @@ fn canister_post_upgrade() { )); validate_stable_storage(); + + // Timers etc should not be scheduled until after Governance has been initialized, since + // some of them may rely on Governance state to determine when they should run. + schedule_timers(); } #[cfg(feature = "test")] diff --git a/rs/nns/governance/init/src/lib.rs b/rs/nns/governance/init/src/lib.rs index 80dadc60553..f442ef41fa0 100644 --- a/rs/nns/governance/init/src/lib.rs +++ b/rs/nns/governance/init/src/lib.rs @@ -95,6 +95,17 @@ impl GovernanceCanisterInitPayloadBuilder { TEST_NEURON_2_OWNER_PRINCIPAL, TEST_NEURON_3_ID, TEST_NEURON_3_OWNER_PRINCIPAL, }; use ic_nns_governance_api::pb::v1::{neuron::DissolveState, Neuron}; + use std::time::SystemTime; + + // This assumption here is that with_current_time is used. + // Alternatively, we could use u64::MAX, but u64::MAX is not as + // realistic. 
+ let voting_power_refreshed_timestamp_seconds = Some( + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(), + ); let mut neuron1 = { let neuron_id = NeuronIdProto::from(self.new_neuron_id()); @@ -109,6 +120,7 @@ impl GovernanceCanisterInitPayloadBuilder { * TEST_NEURON_TOTAL_STAKE_E8S */ account: subaccount, not_for_profit: true, + voting_power_refreshed_timestamp_seconds, ..Default::default() } }; @@ -135,6 +147,7 @@ impl GovernanceCanisterInitPayloadBuilder { aging_since_timestamp_seconds: 1, account: subaccount, not_for_profit: false, + voting_power_refreshed_timestamp_seconds, ..Default::default() } }; @@ -153,6 +166,7 @@ impl GovernanceCanisterInitPayloadBuilder { aging_since_timestamp_seconds: 10, account: subaccount, not_for_profit: false, + voting_power_refreshed_timestamp_seconds, ..Default::default() } }; diff --git a/rs/nns/governance/src/canister_state.rs b/rs/nns/governance/src/canister_state.rs index 2bf61a63850..7e3d6883d3c 100644 --- a/rs/nns/governance/src/canister_state.rs +++ b/rs/nns/governance/src/canister_state.rs @@ -109,6 +109,10 @@ pub fn set_governance(gov: Governance) { .validate() .expect("Error initializing the governance canister."); } +#[cfg(any(test, not(target_arch = "wasm32")))] +pub fn set_governance_for_tests(gov: Governance) { + GOVERNANCE.set(gov); +} #[derive(Default)] pub struct CanisterEnv { diff --git a/rs/nns/governance/src/governance.rs b/rs/nns/governance/src/governance.rs index b3d40b26fb7..bbdc0465c76 100644 --- a/rs/nns/governance/src/governance.rs +++ b/rs/nns/governance/src/governance.rs @@ -6396,6 +6396,9 @@ impl Governance { } } + pub fn get_ledger(&self) -> Arc { + self.ledger.clone() + } /// Triggers a reward distribution event if enough time has passed since /// the last one. This is intended to be called by a cron /// process. @@ -6420,24 +6423,7 @@ impl Governance { LOG_PREFIX, e, ), } - // Second try to distribute voting rewards (once per day). - } else if self.should_distribute_rewards() { - // Getting the total ICP supply from the ledger is expensive enough that we - // don't want to do it on every call to `run_periodic_tasks`. So we only - // fetch it when it's needed. - match self.ledger.total_supply().await { - Ok(supply) => { - if self.should_distribute_rewards() { - self.distribute_rewards(supply); - } - } - Err(e) => println!( - "{}Error when getting total ICP supply: {}", - LOG_PREFIX, - GovernanceError::from(e), - ), - } - // Third try to compute cached metrics (once per day). + // Second try to compute cached metrics (once per day). } else if self.should_compute_cached_metrics() { match self.ledger.total_supply().await { Ok(supply) => { @@ -6757,16 +6743,6 @@ impl Governance { self.heap_data.spawning_neurons = Some(false); } - /// Return `true` if rewards should be distributed, `false` otherwise - fn should_distribute_rewards(&self) -> bool { - let latest_distribution_nominal_end_timestamp_seconds = - self.latest_reward_event().day_after_genesis * REWARD_DISTRIBUTION_PERIOD_SECONDS - + self.heap_data.genesis_timestamp_seconds; - - self.most_recent_fully_elapsed_reward_round_end_timestamp_seconds() - > latest_distribution_nominal_end_timestamp_seconds - } - /// Create a reward event. 
/// /// This method: diff --git a/rs/nns/governance/src/timer_tasks/mod.rs b/rs/nns/governance/src/timer_tasks/mod.rs index d17356e8b36..b95ab8f3fce 100644 --- a/rs/nns/governance/src/timer_tasks/mod.rs +++ b/rs/nns/governance/src/timer_tasks/mod.rs @@ -1,10 +1,18 @@ -use ic_nervous_system_timer_task::RecurringAsyncTask; +use ic_nervous_system_timer_task::{RecurringAsyncTask, TimerTaskMetricsRegistry}; use seeding::SeedingTask; +use std::cell::RefCell; use crate::canister_state::GOVERNANCE; +use crate::timer_tasks::reward_distribution::CalculateDistributableRewardsTask; +mod reward_distribution; mod seeding; +thread_local! { + static METRICS_REGISTRY: RefCell = RefCell::new(TimerTaskMetricsRegistry::default()); +} + pub fn schedule_tasks() { - SeedingTask::new(&GOVERNANCE).schedule(); + SeedingTask::new(&GOVERNANCE).schedule(&METRICS_REGISTRY); + CalculateDistributableRewardsTask::new(&GOVERNANCE).schedule(&METRICS_REGISTRY); } diff --git a/rs/nns/governance/src/timer_tasks/reward_distribution.rs b/rs/nns/governance/src/timer_tasks/reward_distribution.rs new file mode 100644 index 00000000000..1610d8669da --- /dev/null +++ b/rs/nns/governance/src/timer_tasks/reward_distribution.rs @@ -0,0 +1,192 @@ +use crate::governance::{Governance, LOG_PREFIX, REWARD_DISTRIBUTION_PERIOD_SECONDS}; +use crate::pb::v1::GovernanceError; +use async_trait::async_trait; +use ic_nervous_system_timer_task::RecurringAsyncTask; +use std::cell::RefCell; +use std::thread::LocalKey; +use std::time::Duration; + +pub(super) struct CalculateDistributableRewardsTask { + governance: &'static LocalKey>, +} + +impl CalculateDistributableRewardsTask { + pub fn new(governance: &'static LocalKey>) -> Self { + Self { governance } + } + + fn next_reward_task_from_now(&self) -> Duration { + self.governance.with_borrow(|governance| { + let latest_day_after_genesis = governance.latest_reward_event().day_after_genesis; + let now = governance.env.now(); + let genesis_timestamp_seconds = governance.heap_data.genesis_timestamp_seconds; + + delay_until_next_run(now, genesis_timestamp_seconds, latest_day_after_genesis) + }) + } +} + +fn delay_until_next_run( + now: u64, + genesis_timestamp_seconds: u64, + latest_reward_day_after_genesis: u64, +) -> Duration { + let latest_distribution_nominal_end_timestamp_seconds = latest_reward_day_after_genesis + * REWARD_DISTRIBUTION_PERIOD_SECONDS + + genesis_timestamp_seconds; + + // We add 1 to the end of the period to make sure we always run after the period is over, to + // avoid missing any proposals that would be ready to settle right on the edge of the period. + let next = + latest_distribution_nominal_end_timestamp_seconds + REWARD_DISTRIBUTION_PERIOD_SECONDS + 1; + + // We want the difference between next and now. 
If it's in the past, we want to run + // immediately + Duration::from_secs(next.saturating_sub(now)) +} + +#[async_trait] +impl RecurringAsyncTask for CalculateDistributableRewardsTask { + async fn execute(self) -> (Duration, Self) { + let total_supply = self + .governance + .with_borrow(|governance| governance.get_ledger()) + .total_supply() + .await; + match total_supply { + Ok(total_supply) => { + self.governance.with_borrow_mut(|governance| { + governance.distribute_rewards(total_supply); + }); + } + Err(err) => { + ic_cdk::println!( + "{}Error when getting total ICP supply: {}", + LOG_PREFIX, + GovernanceError::from(err) + ) + } + } + + let next_run = self.next_reward_task_from_now(); + (next_run, self) + } + + fn initial_delay(&self) -> Duration { + self.next_reward_task_from_now() + } + + const NAME: &'static str = "calculate_distributable_rewards"; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::canister_state::{set_governance_for_tests, CanisterRandomnessGenerator}; + use crate::governance::Governance; + use crate::test_utils::{MockEnvironment, StubCMC, StubIcpLedger}; + use std::sync::Arc; + + fn test_delay_until_next_run( + now: u64, + genesis_timestamp_seconds: u64, + latest_reward_day_after_genesis: u64, + expected: Duration, + ) { + let next = delay_until_next_run( + now, + genesis_timestamp_seconds, + latest_reward_day_after_genesis, + ); + assert_eq!(next, expected); + } + + #[test] + fn test_delay_until_next_run_all_zero() { + let now = 0; + let genesis_timestamp_seconds = 0; + let latest_reward_day_after_genesis = 0; + + test_delay_until_next_run( + now, + genesis_timestamp_seconds, + latest_reward_day_after_genesis, + Duration::from_secs(REWARD_DISTRIBUTION_PERIOD_SECONDS + 1), + ); + } + + #[test] + fn test_delay_until_next_run_missed_days() { + let now = REWARD_DISTRIBUTION_PERIOD_SECONDS * 3; + let genesis_timestamp_seconds = 0; + let latest_reward_day_after_genesis = 1; + + test_delay_until_next_run( + now, + genesis_timestamp_seconds, + latest_reward_day_after_genesis, + Duration::from_secs(0), + ); + } + + #[test] + fn test_delay_until_next_run_exactly_at_event() { + let now = REWARD_DISTRIBUTION_PERIOD_SECONDS + 1; + let genesis_timestamp_seconds = 0; + let latest_reward_day_after_genesis = 1; + + test_delay_until_next_run( + now, + genesis_timestamp_seconds, + latest_reward_day_after_genesis, + Duration::from_secs(REWARD_DISTRIBUTION_PERIOD_SECONDS), + ); + } + + #[test] + fn test_delay_until_next_run_with_positive_genesis_value() { + let now = 10_000 + REWARD_DISTRIBUTION_PERIOD_SECONDS * 5 + 500; + let genesis_timestamp_seconds = 10_000; + let latest_reward_day_after_genesis = 5; + + test_delay_until_next_run( + now, + genesis_timestamp_seconds, + latest_reward_day_after_genesis, + Duration::from_secs(REWARD_DISTRIBUTION_PERIOD_SECONDS - 500 + 1), + ); + } + + #[test] + fn test_governance_integration_with_delay_calculation() { + let now = 10_000 + REWARD_DISTRIBUTION_PERIOD_SECONDS * 5 + 500; + let genesis_timestamp_seconds = 10_000; + let latest_reward_day_after_genesis = 5; + + let governance_proto = crate::pb::v1::Governance { + genesis_timestamp_seconds, + latest_reward_event: Some(crate::pb::v1::RewardEvent { + day_after_genesis: latest_reward_day_after_genesis, + ..Default::default() + }), + ..Default::default() + }; + + let gov = Governance::new( + governance_proto, + Arc::new(MockEnvironment::new(vec![], now)), + Arc::new(StubIcpLedger {}), + Arc::new(StubCMC {}), + Box::new(CanisterRandomnessGenerator::new()), + ); + 
set_governance_for_tests(gov); + + let task = CalculateDistributableRewardsTask::new(&crate::canister_state::GOVERNANCE); + + let next = task.next_reward_task_from_now(); + assert_eq!( + next, + Duration::from_secs(REWARD_DISTRIBUTION_PERIOD_SECONDS - 500 + 1) + ); + } +} diff --git a/rs/nns/governance/tests/fake.rs b/rs/nns/governance/tests/fake.rs index b832e430b74..f940eba48a4 100644 --- a/rs/nns/governance/tests/fake.rs +++ b/rs/nns/governance/tests/fake.rs @@ -5,6 +5,7 @@ use futures::future::FutureExt; use ic_base_types::{CanisterId, PrincipalId}; use ic_ledger_core::tokens::CheckedSub; use ic_nervous_system_common::{cmc::CMC, ledger::IcpLedger, NervousSystemError}; +use ic_nervous_system_timers::test::{advance_time_for_timers, set_time_for_timers}; use ic_nns_common::{ pb::v1::{NeuronId, ProposalId}, types::UpdateIcpXdrConversionRatePayload, @@ -113,10 +114,14 @@ pub struct FakeDriver { /// Create a default mock driver. impl Default for FakeDriver { fn default() -> Self { - Self { + let ret = Self { state: Arc::new(Mutex::new(Default::default())), error_on_next_ledger_call: Arc::new(Mutex::new(None)), - } + }; + set_time_for_timers(std::time::Duration::from_secs( + ret.state.try_lock().unwrap().now, + )); + ret } } @@ -130,6 +135,7 @@ impl FakeDriver { /// Constructs a mock driver that starts at the given timestamp. pub fn at(self, timestamp: u64) -> FakeDriver { + set_time_for_timers(std::time::Duration::from_secs(timestamp)); self.state.lock().unwrap().now = timestamp; self } @@ -178,6 +184,7 @@ impl FakeDriver { /// Increases the time by the given amount. pub fn advance_time_by(&mut self, delta_seconds: u64) { + advance_time_for_timers(std::time::Duration::from_secs(delta_seconds)); self.state.lock().unwrap().now += delta_seconds; } @@ -416,12 +423,12 @@ impl RandomnessGenerator for FakeDriver { Ok(bytes) } - fn seed_rng(&mut self, _seed: [u8; 32]) { - todo!() + fn seed_rng(&mut self, seed: [u8; 32]) { + self.state.try_lock().unwrap().rng = ChaCha20Rng::from_seed(seed); } fn get_rng_seed(&self) -> Option<[u8; 32]> { - todo!() + Some(self.state.try_lock().unwrap().rng.get_seed()) } } @@ -587,6 +594,12 @@ impl Environment for FakeDriver { .unwrap()); } + if method_name == "raw_rand" { + let mut bytes = [0u8; 32]; + self.state.try_lock().unwrap().rng.fill_bytes(&mut bytes); + return Ok(Encode!(&bytes).unwrap()); + } + println!( "WARNING: Unexpected canister call:\n\ ..target = {}\n\ diff --git a/rs/nns/governance/tests/fixtures/environment_fixture.rs b/rs/nns/governance/tests/fixtures/environment_fixture.rs index 17491558a94..18eb4eeacf7 100644 --- a/rs/nns/governance/tests/fixtures/environment_fixture.rs +++ b/rs/nns/governance/tests/fixtures/environment_fixture.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use candid::{CandidType, Decode, Encode, Error}; use ic_base_types::CanisterId; +use ic_nervous_system_timers::test::{advance_time_for_timers, set_time_for_timers}; use ic_nns_governance::governance::RandomnessGenerator; use ic_nns_governance::{ governance::{Environment, HeapGrowthPotential, RngError}, @@ -12,6 +13,7 @@ use ic_sns_wasm::pb::v1::{DeployNewSnsRequest, ListDeployedSnsesRequest}; use proptest::prelude::RngCore; use rand::rngs::StdRng; use rand_chacha::ChaCha20Rng; +use std::time::Duration; use std::{ collections::VecDeque, sync::{Arc, Mutex}, @@ -61,12 +63,18 @@ pub struct EnvironmentFixture { impl EnvironmentFixture { pub fn new(state: EnvironmentFixtureState) -> Self { - EnvironmentFixture { + let ret = EnvironmentFixture { environment_fixture_state: 
Arc::new(Mutex::new(state)), - } + }; + set_time_for_timers(Duration::from_secs( + ret.environment_fixture_state.try_lock().unwrap().now, + )); + + ret } pub fn advance_time_by(&self, delta_seconds: u64) { + advance_time_for_timers(Duration::from_secs(delta_seconds)); self.environment_fixture_state.try_lock().unwrap().now += delta_seconds } diff --git a/rs/nns/governance/tests/governance.rs b/rs/nns/governance/tests/governance.rs index d7802bb9488..884b7dcbca0 100644 --- a/rs/nns/governance/tests/governance.rs +++ b/rs/nns/governance/tests/governance.rs @@ -33,6 +33,9 @@ use ic_nervous_system_common_test_keys::{ }; use ic_nervous_system_common_test_utils::{LedgerReply, SpyLedger}; use ic_nervous_system_proto::pb::v1::{Decimal, Duration, GlobalTimeOfDay, Image, Percentage}; +use ic_nervous_system_timers::test::{ + run_pending_timers, run_pending_timers_every_interval_for_count, +}; use ic_neurons_fund::{ NeuronsFundParticipationLimits, PolynomialMatchingFunction, SerializableFunction, }; @@ -43,6 +46,7 @@ use ic_nns_common::{ use ic_nns_constants::{ GOVERNANCE_CANISTER_ID, LEDGER_CANISTER_ID as ICP_LEDGER_CANISTER_ID, SNS_WASM_CANISTER_ID, }; +use ic_nns_governance::canister_state::{governance_mut, set_governance_for_tests}; use ic_nns_governance::governance::RandomnessGenerator; use ic_nns_governance::{ governance::{ @@ -141,6 +145,7 @@ use std::{ #[cfg(feature = "tla")] use ic_nns_governance::governance::tla::{check_traces as tla_check_traces, TLA_TRACES_LKEY}; use ic_nns_governance::storage::reset_stable_memory; +use ic_nns_governance::timer_tasks::schedule_tasks; #[cfg(feature = "tla")] use tla_instrumentation_proc_macros::with_tla_trace_check; @@ -3529,13 +3534,17 @@ async fn test_reward_event_proposals_last_longer_than_reward_period() { // Proposals last longer than the reward period let wait_for_quiet_threshold_seconds = 5 * REWARD_DISTRIBUTION_PERIOD_SECONDS; fixture.wait_for_quiet_threshold_seconds = wait_for_quiet_threshold_seconds; - let mut gov = Governance::new( + let gov = Governance::new( fixture, fake_driver.get_fake_env(), fake_driver.get_fake_ledger(), fake_driver.get_fake_cmc(), fake_driver.get_fake_randomness_generator(), ); + set_governance_for_tests(gov); + let gov = governance_mut(); + schedule_tasks(); + let expected_initial_event = RewardEvent { day_after_genesis: 0, actual_timestamp_seconds: genesis_timestamp_seconds, @@ -3548,12 +3557,14 @@ async fn test_reward_event_proposals_last_longer_than_reward_period() { assert_eq!(*gov.latest_reward_event(), expected_initial_event); fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS / 2); + run_pending_timers(); gov.run_periodic_tasks().now_or_never(); // Too early: nothing should have changed assert_eq!(*gov.latest_reward_event(), expected_initial_event); fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); - // We are now 1.5 reward periods (1.5 days) past genesis. + // We are now 1.5 reward periods (1.5 days) past genesis.1 + run_pending_timers(); gov.run_periodic_tasks().now_or_never(); // A reward event should have happened, albeit an empty one, i.e., // given that no voting took place, no rewards were distributed. @@ -3596,6 +3607,7 @@ async fn test_reward_event_proposals_last_longer_than_reward_period() { // total_available_e8s_equivalent is equal to 199 maturity. fake_driver.advance_time_by(2 * REWARD_DISTRIBUTION_PERIOD_SECONDS); // We are now at +3.5 reward periods. 
+ run_pending_timers(); gov.run_periodic_tasks().now_or_never(); { let fully_elapsed_reward_rounds = 3; @@ -3621,6 +3633,7 @@ async fn test_reward_event_proposals_last_longer_than_reward_period() { fake_driver.advance_time_by(3 * REWARD_DISTRIBUTION_PERIOD_SECONDS - 5); // We are now at +6.5 - epsilon reward periods. Notice that at 6.5 reward // periods, the proposal become rewardable. + run_pending_timers(); gov.run_periodic_tasks().now_or_never(); // This should have triggered an empty reward event assert_eq!(gov.latest_reward_event().day_after_genesis, 6); @@ -3630,10 +3643,12 @@ async fn test_reward_event_proposals_last_longer_than_reward_period() { // This should generate a RewardEvent, because we now have a rewardable // proposal (i.e. the proposal has reward_status ReadyToSettle). + run_pending_timers(); gov.run_periodic_tasks().now_or_never(); assert_eq!(gov.latest_reward_event().day_after_genesis, 6); // let's advance far enough to trigger a reward event fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); + run_pending_timers(); gov.run_periodic_tasks().now_or_never(); // Inspect latest_reward_event. @@ -3692,6 +3707,7 @@ async fn test_reward_event_proposals_last_longer_than_reward_period() { // Now let's advance again -- a new empty reward event should happen fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); + run_pending_timers(); gov.run_periodic_tasks().now_or_never(); assert_eq!( *gov.latest_reward_event(), @@ -3734,14 +3750,18 @@ async fn test_restricted_proposals_are_not_eligible_for_voting_rewards() { fixture.wait_for_quiet_threshold_seconds = wait_for_quiet_threshold_seconds; fixture.short_voting_period_seconds = wait_for_quiet_threshold_seconds; fixture.neuron_management_voting_period_seconds = Some(wait_for_quiet_threshold_seconds); - let mut gov = Governance::new( + let gov = Governance::new( fixture, fake_driver.get_fake_env(), fake_driver.get_fake_ledger(), fake_driver.get_fake_cmc(), fake_driver.get_fake_randomness_generator(), ); - gov.run_periodic_tasks().now_or_never(); + set_governance_for_tests(gov); + let gov = governance_mut(); + schedule_tasks(); + + run_pending_timers(); // Initial reward event assert_eq!( *gov.latest_reward_event(), @@ -3791,8 +3811,9 @@ async fn test_restricted_proposals_are_not_eligible_for_voting_rewards() { // for the reward event. 
// total_available_e8s_equivalent is equal to reward function * total supply / 365.25, // which is 10% * 1234567890/365.25 = 338006 - fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); - gov.run_periodic_tasks().now_or_never(); + fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS + 1); + run_pending_timers(); + assert_eq!( *gov.latest_reward_event(), RewardEvent { @@ -3808,6 +3829,7 @@ async fn test_restricted_proposals_are_not_eligible_for_voting_rewards() { { gov.run_periodic_tasks().now_or_never(); + run_pending_timers(); let info = gov.get_proposal_data(ProposalId { id: 1 }).unwrap(); assert_eq!(info.status(), ProposalStatus::Rejected); assert_eq!( @@ -3868,7 +3890,7 @@ fn test_reward_distribution_skips_deleted_neurons() { // Let's set genesis let genesis_timestamp_seconds = fake_driver.now(); fixture.genesis_timestamp_seconds = genesis_timestamp_seconds; - let mut gov = Governance::new( + let gov = Governance::new( fixture, fake_driver.get_fake_env(), fake_driver.get_fake_ledger(), @@ -3876,6 +3898,10 @@ fn test_reward_distribution_skips_deleted_neurons() { fake_driver.get_fake_randomness_generator(), ); + set_governance_for_tests(gov); + let gov = governance_mut(); + schedule_tasks(); + // Make sure that the fixture function indeed did not create a neuron 999. assert_matches!(gov.neuron_store.with_neuron(&NeuronId { id: 999 }, |n| n.clone()).map_err(|e| { let gov_error: GovernanceError = e.into(); @@ -3883,7 +3909,7 @@ fn test_reward_distribution_skips_deleted_neurons() { }), Err(e) if e.error_type == NotFound as i32); // The proposal at genesis time is not ready to be settled - gov.run_periodic_tasks().now_or_never(); + run_pending_timers(); assert_eq!( *gov.latest_reward_event(), RewardEvent { @@ -3897,8 +3923,9 @@ fn test_reward_distribution_skips_deleted_neurons() { } ); - fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); - gov.run_periodic_tasks().now_or_never(); + fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS + 1); + run_pending_timers(); + assert_eq!( *gov.latest_reward_event(), RewardEvent { @@ -3946,13 +3973,17 @@ async fn test_genesis_in_the_future_in_supported() { // Let's set genesis let genesis_timestamp_seconds = fake_driver.now() + 3 * REWARD_DISTRIBUTION_PERIOD_SECONDS / 2; fixture.genesis_timestamp_seconds = genesis_timestamp_seconds; - let mut gov = Governance::new( + let gov = Governance::new( fixture, fake_driver.get_fake_env(), fake_driver.get_fake_ledger(), fake_driver.get_fake_cmc(), fake_driver.get_fake_randomness_generator(), ); + set_governance_for_tests(gov); + let gov = governance_mut(); + schedule_tasks(); + gov.run_periodic_tasks().now_or_never(); // At genesis, we should create an empty reward event assert_eq!( @@ -4010,7 +4041,7 @@ async fn test_genesis_in_the_future_in_supported() { .unwrap(); fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); - gov.run_periodic_tasks().now_or_never(); + run_pending_timers(); // We're still pre-genesis at that point assert!(fake_driver.now() < genesis_timestamp_seconds); // No new reward event should have been created... @@ -4076,7 +4107,7 @@ async fn test_genesis_in_the_future_in_supported() { .unwrap(); fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); - gov.run_periodic_tasks().now_or_never(); + run_pending_timers(); // Now we're 0.5 reward period after genesis. Still no new reward event // expected. 
assert_eq!( @@ -4107,11 +4138,11 @@ async fn test_genesis_in_the_future_in_supported() { ); // Let's go just at the time we should create the first reward event - fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS / 2); - gov.run_periodic_tasks().now_or_never(); + fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS / 2 + 1); + run_pending_timers(); assert_eq!( fake_driver.now(), - genesis_timestamp_seconds + REWARD_DISTRIBUTION_PERIOD_SECONDS + genesis_timestamp_seconds + REWARD_DISTRIBUTION_PERIOD_SECONDS + 1 ); // Given that the second neuron is much bigger (stake 953) compared to the // the first neuron (stake 23) and only the first neuron voted, @@ -4132,7 +4163,7 @@ async fn test_genesis_in_the_future_in_supported() { // Let's go just at the time we should create the first reward event fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); - gov.run_periodic_tasks().now_or_never(); + run_pending_timers(); // This time, the other long proposal submitted before genesis should be // considered assert_eq!( @@ -4193,8 +4224,9 @@ fn compute_maturities( let proposals: Vec = proposals.into_iter().map(|x| x.into()).collect(); - let mut fake_driver = - fake::FakeDriver::default().with_supply(Tokens::from_e8s(365_250 * reward_pot_e8s / 100)); + let mut fake_driver = fake::FakeDriver::default() + .at(DEFAULT_TEST_START_TIMESTAMP_SECONDS) + .with_supply(Tokens::from_e8s(365_250 * reward_pot_e8s / 100)); let neurons = stakes_e8s .iter() @@ -4220,13 +4252,16 @@ fn compute_maturities( .with_wait_for_quiet_threshold(10) .build(); - let mut gov = Governance::new( + let gov = Governance::new( governance_proto, fake_driver.get_fake_env(), fake_driver.get_fake_ledger(), fake_driver.get_fake_cmc(), fake_driver.get_fake_randomness_generator(), ); + set_governance_for_tests(gov); + let gov = governance_mut(); + schedule_tasks(); let expected_initial_event = RewardEvent { day_after_genesis: 0, @@ -4237,11 +4272,10 @@ fn compute_maturities( rounds_since_last_distribution: Some(0), latest_round_available_e8s_equivalent: Some(0), }; - assert_eq!(*gov.latest_reward_event(), expected_initial_event); for (i, behavior) in (1_u64..).zip(proposals.iter()) { - behavior.propose_and_vote(&mut gov, format!("proposal {}", i)); + behavior.propose_and_vote(gov, format!("proposal {}", i)); } // Let's advance time by one reward periods. All proposals should be considered @@ -4255,6 +4289,9 @@ fn compute_maturities( } gov.run_periodic_tasks().now_or_never(); + fake_driver.advance_time_by(1); + run_pending_timers(); + // Inspect latest_reward_event. 
let actual_reward_event = gov.latest_reward_event(); assert_eq!( @@ -4262,7 +4299,8 @@ fn compute_maturities( RewardEvent { day_after_genesis: 1, actual_timestamp_seconds: DEFAULT_TEST_START_TIMESTAMP_SECONDS - + REWARD_DISTRIBUTION_PERIOD_SECONDS, + + REWARD_DISTRIBUTION_PERIOD_SECONDS + + 1, settled_proposals: (1_u64..=proposals.len() as u64) .map(|id| ProposalId { id }) .collect(), @@ -6991,7 +7029,7 @@ fn test_staked_maturity() { let dissolve_delay_seconds = MIN_DISSOLVE_DELAY_FOR_VOTE_ELIGIBILITY_SECONDS; let neuron_stake_e8s = 10 * 100_000_000; // 10 ICP - let (mut driver, mut gov, id, _to_subaccount) = governance_with_staked_neuron( + let (mut driver, gov, id, _to_subaccount) = governance_with_staked_neuron( dissolve_delay_seconds, neuron_stake_e8s, block_height, @@ -6999,6 +7037,10 @@ fn test_staked_maturity() { nonce, ); + set_governance_for_tests(gov); + let gov = governance_mut(); + schedule_tasks(); + gov.neuron_store .with_neuron_mut(&id, |neuron| { assert_eq!(neuron.maturity_e8s_equivalent, 0); @@ -7049,6 +7091,7 @@ fn test_staked_maturity() { // Advance time by 5 days and run periodic tasks so that the neuron is granted (staked) maturity. driver.advance_time_by(5 * 24 * 3600); + run_pending_timers(); gov.run_periodic_tasks().now_or_never(); let neuron = gov @@ -7110,6 +7153,7 @@ fn test_staked_maturity() { .expect("Configuring neuron failed"); driver.advance_time_by(MIN_DISSOLVE_DELAY_FOR_VOTE_ELIGIBILITY_SECONDS); + run_pending_timers(); gov.unstake_maturity_of_dissolved_neurons(); // All the maturity should now be regular maturity @@ -9555,13 +9599,17 @@ async fn test_max_number_of_proposals_with_ballots() { wait_for_quiet_threshold_seconds: 5, ..fixture_two_neurons_second_is_bigger() }; - let mut gov = Governance::new( + let gov = Governance::new( proto, fake_driver.get_fake_env(), fake_driver.get_fake_ledger(), fake_driver.get_fake_cmc(), fake_driver.get_fake_randomness_generator(), ); + set_governance_for_tests(gov); + let gov = governance_mut(); + schedule_tasks(); + // Vote with neuron 1. It is smaller, so proposals are not auto-accepted. for i in 0..MAX_NUMBER_OF_PROPOSALS_WITH_BALLOTS { gov.make_proposal( @@ -9627,7 +9675,8 @@ async fn test_max_number_of_proposals_with_ballots() { ); fake_driver.advance_time_by(10); - gov.run_periodic_tasks().now_or_never(); + gov.run_periodic_tasks().now_or_never().unwrap(); + run_pending_timers(); // Now all proposals should have been rejected. for i in 1_u64..MAX_NUMBER_OF_PROPOSALS_WITH_BALLOTS as u64 + 2 { @@ -9657,7 +9706,8 @@ async fn test_max_number_of_proposals_with_ballots() { // Let's make a reward event happen fake_driver.advance_time_by(REWARD_DISTRIBUTION_PERIOD_SECONDS); - gov.run_periodic_tasks().now_or_never(); + gov.run_periodic_tasks().now_or_never().unwrap(); + run_pending_timers(); // Now it should be allowed to submit a new one gov.make_proposal( @@ -14277,7 +14327,7 @@ async fn test_settle_neurons_fund_is_idempotent_for_create_service_nervous_syste } #[tokio::test] -async fn distribute_rewards_load_test() { +async fn distribute_rewards_test() { // Step 1: Prepare the world. let genesis_timestamp_seconds = 1; @@ -14385,31 +14435,24 @@ async fn distribute_rewards_load_test() { ..Default::default() }; - let mut governance = Governance::new( + let governance = Governance::new( proto, helper.get_fake_env(), helper.get_fake_ledger(), helper.get_fake_cmc(), helper.get_fake_randomness_generator(), ); + set_governance_for_tests(governance); + let governance = governance_mut(); + schedule_tasks(); // Prevent gc. 
    governance.latest_gc_timestamp_seconds = now;
 
     // Step 2: Run code under test.
-    let clock = std::time::Instant::now;
-    let start = clock();
     governance.run_periodic_tasks().await;
-    let execution_duration_seconds = (clock() - start).as_secs_f64();
+    run_pending_timers_every_interval_for_count(std::time::Duration::from_secs(2), 10);
 
-    // Step 3: Inspect results. The main thing is to make sure that the code
-    // under test ran within a "reasonable" amount of time. On a 2019 MacBook
-    // Pro, it takes < 1.5 s. The limit is set to > 10x that to hopefully avoid
-    // flakes in CI.
-    assert!(
-        execution_duration_seconds < 5.0,
-        "{}",
-        execution_duration_seconds
-    );
+    // Step 3: Inspect results.
 
     // Step 3.1: Inspect neurons to make sure they have been rewarded for voting.
     governance.neuron_store.with_active_neurons_iter(|iter| {
diff --git a/rs/nns/governance/tests/proposals.rs b/rs/nns/governance/tests/proposals.rs
index 6cf9dc72264..d9022832ee5 100644
--- a/rs/nns/governance/tests/proposals.rs
+++ b/rs/nns/governance/tests/proposals.rs
@@ -1,6 +1,9 @@
 use ic_base_types::PrincipalId;
 use ic_nervous_system_common::{E8, ONE_DAY_SECONDS};
+use ic_nervous_system_timers::test::run_pending_timers_every_interval_for_count;
 use ic_nns_common::pb::v1::{NeuronId, ProposalId};
+use ic_nns_governance::canister_state::{governance_mut, set_governance_for_tests};
+use ic_nns_governance::timer_tasks::schedule_tasks;
 use ic_nns_governance::{
     governance::{Governance, REWARD_DISTRIBUTION_PERIOD_SECONDS},
     pb::v1::{
@@ -14,6 +17,7 @@ use icp_ledger::Tokens;
 use lazy_static::lazy_static;
 use maplit::{btreemap, hashmap};
 use std::collections::BTreeMap;
+use std::time::Duration;
 
 pub mod fake;
 
@@ -213,7 +217,7 @@ async fn test_distribute_rewards_with_total_potential_voting_power() {
         .at(NOW_SECONDS)
         .with_supply(Tokens::from_tokens(100).unwrap());
 
-    let mut governance = Governance::new(
+    let governance = Governance::new(
         GOVERNANCE_PROTO.clone(),
         fake_driver.get_fake_env(),
         fake_driver.get_fake_ledger(),
@@ -221,8 +225,15 @@ async fn test_distribute_rewards_with_total_potential_voting_power() {
         fake_driver.get_fake_randomness_generator(),
     );
 
+    set_governance_for_tests(governance);
+    let governance = governance_mut();
+    schedule_tasks();
+
     // Step 2: Call code under test.
-    governance.run_periodic_tasks().await;
+    run_pending_timers_every_interval_for_count(
+        Duration::from_secs(REWARD_DISTRIBUTION_PERIOD_SECONDS),
+        1,
+    );
 
     // Step 3: Inspect result(s).
     let get_neuron_rewards = |neuron_id| {
diff --git a/rs/nns/governance/unreleased_changelog.md b/rs/nns/governance/unreleased_changelog.md
index 1ec441c51fd..6423403dd98 100644
--- a/rs/nns/governance/unreleased_changelog.md
+++ b/rs/nns/governance/unreleased_changelog.md
@@ -9,12 +9,13 @@ on the process that this file is part of, see
 
 ## Added
 
-## Changed
+* Collect metrics about timer tasks defined using ic_nervous_system_timer_task library.
 
+## Changed
+* Voting Rewards will be scheduled by a timer instead of by heartbeats.
+
 ## Deprecated
 
-* NnsCanisterUpgrade/NnsRootUpgrade NNS funtions are made obsolete.
-
 ## Removed
 
 ## Fixed
diff --git a/rs/nns/integration_tests/src/neuron_following.rs b/rs/nns/integration_tests/src/neuron_following.rs
index 58a173495f4..a3d2855813b 100644
--- a/rs/nns/integration_tests/src/neuron_following.rs
+++ b/rs/nns/integration_tests/src/neuron_following.rs
@@ -200,9 +200,6 @@ fn follow_same_neuron_multiple_times() {
     );
 }
 
-// This test is failing because of timing issues. We disable it until the NNS team
-// has a fix.
-#[ignore]
 #[test]
 fn vote_propagation_with_following() {
     let state_machine = setup_state_machine_with_nns_canisters();
diff --git a/rs/nns/sns-wasm/CHANGELOG.md b/rs/nns/sns-wasm/CHANGELOG.md
index cea83b43281..a242627fcfe 100644
--- a/rs/nns/sns-wasm/CHANGELOG.md
+++ b/rs/nns/sns-wasm/CHANGELOG.md
@@ -11,6 +11,16 @@ here were moved from the adjacent `unreleased_changelog.md` file.
 
 INSERT NEW RELEASES HERE
 
+# 2025-03-01: Proposal 135614
+
+http://dashboard.internetcomputer.org/proposal/135614
+
+## Added
+
+* Enable [automatic advancement of SNS target versions for newly launched
+  SNSs](https://forum.dfinity.org/t/proposal-automatic-sns-target-version-advancement-for-newly-created-snss).
+
+
 # 2025-02-21: Proposal 135437
 
 http://dashboard.internetcomputer.org/proposal/135437
diff --git a/rs/pocket_ic_server/tests/test.rs b/rs/pocket_ic_server/tests/test.rs
index 2f0edffeaca..cb9a6db10d1 100644
--- a/rs/pocket_ic_server/tests/test.rs
+++ b/rs/pocket_ic_server/tests/test.rs
@@ -429,7 +429,7 @@ fn test_specified_id() {
         .unwrap();
     let canister_id = rt.block_on(async {
         let agent = ic_agent::Agent::builder()
-            .with_url(endpoint.clone())
+            .with_url(endpoint)
             .build()
             .unwrap();
         agent.fetch_root_key().await.unwrap();
diff --git a/rs/protobuf/def/types/v1/consensus.proto b/rs/protobuf/def/types/v1/consensus.proto
index 33bd3b36180..a4dafff1ed5 100644
--- a/rs/protobuf/def/types/v1/consensus.proto
+++ b/rs/protobuf/def/types/v1/consensus.proto
@@ -67,6 +67,7 @@ message Block {
   reserved 14;
   bytes canister_http_payload_bytes = 15;
   bytes query_stats_payload_bytes = 16;
+  bytes vetkd_payload_bytes = 17;
   bytes payload_hash = 11;
 }
 
diff --git a/rs/protobuf/src/gen/types/types.v1.rs b/rs/protobuf/src/gen/types/types.v1.rs
index 5615591241d..f326cb9fce6 100644
--- a/rs/protobuf/src/gen/types/types.v1.rs
+++ b/rs/protobuf/src/gen/types/types.v1.rs
@@ -1292,6 +1292,8 @@ pub struct Block {
     pub canister_http_payload_bytes: ::prost::alloc::vec::Vec<u8>,
     #[prost(bytes = "vec", tag = "16")]
     pub query_stats_payload_bytes: ::prost::alloc::vec::Vec<u8>,
+    #[prost(bytes = "vec", tag = "17")]
+    pub vetkd_payload_bytes: ::prost::alloc::vec::Vec<u8>,
     #[prost(bytes = "vec", tag = "11")]
     pub payload_hash: ::prost::alloc::vec::Vec<u8>,
 }
diff --git a/rs/replica/setup_ic_network/BUILD.bazel b/rs/replica/setup_ic_network/BUILD.bazel
index bba68f83a82..57ecf89ca22 100644
--- a/rs/replica/setup_ic_network/BUILD.bazel
+++ b/rs/replica/setup_ic_network/BUILD.bazel
@@ -8,6 +8,7 @@ DEPENDENCIES = [
     "//rs/config",
     "//rs/consensus/dkg",
     "//rs/consensus/utils",
+    "//rs/consensus/vetkd",
     "//rs/crypto/interfaces/sig_verification",
     "//rs/crypto/tls_interfaces",
     "//rs/cycles_account_manager",
diff --git a/rs/replica/setup_ic_network/Cargo.toml b/rs/replica/setup_ic_network/Cargo.toml
index a9e464849a7..d9bddfa3b96 100644
--- a/rs/replica/setup_ic_network/Cargo.toml
+++ b/rs/replica/setup_ic_network/Cargo.toml
@@ -15,6 +15,7 @@ ic-consensus = { path = "../../consensus" }
 ic-consensus-dkg = { path = "../../consensus/dkg" }
 ic-consensus-manager = { path = "../../p2p/consensus_manager" }
 ic-consensus-utils = { path = "../../consensus/utils" }
+ic-consensus-vetkd = { path = "../../consensus/vetkd" }
 ic-crypto-interfaces-sig-verification = { path = "../../crypto/interfaces/sig_verification" }
 ic-crypto-tls-interfaces = { path = "../../crypto/tls_interfaces" }
 ic-cycles-account-manager = { path = "../../cycles_account_manager" }
diff --git a/rs/replica/setup_ic_network/src/lib.rs b/rs/replica/setup_ic_network/src/lib.rs
index f479c71404c..0b9068b0b13 100644
--- a/rs/replica/setup_ic_network/src/lib.rs +++ b/rs/replica/setup_ic_network/src/lib.rs @@ -18,6 +18,7 @@ use ic_consensus::{ use ic_consensus_dkg::DkgBouncer; use ic_consensus_manager::{AbortableBroadcastChannel, AbortableBroadcastChannelBuilder}; use ic_consensus_utils::{crypto::ConsensusCrypto, pool_reader::PoolReader}; +use ic_consensus_vetkd::VetKdPayloadBuilderImpl; use ic_crypto_interfaces_sig_verification::IngressSigVerifier; use ic_crypto_tls_interfaces::TlsConfig; use ic_cycles_account_manager::CyclesAccountManager; @@ -534,6 +535,17 @@ fn start_consensus( metrics_registry, log.clone(), )); + + let vetkd_payload_builder = Arc::new(VetKdPayloadBuilderImpl::new( + artifact_pools.idkg_pool.clone(), + consensus_pool_cache.clone(), + consensus_crypto.clone(), + state_reader.clone(), + subnet_id, + registry_client.clone(), + metrics_registry, + log.clone(), + )); // ------------------------------------------------------------------------ let replica_config = ReplicaConfig { node_id, subnet_id }; @@ -556,6 +568,7 @@ fn start_consensus( self_validating_payload_builder, https_outcalls_payload_builder, Arc::from(query_stats_payload_builder), + vetkd_payload_builder, Arc::clone(&artifact_pools.dkg_pool) as Arc<_>, Arc::clone(&artifact_pools.idkg_pool) as Arc<_>, Arc::clone(&dkg_key_manager) as Arc<_>, diff --git a/rs/sns/governance/CHANGELOG.md b/rs/sns/governance/CHANGELOG.md index 8da6d1fc155..9af177df21e 100644 --- a/rs/sns/governance/CHANGELOG.md +++ b/rs/sns/governance/CHANGELOG.md @@ -11,6 +11,81 @@ here were moved from the adjacent `unreleased_changelog.md` file. INSERT NEW RELEASES HERE +# 2025-03-01: Proposal 135615 + +http://dashboard.internetcomputer.org/proposal/135615 + +## Added + +* New type of SNS proposals `SetTopicsForCustomProposals` can be used to batch-set topics for all custom proposals (or any non-empty subset thereof) at once. + + Example usage: + + ```bash + dfx canister --ic call ${SNS_GOVERNANCE_CANISTER_ID} manage_neuron '( + record { + subaccount = blob "'${PROPOSER_SNS_NEURON_SUBACCOUNT}'"; + command = opt variant { + MakeProposal = record { + url = "https://forum.dfinity.org/t/sns-topics-plan"; + title = "Set topics for custom SNS proposals"; + action = opt variant { + SetTopicsForCustomProposals = record { + custom_function_id_to_topic = vec { + record { + 42; variant { ApplicationBusinessLogic } + } + record { + 123; variant { DaoCommunitySettings } + } + }; + } + }; + summary = "Set topics ApplicationBusinessLogic and \ + DaoCommunitySettings for SNS proposals with \ + IDs 42 and 123 resp."; + } + }; + }, + )' + ``` + +## Changed + +* Enable +[automatic target version advancement](https://forum.dfinity.org/t/proposal-opt-in-mechanism-for-automatic-sns-target-version-advancement/39874) +for newly deployed SNSs. 
To opt out, please submit a `ManageNervousSystemParameters` proposal, e.g.: + + ```bash + dfx canister --ic call ${SNS_GOVERNANCE_CANISTER_ID} manage_neuron '( + record { + subaccount = blob "'${PROPOSER_SNS_NEURON_SUBACCOUNT}'"; + command = opt variant { + MakeProposal = record { + url = "https://forum.dfinity.org/t/proposal-opt-in-mechanism-for-automatic-sns-target-version-advancement"; + title = "Opt out from automatic advancement of SNS target versions"; + action = opt variant { + ManageNervousSystemParameters = record { + automatically_advance_target_version = opt false; + } + }; + summary = "Disable automatically advancing the target version \ + of this SNS to have full control over the delivery of SNS framework \ + upgrades blessed by the NNS."; + } + }; + }, + )' + ``` + +## Fixed + +* `ManageNervousSystemParameters` proposals now enforce that at least one field is set. + +* Errors caused by trying to submit proposals restricted in pre-initialization mode should no + longer overflow. + + # 2025-02-15: Proposal 135315 http://dashboard.internetcomputer.org/proposal/135315 diff --git a/rs/sns/governance/unreleased_changelog.md b/rs/sns/governance/unreleased_changelog.md index d5bc304304f..94126a0ff42 100644 --- a/rs/sns/governance/unreleased_changelog.md +++ b/rs/sns/governance/unreleased_changelog.md @@ -9,76 +9,12 @@ on the process that this file is part of, see ## Added -* New type of SNS proposals `SetTopicsForCustomProposals` can be used to batch-set topics for all custom proposals (or any non-empty subset thereof) at once. - - Example usage: - - ```bash - dfx canister --ic call ${SNS_GOVERNANCE_CANISTER_ID} manage_neuron '( - record { - subaccount = blob "'${PROPOSER_SNS_NEURON_SUBACCOUNT}'"; - command = opt variant { - MakeProposal = record { - url = "https://forum.dfinity.org/t/sns-topics-plan"; - title = "Set topics for custom SNS proposals"; - action = opt variant { - SetTopicsForCustomProposals = record { - custom_function_id_to_topic = vec { - record { - 42; variant { ApplicationBusinessLogic } - } - record { - 123; variant { DaoCommunitySettings } - } - }; - } - }; - summary = "Set topics ApplicationBusinessLogic and \ - DaoCommunitySettings for SNS proposals with \ - IDs 42 and 123 resp."; - } - }; - }, - )' - ``` - ## Changed -* Enable -[automatic target version advancement](https://forum.dfinity.org/t/proposal-opt-in-mechanism-for-automatic-sns-target-version-advancement/39874) -for newly deployed SNSs. To opt out, please submit a `ManageNervousSystemParameters` proposal, e.g.: - - ```bash - dfx canister --ic call ${SNS_GOVERNANCE_CANISTER_ID} manage_neuron '( - record { - subaccount = blob "'${PROPOSER_SNS_NEURON_SUBACCOUNT}'"; - command = opt variant { - MakeProposal = record { - url = "https://forum.dfinity.org/t/proposal-opt-in-mechanism-for-automatic-sns-target-version-advancement"; - title = "Opt out from automatic advancement of SNS target versions"; - action = opt variant { - ManageNervousSystemParameters = record { - automatically_advance_target_version = opt false; - } - }; - summary = "Disable automatically advancing the target version \ - of this SNS to have full control over the delivery of SNS framework \ - upgrades blessed by the NNS."; - } - }; - }, - )' - ``` - ## Deprecated ## Removed ## Fixed -* `ManageNervousSystemParameters` proposals now enforce that at least one field is set. - -* Errors caused by trying to submit proposals restricted in pre-initialization mode should no - longer overflow. 
- ## Security diff --git a/rs/state_layout/BUILD.bazel b/rs/state_layout/BUILD.bazel index 34ee6108cd8..0b59bb7ddd6 100644 --- a/rs/state_layout/BUILD.bazel +++ b/rs/state_layout/BUILD.bazel @@ -36,7 +36,10 @@ DEV_DEPENDENCIES = [ "@crate_index//:proptest", ] -MACRO_DEV_DEPENDENCIES = [] +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:test-strategy", +] rust_library( name = "state_layout", diff --git a/rs/state_layout/Cargo.toml b/rs/state_layout/Cargo.toml index 7734ddd8ad9..d5652807b92 100644 --- a/rs/state_layout/Cargo.toml +++ b/rs/state_layout/Cargo.toml @@ -33,3 +33,4 @@ ic-test-utilities-tmpdir = { path = "../test_utilities/tmpdir" } ic-test-utilities-types = { path = "../test_utilities/types" } itertools = { workspace = true } proptest = { workspace = true } +test-strategy = "0.4.0" diff --git a/rs/state_layout/src/state_layout/tests.rs b/rs/state_layout/src/state_layout/tests.rs index 6cc3dc22748..f5e759f2a36 100644 --- a/rs/state_layout/src/state_layout/tests.rs +++ b/rs/state_layout/src/state_layout/tests.rs @@ -573,9 +573,13 @@ fn test_all_existing_pagemaps() { ); } -proptest! { -#[test] -fn read_back_wasm_memory_overlay_file_names(heights in random_sorted_unique_heights(10)) { +#[test_strategy::proptest] +fn read_back_wasm_memory_overlay_file_names( + #[strategy(random_sorted_unique_heights( + 10, // max_length + ))] + heights: Vec, +) { let tmp = tmpdir("canister"); let canister_layout: CanisterLayout = CanisterLayout::new_untracked(tmp.path().to_owned()).unwrap(); @@ -591,8 +595,18 @@ fn read_back_wasm_memory_overlay_file_names(heights in random_sorted_unique_heig // Create some other files that should be ignored. File::create(canister_layout.raw_path().join("otherfile")).unwrap(); - File::create(canister_layout.stable_memory().overlay(Height::new(42), Shard::new(0))).unwrap(); - File::create(canister_layout.wasm_chunk_store().overlay(Height::new(42), Shard::new(0))).unwrap(); + File::create( + canister_layout + .stable_memory() + .overlay(Height::new(42), Shard::new(0)), + ) + .unwrap(); + File::create( + canister_layout + .wasm_chunk_store() + .overlay(Height::new(42), Shard::new(0)), + ) + .unwrap(); File::create(canister_layout.vmemory_0().base()).unwrap(); let existing_overlays = canister_layout.vmemory_0().existing_overlays().unwrap(); @@ -601,8 +615,13 @@ fn read_back_wasm_memory_overlay_file_names(heights in random_sorted_unique_heig prop_assert_eq!(overlay_names, existing_overlays); } -#[test] -fn read_back_stable_memory_overlay_file_names(heights in random_sorted_unique_heights(10)) { +#[test_strategy::proptest] +fn read_back_stable_memory_overlay_file_names( + #[strategy(random_sorted_unique_heights( + 10, // max_length + ))] + heights: Vec, +) { let tmp = tmpdir("canister"); let canister_layout: CanisterLayout = CanisterLayout::new_untracked(tmp.path().to_owned()).unwrap(); @@ -618,8 +637,18 @@ fn read_back_stable_memory_overlay_file_names(heights in random_sorted_unique_he // Create some other files that should be ignored. 
File::create(canister_layout.raw_path().join("otherfile")).unwrap(); - File::create(canister_layout.vmemory_0().overlay(Height::new(42), Shard::new(0))).unwrap(); - File::create(canister_layout.wasm_chunk_store().overlay(Height::new(42), Shard::new(0))).unwrap(); + File::create( + canister_layout + .vmemory_0() + .overlay(Height::new(42), Shard::new(0)), + ) + .unwrap(); + File::create( + canister_layout + .wasm_chunk_store() + .overlay(Height::new(42), Shard::new(0)), + ) + .unwrap(); File::create(canister_layout.stable_memory().base()).unwrap(); let existing_overlays = canister_layout.stable_memory().existing_overlays().unwrap(); @@ -628,14 +657,23 @@ fn read_back_stable_memory_overlay_file_names(heights in random_sorted_unique_he prop_assert_eq!(overlay_names, existing_overlays); } -#[test] -fn read_back_wasm_chunk_store_overlay_file_names(heights in random_sorted_unique_heights(10)) { +#[test_strategy::proptest] +fn read_back_wasm_chunk_store_overlay_file_names( + #[strategy(random_sorted_unique_heights( + 10, // max_length + ))] + heights: Vec, +) { let tmp = tmpdir("canister"); let canister_layout: CanisterLayout = CanisterLayout::new_untracked(tmp.path().to_owned()).unwrap(); let overlay_names: Vec = heights .iter() - .map(|h| canister_layout.wasm_chunk_store().overlay(*h, Shard::new(0))) + .map(|h| { + canister_layout + .wasm_chunk_store() + .overlay(*h, Shard::new(0)) + }) .collect(); // Create the overlay files in the directory. @@ -645,26 +683,49 @@ fn read_back_wasm_chunk_store_overlay_file_names(heights in random_sorted_unique // Create some other files that should be ignored. File::create(canister_layout.raw_path().join("otherfile")).unwrap(); - File::create(canister_layout.vmemory_0().overlay(Height::new(42), Shard::new(0))).unwrap(); - File::create(canister_layout.stable_memory().overlay(Height::new(42), Shard::new(0))).unwrap(); + File::create( + canister_layout + .vmemory_0() + .overlay(Height::new(42), Shard::new(0)), + ) + .unwrap(); + File::create( + canister_layout + .stable_memory() + .overlay(Height::new(42), Shard::new(0)), + ) + .unwrap(); File::create(canister_layout.wasm_chunk_store().base()).unwrap(); - let existing_overlays = canister_layout.wasm_chunk_store().existing_overlays().unwrap(); + let existing_overlays = canister_layout + .wasm_chunk_store() + .existing_overlays() + .unwrap(); // We expect the list of paths to be the same including ordering. prop_assert_eq!(overlay_names, existing_overlays); } -#[test] -fn read_back_checkpoint_directory_names(heights in random_sorted_unique_heights(10)) { +#[test_strategy::proptest] +fn read_back_checkpoint_directory_names( + #[strategy(random_sorted_unique_heights( + 10, // max_length + ))] + heights: Vec, +) { with_test_replica_logger(|log| { let tmp = tmpdir("state_layout"); let metrics_registry = ic_metrics::MetricsRegistry::new(); - let state_layout = StateLayout::try_new(log, tmp.path().to_owned(), &metrics_registry).unwrap(); + let state_layout = + StateLayout::try_new(log, tmp.path().to_owned(), &metrics_registry).unwrap(); let checkpoint_names: Vec = heights .iter() - .map(|h| state_layout.checkpoints().join(StateLayout::checkpoint_name(*h))) + .map(|h| { + state_layout + .checkpoints() + .join(StateLayout::checkpoint_name(*h)) + }) .collect(); // Create the (empty) checkpoint directories. 
@@ -679,8 +740,15 @@ fn read_back_checkpoint_directory_names(heights in random_sorted_unique_heights( }); } -#[test] -fn read_back_canister_snapshot_ids(mut snapshot_ids in random_unique_snapshot_ids(10, 10, 10)) { +#[test_strategy::proptest] +fn read_back_canister_snapshot_ids( + #[strategy(random_unique_snapshot_ids( + 10, // max_length + 10, // canister_count + 10, // snapshots_per_canister_count + ))] + mut snapshot_ids: Vec, +) { let tmp = tmpdir("checkpoint"); let checkpoint_layout: CheckpointLayout = CheckpointLayout::new_untracked(tmp.path().to_owned(), Height::new(0)).unwrap(); @@ -694,36 +762,55 @@ fn read_back_canister_snapshot_ids(mut snapshot_ids in random_unique_snapshot_id prop_assert_eq!(snapshot_ids, actual_snapshot_ids); } -#[test] -fn can_add_and_delete_canister_snapshots(snapshot_ids in random_unique_snapshot_ids(10, 10, 10)) { +#[test_strategy::proptest] +fn can_add_and_delete_canister_snapshots( + #[strategy(random_unique_snapshot_ids( + 10, // max_length + 10, // canister_count + 10, // snapshots_per_canister_count + ))] + snapshot_ids: Vec, +) { let tmp = tmpdir("checkpoint"); let checkpoint_layout: CheckpointLayout = CheckpointLayout::new_untracked(tmp.path().to_owned(), Height::new(0)).unwrap(); - fn check_snapshot_layout(checkpoint_layout: &CheckpointLayout, expected_snapshot_ids: &[SnapshotId]) { + fn check_snapshot_layout( + checkpoint_layout: &CheckpointLayout, + expected_snapshot_ids: &[SnapshotId], + ) { let actual_snapshot_ids = checkpoint_layout.snapshot_ids().unwrap(); let mut expected_snapshot_ids = expected_snapshot_ids.to_vec(); expected_snapshot_ids.sort(); assert_eq!(expected_snapshot_ids, actual_snapshot_ids); - let num_unique_canisters = actual_snapshot_ids.iter().map(|snapshot_id| snapshot_id.get_canister_id()).unique().count(); - - let num_canister_directories = std::fs::read_dir(checkpoint_layout.raw_path().join(SNAPSHOTS_DIR)).unwrap().count(); + let num_unique_canisters = actual_snapshot_ids + .iter() + .map(|snapshot_id| snapshot_id.get_canister_id()) + .unique() + .count(); + + let num_canister_directories = + std::fs::read_dir(checkpoint_layout.raw_path().join(SNAPSHOTS_DIR)) + .unwrap() + .count(); assert_eq!(num_unique_canisters, num_canister_directories); } for i in 0..snapshot_ids.len() { check_snapshot_layout(&checkpoint_layout, &snapshot_ids[..i]); checkpoint_layout.snapshot(&snapshot_ids[i]).unwrap(); // Creates the directory as side effect. - check_snapshot_layout(&checkpoint_layout, &snapshot_ids[..(i+1)]); + check_snapshot_layout(&checkpoint_layout, &snapshot_ids[..(i + 1)]); } for i in 0..snapshot_ids.len() { check_snapshot_layout(&checkpoint_layout, &snapshot_ids[i..]); - checkpoint_layout.snapshot(&snapshot_ids[i]).unwrap().delete_dir().unwrap(); - check_snapshot_layout(&checkpoint_layout, &snapshot_ids[(i+1)..]); + checkpoint_layout + .snapshot(&snapshot_ids[i]) + .unwrap() + .delete_dir() + .unwrap(); + check_snapshot_layout(&checkpoint_layout, &snapshot_ids[(i + 1)..]); } } - -} diff --git a/rs/state_machine_tests/BUILD.bazel b/rs/state_machine_tests/BUILD.bazel index 7df73104794..760853a70be 100644 --- a/rs/state_machine_tests/BUILD.bazel +++ b/rs/state_machine_tests/BUILD.bazel @@ -108,6 +108,11 @@ DEV_DEPENDENCIES = [ MACRO_DEPENDENCIES = [] +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. 
+ "@crate_index//:test-strategy", +] + rust_binary( name = "ic-test-state-machine", testonly = True, @@ -155,6 +160,7 @@ rust_ic_test( rust_ic_test( name = "state_machine_unit_test", crate = ":state_machine_tests", + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = [ # Keep sorted. "@crate_index//:proptest", diff --git a/rs/state_machine_tests/Cargo.toml b/rs/state_machine_tests/Cargo.toml index 9c53b943505..a1ebe78e2db 100644 --- a/rs/state_machine_tests/Cargo.toml +++ b/rs/state_machine_tests/Cargo.toml @@ -82,3 +82,4 @@ ic-base-types = { path = "../types/base_types" } ic-test-utilities = { path = "../test_utilities" } ic-universal-canister = { path = "../universal_canister/lib" } proptest = { workspace = true } +test-strategy = "0.4.0" diff --git a/rs/state_machine_tests/src/lib.rs b/rs/state_machine_tests/src/lib.rs index 91faee37b63..b6a1d698b81 100644 --- a/rs/state_machine_tests/src/lib.rs +++ b/rs/state_machine_tests/src/lib.rs @@ -105,7 +105,7 @@ use ic_replicated_state::{ use ic_state_layout::{CheckpointLayout, ReadOnly}; use ic_state_manager::StateManagerImpl; use ic_test_utilities::crypto::CryptoReturningOk; -use ic_test_utilities_consensus::FakeConsensusPoolCache; +use ic_test_utilities_consensus::{batch::MockBatchPayloadBuilder, FakeConsensusPoolCache}; use ic_test_utilities_metrics::{ fetch_counter_vec, fetch_histogram_stats, fetch_int_counter, fetch_int_gauge, fetch_int_gauge_vec, Labels, @@ -852,6 +852,7 @@ pub struct StateMachine { /// A drop guard to gracefully cancel the ingress watcher task. _ingress_watcher_drop_guard: tokio_util::sync::DropGuard, query_stats_payload_builder: Arc, + vetkd_payload_builder: Arc, remove_old_states: bool, // This field must be the last one so that the temporary directory is deleted at the very end. state_dir: Box, @@ -1309,6 +1310,7 @@ impl StateMachineBuilder { self_validating_payload_builder, sm.canister_http_payload_builder.clone(), sm.query_stats_payload_builder.clone(), + sm.vetkd_payload_builder.clone(), sm.metrics_registry.clone(), sm.replica_logger.clone(), )); @@ -1615,6 +1617,8 @@ impl StateMachine { replica_logger.clone(), )); + let vetkd_payload_builder = Arc::new(MockBatchPayloadBuilder::new().expect_noop()); + // Setup ingress watcher for synchronous call endpoint. let (completed_execution_messages_tx, completed_execution_messages_rx) = mpsc::channel(COMPLETED_EXECUTION_MESSAGES_BUFFER_SIZE); @@ -1859,6 +1863,7 @@ impl StateMachine { canister_http_pool, canister_http_payload_builder, query_stats_payload_builder: pocket_query_stats_payload_builder, + vetkd_payload_builder, remove_old_states, } } diff --git a/rs/state_machine_tests/src/tests.rs b/rs/state_machine_tests/src/tests.rs index be007266078..c968fedeb22 100644 --- a/rs/state_machine_tests/src/tests.rs +++ b/rs/state_machine_tests/src/tests.rs @@ -1,31 +1,29 @@ use ic_secp256k1::{DerivationIndex, DerivationPath, PrivateKey, PublicKey}; -use proptest::{collection::vec as pvec, prelude::*, prop_assert, proptest}; - -proptest! 
{ - #[test] - fn test_derivation_prop( - derivation_path_bytes in pvec(pvec(any::(), 1..10), 1..10), - message_hash in pvec(any::(), 32), - ) { - let private_key_bytes = - hex::decode("fb7d1f5b82336bb65b82bf4f27776da4db71c1ef632c6a7c171c0cbfa2ea4920").unwrap(); - - let ecdsa_secret_key: PrivateKey = - PrivateKey::deserialize_sec1(private_key_bytes.as_slice()).unwrap(); - - let derivation_path = DerivationPath::new( - derivation_path_bytes - .into_iter() - .map(DerivationIndex) - .collect(), - ); - - let derived_secret_key = ecdsa_secret_key.derive_subkey(&derivation_path).0; - let signature = derived_secret_key.sign_message_with_ecdsa(&message_hash); - - let derived_public_key = derived_secret_key.public_key(); - prop_assert!(derived_public_key.verify_ecdsa_signature(&message_hash, &signature)); - } +use proptest::{collection::vec as pvec, prelude::*, prop_assert}; + +#[test_strategy::proptest] +fn test_derivation_prop( + #[strategy(pvec(pvec(any::(), 1..10), 1..10))] derivation_path_bytes: Vec>, + #[strategy(pvec(any::(), 32))] message_hash: Vec, +) { + let private_key_bytes = + hex::decode("fb7d1f5b82336bb65b82bf4f27776da4db71c1ef632c6a7c171c0cbfa2ea4920").unwrap(); + + let ecdsa_secret_key: PrivateKey = + PrivateKey::deserialize_sec1(private_key_bytes.as_slice()).unwrap(); + + let derivation_path = DerivationPath::new( + derivation_path_bytes + .into_iter() + .map(DerivationIndex) + .collect(), + ); + + let derived_secret_key = ecdsa_secret_key.derive_subkey(&derivation_path).0; + let signature = derived_secret_key.sign_message_with_ecdsa(&message_hash); + + let derived_public_key = derived_secret_key.public_key(); + prop_assert!(derived_public_key.verify_ecdsa_signature(&message_hash, &signature)); } #[test] diff --git a/rs/state_manager/BUILD.bazel b/rs/state_manager/BUILD.bazel index 5490a78eb31..1296768a14b 100644 --- a/rs/state_manager/BUILD.bazel +++ b/rs/state_manager/BUILD.bazel @@ -3,6 +3,16 @@ load("//bazel:defs.bzl", "rust_bench", "rust_ic_test") package(default_visibility = ["//visibility:public"]) +MACRO_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:strum_macros", +] + +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:test-strategy", +] + rust_library( name = "state_manager", srcs = glob(["src/**"]), @@ -11,10 +21,7 @@ rust_library( "//conditions:default": [], }), crate_name = "ic_state_manager", - proc_macro_deps = [ - # Keep sorted. - "@crate_index//:strum_macros", - ], + proc_macro_deps = MACRO_DEPENDENCIES, version = "0.9.0", deps = [ # Keep sorted. @@ -64,6 +71,7 @@ rust_test( name = "state_manager_lib_tests", timeout = "long", crate = ":state_manager", + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = [ # Keep sorted. "//packages/ic-error-types", @@ -94,6 +102,7 @@ rust_ic_test( "tests/state_manager.rs", ], crate_root = "tests/state_manager.rs", + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = [ # Keep sorted. 
":state_manager", diff --git a/rs/state_manager/Cargo.toml b/rs/state_manager/Cargo.toml index 120f3b26e97..bddd768f025 100644 --- a/rs/state_manager/Cargo.toml +++ b/rs/state_manager/Cargo.toml @@ -76,6 +76,7 @@ ic-wasm-types = { path = "../types/wasm_types" } maplit = "1.0.2" proptest = { workspace = true } strum = { workspace = true } +test-strategy = "0.4.0" [[bench]] name = "bench_traversal" diff --git a/rs/state_manager/src/labeled_tree_visitor/tests.rs b/rs/state_manager/src/labeled_tree_visitor/tests.rs index a894f4d995f..b35a7b3e39f 100644 --- a/rs/state_manager/src/labeled_tree_visitor/tests.rs +++ b/rs/state_manager/src/labeled_tree_visitor/tests.rs @@ -108,11 +108,9 @@ fn arb_tree() -> impl Strategy>> { ) } -proptest! { - #[test] - fn roundtrip(t in arb_tree()) { - let mut v = LabeledTreeVisitor::default(); - let _ = traverse_labeled_tree(&t, &mut v); - prop_assert_eq!(v.finish(), t); - } +#[test_strategy::proptest] +fn roundtrip(#[strategy(arb_tree())] t: LabeledTree>) { + let mut v = LabeledTreeVisitor::default(); + let _ = traverse_labeled_tree(&t, &mut v); + prop_assert_eq!(v.finish(), t); } diff --git a/rs/state_manager/src/manifest/tests/compatibility.rs b/rs/state_manager/src/manifest/tests/compatibility.rs index 8b4188f4329..610fe24ed53 100644 --- a/rs/state_manager/src/manifest/tests/compatibility.rs +++ b/rs/state_manager/src/manifest/tests/compatibility.rs @@ -599,24 +599,33 @@ fn deterministic_manifest_hash() { ); } -proptest! { - #[test] - fn chunk_info_deterministic_encoding(chunk_info in arbitrary_chunk_info()) { - assert_eq!(encode_chunk_info(&chunk_info), encode_chunk_info_expected(&chunk_info)); - } +#[test_strategy::proptest] +fn chunk_info_deterministic_encoding(#[strategy(arbitrary_chunk_info())] chunk_info: ChunkInfo) { + assert_eq!( + encode_chunk_info(&chunk_info), + encode_chunk_info_expected(&chunk_info) + ); +} - #[test] - fn file_info_deterministic_encoding(file_info in arbitrary_file_info()) { - assert_eq!(encode_file_info(&file_info), encode_file_info_expected(&file_info)); - } +#[test_strategy::proptest] +fn file_info_deterministic_encoding(#[strategy(arbitrary_file_info())] file_info: FileInfo) { + assert_eq!( + encode_file_info(&file_info), + encode_file_info_expected(&file_info) + ); +} - #[test] - fn manifest_deterministic_encoding( - version in 0..=(MAX_SUPPORTED_STATE_SYNC_VERSION as u32), - file_table in prop::collection::vec(arbitrary_file_info(), 0..=1000), - chunk_table in prop::collection::vec(arbitrary_chunk_info(), 0..=1000) - ) { - let manifest = Manifest::new(version.try_into().unwrap(), file_table, chunk_table); - assert_eq!(encode_manifest(&manifest), encode_manifest_expected(&manifest)); - } +#[test_strategy::proptest] +fn manifest_deterministic_encoding( + #[strategy(0..=MAX_SUPPORTED_STATE_SYNC_VERSION as u32)] version: u32, + #[strategy(prop::collection::vec(arbitrary_file_info(), 0..=1000))] file_table: Vec, + #[strategy(prop::collection::vec(arbitrary_chunk_info(), 0..=1000))] chunk_table: Vec< + ChunkInfo, + >, +) { + let manifest = Manifest::new(version.try_into().unwrap(), file_table, chunk_table); + assert_eq!( + encode_manifest(&manifest), + encode_manifest_expected(&manifest) + ); } diff --git a/rs/state_manager/src/stream_encoding/tests.rs b/rs/state_manager/src/stream_encoding/tests.rs index 80740d205e6..7a247a5044a 100644 --- a/rs/state_manager/src/stream_encoding/tests.rs +++ b/rs/state_manager/src/stream_encoding/tests.rs @@ -2,68 +2,101 @@ use super::*; use ic_base_types::NumSeconds; use 
ic_canonical_state::MAX_SUPPORTED_CERTIFICATION_VERSION; use ic_registry_subnet_type::SubnetType; -use ic_replicated_state::{testing::ReplicatedStateTesting, ReplicatedState}; +use ic_replicated_state::{testing::ReplicatedStateTesting, ReplicatedState, Stream}; use ic_test_utilities_state::{arb_stream, new_canister_state}; use ic_test_utilities_types::ids::{canister_test_id, subnet_test_id, user_test_id}; use ic_types::{xnet::StreamSlice, Cycles}; -use proptest::prelude::*; const INITIAL_CYCLES: Cycles = Cycles::new(1 << 36); -proptest! { - #[test] - fn stream_encode_decode_roundtrip(stream in arb_stream(0, 10, 0, 10)) { - let mut state = ReplicatedState::new(subnet_test_id(1), SubnetType::Application); +#[test_strategy::proptest] +fn stream_encode_decode_roundtrip( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, +) { + let mut state = ReplicatedState::new(subnet_test_id(1), SubnetType::Application); - let subnet = subnet_test_id(42); - let stream_slice: StreamSlice = stream.clone().into(); - state.modify_streams(|streams| { - streams.insert(subnet, stream); - }); - state.metadata.certification_version = MAX_SUPPORTED_CERTIFICATION_VERSION; + let subnet = subnet_test_id(42); + let stream_slice: StreamSlice = stream.clone().into(); + state.modify_streams(|streams| { + streams.insert(subnet, stream); + }); + state.metadata.certification_version = MAX_SUPPORTED_CERTIFICATION_VERSION; - // Add some noise, for good measure. - state.put_canister_state(new_canister_state( - canister_test_id(13), - user_test_id(24).get(), - INITIAL_CYCLES, - NumSeconds::from(100_000), - )); + // Add some noise, for good measure. + state.put_canister_state(new_canister_state( + canister_test_id(13), + user_test_id(24).get(), + INITIAL_CYCLES, + NumSeconds::from(100_000), + )); - let tree_encoding = encode_stream_slice(&state, subnet, stream_slice.header().begin(), stream_slice.header().end(), None).0; - let bytes = encode_tree(tree_encoding.clone()); - assert_eq!(decode_stream_slice(&bytes[..]), Ok((subnet, stream_slice)), "failed to decode tree {:?}", tree_encoding); - } - - #[test] - fn stream_encode_with_size_limit(stream in arb_stream(0, 10, 0, 10), size_limit in 0..1000usize) { - let mut state = ReplicatedState::new(subnet_test_id(1), SubnetType::Application); + let tree_encoding = encode_stream_slice( + &state, + subnet, + stream_slice.header().begin(), + stream_slice.header().end(), + None, + ) + .0; + let bytes = encode_tree(tree_encoding.clone()); + assert_eq!( + decode_stream_slice(&bytes[..]), + Ok((subnet, stream_slice)), + "failed to decode tree {:?}", + tree_encoding + ); +} - let subnet = subnet_test_id(42); - let stream_slice: StreamSlice = stream.clone().into(); - state.modify_streams(|streams| { - streams.insert(subnet, stream); - }); - state.metadata.certification_version = MAX_SUPPORTED_CERTIFICATION_VERSION; +#[test_strategy::proptest] +fn stream_encode_with_size_limit( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, + #[strategy(0..1000usize)] size_limit: usize, +) { + let mut state = ReplicatedState::new(subnet_test_id(1), SubnetType::Application); - let tree_encoding = encode_stream_slice(&state, subnet, stream_slice.header().begin(), stream_slice.header().end(), Some(size_limit)).0; - let bytes = encode_tree(tree_encoding.clone()); - match decode_stream_slice(&bytes[..]) { - Ok((actual_subnet, actual_slice)) => { - 
assert_eq!(subnet, actual_subnet); - match stream_slice.messages() { - // Expect at least one message. - Some(messages) => { - assert_eq!(stream_slice.header(), actual_slice.header()); - assert_eq!(stream_slice.header().begin(), messages.begin()); - assert!(messages.begin() < messages.end()); - } + let subnet = subnet_test_id(42); + let stream_slice: StreamSlice = stream.clone().into(); + state.modify_streams(|streams| { + streams.insert(subnet, stream); + }); + state.metadata.certification_version = MAX_SUPPORTED_CERTIFICATION_VERSION; - // `stream` had no messages. - None => assert_eq!(stream_slice, actual_slice) + let tree_encoding = encode_stream_slice( + &state, + subnet, + stream_slice.header().begin(), + stream_slice.header().end(), + Some(size_limit), + ) + .0; + let bytes = encode_tree(tree_encoding.clone()); + match decode_stream_slice(&bytes[..]) { + Ok((actual_subnet, actual_slice)) => { + assert_eq!(subnet, actual_subnet); + match stream_slice.messages() { + // Expect at least one message. + Some(messages) => { + assert_eq!(stream_slice.header(), actual_slice.header()); + assert_eq!(stream_slice.header().begin(), messages.begin()); + assert!(messages.begin() < messages.end()); } - }, - Err(e) => panic!("Failed to decode tree {:?}: {}", tree_encoding, e) + + // `stream` had no messages. + None => assert_eq!(stream_slice, actual_slice), + } } + Err(e) => panic!("Failed to decode tree {:?}: {}", tree_encoding, e), } } diff --git a/rs/state_manager/src/tree_diff.rs b/rs/state_manager/src/tree_diff.rs index d32a27d1cec..112af29237d 100644 --- a/rs/state_manager/src/tree_diff.rs +++ b/rs/state_manager/src/tree_diff.rs @@ -522,30 +522,50 @@ mod tests { .boxed() } - proptest! { - #[test] - fn tree_diff_against_self_is_empty(tree in arb_tree(4, 3)) { - prop_assert!(diff_rose_trees(&tree, &tree).is_empty()); - } - - #[test] - fn tree_diff_detects_changing_single_hash((tree, idx) in arb_tree_and_leaf_index(4, 3), - new_hash in any::<[u8; 32]>().prop_map(Digest)) { - let size = num_leaves(&tree); - prop_assume!(idx < size); - let mut tree_2 = tree.clone(); - let (path, _old_hash) = modify_leaf_at_index(&mut tree_2, idx, new_hash.clone()).unwrap(); - let expected_diff = changes(&[(path, Change::InsertLeaf(new_hash))][..]); - assert_eq!(diff_rose_trees(&tree, &tree_2), expected_diff); - } - - #[test] - fn tree_diff_detects_removing_a_node((tree, idx) in arb_tree_and_edge_index(4, 3)) { - let mut tree_2 = tree.clone(); - let (path, _node) = remove_edge_at_index(&mut tree_2, idx).unwrap(); - let expected_diff = changes(&[(path, Change::DeleteSubtree)][..]); - assert_eq!(diff_rose_trees(&tree, &tree_2), expected_diff); - } + #[test_strategy::proptest] + fn tree_diff_against_self_is_empty( + #[strategy(arb_tree( + 4, // max_height + 3, // max_width + ))] + tree: RoseHashTree, + ) { + prop_assert!(diff_rose_trees(&tree, &tree).is_empty()); + } + + #[test_strategy::proptest] + fn tree_diff_detects_changing_single_hash( + #[strategy(arb_tree_and_leaf_index( + 4, // max_height + 3, // max_width + ))] + test_tree: (RoseHashTree, usize), + #[strategy(any::<[u8; 32]>())] new_hash: [u8; 32], + ) { + let (tree, idx) = test_tree; + let new_hash = Digest(new_hash); + + let size = num_leaves(&tree); + prop_assume!(idx < size); + let mut tree_2 = tree.clone(); + let (path, _old_hash) = modify_leaf_at_index(&mut tree_2, idx, new_hash.clone()).unwrap(); + let expected_diff = changes(&[(path, Change::InsertLeaf(new_hash))][..]); + assert_eq!(diff_rose_trees(&tree, &tree_2), expected_diff); + } + + 
#[test_strategy::proptest] + fn tree_diff_detects_removing_a_node( + #[strategy(arb_tree_and_edge_index( + 4, // max_height + 3, // max_width + ))] + test_tree: (RoseHashTree, usize), + ) { + let (tree, idx) = test_tree; + let mut tree_2 = tree.clone(); + let (path, _node) = remove_edge_at_index(&mut tree_2, idx).unwrap(); + let expected_diff = changes(&[(path, Change::DeleteSubtree)][..]); + assert_eq!(diff_rose_trees(&tree, &tree_2), expected_diff); } #[test] diff --git a/rs/state_manager/tests/state_manager.rs b/rs/state_manager/tests/state_manager.rs index e1496f254bd..e34c4ce3fbc 100644 --- a/rs/state_manager/tests/state_manager.rs +++ b/rs/state_manager/tests/state_manager.rs @@ -6447,292 +6447,385 @@ fn restore_chunk_store_from_snapshot() { assert!(env.execute_ingress(canister_id, "read", vec![],).is_err(),); } -proptest! { - #[test] - fn stream_store_encode_decode(stream in arb_stream(0, 10, 0, 10), size_limit in 0..20usize) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - Some(size_limit), - /* custom destination subnet */ - None, - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, slice| { - // we do not modify the slice before decoding it again - so this should succeed - (state_manager, slice) - } - ); - } +#[test_strategy::proptest] +fn stream_store_encode_decode( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, + #[strategy(0..20usize)] size_limit: usize, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + Some(size_limit), + /* custom destination subnet */ + None, + /* certification verification should succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, slice| { + // we do not modify the slice before decoding it again - so this should succeed + (state_manager, slice) + }, + ); +} - #[test] - #[should_panic(expected = "InvalidSignature")] - fn stream_store_decode_with_modified_hash_fails(stream in arb_stream(0, 10, 0, 10), size_limit in 0..20usize) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - Some(size_limit), - /* custom destination subnet */ - None, - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, mut slice| { - let mut hash = slice.certification.signed.content.hash.get(); - *hash.0.first_mut().unwrap() = hash.0.first().unwrap().overflowing_add(1).0; - slice.certification.signed.content.hash = CryptoHashOfPartialState::from(hash); - - (state_manager, slice) - } - ); - } +#[test_strategy::proptest] +#[should_panic(expected = "InvalidSignature")] +fn stream_store_decode_with_modified_hash_fails( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, + #[strategy(0..20usize)] size_limit: usize, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + Some(size_limit), + /* custom destination subnet */ + None, + /* certification verification should succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, mut slice| { + let mut hash = slice.certification.signed.content.hash.get(); + *hash.0.first_mut().unwrap() = hash.0.first().unwrap().overflowing_add(1).0; + 
slice.certification.signed.content.hash = CryptoHashOfPartialState::from(hash); + + (state_manager, slice) + }, + ); +} - #[test] - #[should_panic(expected = "Failed to deserialize witness")] - fn stream_store_decode_with_empty_witness_fails(stream in arb_stream(0, 10, 0, 10), size_limit in 0..20usize) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - Some(size_limit), - /* custom destination subnet */ - None, - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, mut slice| { - slice.merkle_proof = vec![]; +#[test_strategy::proptest] +#[should_panic(expected = "Failed to deserialize witness")] +fn stream_store_decode_with_empty_witness_fails( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, + #[strategy(0..20usize)] size_limit: usize, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + Some(size_limit), + /* custom destination subnet */ + None, + /* certification verification should succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, mut slice| { + slice.merkle_proof = vec![]; - (state_manager, slice) - } - ); - } + (state_manager, slice) + }, + ); +} - #[test] - #[should_panic(expected = "InconsistentPartialTree")] - fn stream_store_decode_slice_push_additional_message(stream in arb_stream(0, 10, 0, 10)) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - None, - /* custom destination subnet */ - None, - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, slice| { - /* generate replacement stream for slice.payload */ - modify_encoded_stream_helper(state_manager, slice, |decoded_slice| { - let mut messages = match decoded_slice.messages() { - None => StreamIndexedQueue::default(), - Some(messages) => messages.clone(), - }; - - let req = RequestBuilder::default() - .sender(CanisterId::unchecked_from_principal(PrincipalId::try_from(&[2][..]).unwrap())) - .receiver(CanisterId::unchecked_from_principal(PrincipalId::try_from(&[3][..]).unwrap())) - .method_name("test".to_string()) - .sender_reply_callback(CallbackId::from(999)) - .build(); - - messages.push(req.into()); - - let signals_end = decoded_slice.header().signals_end(); - - Stream::new(messages, signals_end) - }) - } - ); - } +#[test_strategy::proptest] +#[should_panic(expected = "InconsistentPartialTree")] +fn stream_store_decode_slice_push_additional_message( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + None, + /* custom destination subnet */ + None, + /* certification verification should succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, slice| { + /* generate replacement stream for slice.payload */ + modify_encoded_stream_helper(state_manager, slice, |decoded_slice| { + let mut messages = match decoded_slice.messages() { + None => StreamIndexedQueue::default(), + Some(messages) => messages.clone(), + }; - /// Depending on the specific input, may fail with either `InvalidSignature` or - /// `InconsistentPartialTree`. Hence, only a generic `should_panic`. 
- #[test] - #[should_panic] - fn stream_store_decode_slice_modify_message_begin(stream in arb_stream(0, 10, 0, 10)) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - None, - /* custom destination subnet */ - None, - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, slice| { - /* generate replacement stream for slice.payload */ - modify_encoded_stream_helper( - state_manager, - slice, - |decoded_slice| { - let mut messages = StreamIndexedQueue::with_begin(StreamIndex::from(99999)); - let signals_end = decoded_slice.header().signals_end(); - - if let Some(decoded_messages) = decoded_slice.messages() { - for (_index, msg) in decoded_messages.iter() { - messages.push(msg.clone()); - } + let req = RequestBuilder::default() + .sender(CanisterId::unchecked_from_principal( + PrincipalId::try_from(&[2][..]).unwrap(), + )) + .receiver(CanisterId::unchecked_from_principal( + PrincipalId::try_from(&[3][..]).unwrap(), + )) + .method_name("test".to_string()) + .sender_reply_callback(CallbackId::from(999)) + .build(); + + messages.push(req.into()); + + let signals_end = decoded_slice.header().signals_end(); + + Stream::new(messages, signals_end) + }) + }, + ); +} + +/// Depending on the specific input, may fail with either `InvalidSignature` or +/// `InconsistentPartialTree`. Hence, only a generic `should_panic`. +#[test_strategy::proptest] +#[should_panic] +fn stream_store_decode_slice_modify_message_begin( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + None, + /* custom destination subnet */ + None, + /* certification verification should succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, slice| { + /* generate replacement stream for slice.payload */ + modify_encoded_stream_helper(state_manager, slice, |decoded_slice| { + let mut messages = StreamIndexedQueue::with_begin(StreamIndex::from(99999)); + let signals_end = decoded_slice.header().signals_end(); + + if let Some(decoded_messages) = decoded_slice.messages() { + for (_index, msg) in decoded_messages.iter() { + messages.push(msg.clone()); } + } - Stream::new(messages, signals_end) - }) - } - ); - } + Stream::new(messages, signals_end) + }) + }, + ); +} - #[test] - #[should_panic(expected = "InvalidSignature")] - fn stream_store_decode_slice_modify_signals_end(stream in arb_stream(0, 10, 0, 10)) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - None, - /* custom destination subnet */ - None, - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, slice| { - /* generate replacement stream for slice.payload */ - modify_encoded_stream_helper(state_manager, slice, |decoded_slice| { - let messages = decoded_slice.messages() - .unwrap_or(&StreamIndexedQueue::default()).clone(); - let signals_end = decoded_slice.header().signals_end() + 99999.into(); - - Stream::new(messages, signals_end) - }) - } - ); - } +#[test_strategy::proptest] +#[should_panic(expected = "InvalidSignature")] +fn stream_store_decode_slice_modify_signals_end( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, +) { + 
encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + None, + /* custom destination subnet */ + None, + /* certification verification should succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, slice| { + /* generate replacement stream for slice.payload */ + modify_encoded_stream_helper(state_manager, slice, |decoded_slice| { + let messages = decoded_slice + .messages() + .unwrap_or(&StreamIndexedQueue::default()) + .clone(); + let signals_end = decoded_slice.header().signals_end() + 99999.into(); + + Stream::new(messages, signals_end) + }) + }, + ); +} - #[test] - #[should_panic(expected = "InvalidSignature")] - fn stream_store_decode_slice_push_signal(stream in arb_stream(0, 10, 0, 10)) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - None, - /* custom destination subnet */ - None, - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, slice| { - /* generate replacement stream for slice.payload */ - modify_encoded_stream_helper(state_manager, slice, |decoded_slice| { - let messages = decoded_slice.messages() - .unwrap_or(&StreamIndexedQueue::default()).clone(); - let mut signals_end = decoded_slice.header().signals_end(); - - signals_end.inc_assign(); - - Stream::new(messages, signals_end) - }) - } - ); - } +#[test_strategy::proptest] +#[should_panic(expected = "InvalidSignature")] +fn stream_store_decode_slice_push_signal( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + None, + /* custom destination subnet */ + None, + /* certification verification should succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, slice| { + /* generate replacement stream for slice.payload */ + modify_encoded_stream_helper(state_manager, slice, |decoded_slice| { + let messages = decoded_slice + .messages() + .unwrap_or(&StreamIndexedQueue::default()) + .clone(); + let mut signals_end = decoded_slice.header().signals_end(); + + signals_end.inc_assign(); + + Stream::new(messages, signals_end) + }) + }, + ); +} - #[test] - #[should_panic(expected = "InvalidDestination")] - fn stream_store_decode_with_invalid_destination(stream in arb_stream(0, 10, 0, 10), size_limit in 0..20usize) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - Some(size_limit), - /* custom destination subnet */ - Some(subnet_test_id(1)), - /* certification verification should succeed */ - true, - /* modification between encoding and decoding */ - |state_manager, slice| { - // Do not modify the slice before decoding it again - the wrong - // destination subnet should already make it fail - (state_manager, slice) - } - ); - } +#[test_strategy::proptest] +#[should_panic(expected = "InvalidDestination")] +fn stream_store_decode_with_invalid_destination( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, + #[strategy(0..20usize)] size_limit: usize, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + Some(size_limit), + /* custom destination subnet */ + Some(subnet_test_id(1)), + /* certification verification should 
succeed */ + true, + /* modification between encoding and decoding */ + |state_manager, slice| { + // Do not modify the slice before decoding it again - the wrong + // destination subnet should already make it fail + (state_manager, slice) + }, + ); +} - #[test] - #[should_panic(expected = "InvalidSignature")] - fn stream_store_decode_with_rejecting_verifier(stream in arb_stream(0, 10, 0, 10), size_limit in 0..20usize) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - Some(size_limit), - /* custom destination subnet */ - None, - /* certification verification should fail */ - false, - /* modification between encoding and decoding */ - |state_manager, slice| { - // Do not modify the slice before decoding it again - the signature validation - // failure caused by passing the `RejectingVerifier` should already make it fail. - (state_manager, slice) - } - ); - } +#[test_strategy::proptest] +#[should_panic(expected = "InvalidSignature")] +fn stream_store_decode_with_rejecting_verifier( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, + #[strategy(0..20usize)] size_limit: usize, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + Some(size_limit), + /* custom destination subnet */ + None, + /* certification verification should fail */ + false, + /* modification between encoding and decoding */ + |state_manager, slice| { + // Do not modify the slice before decoding it again - the signature validation + // failure caused by passing the `RejectingVerifier` should already make it fail. + (state_manager, slice) + }, + ); +} - /// If both signature verification and slice decoding would fail, we expect to - /// see an error about the former. - #[test] - #[should_panic(expected = "InvalidSignature")] - fn stream_store_decode_with_invalid_destination_and_rejecting_verifier(stream in arb_stream(0, 10, 0, 10), size_limit in 0..20usize) { - encode_decode_stream_test( - /* stream to be used */ - stream, - /* size limit used upon encoding */ - Some(size_limit), - /* custom destination subnet */ - Some(subnet_test_id(1)), - /* certification verification should fail */ - false, - /* modification between encoding and decoding */ - |state_manager, slice| { - // Do not modify the slice, the wrong destination subnet and rejecting verifier - // should make it fail regardless. - (state_manager, slice) - } - ); - } +/// If both signature verification and slice decoding would fail, we expect to +/// see an error about the former. +#[test_strategy::proptest] +#[should_panic(expected = "InvalidSignature")] +fn stream_store_decode_with_invalid_destination_and_rejecting_verifier( + #[strategy(arb_stream( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + stream: Stream, + #[strategy(0..20usize)] size_limit: usize, +) { + encode_decode_stream_test( + /* stream to be used */ + stream, + /* size limit used upon encoding */ + Some(size_limit), + /* custom destination subnet */ + Some(subnet_test_id(1)), + /* certification verification should fail */ + false, + /* modification between encoding and decoding */ + |state_manager, slice| { + // Do not modify the slice, the wrong destination subnet and rejecting verifier + // should make it fail regardless. 
+ (state_manager, slice) + }, + ); +} - #[test] - fn stream_store_encode_partial((stream, begin, count) in arb_stream_slice(1, 10, 0, 10), byte_limit in 0..1000usize) { - // Partial slice with messages beginning at `begin + 1`. - encode_partial_slice_test( - stream, - begin, - begin.increment(), - count - 1, - byte_limit - ); - } +#[test_strategy::proptest] +fn stream_store_encode_partial( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), + #[strategy(0..1000usize)] byte_limit: usize, +) { + let (stream, begin, count) = test_slice; + // Partial slice with messages beginning at `begin + 1`. + encode_partial_slice_test(stream, begin, begin.increment(), count - 1, byte_limit); } // 1 test case is sufficient to test index validation. -proptest! { - #![proptest_config(ProptestConfig::with_cases(1))] - - #[test] - #[should_panic(expected = "failed to encode certified stream: InvalidSliceIndices")] - fn stream_store_encode_partial_bad_indices((stream, begin, count) in arb_stream_slice(1, 10, 0, 10), byte_limit in 0..1000usize) { - // `witness_begin` (`== begin + 1`) after `msg_begin` (`== begin`). - encode_partial_slice_test( - stream, - begin.increment(), - begin, - count, - byte_limit - ); - } +#[test_strategy::proptest(ProptestConfig::with_cases(1))] +#[should_panic(expected = "failed to encode certified stream: InvalidSliceIndices")] +fn stream_store_encode_partial_bad_indices( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), + #[strategy(0..1000usize)] byte_limit: usize, +) { + let (stream, begin, count) = test_slice; + // `witness_begin` (`== begin + 1`) after `msg_begin` (`== begin`). + encode_partial_slice_test(stream, begin.increment(), begin, count, byte_limit); } /// Test if query stats are correctly aggregated into the canister state. @@ -6916,16 +7009,17 @@ fn arbitrary_test_canister_op() -> impl Strategy { } } -proptest! { -#![proptest_config(ProptestConfig { +#[test_strategy::proptest(ProptestConfig { // Fork to prevent flaky timeouts due to closed sandbox fds fork: true, // We go for fewer, but longer runs ..ProptestConfig::with_cases(5) })] - -#[test] -fn random_canister_input(ops in proptest::collection::vec(arbitrary_test_canister_op(), 1..50)) { +fn random_canister_input( + #[strategy(proptest::collection::vec(arbitrary_test_canister_op(), 1..50))] ops: Vec< + TestCanisterOp, + >, +) { /// Execute op against the state machine `env` fn execute_op(env: StateMachine, canister_id: CanisterId, op: TestCanisterOp) -> StateMachine { match op { @@ -6950,7 +7044,8 @@ fn random_canister_input(ops in proptest::collection::vec(arbitrary_test_caniste } TestCanisterOp::CanisterReinstall => { env.reinstall_canister_wat(canister_id, TEST_CANISTER, vec![]); - env.execute_ingress(canister_id, "grow_page", vec![]).unwrap(); + env.execute_ingress(canister_id, "grow_page", vec![]) + .unwrap(); env } TestCanisterOp::Checkpoint => { @@ -6974,8 +7069,7 @@ fn random_canister_input(ops in proptest::collection::vec(arbitrary_test_caniste let canister_id = env.install_canister_wat(TEST_CANISTER, vec![], None); - env - .execute_ingress(canister_id, "grow_page", vec![]) + env.execute_ingress(canister_id, "grow_page", vec![]) .unwrap(); // Execute all operations the state machine. 
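For reference, the conversion applied throughout the test hunks above follows one pattern: a `proptest!` block whose test takes `name(arg in strategy)` parameters becomes a plain function annotated with `#[test_strategy::proptest]`, each input declared as an ordinary argument carrying a `#[strategy(...)]` attribute, and any `ProptestConfig` passed as an argument to the attribute itself. A minimal, self-contained sketch of that pattern, assuming the `proptest` and `test-strategy` crates; `roundtrip_is_lossless` and `arb_payload` are illustrative names, not part of this change:

use proptest::prelude::*;

// Hypothetical strategy standing in for generators such as arb_stream / arb_stream_slice.
fn arb_payload() -> impl Strategy<Value = Vec<u8>> {
    proptest::collection::vec(any::<u8>(), 0..16)
}

// Old style (for comparison):
//
// proptest! {
//     #[test]
//     fn roundtrip_is_lossless(payload in arb_payload(), limit in 0..20usize) { ... }
// }

// New style: a regular function; strategies become per-argument attributes.
#[test_strategy::proptest]
fn roundtrip_is_lossless(
    #[strategy(arb_payload())] payload: Vec<u8>,
    #[strategy(0..20usize)] limit: usize,
) {
    // The body is ordinary test code; assertion failures are reported per
    // generated case, exactly as with proptest!.
    assert!(payload.len() < 16 && limit < 20);
}

// Per-test configuration moves into the attribute, e.g.
// #[test_strategy::proptest(ProptestConfig::with_cases(1))]
// or, as in random_canister_input above, a struct-update expression that enables fork.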
@@ -6983,4 +7077,3 @@ fn random_canister_input(ops in proptest::collection::vec(arbitrary_test_caniste env = execute_op(env, canister_id, op.clone()); } } -} diff --git a/rs/test_utilities/types/src/batch/payload.rs b/rs/test_utilities/types/src/batch/payload.rs index 180a4ef9e2a..71e4d80d7c4 100644 --- a/rs/test_utilities/types/src/batch/payload.rs +++ b/rs/test_utilities/types/src/batch/payload.rs @@ -14,6 +14,7 @@ impl Default for PayloadBuilder { self_validating: SelfValidatingPayload::default(), canister_http: vec![], query_stats: vec![], + vetkd: vec![], }, } } diff --git a/rs/tree_deserializer/BUILD.bazel b/rs/tree_deserializer/BUILD.bazel index 7fad34c5f40..9f54b55cae3 100644 --- a/rs/tree_deserializer/BUILD.bazel +++ b/rs/tree_deserializer/BUILD.bazel @@ -15,13 +15,16 @@ rust_library( ], ) +MACRO_DEV_DEPENDENCIES = [ + # Keep sorted. + "@crate_index//:proptest-derive", + "@crate_index//:test-strategy", +] + rust_test( name = "tree_deserializer_test", crate = ":tree_deserializer", - proc_macro_deps = [ - # Keep sorted. - "@crate_index//:proptest-derive", - ], + proc_macro_deps = MACRO_DEV_DEPENDENCIES, deps = [ # Keep sorted. "@crate_index//:maplit", diff --git a/rs/tree_deserializer/Cargo.toml b/rs/tree_deserializer/Cargo.toml index 01f7b00b17d..db97e766bf0 100644 --- a/rs/tree_deserializer/Cargo.toml +++ b/rs/tree_deserializer/Cargo.toml @@ -15,3 +15,4 @@ serde = { workspace = true } maplit = "1.0.2" proptest = { workspace = true } proptest-derive = { workspace = true } +test-strategy = "0.4.0" diff --git a/rs/tree_deserializer/src/tests.rs b/rs/tree_deserializer/src/tests.rs index f3a5db7582c..3be57860295 100644 --- a/rs/tree_deserializer/src/tests.rs +++ b/rs/tree_deserializer/src/tests.rs @@ -154,11 +154,9 @@ fn can_collect_leaves() { ); } -proptest! { - #[test] - fn tree_encoding_roundtrip(s in any::()) { - let t = encode_as_tree(&s); - let s_decoded = decode(&t).expect("failed to decode a struct"); - assert_eq!(s, s_decoded); - } +#[test_strategy::proptest] +fn tree_encoding_roundtrip(#[strategy(any::())] s: S) { + let t = encode_as_tree(&s); + let s_decoded = decode(&t).expect("failed to decode a struct"); + assert_eq!(s, s_decoded); } diff --git a/rs/types/types/src/batch.rs b/rs/types/types/src/batch.rs index b145fcb66e0..faa1cd4dd9e 100644 --- a/rs/types/types/src/batch.rs +++ b/rs/types/types/src/batch.rs @@ -120,6 +120,7 @@ pub struct BatchPayload { pub self_validating: SelfValidatingPayload, pub canister_http: Vec, pub query_stats: Vec, + pub vetkd: Vec, } /// Batch properties collected form the last DKG summary block. 
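The `vetkd` field added to `BatchPayload` above is threaded through the hunks below via exhaustive struct destructuring rather than per-field access; that idiom is what forces every `is_empty()` check, default-value test, and proto conversion to be revisited when a payload section is added, because an unmentioned field makes the pattern fail to compile. A minimal sketch of the idiom under an assumed, simplified payload type (`Payload` and its fields are hypothetical stand-ins, not the real `BatchPayload`):

// Illustrative only: shows why exhaustive destructuring surfaces new fields at compile time.
#[derive(Default)]
struct Payload {
    ingress: Vec<u8>,
    query_stats: Vec<u8>,
    vetkd: Vec<u8>,
}

impl Payload {
    fn is_empty(&self) -> bool {
        // No `..` rest pattern: adding a field to `Payload` turns this into a
        // compile error at every such site until the new field is handled.
        let Payload {
            ingress,
            query_stats,
            vetkd,
        } = self;
        ingress.is_empty() && query_stats.is_empty() && vetkd.is_empty()
    }
}

fn main() {
    assert!(Payload::default().is_empty());
}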
@@ -179,6 +180,7 @@ impl BatchPayload { self_validating, canister_http, query_stats, + vetkd, } = &self; ingress.is_empty() @@ -186,6 +188,7 @@ impl BatchPayload { && self_validating.is_empty() && canister_http.is_empty() && query_stats.is_empty() + && vetkd.is_empty() } } @@ -306,12 +309,14 @@ mod tests { self_validating, canister_http, query_stats, + vetkd, } = BatchPayload::default(); assert_eq!(ingress.count_bytes(), 0); assert_eq!(self_validating.count_bytes(), 0); assert_eq!(canister_http.len(), 0); assert_eq!(query_stats.len(), 0); + assert_eq!(vetkd.len(), 0); } /// This is a quick test to check the invariant, that the [`Default`] implementation @@ -327,6 +332,7 @@ mod tests { self_validating, canister_http, query_stats, + vetkd, } = &payload; assert!(ingress.is_empty()); @@ -334,6 +340,7 @@ mod tests { assert!(self_validating.is_empty()); assert!(canister_http.is_empty()); assert!(query_stats.is_empty()); + assert!(vetkd.is_empty()); } #[test] diff --git a/rs/types/types/src/consensus.rs b/rs/types/types/src/consensus.rs index f866e2536b6..d434405ff0f 100644 --- a/rs/types/types/src/consensus.rs +++ b/rs/types/types/src/consensus.rs @@ -1295,6 +1295,7 @@ impl From<&Block> for pb::Block { self_validating_payload, canister_http_payload_bytes, query_stats_payload_bytes, + vetkd_payload_bytes, idkg_payload, ) = if payload.is_summary() { ( @@ -1304,6 +1305,7 @@ impl From<&Block> for pb::Block { None, vec![], vec![], + vec![], payload.as_summary().idkg.as_ref().map(|idkg| idkg.into()), ) } else { @@ -1315,6 +1317,7 @@ impl From<&Block> for pb::Block { Some(pb::SelfValidatingPayload::from(&batch.self_validating)), batch.canister_http.clone(), batch.query_stats.clone(), + batch.vetkd.clone(), payload.as_data().idkg.as_ref().map(|idkg| idkg.into()), ) }; @@ -1332,6 +1335,7 @@ impl From<&Block> for pb::Block { self_validating_payload, canister_http_payload_bytes, query_stats_payload_bytes, + vetkd_payload_bytes, idkg_payload, payload_hash: block.payload.get_hash().clone().get().0, } @@ -1362,6 +1366,7 @@ impl TryFrom for Block { .unwrap_or_default(), canister_http: block.canister_http_payload_bytes, query_stats: block.query_stats_payload_bytes, + vetkd: block.vetkd_payload_bytes, }; let payload = match dkg_payload { diff --git a/rs/xnet/payload_builder/BUILD.bazel b/rs/xnet/payload_builder/BUILD.bazel index db73bf84e9e..1baa76ea6fd 100644 --- a/rs/xnet/payload_builder/BUILD.bazel +++ b/rs/xnet/payload_builder/BUILD.bazel @@ -69,6 +69,7 @@ DEV_DEPENDENCIES = [ MACRO_DEV_DEPENDENCIES = [ # Keep sorted. 
+ "@crate_index//:test-strategy", ] rust_library( diff --git a/rs/xnet/payload_builder/Cargo.toml b/rs/xnet/payload_builder/Cargo.toml index 8dce8fb6a42..93e884b1369 100644 --- a/rs/xnet/payload_builder/Cargo.toml +++ b/rs/xnet/payload_builder/Cargo.toml @@ -60,4 +60,5 @@ nix = { workspace = true } proptest = { workspace = true } reqwest = { workspace = true } tempfile = { workspace = true } +test-strategy = "0.4.0" url = { workspace = true } diff --git a/rs/xnet/payload_builder/tests/certified_slice_pool.rs b/rs/xnet/payload_builder/tests/certified_slice_pool.rs index a9fe02b511c..61c1eb694ea 100644 --- a/rs/xnet/payload_builder/tests/certified_slice_pool.rs +++ b/rs/xnet/payload_builder/tests/certified_slice_pool.rs @@ -5,6 +5,7 @@ use ic_interfaces_certified_stream_store::DecodeStreamError; use ic_interfaces_certified_stream_store_mocks::MockCertifiedStreamStore; use ic_metrics::MetricsRegistry; use ic_protobuf::{messaging::xnet::v1, proxy::ProtoProxy}; +use ic_replicated_state::Stream; use ic_test_utilities_logger::with_test_replica_logger; use ic_test_utilities_metrics::{metric_vec, HistogramStats}; use ic_test_utilities_state::arb_stream_slice; @@ -29,867 +30,1032 @@ pub const SRC_SUBNET: SubnetId = REMOTE_SUBNET; pub const DST_SUBNET: SubnetId = OWN_SUBNET; pub const REGISTRY_VERSION: RegistryVersion = RegistryVersion::new(169); -proptest! { - #[test] - fn slice_unpack_roundtrip((stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10)) { - with_test_replica_logger(|log| { - let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); - - let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); - let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()) - .expect("failed to unpack certified stream"); +#[test_strategy::proptest] +fn slice_unpack_roundtrip( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); + + let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); + let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()) + .expect("failed to unpack certified stream"); + + assert_slices_eq(certified_slice, CertifiedStreamSlice::from(unpacked)); + }); +} - assert_slices_eq( - certified_slice, - CertifiedStreamSlice::from(unpacked) - ); - }); +#[test_strategy::proptest] +fn slice_garbage_collect( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (mut stream, from, msg_count) = test_slice; + + /// Convenience wrapper for `UnpackedStreamSlice::garbage_collect()` that takes + /// and returns `CertifiedStreamSlices`. 
+ fn gc( + certified_slice: &CertifiedStreamSlice, + message_index: StreamIndex, + signal_index: StreamIndex, + ) -> Option { + let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()) + .expect("failed to unpack certified stream"); + + unpacked + .garbage_collect(&ExpectedIndices { + message_index, + signal_index, + }) + .unwrap() + .map(|leftover| leftover.into()) } - #[test] - fn slice_garbage_collect((mut stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10)) { - /// Convenience wrapper for `UnpackedStreamSlice::garbage_collect()` that takes - /// and returns `CertifiedStreamSlices`. - fn gc( - certified_slice: &CertifiedStreamSlice, - message_index: StreamIndex, - signal_index: StreamIndex, - ) -> Option { - let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()) - .expect("failed to unpack certified stream"); - - unpacked - .garbage_collect(&ExpectedIndices { - message_index, - signal_index - }) - .unwrap() - .map(|leftover| leftover.into()) - } + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. + stream.push_accept_signal(); + let signals_end = stream.signals_end(); - with_test_replica_logger(|log| { - // Increment `signals_end` so we can later safely decrement it without underflow. - stream.push_accept_signal(); - let signals_end = stream.signals_end(); + let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); + let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); - let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); - let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); - - if msg_count > 0 { - // Garbage collecting no messages and no signals should yield the original slice. + if msg_count > 0 { + // Garbage collecting no messages and no signals should yield the original slice. + assert_opt_slices_eq( + Some(fixture.get_slice(DST_SUBNET, from, msg_count)), + gc(&certified_slice, from, signals_end.decrement()), + ); + // Garbage collecting no messages and all signals should yield the original slice. + assert_opt_slices_eq( + Some(fixture.get_slice(DST_SUBNET, from, msg_count)), + gc(&certified_slice, from, signals_end), + ); + if msg_count > 1 { + let from_middle = from + StreamIndex::from(msg_count as u64 / 2); + let msgs_from_middle = (msg_count + 1) / 2; + // Garbage collecting some messages and no signals should truncate the slice. assert_opt_slices_eq( - Some(fixture.get_slice(DST_SUBNET, from, msg_count)), - gc(&certified_slice, from, signals_end.decrement()), + Some(fixture.get_slice(DST_SUBNET, from_middle, msgs_from_middle)), + gc(&certified_slice, from_middle, signals_end.decrement()), ); - // Garbage collecting no messages and all signals should yield the original slice. + // Garbage collecting some messages and all signals should truncate the slice. assert_opt_slices_eq( - Some(fixture.get_slice(DST_SUBNET, from, msg_count)), - gc(&certified_slice, from, signals_end), + Some(fixture.get_slice(DST_SUBNET, from_middle, msgs_from_middle)), + gc(&certified_slice, from_middle, signals_end), ); - if msg_count > 1 { - let from_middle = from + StreamIndex::from(msg_count as u64 / 2); - let msgs_from_middle = (msg_count + 1) / 2; - // Garbage collecting some messages and no signals should truncate the slice. 
- assert_opt_slices_eq( - Some(fixture.get_slice(DST_SUBNET, from_middle, msgs_from_middle)), - gc(&certified_slice, from_middle, signals_end.decrement()), - ); - // Garbage collecting some messages and all signals should truncate the slice. - assert_opt_slices_eq( - Some(fixture.get_slice(DST_SUBNET, from_middle, msgs_from_middle)), - gc(&certified_slice, from_middle, signals_end), - ); - } } + } - let to = from + StreamIndex::from(msg_count as u64); - // Garbage collecting all messages and no signals should yield an empty slice. - assert_opt_slices_eq( - Some(fixture.get_slice(DST_SUBNET, to, 0)), - gc(&certified_slice, to, signals_end.decrement()), - ); - // Garbage collecting all messages and all signals should yield `None`. - assert_opt_slices_eq( - None, - gc(&certified_slice, to, signals_end), - ); - }); - } + let to = from + StreamIndex::from(msg_count as u64); + // Garbage collecting all messages and no signals should yield an empty slice. + assert_opt_slices_eq( + Some(fixture.get_slice(DST_SUBNET, to, 0)), + gc(&certified_slice, to, signals_end.decrement()), + ); + // Garbage collecting all messages and all signals should yield `None`. + assert_opt_slices_eq(None, gc(&certified_slice, to, signals_end)); + }); +} - #[test] - fn slice_take_prefix((stream, from, msg_count) in arb_stream_slice(0, 100, 0, 100)) { - /// Convenience wrapper for `UnpackedStreamSlice::take_prefix()` that takes a - /// `&CertifiedStreamSlice` argument. - fn take_prefix( - certified_slice: &CertifiedStreamSlice, - msg_limit: Option, - byte_limit: Option, - ) -> (Option, Option) { - let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()) - .expect("failed to unpack certified stream"); - - let (prefix, postfix) = unpacked.take_prefix(msg_limit, byte_limit).unwrap(); - - // Ensure that any limits were respected. - if let (Some(msg_limit), Some(prefix)) = (msg_limit, prefix.as_ref()) { - assert!(testing::slice_len(prefix) <= msg_limit); - } - if let (Some(byte_limit), Some(prefix)) = (byte_limit, prefix.as_ref()) { - assert!(prefix.count_bytes() <= byte_limit); +#[test_strategy::proptest] +fn slice_take_prefix( + #[strategy(arb_stream_slice( + 0, // min_size + 100, // max_size + 0, // min_signal_count + 100, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + /// Convenience wrapper for `UnpackedStreamSlice::take_prefix()` that takes a + /// `&CertifiedStreamSlice` argument. + fn take_prefix( + certified_slice: &CertifiedStreamSlice, + msg_limit: Option, + byte_limit: Option, + ) -> (Option, Option) { + let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()) + .expect("failed to unpack certified stream"); + + let (prefix, postfix) = unpacked.take_prefix(msg_limit, byte_limit).unwrap(); + + // Ensure that any limits were respected. + if let (Some(msg_limit), Some(prefix)) = (msg_limit, prefix.as_ref()) { + assert!(testing::slice_len(prefix) <= msg_limit); + } + if let (Some(byte_limit), Some(prefix)) = (byte_limit, prefix.as_ref()) { + assert!(prefix.count_bytes() <= byte_limit); + } + // Testing the signal limit is pointless here because it requires very large streams + // that would make this test needlessly slow. There is a dedicated test for it. + + // And that a longer prefix would have gone over one the limits. 
+ let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()).unwrap(); + match prefix.as_ref() { + Some(prefix) if postfix.is_some() => { + let prefix_len = testing::slice_len(prefix); + let longer_prefix = unpacked + .take_prefix(Some(prefix_len + 1), None) + .unwrap() + .0 + .unwrap(); + let over_msg_limit = msg_limit + .map(|limit| testing::slice_len(&longer_prefix) > limit) + .unwrap_or_default(); + let over_byte_limit = byte_limit + .map(|limit| longer_prefix.count_bytes() > limit) + .unwrap_or_default(); + assert!(over_msg_limit || over_byte_limit) } - // Testing the signal limit is pointless here because it requires very large streams - // that would make this test needlessly slow. There is a dedicated test for it. - - // And that a longer prefix would have gone over one the limits. - let unpacked = UnpackedStreamSlice::try_from(certified_slice.clone()).unwrap(); - match prefix.as_ref() { - Some(prefix) if postfix.is_some() => { - let prefix_len = testing::slice_len(prefix); - let longer_prefix = unpacked.take_prefix(Some(prefix_len + 1), None).unwrap().0.unwrap(); - let over_msg_limit = msg_limit.map(|limit| testing::slice_len(&longer_prefix) > limit).unwrap_or_default(); - let over_byte_limit = byte_limit.map(|limit| longer_prefix.count_bytes() > limit).unwrap_or_default(); - assert!(over_msg_limit || over_byte_limit) - }, - None => { - let empty_prefix = unpacked.take_prefix(Some(0), None).unwrap().0.unwrap(); - assert!(empty_prefix.count_bytes() > byte_limit.unwrap()) - } - _ => {} + None => { + let empty_prefix = unpacked.take_prefix(Some(0), None).unwrap().0.unwrap(); + assert!(empty_prefix.count_bytes() > byte_limit.unwrap()) } - - (prefix.map(|prefix| prefix.into()), postfix.map(|postfix| postfix.into())) + _ => {} } - /// Helper producing two adjacent `CertifiedStreamSlices` starting at `from` and - /// of lengths `prefix_msg_count` and respectively `postfix_msg_count`. - fn split( - fixture: &StateManagerFixture, - subnet_id: SubnetId, - from: StreamIndex, - prefix_msg_count: usize, - postfix_msg_count: usize, - ) -> (Option, Option) { - ( - Some(fixture.get_slice(subnet_id, from, prefix_msg_count)), - Some(fixture.get_slice( - subnet_id, from + StreamIndex::from(prefix_msg_count as u64), postfix_msg_count)), - ) - } + ( + prefix.map(|prefix| prefix.into()), + postfix.map(|postfix| postfix.into()), + ) + } - with_test_replica_logger(|log| { - let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); - let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); + /// Helper producing two adjacent `CertifiedStreamSlices` starting at `from` and + /// of lengths `prefix_msg_count` and respectively `postfix_msg_count`. + fn split( + fixture: &StateManagerFixture, + subnet_id: SubnetId, + from: StreamIndex, + prefix_msg_count: usize, + postfix_msg_count: usize, + ) -> (Option, Option) { + ( + Some(fixture.get_slice(subnet_id, from, prefix_msg_count)), + Some(fixture.get_slice( + subnet_id, + from + StreamIndex::from(prefix_msg_count as u64), + postfix_msg_count, + )), + ) + } - // Taking an unlimited prefix should result in the full slice and no leftover. + with_test_replica_logger(|log| { + let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); + let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); + + // Taking an unlimited prefix should result in the full slice and no leftover. 
+ assert_opt_slice_pairs_eq( + (Some(certified_slice.clone()), None), + take_prefix(&certified_slice, None, None), + ); + + // Taking a too-small prefix should result in no prefix and the original left over. + assert_opt_slice_pairs_eq( + (None, Some(certified_slice.clone())), + take_prefix(&certified_slice, None, Some(13)), + ); + + // Even if requesting for zero messages. + assert_opt_slice_pairs_eq( + (None, Some(certified_slice.clone())), + take_prefix(&certified_slice, Some(0), Some(13)), + ); + + if msg_count > 0 { + // Taking zero messages should result in an empty prefix and the original left over. assert_opt_slice_pairs_eq( - (Some(certified_slice.clone()), None), - take_prefix(&certified_slice, None, None), + split(&fixture, DST_SUBNET, from, 0, msg_count), + take_prefix(&certified_slice, Some(0), None), ); - // Taking a too-small prefix should result in no prefix and the original left over. + // Taking an unlimited number of messages with a byte limit just under the byte size + // should result in `msg_count - 1` messages and 1 message left over. + let byte_size = UnpackedStreamSlice::try_from(certified_slice.clone()) + .expect("failed to unpack certified stream") + .count_bytes(); assert_opt_slice_pairs_eq( - (None, Some(certified_slice.clone())), - take_prefix(&certified_slice, None, Some(13)), + split(&fixture, DST_SUBNET, from, msg_count - 1, 1), + take_prefix(&certified_slice, None, Some(byte_size - 1)), ); - // Even if requesting for zero messages. + // As should taking `msg_count - 1` messages. assert_opt_slice_pairs_eq( - (None, Some(certified_slice.clone())), - take_prefix(&certified_slice, Some(0), Some(13)), + split(&fixture, DST_SUBNET, from, msg_count - 1, 1), + take_prefix(&certified_slice, Some(msg_count - 1), None), ); - if msg_count > 0 { - // Taking zero messages should result in an empty prefix and the original left over. - assert_opt_slice_pairs_eq( - split(&fixture, DST_SUBNET, from, 0, msg_count), - take_prefix(&certified_slice, Some(0), None), - ); - - // Taking an unlimited number of messages with a byte limit just under the byte size - // should result in `msg_count - 1` messages and 1 message left over. - let byte_size = UnpackedStreamSlice::try_from(certified_slice.clone()) - .expect("failed to unpack certified stream").count_bytes(); - assert_opt_slice_pairs_eq( - split(&fixture, DST_SUBNET, from, msg_count - 1, 1), - take_prefix(&certified_slice, None, Some(byte_size - 1)), - ); - - // As should taking `msg_count - 1` messages. - assert_opt_slice_pairs_eq( - split(&fixture, DST_SUBNET, from, msg_count - 1, 1), - take_prefix(&certified_slice, Some(msg_count - 1), None), - ); - - // But setting both limits exactly should result in the full slice and no leftover. - assert_opt_slice_pairs_eq( + // But setting both limits exactly should result in the full slice and no leftover. + assert_opt_slice_pairs_eq( (Some(certified_slice.clone()), None), - take_prefix(&certified_slice, Some(msg_count), Some(byte_size)), - ); - } else { - // Taking zero messages from an empty slice should result in the full slice and no leftover. - assert_opt_slice_pairs_eq( - (Some(certified_slice.clone()), None), - take_prefix(&certified_slice, Some(0), None), - ); - } - }); - } - - #[test] - fn invalid_slice((stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10)) { - // Returns the provided slice, adjusted by the provided function. 
- fn adjust>)>( - slice: &CertifiedStreamSlice, - mut f: F, - ) -> CertifiedStreamSlice { - let mut adjusted = slice.clone(); - let mut tree = v1::LabeledTree::proxy_decode(slice.payload.as_slice()).unwrap(); - f(&mut tree); - adjusted.payload = v1::LabeledTree::proxy_encode(tree); - adjusted - } - - // Asserts that unpacking the given slice fails with the expected error message. - fn assert_unpack_fails( - expected: InvalidSlice, - invalid_slice: CertifiedStreamSlice, - ) { - match UnpackedStreamSlice::try_from(invalid_slice) { - Err(CertifiedSliceError::InvalidPayload(reason)) => assert_eq!(expected, reason), - actual => panic!( - "Expected Err(CertifiedSliceError::InvalidPayload((\"{:?}\")), got {:?}", - expected, - actual - ), - } - } - - // Returns the `FlatMap` contained in a `SubTree`. - fn children_of(tree: &mut LabeledTree>) -> &mut FlatMap>> { - match tree { - LabeledTree::SubTree(children) => children, - LabeledTree::Leaf(_) => panic!("not a SubTree"), - } - } - - with_test_replica_logger(|log| { - let stream_begin = stream.messages_begin(); - let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); - - let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); - - assert_unpack_fails( - InvalidSlice::MissingStreams, - adjust( - &certified_slice, - |tree| { - children_of(tree).split_off(&Label::from("")); - } - ), - ); - - assert_unpack_fails( - InvalidSlice::MissingStream, - adjust( - &certified_slice, - |tree| { - let streams = children_of(tree).get_mut(&Label::from("streams")).unwrap(); - children_of(streams).split_off(&Label::from("")); - } - ), + take_prefix(&certified_slice, Some(msg_count), Some(byte_size)), ); - - assert_unpack_fails( - InvalidSlice::MissingHeader, - adjust( - &certified_slice, - |tree| { - let streams = children_of(tree).get_mut(&Label::from("streams")).unwrap(); - let streams_tree = children_of(streams); - let subnet_id = streams_tree.keys()[0].clone(); - let stream = streams_tree.get_mut(&subnet_id).unwrap(); - children_of(stream).split_off(&Label::from("")); - } - ), + } else { + // Taking zero messages from an empty slice should result in the full slice and no leftover. + assert_opt_slice_pairs_eq( + (Some(certified_slice.clone()), None), + take_prefix(&certified_slice, Some(0), None), ); + } + }); +} - // Must have at least 2 messages and be able to prepend one. - if msg_count > 1 && from.get() > 0 { - // Stream with an extra message prepended to payload only. - let slice_with_extra_message = adjust( - &certified_slice, - |tree| { - let streams = children_of(tree).get_mut(&Label::from("streams")).unwrap(); - let streams_tree = children_of(streams); - let subnet_id = streams_tree.keys()[0].clone(); - let stream = streams_tree.get_mut(&subnet_id).unwrap(); - let stream_tree = children_of(stream); - let messages = stream_tree.get_mut(&Label::from("messages")).unwrap(); - let messages_tree = children_of(messages); - let mut messages_vec: Vec<_> = - messages_tree.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); - messages_vec.insert(0, ( - from.decrement().to_label(), - LabeledTree::Leaf(vec![]) - )); - std::mem::swap(messages_tree, &mut FlatMap::from_key_values(messages_vec)); - } - ); - - if from > stream_begin { - // Valid slice, but mismatching withess. - - // Unpacking will succeed, as we're not validating against the witness. - let unpacked = UnpackedStreamSlice::try_from(slice_with_extra_message.clone()).unwrap(); - - // But GC should fail. 
- match unpacked.garbage_collect(&ExpectedIndices { - message_index: from.increment(), - signal_index: StreamIndex::from(u64::MAX), - }) { - Err(CertifiedSliceError::WitnessPruningFailed(_)) => {} - actual => panic!( - "Expected Err(CertifiedSliceError::WitnessPruningFailed(_), got {:?}", - actual - ), - } - - // As should taking a prefix. - let unpacked = UnpackedStreamSlice::try_from(slice_with_extra_message).unwrap(); - match unpacked.take_prefix(Some(1), None) { - Err(CertifiedSliceError::WitnessPruningFailed(_)) => {} - actual => panic!( - "Expected Err(CertifiedSliceError::WitnessPruningFailed(_), got {:?}", - actual - ), - } - } else { - // Invalid slice, begin index before stream begin index. Unpacking should fail. - match UnpackedStreamSlice::try_from(slice_with_extra_message) { - Err(CertifiedSliceError::InvalidPayload(InvalidSlice::InvalidBounds)) => {} - actual => panic!( - "Expected Err(CertifiedSliceError::InvalidPayload(InvalidBounds), got {:?}", - actual - ), - } - } - } - }); +#[test_strategy::proptest] +fn invalid_slice( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + // Returns the provided slice, adjusted by the provided function. + fn adjust>)>( + slice: &CertifiedStreamSlice, + mut f: F, + ) -> CertifiedStreamSlice { + let mut adjusted = slice.clone(); + let mut tree = v1::LabeledTree::proxy_decode(slice.payload.as_slice()).unwrap(); + f(&mut tree); + adjusted.payload = v1::LabeledTree::proxy_encode(tree); + adjusted } - /// Verifies that the size estimate returned by `count_bytes()` is within - /// 5% of the actual size of the encoded struct. - /// - /// If this test fails, you need to check where the error lies (payload vs. - /// witness) and adjust the estimate accordingly. Or bump the error margin. - #[test] - fn slice_accurate_count_bytes((stream, from, msg_count) in arb_stream_slice(2, 100, 0, 100)) { - /// Asserts that the `actual` value is within `+/-(error_percent% + - /// absolute_error)` of the `expected` value. - fn assert_almost_equal( - expected: usize, - actual: usize, - error_percent: usize, - absolute_error: usize, - ) { - let expected_min = expected * (100 - error_percent) / 100 - absolute_error; - let expected_max = expected * (100 + error_percent) / 100 + absolute_error; - assert!( - expected_min <= actual && actual <= expected_max, - "Expecting estimated size to be within {}% + {} of {}, was {}", - error_percent, - absolute_error, - expected, - actual, - ); - } - - /// Verifies that the result of calling `count_bytes()` on the - /// `UnpackedStreamSlice` unpacked from `slice` is within 5% of the - /// byte size of `slice`. 
- fn assert_good_estimate(slice: CertifiedStreamSlice) { - let unpacked = UnpackedStreamSlice::try_from(slice.clone()) - .expect("failed to unpack certified stream"); - - let packed_payload_bytes = slice.payload.len(); - let unpacked_payload_bytes = testing::payload_count_bytes(&unpacked); - assert_almost_equal(packed_payload_bytes, unpacked_payload_bytes, 1, 10); - - let packed_witness_bytes = slice.merkle_proof.len(); - let unpacked_witness_bytes = testing::witness_count_bytes(&unpacked); - assert_almost_equal(packed_witness_bytes, unpacked_witness_bytes, 5, 10); - - let packed_bytes = - slice.payload.len() + slice.merkle_proof.len() + slice.certification.count_bytes(); - let unpacked_bytes = unpacked.count_bytes(); - assert_almost_equal(packed_bytes, unpacked_bytes, MAX_XNET_PAYLOAD_SIZE_ERROR_MARGIN_PERCENT as usize, 0); + // Asserts that unpacking the given slice fails with the expected error message. + fn assert_unpack_fails(expected: InvalidSlice, invalid_slice: CertifiedStreamSlice) { + match UnpackedStreamSlice::try_from(invalid_slice) { + Err(CertifiedSliceError::InvalidPayload(reason)) => assert_eq!(expected, reason), + actual => panic!( + "Expected Err(CertifiedSliceError::InvalidPayload((\"{:?}\")), got {:?}", + expected, actual + ), } - - with_test_replica_logger(|log| { - let fixture = - StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); - - // Verify that we have good estimates for empty, single-message - // and many-message slices, to ensure that both fixed and - // per-message overheads are accurate. - assert_good_estimate(fixture.get_slice(DST_SUBNET, from, 0)); - assert_good_estimate(fixture.get_slice(DST_SUBNET, from, 1)); - assert_good_estimate(fixture.get_slice(DST_SUBNET, from, msg_count)); - }); } - /// Verifies that `certified_slice_count_bytes(&slice)` (used in payload - /// validation) produces the exact same estimate as - /// `UnpackedStreamSlice::try_from(slice).unwrap().count_bytes()` (used in - /// payload building). - #[test] - fn matching_count_bytes((stream, from, msg_count) in arb_stream_slice(2, 100, 0, 100)) { - /// Verifies that the two ways of computing a byte size estimate produce - /// the exact same result. - fn assert_matching_count_bytes(slice: CertifiedStreamSlice) { - let fn_estimate = certified_slice_count_bytes(&slice) - .expect("failed to unpack certified stream"); - let unpacked = UnpackedStreamSlice::try_from(slice) - .expect("failed to unpack certified stream"); - - assert_eq!(unpacked.count_bytes(), fn_estimate); + // Returns the `FlatMap` contained in a `SubTree`. + fn children_of(tree: &mut LabeledTree>) -> &mut FlatMap>> { + match tree { + LabeledTree::SubTree(children) => children, + LabeledTree::Leaf(_) => panic!("not a SubTree"), } - - with_test_replica_logger(|log| { - let fixture = - StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); - - // Verify equality for empty, single-message and many-message slices. - assert_matching_count_bytes(fixture.get_slice(DST_SUBNET, from, 0)); - assert_matching_count_bytes(fixture.get_slice(DST_SUBNET, from, 1)); - assert_matching_count_bytes(fixture.get_slice(DST_SUBNET, from, msg_count)); - }); } - #[test] - fn pool( - (mut stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10), - ) { - /// Asserts that the pool has a cached stream position for the given subnet. 
- fn has_stream_position(subnet_id: SubnetId, pool: &CertifiedSlicePool) -> bool { - !matches!(pool.slice_stats(subnet_id), (None, _, _, _)) - } - /// Asserts that the pool contains a slice from the given subnet. - fn has_slice(subnet_id: SubnetId, pool: &CertifiedSlicePool) -> bool { - !matches!(pool.slice_stats(subnet_id), (_, None, 0, 0)) - } - /// Takes the full pooled slice from the given subnet. Panics if no such slice exists. - fn take_slice(subnet_id: SubnetId, pool: &mut CertifiedSlicePool) -> Option { - pool.take_slice(subnet_id, None, None, None).unwrap().map(|(slice, _)| slice) - } - /// Asserts that the pool contains a slice with the expected stats and non-zero byte size. - fn assert_has_slice( - subnet_id: SubnetId, - pool: &mut CertifiedSlicePool, - expected_stream_position: Option, - expected_slice_begin: Option, - expected_msg_count: usize, - ) { - let (stream_position, slice_begin, msg_count, byte_size) = pool.slice_stats(subnet_id); - assert_eq!(expected_stream_position, stream_position); - assert_eq!(expected_slice_begin, slice_begin); - assert_eq!(expected_msg_count, msg_count); - assert!(byte_size > 0); - } + with_test_replica_logger(|log| { + let stream_begin = stream.messages_begin(); + let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); + + let certified_slice = fixture.get_slice(DST_SUBNET, from, msg_count); + + assert_unpack_fails( + InvalidSlice::MissingStreams, + adjust(&certified_slice, |tree| { + children_of(tree).split_off(&Label::from("")); + }), + ); + + assert_unpack_fails( + InvalidSlice::MissingStream, + adjust(&certified_slice, |tree| { + let streams = children_of(tree).get_mut(&Label::from("streams")).unwrap(); + children_of(streams).split_off(&Label::from("")); + }), + ); + + assert_unpack_fails( + InvalidSlice::MissingHeader, + adjust(&certified_slice, |tree| { + let streams = children_of(tree).get_mut(&Label::from("streams")).unwrap(); + let streams_tree = children_of(streams); + let subnet_id = streams_tree.keys()[0].clone(); + let stream = streams_tree.get_mut(&subnet_id).unwrap(); + children_of(stream).split_off(&Label::from("")); + }), + ); + + // Must have at least 2 messages and be able to prepend one. + if msg_count > 1 && from.get() > 0 { + // Stream with an extra message prepended to payload only. + let slice_with_extra_message = adjust(&certified_slice, |tree| { + let streams = children_of(tree).get_mut(&Label::from("streams")).unwrap(); + let streams_tree = children_of(streams); + let subnet_id = streams_tree.keys()[0].clone(); + let stream = streams_tree.get_mut(&subnet_id).unwrap(); + let stream_tree = children_of(stream); + let messages = stream_tree.get_mut(&Label::from("messages")).unwrap(); + let messages_tree = children_of(messages); + let mut messages_vec: Vec<_> = messages_tree + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + messages_vec.insert(0, (from.decrement().to_label(), LabeledTree::Leaf(vec![]))); + std::mem::swap(messages_tree, &mut FlatMap::from_key_values(messages_vec)); + }); + + if from > stream_begin { + // Valid slice, but mismatching withess. + + // Unpacking will succeed, as we're not validating against the witness. + let unpacked = + UnpackedStreamSlice::try_from(slice_with_extra_message.clone()).unwrap(); + + // But GC should fail. 
+ match unpacked.garbage_collect(&ExpectedIndices { + message_index: from.increment(), + signal_index: StreamIndex::from(u64::MAX), + }) { + Err(CertifiedSliceError::WitnessPruningFailed(_)) => {} + actual => panic!( + "Expected Err(CertifiedSliceError::WitnessPruningFailed(_), got {:?}", + actual + ), + } - with_test_replica_logger(|log| { - // Increment `signals_end` so we can later safely decrement it without underflow. - stream.push_accept_signal(); - - // Indices just before the slice. Garbage collecting these should be a no-op. - let indices_before = ExpectedIndices{ - message_index: from, - signal_index: stream.signals_end().decrement(), - }; - let zero_indices = ExpectedIndices::default(); - - let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream); - let slice = fixture.get_slice(DST_SUBNET, from, msg_count); - let messages_begin = if msg_count > 0 { - Some(from) + // As should taking a prefix. + let unpacked = UnpackedStreamSlice::try_from(slice_with_extra_message).unwrap(); + match unpacked.take_prefix(Some(1), None) { + Err(CertifiedSliceError::WitnessPruningFailed(_)) => {} + actual => panic!( + "Expected Err(CertifiedSliceError::WitnessPruningFailed(_), got {:?}", + actual + ), + } } else { - None - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Actual return value does not matter as long as it's `Ok(_)`. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &MetricsRegistry::new()); - - // Empty pool is empty. - assert!(pool.peers().next().is_none()); - assert!(!has_stream_position(SRC_SUBNET, &pool)); - assert!(!has_slice(SRC_SUBNET, &pool)); - assert!(take_slice(SRC_SUBNET, &mut pool).is_none()); - - // Populate the pool. - pool.put(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()).unwrap(); - - // Peers and stream positions still not set. - assert!(pool.peers().next().is_none()); - assert!(!has_stream_position(SRC_SUBNET, &pool)); - - // But we can take the slice out of the pool... - assert!(has_slice(SRC_SUBNET, &pool)); - assert_eq!(slice, take_slice(SRC_SUBNET, &mut pool).unwrap()); - // ...once. - assert!(!has_slice(SRC_SUBNET, &pool)); - assert!(take_slice(SRC_SUBNET, &mut pool).is_none()); - - // Create a fresh, populated pool. - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - pool.garbage_collect(btreemap! {SRC_SUBNET => ExpectedIndices::default()}); - pool.put(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()).unwrap(); - - // Sanity check that the slice is in the pool. - { - let mut peers = pool.peers(); - assert_eq!(Some(&SRC_SUBNET), peers.next()); - assert!(peers.next().is_none()); - - pool.observe_pool_size_bytes(); - assert_eq!( - UnpackedStreamSlice::try_from(slice.clone()).unwrap().count_bytes(), - fixture.fetch_pool_size_bytes() - ); + // Invalid slice, begin index before stream begin index. Unpacking should fail. 
+ match UnpackedStreamSlice::try_from(slice_with_extra_message) { + Err(CertifiedSliceError::InvalidPayload(InvalidSlice::InvalidBounds)) => {} + actual => panic!( + "Expected Err(CertifiedSliceError::InvalidPayload(InvalidBounds), got {:?}", + actual + ), + } } - assert_has_slice(SRC_SUBNET, &mut pool, Some(zero_indices), messages_begin, msg_count); - - // Garbage collecting no messages and no signals should be a no-op. - pool.garbage_collect(btreemap! {SRC_SUBNET => indices_before.clone()}); - // But stream position should be updated. - assert_has_slice(SRC_SUBNET, &mut pool, Some(indices_before.clone()), messages_begin, msg_count); + } + }); +} - // Taking a slice with too low a byte limit should also be a no-op. - assert_eq!( - None, - pool.take_slice(SRC_SUBNET, Some(&indices_before), None, Some(1)).unwrap(), - ); - assert_has_slice(SRC_SUBNET, &mut pool, Some(indices_before.clone()), messages_begin, msg_count); +/// Verifies that the size estimate returned by `count_bytes()` is within +/// 5% of the actual size of the encoded struct. +/// +/// If this test fails, you need to check where the error lies (payload vs. +/// witness) and adjust the estimate accordingly. Or bump the error margin. +#[test_strategy::proptest] +fn slice_accurate_count_bytes( + #[strategy(arb_stream_slice( + 2, // min_size + 100, // max_size + 0, // min_signal_count + 100, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + /// Asserts that the `actual` value is within `+/-(error_percent% + + /// absolute_error)` of the `expected` value. + fn assert_almost_equal( + expected: usize, + actual: usize, + error_percent: usize, + absolute_error: usize, + ) { + let expected_min = expected * (100 - error_percent) / 100 - absolute_error; + let expected_max = expected * (100 + error_percent) / 100 + absolute_error; + assert!( + expected_min <= actual && actual <= expected_max, + "Expecting estimated size to be within {}% + {} of {}, was {}", + error_percent, + absolute_error, + expected, + actual, + ); + } - // Taking a slice with message limit zero should return the header only... - assert_opt_slices_eq( - Some(fixture.get_slice(DST_SUBNET, from, 0)), - pool.take_slice(SRC_SUBNET, Some(&indices_before), Some(0), None) - .unwrap() - .map(|(slice, _)| slice), - ); - // ...but advance `signals_end`. - let mut stream_position = ExpectedIndices { - message_index: from, - signal_index: indices_before.signal_index.increment(), - }; - if msg_count == 0 { - // Slice had length zero, it should have been consumed. - assert_eq!( - (Some(stream_position), None, 0, 0), - pool.slice_stats(SRC_SUBNET) - ); - // Terminate early. - return; - } + /// Verifies that the result of calling `count_bytes()` on the + /// `UnpackedStreamSlice` unpacked from `slice` is within 5% of the + /// byte size of `slice`. 
+ fn assert_good_estimate(slice: CertifiedStreamSlice) { + let unpacked = UnpackedStreamSlice::try_from(slice.clone()) + .expect("failed to unpack certified stream"); + + let packed_payload_bytes = slice.payload.len(); + let unpacked_payload_bytes = testing::payload_count_bytes(&unpacked); + assert_almost_equal(packed_payload_bytes, unpacked_payload_bytes, 1, 10); + + let packed_witness_bytes = slice.merkle_proof.len(); + let unpacked_witness_bytes = testing::witness_count_bytes(&unpacked); + assert_almost_equal(packed_witness_bytes, unpacked_witness_bytes, 5, 10); + + let packed_bytes = + slice.payload.len() + slice.merkle_proof.len() + slice.certification.count_bytes(); + let unpacked_bytes = unpacked.count_bytes(); + assert_almost_equal( + packed_bytes, + unpacked_bytes, + MAX_XNET_PAYLOAD_SIZE_ERROR_MARGIN_PERCENT as usize, + 0, + ); + } - // Slice was non-empty, messages should still be there. - assert_has_slice(SRC_SUBNET, &mut pool, Some(stream_position.clone()), Some(from), msg_count); - - // Pretend message 0 was already included into a block and take the next 1 message. - stream_position.message_index.inc_assign(); - let prefix = pool.take_slice(SRC_SUBNET, Some(&stream_position), Some(1), None).unwrap(); - if msg_count == 1 { - // Attempting to take a second message should have returned nothing... - assert_eq!(None, prefix); - // ...and GC-ed everything. - assert_eq!( - (Some(stream_position), None, 0, 0), - pool.slice_stats(SRC_SUBNET) - ); - // Terminate early. - return; - } + with_test_replica_logger(|log| { + let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream); - // A slice containing the second message should have been returned. - assert_opt_slices_eq( - Some(fixture.get_slice(DST_SUBNET, from.increment(), 1)), - prefix.map(|(slice, _)| slice), - ); + // Verify that we have good estimates for empty, single-message + // and many-message slices, to ensure that both fixed and + // per-message overheads are accurate. + assert_good_estimate(fixture.get_slice(DST_SUBNET, from, 0)); + assert_good_estimate(fixture.get_slice(DST_SUBNET, from, 1)); + assert_good_estimate(fixture.get_slice(DST_SUBNET, from, msg_count)); + }); +} - stream_position.message_index.inc_assign(); - if msg_count == 2 { - // Slice should have been consumed. - assert_eq!( - (Some(stream_position), None, 0, 0), - pool.slice_stats(SRC_SUBNET) - ); - // Terminate early. - return; - } +/// Verifies that `certified_slice_count_bytes(&slice)` (used in payload +/// validation) produces the exact same estimate as +/// `UnpackedStreamSlice::try_from(slice).unwrap().count_bytes()` (used in +/// payload building). +#[test_strategy::proptest] +fn matching_count_bytes( + #[strategy(arb_stream_slice( + 2, // min_size + 100, // max_size + 0, // min_signal_count + 100, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + /// Verifies that the two ways of computing a byte size estimate produce + /// the exact same result. + fn assert_matching_count_bytes(slice: CertifiedStreamSlice) { + let fn_estimate = + certified_slice_count_bytes(&slice).expect("failed to unpack certified stream"); + let unpacked = + UnpackedStreamSlice::try_from(slice).expect("failed to unpack certified stream"); + + assert_eq!(unpacked.count_bytes(), fn_estimate); + } - // Rest of slice should be in the pool. 
-            assert_has_slice(
-                SRC_SUBNET,
-                &mut pool,
-                Some(stream_position.clone()),
-                Some(stream_position.message_index),
-                msg_count - 2);
-
-            // GC-ing with an earlier message index should leave the slice unchanged...
-            let earlier_message_index = from.increment();
-            let earlier_indices = ExpectedIndices {
-                message_index: earlier_message_index,
-                signal_index: stream_position.signal_index,
-            };
-            pool.garbage_collect(btreemap! {SRC_SUBNET => earlier_indices.clone()});
-            assert_has_slice(
-                SRC_SUBNET,
-                &mut pool,
-                Some(earlier_indices.clone()),
-                Some(stream_position.message_index),
-                msg_count - 2);
+    with_test_replica_logger(|log| {
+        let fixture = StateManagerFixture::new(log).with_stream(DST_SUBNET, stream);

-            // ...but putting back the original slice now should replace it (from the earlier index).
-            pool.put(SRC_SUBNET, slice, REGISTRY_VERSION, log).unwrap();
-            assert_has_slice(SRC_SUBNET, &mut pool, Some(earlier_indices), Some(earlier_message_index), msg_count - 1);
+        // Verify equality for empty, single-message and many-message slices.
+        assert_matching_count_bytes(fixture.get_slice(DST_SUBNET, from, 0));
+        assert_matching_count_bytes(fixture.get_slice(DST_SUBNET, from, 1));
+        assert_matching_count_bytes(fixture.get_slice(DST_SUBNET, from, msg_count));
+    });
+}

-            assert_eq!(
-                metric_vec(&[
-                    (&[(LABEL_STATUS, STATUS_SUCCESS)], 2),
-                    (&[(LABEL_STATUS, STATUS_NONE)], 1),
-                ]),
-                fixture.fetch_pool_take_count()
-            );
-            // take_slice() returned 2 Some(_) results, one empty, one with a single message.
-            assert_eq!(
-                HistogramStats {
-                    count: 2,
-                    sum: 1.0
-                },
-                fixture.fetch_pool_take_messages()
-            );
-            // Called take_slice() 3x, skipping one message total.
-            assert_eq!(
-                HistogramStats {
-                    count: 3,
-                    sum: 1.0
-                },
-                fixture.fetch_pool_take_gced_messages()
-            );
-            assert_eq!(2, fixture.fetch_pool_take_size_bytes().count);
-        });
+#[test_strategy::proptest]
+fn pool(
+    #[strategy(arb_stream_slice(
+        1, // min_size
+        10, // max_size
+        0, // min_signal_count
+        10, // max_signal_count
+    ))]
+    test_slice: (Stream, StreamIndex, usize),
+) {
+    let (mut stream, from, msg_count) = test_slice;
+
+    /// Asserts that the pool has a cached stream position for the given subnet.
+    fn has_stream_position(subnet_id: SubnetId, pool: &CertifiedSlicePool) -> bool {
+        !matches!(pool.slice_stats(subnet_id), (None, _, _, _))
     }
-
-    #[test]
-    fn pool_append_same_slice(
-        (mut stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10),
+    /// Asserts that the pool contains a slice from the given subnet.
+    fn has_slice(subnet_id: SubnetId, pool: &CertifiedSlicePool) -> bool {
+        !matches!(pool.slice_stats(subnet_id), (_, None, 0, 0))
+    }
+    /// Takes the full pooled slice from the given subnet. Panics if no such slice exists.
+    fn take_slice(
+        subnet_id: SubnetId,
+        pool: &mut CertifiedSlicePool,
+    ) -> Option<CertifiedStreamSlice> {
+        pool.take_slice(subnet_id, None, None, None)
+            .unwrap()
+            .map(|(slice, _)| slice)
+    }
+    /// Asserts that the pool contains a slice with the expected stats and non-zero byte size.
+    fn assert_has_slice(
+        subnet_id: SubnetId,
+        pool: &mut CertifiedSlicePool,
+        expected_stream_position: Option<ExpectedIndices>,
+        expected_slice_begin: Option<StreamIndex>,
+        expected_msg_count: usize,
     ) {
-        let to = from + (msg_count as u64).into();
-        with_test_replica_logger(|log| {
-            // Increment `signals_end` so we can later safely decrement it without underflow.
- stream.push_accept_signal(); - - let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); - let slice = fixture.get_slice(DST_SUBNET, from, msg_count); - let slice_bytes = UnpackedStreamSlice::try_from(slice.clone()).unwrap().count_bytes(); - - // Stream position guaranteed to yield a slice, even if empty. - let stream_position = ExpectedIndices{ - message_index: from, - signal_index: stream.signals_end().decrement(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Actual return value does not matter as long as it's `Ok(_)`. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - - // `append()` with no slice present is equivalent to `put()`. - pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()).unwrap(); - // Note: this takes the slice and updates the cached stream position to its end indices. - assert_opt_slices_eq( - Some(slice.clone()), - pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); - - // Appending the same slice after taking it should be a no-op. - pool.append(SRC_SUBNET, slice, REGISTRY_VERSION, log.clone()).unwrap(); - let mut stream_position = ExpectedIndices{ - message_index: to, - signal_index: stream.signals_end(), - }; - assert_eq!( - (Some(stream_position.clone()), None, 0, 0), - pool.slice_stats(SRC_SUBNET) - ); - - // But appending the same slice with a higher `signals_end` should result in an empty - // slice (with the new `signals_end`). - stream.push_accept_signal(); - let new_fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); - let new_slice = new_fixture.get_slice(DST_SUBNET, from, msg_count); + let (stream_position, slice_begin, msg_count, byte_size) = pool.slice_stats(subnet_id); + assert_eq!(expected_stream_position, stream_position); + assert_eq!(expected_slice_begin, slice_begin); + assert_eq!(expected_msg_count, msg_count); + assert!(byte_size > 0); + } - pool.append(SRC_SUBNET, new_slice, REGISTRY_VERSION, log).unwrap(); + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. + stream.push_accept_signal(); + + // Indices just before the slice. Garbage collecting these should be a no-op. + let indices_before = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end().decrement(), + }; + let zero_indices = ExpectedIndices::default(); + + let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream); + let slice = fixture.get_slice(DST_SUBNET, from, msg_count); + let messages_begin = if msg_count > 0 { Some(from) } else { None }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Actual return value does not matter as long as it's `Ok(_)`. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &MetricsRegistry::new()); + + // Empty pool is empty. 
+ assert!(pool.peers().next().is_none()); + assert!(!has_stream_position(SRC_SUBNET, &pool)); + assert!(!has_slice(SRC_SUBNET, &pool)); + assert!(take_slice(SRC_SUBNET, &mut pool).is_none()); + + // Populate the pool. + pool.put(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()) + .unwrap(); + + // Peers and stream positions still not set. + assert!(pool.peers().next().is_none()); + assert!(!has_stream_position(SRC_SUBNET, &pool)); + + // But we can take the slice out of the pool... + assert!(has_slice(SRC_SUBNET, &pool)); + assert_eq!(slice, take_slice(SRC_SUBNET, &mut pool).unwrap()); + // ...once. + assert!(!has_slice(SRC_SUBNET, &pool)); + assert!(take_slice(SRC_SUBNET, &mut pool).is_none()); + + // Create a fresh, populated pool. + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + pool.garbage_collect(btreemap! {SRC_SUBNET => ExpectedIndices::default()}); + pool.put(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()) + .unwrap(); + + // Sanity check that the slice is in the pool. + { + let mut peers = pool.peers(); + assert_eq!(Some(&SRC_SUBNET), peers.next()); + assert!(peers.next().is_none()); - let empty_slice = new_fixture.get_slice(DST_SUBNET, to, 0); - let empty_slice_bytes = UnpackedStreamSlice::try_from(empty_slice.clone()).unwrap().count_bytes(); - assert_opt_slices_eq( - Some(empty_slice), - pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + pool.observe_pool_size_bytes(); + assert_eq!( + UnpackedStreamSlice::try_from(slice.clone()) .unwrap() - .map(|(slice, _)| slice), + .count_bytes(), + fixture.fetch_pool_size_bytes() ); - stream_position.signal_index = stream.signals_end(); + } + assert_has_slice( + SRC_SUBNET, + &mut pool, + Some(zero_indices), + messages_begin, + msg_count, + ); + + // Garbage collecting no messages and no signals should be a no-op. + pool.garbage_collect(btreemap! {SRC_SUBNET => indices_before.clone()}); + // But stream position should be updated. + assert_has_slice( + SRC_SUBNET, + &mut pool, + Some(indices_before.clone()), + messages_begin, + msg_count, + ); + + // Taking a slice with too low a byte limit should also be a no-op. + assert_eq!( + None, + pool.take_slice(SRC_SUBNET, Some(&indices_before), None, Some(1)) + .unwrap(), + ); + assert_has_slice( + SRC_SUBNET, + &mut pool, + Some(indices_before.clone()), + messages_begin, + msg_count, + ); + + // Taking a slice with message limit zero should return the header only... + assert_opt_slices_eq( + Some(fixture.get_slice(DST_SUBNET, from, 0)), + pool.take_slice(SRC_SUBNET, Some(&indices_before), Some(0), None) + .unwrap() + .map(|(slice, _)| slice), + ); + // ...but advance `signals_end`. + let mut stream_position = ExpectedIndices { + message_index: from, + signal_index: indices_before.signal_index.increment(), + }; + if msg_count == 0 { + // Slice had length zero, it should have been consumed. assert_eq!( (Some(stream_position), None, 0, 0), pool.slice_stats(SRC_SUBNET) ); + // Terminate early. + return; + } - pool.observe_pool_size_bytes(); + // Slice was non-empty, messages should still be there. + assert_has_slice( + SRC_SUBNET, + &mut pool, + Some(stream_position.clone()), + Some(from), + msg_count, + ); + + // Pretend message 0 was already included into a block and take the next 1 message. 
+ stream_position.message_index.inc_assign(); + let prefix = pool + .take_slice(SRC_SUBNET, Some(&stream_position), Some(1), None) + .unwrap(); + if msg_count == 1 { + // Attempting to take a second message should have returned nothing... + assert_eq!(None, prefix); + // ...and GC-ed everything. assert_eq!( - 0, - fixture.fetch_pool_size_bytes() - ); - assert_eq!( - metric_vec(&[ - (&[(LABEL_STATUS, STATUS_SUCCESS)], 2), - ]), - fixture.fetch_pool_take_count() - ); - // take_slice() returned 2 Some(_) results, one empty, one with msg_count messages. - assert_eq!( - HistogramStats { - count: 2, - sum: msg_count as f64 - }, - fixture.fetch_pool_take_messages() - ); - // Called take_slice() 2x, not skipping any message. - assert_eq!( - HistogramStats { - count: 2, - sum: 0.0 - }, - fixture.fetch_pool_take_gced_messages() + (Some(stream_position), None, 0, 0), + pool.slice_stats(SRC_SUBNET) ); + // Terminate early. + return; + } + + // A slice containing the second message should have been returned. + assert_opt_slices_eq( + Some(fixture.get_slice(DST_SUBNET, from.increment(), 1)), + prefix.map(|(slice, _)| slice), + ); + + stream_position.message_index.inc_assign(); + if msg_count == 2 { + // Slice should have been consumed. assert_eq!( - HistogramStats { - count: 2, - sum: (slice_bytes + empty_slice_bytes) as f64 - }, - fixture.fetch_pool_take_size_bytes() + (Some(stream_position), None, 0, 0), + pool.slice_stats(SRC_SUBNET) ); - }); - } + // Terminate early. + return; + } - #[test] - fn pool_append_non_empty_to_empty( - (mut stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10), - ) { - with_test_replica_logger(|log| { - // Increment `signals_end` so we can later safely decrement it without underflow. - stream.push_accept_signal(); - - let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); - let slice = fixture.get_slice(DST_SUBNET, from, msg_count); - - // Stream position matching slice begin. - let stream_position = ExpectedIndices{ - message_index: from, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Actual return value does not matter as long as it's `Ok(_)`. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - - // Append an empty slice. - let empty_prefix_slice = fixture.get_slice(DST_SUBNET, from, 0); - pool.append(SRC_SUBNET, empty_prefix_slice, REGISTRY_VERSION, log.clone()).unwrap(); - assert_matches!( - pool.slice_stats(SRC_SUBNET), - (None, None, 0, byte_size) if byte_size > 0 - ); + // Rest of slice should be in the pool. + assert_has_slice( + SRC_SUBNET, + &mut pool, + Some(stream_position.clone()), + Some(stream_position.message_index), + msg_count - 2, + ); + + // GC-ing with an earlier message index should leave the slice unchanged... + let earlier_message_index = from.increment(); + let earlier_indices = ExpectedIndices { + message_index: earlier_message_index, + signal_index: stream_position.signal_index, + }; + pool.garbage_collect(btreemap! {SRC_SUBNET => earlier_indices.clone()}); + assert_has_slice( + SRC_SUBNET, + &mut pool, + Some(earlier_indices.clone()), + Some(stream_position.message_index), + msg_count - 2, + ); + + // ...but putting back the original slice now should replace it (from the earlier index). 
+ pool.put(SRC_SUBNET, slice, REGISTRY_VERSION, log).unwrap(); + assert_has_slice( + SRC_SUBNET, + &mut pool, + Some(earlier_indices), + Some(earlier_message_index), + msg_count - 1, + ); + + assert_eq!( + metric_vec(&[ + (&[(LABEL_STATUS, STATUS_SUCCESS)], 2), + (&[(LABEL_STATUS, STATUS_NONE)], 1), + ]), + fixture.fetch_pool_take_count() + ); + // take_slice() returned 2 Some(_) results, one empty, one with a single message. + assert_eq!( + HistogramStats { count: 2, sum: 1.0 }, + fixture.fetch_pool_take_messages() + ); + // Called take_slice() 3x, skipping one message total. + assert_eq!( + HistogramStats { count: 3, sum: 1.0 }, + fixture.fetch_pool_take_gced_messages() + ); + assert_eq!(2, fixture.fetch_pool_take_size_bytes().count); + }); +} - // Appending the full slice should pool the full slice. - pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log).unwrap(); - assert_matches!( - pool.slice_stats(SRC_SUBNET), - (None, Some(messages_begin), count, byte_size) - if messages_begin == from - && count == msg_count - && byte_size > 0 - ); - assert_opt_slices_eq( - Some(slice), - pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); - }); - } +#[test_strategy::proptest] +fn pool_append_same_slice( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (mut stream, from, msg_count) = test_slice; + + let to = from + (msg_count as u64).into(); + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. + stream.push_accept_signal(); + + let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); + let slice = fixture.get_slice(DST_SUBNET, from, msg_count); + let slice_bytes = UnpackedStreamSlice::try_from(slice.clone()) + .unwrap() + .count_bytes(); + + // Stream position guaranteed to yield a slice, even if empty. + let stream_position = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end().decrement(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Actual return value does not matter as long as it's `Ok(_)`. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + + // `append()` with no slice present is equivalent to `put()`. + pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()) + .unwrap(); + // Note: this takes the slice and updates the cached stream position to its end indices. + assert_opt_slices_eq( + Some(slice.clone()), + pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + + // Appending the same slice after taking it should be a no-op. + pool.append(SRC_SUBNET, slice, REGISTRY_VERSION, log.clone()) + .unwrap(); + let mut stream_position = ExpectedIndices { + message_index: to, + signal_index: stream.signals_end(), + }; + assert_eq!( + (Some(stream_position.clone()), None, 0, 0), + pool.slice_stats(SRC_SUBNET) + ); + + // But appending the same slice with a higher `signals_end` should result in an empty + // slice (with the new `signals_end`). 
+ stream.push_accept_signal(); + let new_fixture = + StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); + let new_slice = new_fixture.get_slice(DST_SUBNET, from, msg_count); + + pool.append(SRC_SUBNET, new_slice, REGISTRY_VERSION, log) + .unwrap(); + + let empty_slice = new_fixture.get_slice(DST_SUBNET, to, 0); + let empty_slice_bytes = UnpackedStreamSlice::try_from(empty_slice.clone()) + .unwrap() + .count_bytes(); + assert_opt_slices_eq( + Some(empty_slice), + pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + stream_position.signal_index = stream.signals_end(); + assert_eq!( + (Some(stream_position), None, 0, 0), + pool.slice_stats(SRC_SUBNET) + ); + + pool.observe_pool_size_bytes(); + assert_eq!(0, fixture.fetch_pool_size_bytes()); + assert_eq!( + metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 2),]), + fixture.fetch_pool_take_count() + ); + // take_slice() returned 2 Some(_) results, one empty, one with msg_count messages. + assert_eq!( + HistogramStats { + count: 2, + sum: msg_count as f64 + }, + fixture.fetch_pool_take_messages() + ); + // Called take_slice() 2x, not skipping any message. + assert_eq!( + HistogramStats { count: 2, sum: 0.0 }, + fixture.fetch_pool_take_gced_messages() + ); + assert_eq!( + HistogramStats { + count: 2, + sum: (slice_bytes + empty_slice_bytes) as f64 + }, + fixture.fetch_pool_take_size_bytes() + ); + }); +} - #[test] - fn pool_append_non_empty_to_non_empty( - (mut stream, from, msg_count) in arb_stream_slice(2, 10, 0, 10), - ) { - with_test_replica_logger(|log| { - // Increment `signals_end` so we can later safely decrement it without underflow. - stream.push_accept_signal(); - - let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); - let slice = fixture.get_slice(DST_SUBNET, from, msg_count); - - // Stream position matching slice begin. - let stream_position = ExpectedIndices{ - message_index: from, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Actual return value does not matter as long as it's `Ok(_)`. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - - // Slice midpoint. - let prefix_len = msg_count / 2; - let suffix_len = msg_count - prefix_len; - let mid = from + (prefix_len as u64).into(); - - // Pool first half of slice. - let prefix_slice = fixture.get_slice(DST_SUBNET, from, prefix_len); - pool.put(SRC_SUBNET, prefix_slice, REGISTRY_VERSION, log.clone()).unwrap(); - assert_matches!( - pool.slice_stats(SRC_SUBNET), - (None, Some(messages_begin), count, byte_size) - if messages_begin == from - && count == prefix_len - && byte_size > 0 - ); +#[test_strategy::proptest] +fn pool_append_non_empty_to_empty( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (mut stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. 
+ stream.push_accept_signal(); + + let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); + let slice = fixture.get_slice(DST_SUBNET, from, msg_count); + + // Stream position matching slice begin. + let stream_position = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Actual return value does not matter as long as it's `Ok(_)`. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + + // Append an empty slice. + let empty_prefix_slice = fixture.get_slice(DST_SUBNET, from, 0); + pool.append( + SRC_SUBNET, + empty_prefix_slice, + REGISTRY_VERSION, + log.clone(), + ) + .unwrap(); + assert_matches!( + pool.slice_stats(SRC_SUBNET), + (None, None, 0, byte_size) if byte_size > 0 + ); + + // Appending the full slice should pool the full slice. + pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log) + .unwrap(); + assert_matches!( + pool.slice_stats(SRC_SUBNET), + (None, Some(messages_begin), count, byte_size) + if messages_begin == from + && count == msg_count + && byte_size > 0 + ); + assert_opt_slices_eq( + Some(slice), + pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + }); +} - // Appending a slice with a duplicate message should fail. - let overlapping_suffix_slice = - fixture.get_partial_slice(DST_SUBNET, from, mid.decrement(), suffix_len + 1); +#[test_strategy::proptest] +fn pool_append_non_empty_to_non_empty( + #[strategy(arb_stream_slice( + 2, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (mut stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. + stream.push_accept_signal(); + + let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); + let slice = fixture.get_slice(DST_SUBNET, from, msg_count); + + // Stream position matching slice begin. + let stream_position = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Actual return value does not matter as long as it's `Ok(_)`. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + + // Slice midpoint. + let prefix_len = msg_count / 2; + let suffix_len = msg_count - prefix_len; + let mid = from + (prefix_len as u64).into(); + + // Pool first half of slice. + let prefix_slice = fixture.get_slice(DST_SUBNET, from, prefix_len); + pool.put(SRC_SUBNET, prefix_slice, REGISTRY_VERSION, log.clone()) + .unwrap(); + assert_matches!( + pool.slice_stats(SRC_SUBNET), + (None, Some(messages_begin), count, byte_size) + if messages_begin == from + && count == prefix_len + && byte_size > 0 + ); + + // Appending a slice with a duplicate message should fail. 
+ let overlapping_suffix_slice = + fixture.get_partial_slice(DST_SUBNET, from, mid.decrement(), suffix_len + 1); + assert_matches!( + pool.append( + SRC_SUBNET, + overlapping_suffix_slice, + REGISTRY_VERSION, + log.clone() + ), + Err(CertifiedSliceError::InvalidAppend( + InvalidAppend::IndexMismatch + )) + ); + // Pooled slice stays unchanged. + assert_matches!( + pool.slice_stats(SRC_SUBNET), + (None, Some(messages_begin), count, byte_size) + if messages_begin == from + && count == prefix_len + && byte_size > 0 + ); + + if msg_count >= 3 { + // Appending a slice with a message gap should fail. + let gapped_suffix_slice = + fixture.get_partial_slice(DST_SUBNET, from, mid.increment(), suffix_len - 1); assert_matches!( - pool.append(SRC_SUBNET, overlapping_suffix_slice, REGISTRY_VERSION, log.clone()), - Err(CertifiedSliceError::InvalidAppend(InvalidAppend::IndexMismatch)) + pool.append( + SRC_SUBNET, + gapped_suffix_slice, + REGISTRY_VERSION, + log.clone() + ), + Err(CertifiedSliceError::InvalidAppend( + InvalidAppend::IndexMismatch + )) ); // Pooled slice stays unchanged. assert_matches!( @@ -899,266 +1065,276 @@ proptest! { && count == prefix_len && byte_size > 0 ); + } - if msg_count >= 3 { - // Appending a slice with a message gap should fail. - let gapped_suffix_slice = - fixture.get_partial_slice(DST_SUBNET, from, mid.increment(), suffix_len - 1); - assert_matches!( - pool.append(SRC_SUBNET, gapped_suffix_slice, REGISTRY_VERSION, log.clone()), - Err(CertifiedSliceError::InvalidAppend(InvalidAppend::IndexMismatch)) - ); - // Pooled slice stays unchanged. - assert_matches!( - pool.slice_stats(SRC_SUBNET), - (None, Some(messages_begin), count, byte_size) - if messages_begin == from - && count == prefix_len - && byte_size > 0 - ); - } - - // Appending the matching second half should succeed. - let suffix_slice = - fixture.get_partial_slice(DST_SUBNET, from, mid, suffix_len); - pool.append(SRC_SUBNET, suffix_slice, REGISTRY_VERSION, log).unwrap(); - // And result in the full slice being pooled. - assert_matches!( - pool.slice_stats(SRC_SUBNET), - (None, Some(messages_begin), count, byte_size) - if messages_begin == from - && count == msg_count - && byte_size > 0 - ); - assert_opt_slices_eq( - Some(slice), - pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); - }); - } - - #[test] - fn pool_put_invalid_slice( - (mut stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10), - ) { - with_test_replica_logger(|log| { - // Increment `signals_end` so we can later safely decrement it without underflow. - stream.push_accept_signal(); - - let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); - let slice = fixture.get_slice(DST_SUBNET, from, msg_count); - - let stream_position = ExpectedIndices{ - message_index: from, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Fail validation for the slice. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - - // `put()` should fail. 
- assert_matches!( - pool.put(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()), - Err(CertifiedSliceError::DecodeStreamError(DecodeStreamError::InvalidSignature(_))) - ); - - // Pool should be untouched. - assert_eq!( - (None, None, 0, 0), - pool.slice_stats(SRC_SUBNET) - ); - assert_opt_slices_eq( - None, - pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); - pool.observe_pool_size_bytes(); - assert_eq!( - 0, - fixture.fetch_pool_size_bytes() - ); - }); - } - - #[test] - fn pool_append_invalid_slice( - (mut stream, _from, _msg_count) in arb_stream_slice(2, 10, 0, 10), - ) { - let stream_begin = stream.messages_begin(); - // Set `from` and `msg_count` so that we always get two non-empty slices. - let from = stream_begin.increment(); - let msg_count = 1; - - with_test_replica_logger(|log| { - // Increment `signals_end` so we can later safely decrement it without underflow. - stream.push_accept_signal(); - - let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); - let prefix_msg_count = (from - stream_begin).get() as usize; - let prefix = fixture.get_slice(DST_SUBNET, stream_begin, prefix_msg_count); - let slice = fixture.get_partial_slice(DST_SUBNET, stream_begin, from, msg_count); - - let mut stream_position = ExpectedIndices{ - message_index: stream_begin, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Accept the prefix as valid, but fail validation for the merged slice. - certified_stream_store - .expect_decode_certified_stream_slice() - .with(always(), always(), eq(prefix.clone())) - .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - - // Populate the pool with the prefix. - pool.put(SRC_SUBNET, prefix.clone(), REGISTRY_VERSION, log.clone()).unwrap(); - assert_matches!( - pool.slice_stats(SRC_SUBNET), - (None, Some(messages_begin), count, byte_size) - if messages_begin == stream_begin - && count == prefix_msg_count - && byte_size > 0 - ); - - // `append()` should fail. - assert_matches!( - pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()), - Err(CertifiedSliceError::DecodeStreamError(DecodeStreamError::InvalidSignature(_))) - ); - - // Pool contents should be unchanged. - assert_matches!( - pool.slice_stats(SRC_SUBNET), - (None, Some(messages_begin), count, byte_size) - if messages_begin == stream_begin - && count == prefix_msg_count - && byte_size > 0 - ); - assert_opt_slices_eq( - Some(prefix.clone()), - pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); - stream_position.message_index = from; - assert_eq!( - (Some(stream_position.clone()), None, 0, 0), - pool.slice_stats(SRC_SUBNET) - ); + // Appending the matching second half should succeed. + let suffix_slice = fixture.get_partial_slice(DST_SUBNET, from, mid, suffix_len); + pool.append(SRC_SUBNET, suffix_slice, REGISTRY_VERSION, log) + .unwrap(); + // And result in the full slice being pooled. 
+ assert_matches!( + pool.slice_stats(SRC_SUBNET), + (None, Some(messages_begin), count, byte_size) + if messages_begin == from + && count == msg_count + && byte_size > 0 + ); + assert_opt_slices_eq( + Some(slice), + pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + }); +} - pool.observe_pool_size_bytes(); - assert_eq!( - 0, - fixture.fetch_pool_size_bytes() - ); - }); - } +#[test_strategy::proptest] +fn pool_put_invalid_slice( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (mut stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. + stream.push_accept_signal(); + + let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); + let slice = fixture.get_slice(DST_SUBNET, from, msg_count); + + let stream_position = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Fail validation for the slice. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + + // `put()` should fail. + assert_matches!( + pool.put(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()), + Err(CertifiedSliceError::DecodeStreamError( + DecodeStreamError::InvalidSignature(_) + )) + ); + + // Pool should be untouched. + assert_eq!((None, None, 0, 0), pool.slice_stats(SRC_SUBNET)); + assert_opt_slices_eq( + None, + pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + pool.observe_pool_size_bytes(); + assert_eq!(0, fixture.fetch_pool_size_bytes()); + }); +} - #[test] - fn pool_append_invalid_slice_to_empty( - (mut stream, from, msg_count) in arb_stream_slice(1, 10, 0, 10), - ) { - with_test_replica_logger(|log| { - // Increment `signals_end` so we can later safely decrement it without underflow. - stream.push_accept_signal(); - - let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); - let slice = fixture.get_slice(DST_SUBNET, from, msg_count); - - let stream_position = ExpectedIndices{ - message_index: from, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Fail validation for the slice. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - - // `append()` should fail. 
- assert_matches!( - pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()), - Err(CertifiedSliceError::DecodeStreamError(DecodeStreamError::InvalidSignature(_))) - ); +#[test_strategy::proptest] +fn pool_append_invalid_slice( + #[strategy(arb_stream_slice( + 2, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (mut stream, _, _) = test_slice; + + let stream_begin = stream.messages_begin(); + // Set `from` and `msg_count` so that we always get two non-empty slices. + let from = stream_begin.increment(); + let msg_count = 1; + + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. + stream.push_accept_signal(); + + let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); + let prefix_msg_count = (from - stream_begin).get() as usize; + let prefix = fixture.get_slice(DST_SUBNET, stream_begin, prefix_msg_count); + let slice = fixture.get_partial_slice(DST_SUBNET, stream_begin, from, msg_count); + + let mut stream_position = ExpectedIndices { + message_index: stream_begin, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Accept the prefix as valid, but fail validation for the merged slice. + certified_stream_store + .expect_decode_certified_stream_slice() + .with(always(), always(), eq(prefix.clone())) + .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + + // Populate the pool with the prefix. + pool.put(SRC_SUBNET, prefix.clone(), REGISTRY_VERSION, log.clone()) + .unwrap(); + assert_matches!( + pool.slice_stats(SRC_SUBNET), + (None, Some(messages_begin), count, byte_size) + if messages_begin == stream_begin + && count == prefix_msg_count + && byte_size > 0 + ); + + // `append()` should fail. + assert_matches!( + pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()), + Err(CertifiedSliceError::DecodeStreamError( + DecodeStreamError::InvalidSignature(_) + )) + ); + + // Pool contents should be unchanged. + assert_matches!( + pool.slice_stats(SRC_SUBNET), + (None, Some(messages_begin), count, byte_size) + if messages_begin == stream_begin + && count == prefix_msg_count + && byte_size > 0 + ); + assert_opt_slices_eq( + Some(prefix.clone()), + pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + stream_position.message_index = from; + assert_eq!( + (Some(stream_position.clone()), None, 0, 0), + pool.slice_stats(SRC_SUBNET) + ); + + pool.observe_pool_size_bytes(); + assert_eq!(0, fixture.fetch_pool_size_bytes()); + }); +} - // Pool should be untouched. 
- assert_eq!( - (None, None, 0, 0), - pool.slice_stats(SRC_SUBNET) - ); - assert_opt_slices_eq( - None, - pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); - pool.observe_pool_size_bytes(); - assert_eq!( - 0, - fixture.fetch_pool_size_bytes() - ); - }); - } +#[test_strategy::proptest] +fn pool_append_invalid_slice_to_empty( + #[strategy(arb_stream_slice( + 1, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (mut stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + // Increment `signals_end` so we can later safely decrement it without underflow. + stream.push_accept_signal(); + + let fixture = StateManagerFixture::new(log.clone()).with_stream(DST_SUBNET, stream.clone()); + let slice = fixture.get_slice(DST_SUBNET, from, msg_count); + + let stream_position = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Fail validation for the slice. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + + // `append()` should fail. + assert_matches!( + pool.append(SRC_SUBNET, slice.clone(), REGISTRY_VERSION, log.clone()), + Err(CertifiedSliceError::DecodeStreamError( + DecodeStreamError::InvalidSignature(_) + )) + ); + + // Pool should be untouched. + assert_eq!((None, None, 0, 0), pool.slice_stats(SRC_SUBNET)); + assert_opt_slices_eq( + None, + pool.take_slice(SRC_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + pool.observe_pool_size_bytes(); + assert_eq!(0, fixture.fetch_pool_size_bytes()); + }); } -proptest! { - #![proptest_config(ProptestConfig::with_cases(10))] +// Testing the 'signals limit' (the limit on messages in a slice such that the number of +// signals after inducting it is capped) requires streams with thousands of messages in it. +// +// It is therefore using a reduced number of cases to keep the load within reasonable bounds. +#[test_strategy::proptest(ProptestConfig::with_cases(10))] +fn pool_take_slice_respects_signal_limit( + #[strategy(arb_stream_slice( + MAX_SIGNALS, // min_size + 2 * MAX_SIGNALS, // max_size + 0, // min_signal_count + 0, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + // Stream position matching slice begin. + let begin = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end(), + }; - // Testing the 'signals limit' (the limit on messages in a slice such that the number of - // signals after inducting it is capped) requires streams with thousands of messages in it. - // - // It is therefore using a reduced number of cases to keep the load within reasonable bounds. - #[test] - fn pool_take_slice_respects_signal_limit( - (stream, from, msg_count) in arb_stream_slice(MAX_SIGNALS, 2 * MAX_SIGNALS, 0, 0), - ) { - with_test_replica_logger(|log| { - // Stream position matching slice begin. 
- let begin = ExpectedIndices{ - message_index: from, - signal_index: stream.signals_end(), - }; - - let stream_begin = stream.messages_begin(); - let fixture = StateManagerFixture::new(log.clone()).with_stream(SRC_SUBNET, stream); - let slice = fixture.get_slice(SRC_SUBNET, from, msg_count); - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Actual return value does not matter as long as it's `Ok(_)`. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); - let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; - let mut pool = CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); - - pool.put(SRC_SUBNET, slice, REGISTRY_VERSION, log.clone()).unwrap(); - let _ = pool.take_slice(SRC_SUBNET, Some(&begin), None, None).unwrap().unwrap(); - - let (new_begin, _, _, _) = pool.slice_stats(SRC_SUBNET); - let messages_end = new_begin.unwrap().message_index; - - assert!( - messages_end <= max_message_index(stream_begin), - "messages_end: {} > max_message_index: {}", - messages_end, - max_message_index(stream_begin), - ); - }); - } + let stream_begin = stream.messages_begin(); + let fixture = StateManagerFixture::new(log.clone()).with_stream(SRC_SUBNET, stream); + let slice = fixture.get_slice(SRC_SUBNET, from, msg_count); + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Actual return value does not matter as long as it's `Ok(_)`. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(|_, _, _| Ok(StreamSliceBuilder::new().build())); + let certified_stream_store = Arc::new(certified_stream_store) as Arc<_>; + let mut pool = + CertifiedSlicePool::new(Arc::clone(&certified_stream_store), &fixture.metrics); + + pool.put(SRC_SUBNET, slice, REGISTRY_VERSION, log.clone()) + .unwrap(); + let _ = pool + .take_slice(SRC_SUBNET, Some(&begin), None, None) + .unwrap() + .unwrap(); + + let (new_begin, _, _, _) = pool.slice_stats(SRC_SUBNET); + let messages_end = new_begin.unwrap().message_index; + + assert!( + messages_end <= max_message_index(stream_begin), + "messages_end: {} > max_message_index: {}", + messages_end, + max_message_index(stream_begin), + ); + }); } diff --git a/rs/xnet/payload_builder/tests/xnet_payload_builder.rs b/rs/xnet/payload_builder/tests/xnet_payload_builder.rs index 7bad4e5bc65..932c5c470e5 100644 --- a/rs/xnet/payload_builder/tests/xnet_payload_builder.rs +++ b/rs/xnet/payload_builder/tests/xnet_payload_builder.rs @@ -281,456 +281,530 @@ fn out_stream(in_stream: &Stream, messages_begin: StreamIndex) -> Stream { ) } -proptest! { - #![proptest_config(ProptestConfig::with_cases(10))] - - /// Tests that the payload builder does not include messages in a stream slice that - /// would lead to more than `MAX_SIGNALS` amount of signals in the outgoing stream. - /// - /// The input consists of - /// - an outgoing stream that has already seen most of the messages coming - /// from the incoming stream, i.e. it has close and up to `MAX_SIGNALS` signals. - /// - a very large incoming stream that has slightly more than `MAX_SIGNALS` messages in it. - /// - /// The stream slice to include in the payload will start from `out_stream.signals_end()`. - /// - /// If there is room for more signals, messages are expected to be included in the slice - /// such that `slice.messages_end() - in_stream.begin()` == `MAX_SIGNALS`, i.e. 
after inducting - /// the slice there would be exactly `MAX_SIGNALS` signals in the `out_stream`. - #[test] - fn get_xnet_payload_respects_signal_limit( - // `MAX_SIGNALS` <= signals_end()` <= `MAX_SIGNALS` + 20 - out_stream in arb_stream_with_config( - 0..=10, // msg_start_range - 30..=40, // size_range - 10..=20, // signal_start_range - (MAX_SIGNALS - 10)..=MAX_SIGNALS, // signal_count_range - RejectReason::all(), - ), - // `MAX_SIGNALS` + 20 <= `messages_end() <= `MAX_SIGNALS` + 40 - in_stream in arb_stream_with_config( - 0..=10, // msg_start_range - (MAX_SIGNALS + 20)..=(MAX_SIGNALS + 30), // size_range - 10..=20, // signal_start_range - 0..=10, // signal_count_range - RejectReason::all(), - ), - ) { - with_test_replica_logger(|log| { - let from = out_stream.signals_end(); - let msg_count = (in_stream.messages_end() - from).get() as usize; - let signals_count_after_gc = (out_stream.signals_end() - in_stream.messages_begin()).get() as usize; - - let mut state_manager = - StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); - state_manager = state_manager.with_stream(SUBNET_1, out_stream); - - let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); - xnet_payload_builder.pool_slice(SUBNET_1, &in_stream, from, msg_count, &log); - - // Build the payload without a byte limit. Messages up to the signal limit should be - // included. - let (payload, raw_payload, byte_size) = xnet_payload_builder.get_xnet_payload(usize::MAX); - assert_eq!(byte_size, xnet_payload_builder.validate_xnet_payload(&raw_payload).unwrap()); - - if signals_count_after_gc < MAX_SIGNALS { - let slice = payload.get(&SUBNET_1).unwrap(); - let messages = slice.messages().unwrap(); - - let signals_count = messages.end().get() - in_stream.messages_begin().get(); - assert_eq!( - MAX_SIGNALS, - signals_count as usize, - "inducting payload would lead to signals_count > MAX_SIGNALS", - ); - } else { - assert!(payload.len() <= 1, "at most one slice expected in the payload"); - if let Some(slice) = payload.get(&SUBNET_1) { - assert!(slice.messages().is_none(), "no messages expected in the slice"); - } - } - }); - } -} - -proptest! { - /// Tests payload building with various alignments of expected indices to - /// slice: just before the pooled slice, within the pooled slice, just - /// after the pooled slice. - #[test] - fn get_xnet_payload_slice_alignment( - (stream, from, msg_count) in arb_stream_slice(5, 10, 0, 10), - ) { - // Bump `from` (and adjust `msg_count` accordingly) so we can decrement it later - // on. - let from = from.increment(); - let msg_count = msg_count - 1; - - with_test_replica_logger(|log| { - let mut state_manager = - StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); - - // We will be creating 3 identical copies of the slice, each coming from a - // different subnet. - // - // Create 3 reverse streams within `state_manager`: - // * One with `signals_end` just before `from` (so there's a missing message). - // * One with `signals_end` just after `from` (so there's an extra message). - // * One with `signals_end` just after `from + msg_count` (so we've already - // seen all messages). 
-            state_manager =
-                state_manager.with_stream(SUBNET_1, out_stream(&stream, from.decrement()));
-            state_manager =
-                state_manager.with_stream(SUBNET_2, out_stream(&stream, from.increment()));
-            state_manager = state_manager.with_stream(
-                SUBNET_3,
-                out_stream(&stream, from + (msg_count as u64 + 1).into()),
-            );
-
-            // Create payload builder with the 3 slices pooled.
-            let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager);
-            xnet_payload_builder.pool_slice(SUBNET_1, &stream, from, msg_count, &log);
-            xnet_payload_builder.pool_slice(SUBNET_2, &stream, from, msg_count, &log);
-            xnet_payload_builder.pool_slice(SUBNET_3, &stream, from, msg_count, &log);
+/// Tests that the payload builder does not include messages in a stream slice that
+/// would lead to more than `MAX_SIGNALS` signals in the outgoing stream.
+///
+/// The input consists of
+/// - an outgoing stream that has already seen most of the messages coming
+/// from the incoming stream, i.e. it has close to (and up to) `MAX_SIGNALS` signals.
+/// - a very large incoming stream that has slightly more than `MAX_SIGNALS` messages in it.
+///
+/// The stream slice to include in the payload will start from `out_stream.signals_end()`.
+///
+/// If there is room for more signals, messages are expected to be included in the slice
+/// such that `slice.messages_end() - in_stream.begin()` == `MAX_SIGNALS`, i.e. after inducting
+/// the slice there would be exactly `MAX_SIGNALS` signals in the `out_stream`.
+#[test_strategy::proptest(ProptestConfig::with_cases(10))]
+fn get_xnet_payload_respects_signal_limit(
+    // `MAX_SIGNALS` <= `signals_end()` <= `MAX_SIGNALS` + 20
+    #[strategy(arb_stream_with_config(
+        0..=10, // msg_start_range
+        30..=40, // size_range
+        10..=20, // signal_start_range
+        (MAX_SIGNALS - 10)..=MAX_SIGNALS, // signal_count_range
+        RejectReason::all(),
+    ))]
+    out_stream: Stream,
+
+    // `MAX_SIGNALS` + 20 <= `messages_end()` <= `MAX_SIGNALS` + 40
+    #[strategy(arb_stream_with_config(
+        0..=10, // msg_start_range
+        (MAX_SIGNALS + 20)..=(MAX_SIGNALS + 30), // size_range
+        10..=20, // signal_start_range
+        0..=10, // signal_count_range
+        RejectReason::all(),
+    ))]
+    in_stream: Stream,
+) {
+    with_test_replica_logger(|log| {
+        let from = out_stream.signals_end();
+        let msg_count = (in_stream.messages_end() - from).get() as usize;
+        let signals_count_after_gc =
+            (out_stream.signals_end() - in_stream.messages_begin()).get() as usize;
+
+        let mut state_manager =
+            StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone());
+        state_manager = state_manager.with_stream(SUBNET_1, out_stream);
+
+        let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager);
+        xnet_payload_builder.pool_slice(SUBNET_1, &in_stream, from, msg_count, &log);
+
+        // Build the payload without a byte limit. Messages up to the signal limit should be
+        // included.
+        let (payload, raw_payload, byte_size) = xnet_payload_builder.get_xnet_payload(usize::MAX);
+        assert_eq!(
+            byte_size,
+            xnet_payload_builder
+                .validate_xnet_payload(&raw_payload)
+                .unwrap()
+        );

-            // Build the payload.
-            let payload = xnet_payload_builder
-                .get_xnet_payload(usize::MAX).0;
+        if signals_count_after_gc < MAX_SIGNALS {
+            let slice = payload.get(&SUBNET_1).unwrap();
+            let messages = slice.messages().unwrap();

-            // Payload should contain 1 slice...
+ let signals_count = messages.end().get() - in_stream.messages_begin().get(); assert_eq!( - 1, - payload.len(), - "Expecting 1 slice in payload, got {}", - payload.len() + MAX_SIGNALS, signals_count as usize, + "inducting payload would lead to signals_count > MAX_SIGNALS", ); - // ...from SUBNET_2... - if let Some(slice) = payload.get(&SUBNET_2) { - assert_eq!(stream.messages_begin(), slice.header().begin()); - assert_eq!(stream.messages_end(), slice.header().end()); - assert_eq!(stream.signals_end(), slice.header().signals_end()); - - // ...with non-empty messages... - if let Some(messages) = slice.messages() { - // ...between (from + 1) and stream.end. - assert_eq!(from.increment(), messages.begin()); - assert_eq!(from + (msg_count as u64).into(), messages.end()); - } else { - panic!("Expected a non-empty slice from SUBNET_2"); - } - } else { - panic!( - "Expected a slice from SUBNET_2, got {:?}", - payload.keys().next() + } else { + assert!( + payload.len() <= 1, + "at most one slice expected in the payload" + ); + if let Some(slice) = payload.get(&SUBNET_1) { + assert!( + slice.messages().is_none(), + "no messages expected in the slice" ); } + } + }); +} - assert_eq!( - metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), - xnet_payload_builder.build_payload_counts() - ); - assert_eq!( - HistogramStats { - count: 1, - sum: (msg_count - 1) as f64 - }, - xnet_payload_builder.slice_messages_stats() - ); - assert_eq!(1, xnet_payload_builder.slice_payload_size_stats().count); - }); - } - - /// Tests payload building with a byte limit just under the total slice - /// size. - #[test] - fn get_xnet_payload_byte_limit_exceeded( - (stream1, from1, msg_count1) in arb_stream_slice(10, 15, 0, 10), - (stream2, from2, msg_count2) in arb_stream_slice(10, 15, 0, 10), - ) { - with_test_replica_logger(|log| { - let mut state_manager = - StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); - - // Create a matching outgoing stream within `state_manager` for each slice. - state_manager = state_manager.with_stream(SUBNET_1, out_stream(&stream1, from1)); - state_manager = state_manager.with_stream(SUBNET_2, out_stream(&stream2, from2)); - - // Create payload builder with the 2 slices pooled. - let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); - let mut slice_bytes_sum = 0; - slice_bytes_sum += xnet_payload_builder.pool_slice(SUBNET_1, &stream1, from1, msg_count1, &log); - slice_bytes_sum += xnet_payload_builder.pool_slice(SUBNET_2, &stream2, from2, msg_count2, &log); - - // Build a payload with a byte limit just under the total size of the 2 slices. - let payload = xnet_payload_builder - .get_xnet_payload(slice_bytes_sum - 1).0; - - // Payload should contain 2 slices. - assert_eq!( - 2, - payload.len(), - "Expecting 2 slices in payload, got {}", - payload.len() - ); - // And exactly one message should be missing. - let msg_count: usize = payload - .values() - .map(|slice| slice.messages().map_or(0, |m| m.len())) - .sum(); - assert_eq!(msg_count1 + msg_count2 - 1, msg_count); +/// Tests payload building with various alignments of expected indices to +/// slice: just before the pooled slice, within the pooled slice, just +/// after the pooled slice. 
+#[test_strategy::proptest] +fn get_xnet_payload_slice_alignment( + #[strategy(arb_stream_slice( + 5, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + // Bump `from` (and adjust `msg_count` accordingly) so we can decrement it later + // on. + let from = from.increment(); + let msg_count = msg_count - 1; + + with_test_replica_logger(|log| { + let mut state_manager = + StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); + + // We will be creating 3 identical copies of the slice, each coming from a + // different subnet. + // + // Create 3 reverse streams within `state_manager`: + // * One with `signals_end` just before `from` (so there's a missing message). + // * One with `signals_end` just after `from` (so there's an extra message). + // * One with `signals_end` just after `from + msg_count` (so we've already + // seen all messages). + state_manager = state_manager.with_stream(SUBNET_1, out_stream(&stream, from.decrement())); + state_manager = state_manager.with_stream(SUBNET_2, out_stream(&stream, from.increment())); + state_manager = state_manager.with_stream( + SUBNET_3, + out_stream(&stream, from + (msg_count as u64 + 1).into()), + ); - assert_eq!( - metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), - xnet_payload_builder.build_payload_counts() - ); - assert_eq!( - HistogramStats { - count: 2, - sum: (msg_count1 + msg_count2 - 1) as f64 - }, - xnet_payload_builder.slice_messages_stats() + // Create payload builder with the 3 slices pooled. + let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); + xnet_payload_builder.pool_slice(SUBNET_1, &stream, from, msg_count, &log); + xnet_payload_builder.pool_slice(SUBNET_2, &stream, from, msg_count, &log); + xnet_payload_builder.pool_slice(SUBNET_3, &stream, from, msg_count, &log); + + // Build the payload. + let payload = xnet_payload_builder.get_xnet_payload(usize::MAX).0; + + // Payload should contain 1 slice... + assert_eq!( + 1, + payload.len(), + "Expecting 1 slice in payload, got {}", + payload.len() + ); + // ...from SUBNET_2... + if let Some(slice) = payload.get(&SUBNET_2) { + assert_eq!(stream.messages_begin(), slice.header().begin()); + assert_eq!(stream.messages_end(), slice.header().end()); + assert_eq!(stream.signals_end(), slice.header().signals_end()); + + // ...with non-empty messages... + if let Some(messages) = slice.messages() { + // ...between (from + 1) and stream.end. + assert_eq!(from.increment(), messages.begin()); + assert_eq!(from + (msg_count as u64).into(), messages.end()); + } else { + panic!("Expected a non-empty slice from SUBNET_2"); + } + } else { + panic!( + "Expected a slice from SUBNET_2, got {:?}", + payload.keys().next() ); - assert_eq!(2, xnet_payload_builder.slice_payload_size_stats().count); - }); - } - - /// Tests payload building with a byte limit too small even for an empty - /// slice. - #[test] - fn get_xnet_payload_byte_limit_too_small( - (stream, from, msg_count) in arb_stream_slice(10, 15, 0, 10), - ) { - with_test_replica_logger(|log| { - let mut state_manager = - StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); + } - // Create a matching outgoing stream within `state_manager` for each slice. 
- state_manager = state_manager.with_stream(REMOTE_SUBNET, out_stream(&stream, from)); + assert_eq!( + metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), + xnet_payload_builder.build_payload_counts() + ); + assert_eq!( + HistogramStats { + count: 1, + sum: (msg_count - 1) as f64 + }, + xnet_payload_builder.slice_messages_stats() + ); + assert_eq!(1, xnet_payload_builder.slice_payload_size_stats().count); + }); +} - // Create payload builder with the slice pooled. - let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); - xnet_payload_builder.pool_slice(REMOTE_SUBNET, &stream, from, msg_count, &log); +/// Tests payload building with a byte limit just under the total slice +/// size. +#[test_strategy::proptest] +fn get_xnet_payload_byte_limit_exceeded( + #[strategy(arb_stream_slice( + 10, // min_size + 15, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice1: (Stream, StreamIndex, usize), + #[strategy(arb_stream_slice( + 10, // min_size + 15, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice2: (Stream, StreamIndex, usize), +) { + let (stream1, from1, msg_count1) = test_slice1; + let (stream2, from2, msg_count2) = test_slice2; + + with_test_replica_logger(|log| { + let mut state_manager = + StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); + + // Create a matching outgoing stream within `state_manager` for each slice. + state_manager = state_manager.with_stream(SUBNET_1, out_stream(&stream1, from1)); + state_manager = state_manager.with_stream(SUBNET_2, out_stream(&stream2, from2)); + + // Create payload builder with the 2 slices pooled. + let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); + let mut slice_bytes_sum = 0; + slice_bytes_sum += + xnet_payload_builder.pool_slice(SUBNET_1, &stream1, from1, msg_count1, &log); + slice_bytes_sum += + xnet_payload_builder.pool_slice(SUBNET_2, &stream2, from2, msg_count2, &log); + + // Build a payload with a byte limit just under the total size of the 2 slices. + let payload = xnet_payload_builder.get_xnet_payload(slice_bytes_sum - 1).0; + + // Payload should contain 2 slices. + assert_eq!( + 2, + payload.len(), + "Expecting 2 slices in payload, got {}", + payload.len() + ); + // And exactly one message should be missing. + let msg_count: usize = payload + .values() + .map(|slice| slice.messages().map_or(0, |m| m.len())) + .sum(); + assert_eq!(msg_count1 + msg_count2 - 1, msg_count); + + assert_eq!( + metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), + xnet_payload_builder.build_payload_counts() + ); + assert_eq!( + HistogramStats { + count: 2, + sum: (msg_count1 + msg_count2 - 1) as f64 + }, + xnet_payload_builder.slice_messages_stats() + ); + assert_eq!(2, xnet_payload_builder.slice_payload_size_stats().count); + }); +} - // Build a payload with a byte limit too small even for an empty slice. - let (payload, _, byte_size) = xnet_payload_builder.get_xnet_payload(1); +/// Tests payload building with a byte limit too small even for an empty +/// slice. 
+#[test_strategy::proptest] +fn get_xnet_payload_byte_limit_too_small( + #[strategy(arb_stream_slice( + 10, // min_size + 15, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + let mut state_manager = + StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); + + // Create a matching outgoing stream within `state_manager` for each slice. + state_manager = state_manager.with_stream(REMOTE_SUBNET, out_stream(&stream, from)); + + // Create payload builder with the slice pooled. + let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); + xnet_payload_builder.pool_slice(REMOTE_SUBNET, &stream, from, msg_count, &log); + + // Build a payload with a byte limit too small even for an empty slice. + let (payload, _, byte_size) = xnet_payload_builder.get_xnet_payload(1); + + // Payload should contain no slices. + assert!( + payload.is_empty(), + "Expecting empty payload, got payload of length {}", + payload.len() + ); + assert_eq!(0, byte_size.get()); - // Payload should contain no slices. - assert!( - payload.is_empty(), - "Expecting empty payload, got payload of length {}", - payload.len() - ); - assert_eq!(0, byte_size.get()); + assert_eq!( + metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), + xnet_payload_builder.build_payload_counts() + ); + assert_eq!( + HistogramStats { count: 0, sum: 0.0 }, + xnet_payload_builder.slice_messages_stats() + ); + assert_eq!(0, xnet_payload_builder.slice_payload_size_stats().count); + }); +} - assert_eq!( - metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), - xnet_payload_builder.build_payload_counts() - ); - assert_eq!( - HistogramStats { - count: 0, - sum: 0.0 - }, - xnet_payload_builder.slice_messages_stats() +/// Tests payload building from a pool containing an empty slice only. +#[test_strategy::proptest] +fn get_xnet_payload_empty_slice( + #[strategy(arb_stream( + 1, // min_size + 1, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + out_stream: Stream, +) { + // Empty incoming stream. + let from = out_stream.signals_end(); + let stream = Stream::new( + StreamIndexedQueue::with_begin(from), + out_stream.header().begin(), + ); + + with_test_replica_logger(|log| { + let mut state_manager = + StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); + + // Place outgoing stream into `state_manager`. + state_manager = state_manager.with_stream(REMOTE_SUBNET, out_stream); + + // Create payload builder with empty slice pooled. + let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); + xnet_payload_builder.pool_slice(REMOTE_SUBNET, &stream, from, 0, &log); + + // Build a payload. + let (payload, _, byte_size) = xnet_payload_builder.get_xnet_payload(usize::MAX); + + // Payload should be empty (we already have all signals in the slice). + assert!( + payload.is_empty(), + "Expecting empty in payload, got a slice" + ); + assert_eq!(0, byte_size.get()); + + // Bump `stream.signals_end` and pool an empty slice again. + let mut updated_stream = stream.clone(); + updated_stream.push_accept_signal(); + xnet_payload_builder.pool_slice(REMOTE_SUBNET, &updated_stream, from, 0, &log); + + // Build a payload again. + let payload = xnet_payload_builder.get_xnet_payload(usize::MAX).0; + + // Payload should now contain 1 empty slice from REMOTE_SUBNET. 
+ assert_eq!( + 1, + payload.len(), + "Expecting 1 slice in payload, got {}", + payload.len() + ); + if let Some(slice) = payload.get(&REMOTE_SUBNET) { + assert_eq!(stream.messages_begin(), slice.header().begin()); + assert_eq!(stream.messages_end(), slice.header().end()); + assert_eq!(updated_stream.signals_end(), slice.header().signals_end()); + assert!(slice.messages().is_none()); + } else { + panic!( + "Expected a slice from REMOTE_SUBNET, got {:?}", + payload.keys().next() ); - assert_eq!(0, xnet_payload_builder.slice_payload_size_stats().count); - }); - } + } - /// Tests payload building from a pool containing an empty slice only. - #[test] - fn get_xnet_payload_empty_slice( - out_stream in arb_stream(1, 1, 0, 10), - ) { - // Empty incoming stream. - let from = out_stream.signals_end(); - let stream = Stream::new( - StreamIndexedQueue::with_begin(from), - out_stream.header().begin(), + assert_eq!( + metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 2)]), + xnet_payload_builder.build_payload_counts() ); + assert_eq!( + HistogramStats { count: 1, sum: 0. }, + xnet_payload_builder.slice_messages_stats() + ); + assert_eq!(1, xnet_payload_builder.slice_payload_size_stats().count); + }); +} - with_test_replica_logger(|log| { - let mut state_manager = - StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); - - // Place outgoing stream into `state_manager`. - state_manager = state_manager.with_stream(REMOTE_SUBNET, out_stream); - - // Create payload builder with empty slice pooled. - let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); - xnet_payload_builder.pool_slice(REMOTE_SUBNET, &stream, from, 0, &log); - - // Build a payload. - let (payload, _, byte_size) = xnet_payload_builder - .get_xnet_payload(usize::MAX); - - // Payload should be empty (we already have all signals in the slice). - assert!(payload.is_empty(), "Expecting empty in payload, got a slice"); - assert_eq!(0, byte_size.get()); +/// Tests payload building on a system subnet when the combined sizes of the +/// incoming stream slice and outgoing stream exceed the system subnet +/// stream throttling limit. +#[test_strategy::proptest] +fn system_subnet_stream_throttling( + #[strategy(arb_stream( + SYSTEM_SUBNET_STREAM_MSG_LIMIT / 2 + 1, // min_size + SYSTEM_SUBNET_STREAM_MSG_LIMIT + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + out_stream: Stream, + #[strategy(arb_stream_slice( + SYSTEM_SUBNET_STREAM_MSG_LIMIT / 2 + 1, // min_size + SYSTEM_SUBNET_STREAM_MSG_LIMIT, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + // Set the outgoing stream's signals_end to the slice begin. + let out_stream = Stream::new(out_stream.messages().clone(), from); + // And the incoming stream's signals_end just beyond the outgoing stream's + // start, so we always get a slice, even empty. + let stream = Stream::new( + stream.messages().clone(), + out_stream.messages_begin().increment(), + ); + + with_test_replica_logger(|log| { + // Fixtures. + let state_manager = StateManagerFixture::with_subnet_type(SubnetType::System, log.clone()) + .with_stream(REMOTE_SUBNET, out_stream.clone()); + let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); + + // Populate payload builder pool with the REMOTE_SUBNET -> OWN_SUBNET slice. 
+ let certified_slice = in_slice(&stream, from, from, msg_count, &log); + { + let mut slice_pool = xnet_payload_builder.certified_slice_pool.lock().unwrap(); + slice_pool + .put( + REMOTE_SUBNET, + certified_slice, + REGISTRY_VERSION, + log.clone(), + ) + .unwrap(); + } - // Bump `stream.signals_end` and pool an empty slice again. - let mut updated_stream = stream.clone(); - updated_stream.push_accept_signal(); - xnet_payload_builder.pool_slice(REMOTE_SUBNET, &updated_stream, from, 0, &log); + let payload = xnet_payload_builder.get_xnet_payload(usize::MAX).0; - // Build a payload again. - let payload = xnet_payload_builder - .get_xnet_payload(usize::MAX).0; + assert_eq!(1, payload.len()); + if let Some(slice) = payload.get(&REMOTE_SUBNET) { + let max_slice_len = + SYSTEM_SUBNET_STREAM_MSG_LIMIT.saturating_sub(out_stream.messages().len()); + let expected_slice_len = msg_count.min(max_slice_len); - // Payload should now contain 1 empty slice from REMOTE_SUBNET. - assert_eq!( - 1, - payload.len(), - "Expecting 1 slice in payload, got {}", - payload.len() - ); - if let Some(slice) = payload.get(&REMOTE_SUBNET) { - assert_eq!(stream.messages_begin(), slice.header().begin()); - assert_eq!(stream.messages_end(), slice.header().end()); - assert_eq!(updated_stream.signals_end(), slice.header().signals_end()); + if expected_slice_len == 0 { assert!(slice.messages().is_none()); + } else if let Some(messages) = slice.messages() { + assert_eq!( + expected_slice_len, + messages.len(), + "Expecting a slice of length min({}, {}), got {}", + msg_count, + max_slice_len, + messages.len() + ); } else { panic!( - "Expected a slice from REMOTE_SUBNET, got {:?}", - payload.keys().next() + "Expecting a slice of length min({}, {}), got an empty slice", + msg_count, max_slice_len ); } assert_eq!( - metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 2)]), + metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), xnet_payload_builder.build_payload_counts() ); assert_eq!( HistogramStats { count: 1, - sum: 0. + sum: expected_slice_len as f64 }, xnet_payload_builder.slice_messages_stats() ); assert_eq!(1, xnet_payload_builder.slice_payload_size_stats().count); - }); - } - - /// Tests payload building on a system subnet when the combined sizes of the - /// incoming stream slice and outgoing stream exceed the system subnet - /// stream throttling limit. - #[test] - fn system_subnet_stream_throttling( - out_stream in arb_stream(SYSTEM_SUBNET_STREAM_MSG_LIMIT / 2 + 1, SYSTEM_SUBNET_STREAM_MSG_LIMIT + 10, 0, 10), - (stream, from, msg_count) in arb_stream_slice(SYSTEM_SUBNET_STREAM_MSG_LIMIT / 2 + 1, SYSTEM_SUBNET_STREAM_MSG_LIMIT, 0, 10), - ) { - // Set the outgoing stream's signals_end to the slice begin. - let out_stream = Stream::new(out_stream.messages().clone(), from); - // And the incoming stream's signals_end just beyond the outgoing stream's - // start, so we always get a slice, even empty. - let stream = Stream::new(stream.messages().clone(), out_stream.messages_begin().increment()); - - with_test_replica_logger(|log| { - // Fixtures. - let state_manager = - StateManagerFixture::with_subnet_type(SubnetType::System, log.clone()) - .with_stream(REMOTE_SUBNET, out_stream.clone()); - let xnet_payload_builder = XNetPayloadBuilderFixture::new(state_manager); - - // Populate payload builder pool with the REMOTE_SUBNET -> OWN_SUBNET slice. 
- let certified_slice = in_slice(&stream, from, from, msg_count, &log); - { - let mut slice_pool = xnet_payload_builder.certified_slice_pool.lock().unwrap(); - slice_pool.put(REMOTE_SUBNET, certified_slice, REGISTRY_VERSION, log.clone()).unwrap(); - } - - let payload = xnet_payload_builder - .get_xnet_payload(usize::MAX).0; - - assert_eq!(1, payload.len()); - if let Some(slice) = payload.get(&REMOTE_SUBNET) { - let max_slice_len = - SYSTEM_SUBNET_STREAM_MSG_LIMIT.saturating_sub(out_stream.messages().len()); - let expected_slice_len = msg_count.min(max_slice_len); - - if expected_slice_len == 0 { - assert!(slice.messages().is_none()); - } else if let Some(messages) = slice.messages() { - assert_eq!( - expected_slice_len, - messages.len(), - "Expecting a slice of length min({}, {}), got {}", - msg_count, - max_slice_len, - messages.len() - ); - } else { - panic!( - "Expecting a slice of length min({}, {}), got an empty slice", - msg_count, max_slice_len - ); - } - - assert_eq!( - metric_vec(&[(&[(LABEL_STATUS, STATUS_SUCCESS)], 1)]), - xnet_payload_builder.build_payload_counts() - ); - assert_eq!( - HistogramStats { - count: 1, - sum: expected_slice_len as f64 - }, - xnet_payload_builder.slice_messages_stats() - ); - assert_eq!(1, xnet_payload_builder.slice_payload_size_stats().count); - } else { - panic!( - "Expecting payload with a single slice, from {}", - REMOTE_SUBNET - ); - } - }); - } - - /// Tests that `validate_xnet_payload()` successfully validates any payload - /// produced by `get_xnet_payload()` and produces the same size estimate. - #[test] - fn validate_xnet_payload( - (stream1, from1, msg_count1) in arb_stream_slice(0, 10, 0, 10), - (stream2, from2, msg_count2) in arb_stream_slice(0, 10, 0, 10), - size_limit_percentage in 0..110u64, - ) { - with_test_replica_logger(|log| { - let mut state_manager = - StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); - - // Create a matching outgoing stream within `state_manager` for each slice. - state_manager = state_manager.with_stream(SUBNET_1, out_stream(&stream1, from1)); - state_manager = state_manager.with_stream(SUBNET_2, out_stream(&stream2, from2)); - - // Create payload builder with the 2 slices pooled. - let fixture = XNetPayloadBuilderFixture::new(state_manager); - let mut slice_bytes_sum = 0; - slice_bytes_sum += fixture.pool_slice(SUBNET_1, &stream1, from1, msg_count1, &log); - slice_bytes_sum += fixture.pool_slice(SUBNET_2, &stream2, from2, msg_count2, &log); - - let validation_context = validation_context_at(fixture.certified_height); - - // Build a payload with a byte limit dictated by `size_limit_percentage`. - let byte_size_limit = (slice_bytes_sum as u64 * size_limit_percentage / 100).into(); - let (payload, byte_size) = fixture.xnet_payload_builder.get_xnet_payload( - &validation_context, - &[], - byte_size_limit, + } else { + panic!( + "Expecting payload with a single slice, from {}", + REMOTE_SUBNET ); - assert!(byte_size <= byte_size_limit); + } + }); +} - // Payload should validate and the size estimate should match. - assert_eq!( - byte_size, - fixture.xnet_payload_builder.validate_xnet_payload( - &payload, - &validation_context, - &[], - ).unwrap() - ); - }); - } +/// Tests that `validate_xnet_payload()` successfully validates any payload +/// produced by `get_xnet_payload()` and produces the same size estimate. 
+#[test_strategy::proptest] +fn validate_xnet_payload( + #[strategy(arb_stream_slice( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice1: (Stream, StreamIndex, usize), + #[strategy(arb_stream_slice( + 0, // min_size + 10, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice2: (Stream, StreamIndex, usize), + #[strategy(0..110u64)] size_limit_percentage: u64, +) { + let (stream1, from1, msg_count1) = test_slice1; + let (stream2, from2, msg_count2) = test_slice2; + + with_test_replica_logger(|log| { + let mut state_manager = + StateManagerFixture::with_subnet_type(SubnetType::Application, log.clone()); + + // Create a matching outgoing stream within `state_manager` for each slice. + state_manager = state_manager.with_stream(SUBNET_1, out_stream(&stream1, from1)); + state_manager = state_manager.with_stream(SUBNET_2, out_stream(&stream2, from2)); + + // Create payload builder with the 2 slices pooled. + let fixture = XNetPayloadBuilderFixture::new(state_manager); + let mut slice_bytes_sum = 0; + slice_bytes_sum += fixture.pool_slice(SUBNET_1, &stream1, from1, msg_count1, &log); + slice_bytes_sum += fixture.pool_slice(SUBNET_2, &stream2, from2, msg_count2, &log); + + let validation_context = validation_context_at(fixture.certified_height); + + // Build a payload with a byte limit dictated by `size_limit_percentage`. + let byte_size_limit = (slice_bytes_sum as u64 * size_limit_percentage / 100).into(); + let (payload, byte_size) = fixture.xnet_payload_builder.get_xnet_payload( + &validation_context, + &[], + byte_size_limit, + ); + assert!(byte_size <= byte_size_limit); + + // Payload should validate and the size estimate should match. + assert_eq!( + byte_size, + fixture + .xnet_payload_builder + .validate_xnet_payload(&payload, &validation_context, &[],) + .unwrap() + ); + }); } /// A fake `XNetClient` that returns the results matching the respective @@ -777,354 +851,460 @@ enum FakeXNetClientError { NoContent, } -proptest! { - /// Tests refilling an empty pool. - #[test] - fn refill_pool_empty( - (stream, from, msg_count) in arb_stream_slice(10, 15, 0, 10), - ) { - with_test_replica_logger(|log| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - - let slice = stream.slice(from, Some(msg_count)); - let certified_slice = in_slice(&stream, from, from, msg_count, &log); - - let stream_position = ExpectedIndices { - message_index: from, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(move |_, _, _| Ok(slice.clone())); - let metrics_registry = MetricsRegistry::new(); - let pool = Arc::new(Mutex::new(CertifiedSlicePool::new(Arc::new(certified_stream_store) as Arc<_>, &metrics_registry))); - pool.lock() - .unwrap() - .garbage_collect(btreemap! 
[REMOTE_SUBNET => stream_position.clone()]); - - let registry = get_registry_for_test(); - let proximity_map = Arc::new(ProximityMap::new(OWN_NODE, registry.clone(), &metrics_registry, log.clone())); - let endpoint_resolver = - XNetEndpointResolver::new(registry.clone(), OWN_NODE, OWN_SUBNET, proximity_map, log.clone()); - let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - 350) * 98 / 100; - let url = endpoint_resolver - .xnet_endpoint_url(REMOTE_SUBNET, from, from, byte_limit) - .unwrap() - .url - .to_string(); - - let xnet_client = Arc::new(FakeXNetClient { - results: btreemap![ - url => Ok(certified_slice.clone()), - ], - }); - - let refill_handle = PoolRefillTask::start( - Arc::clone(&pool), - endpoint_resolver, - xnet_client, - runtime.handle().clone(), - Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), - log, - ); - refill_handle.trigger_refill(registry.get_latest_version()); - - runtime.block_on(async { - let mut count: u64 = 0; - // Keep polling until a slice is present in the pool. - loop { - if let (_, Some(_), _, _) = pool.lock().unwrap().slice_stats(REMOTE_SUBNET) - { - break; - } - count += 1; - if count > 50 { - panic!("refill task failed to complete within 5 seconds"); - } - tokio::time::sleep(Duration::from_millis(100)).await; - } - }); - - assert_opt_slices_eq( - Some(certified_slice), - pool.lock() - .unwrap() - .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); +/// Tests refilling an empty pool. +#[test_strategy::proptest] +fn refill_pool_empty( + #[strategy(arb_stream_slice( + 10, // min_size + 15, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + + let slice = stream.slice(from, Some(msg_count)); + let certified_slice = in_slice(&stream, from, from, msg_count, &log); + + let stream_position = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(move |_, _, _| Ok(slice.clone())); + let metrics_registry = MetricsRegistry::new(); + let pool = Arc::new(Mutex::new(CertifiedSlicePool::new( + Arc::new(certified_stream_store) as Arc<_>, + &metrics_registry, + ))); + pool.lock() + .unwrap() + .garbage_collect(btreemap! [REMOTE_SUBNET => stream_position.clone()]); + + let registry = get_registry_for_test(); + let proximity_map = Arc::new(ProximityMap::new( + OWN_NODE, + registry.clone(), + &metrics_registry, + log.clone(), + )); + let endpoint_resolver = XNetEndpointResolver::new( + registry.clone(), + OWN_NODE, + OWN_SUBNET, + proximity_map, + log.clone(), + ); + let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - 350) * 98 / 100; + let url = endpoint_resolver + .xnet_endpoint_url(REMOTE_SUBNET, from, from, byte_limit) + .unwrap() + .url + .to_string(); + + let xnet_client = Arc::new(FakeXNetClient { + results: btreemap![ + url => Ok(certified_slice.clone()), + ], }); - } - /// Tests refilling a pool with an already existing slice, requiring an - /// append. - #[test] - fn refill_pool_append( - (stream, from, msg_count) in arb_stream_slice(10, 15, 0, 10), - ) { - // Bump `from` so we always get a non-empty prefix. 
- let from = from.increment(); - let msg_count = msg_count - 1; - - with_test_replica_logger(|log| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - - let stream_begin = stream.messages_begin(); - let prefix_msg_count = (from - stream_begin).get() as usize; - let certified_prefix = in_slice(&stream, stream_begin, stream_begin, prefix_msg_count, &log); - let certified_suffix = in_slice(&stream, stream_begin, from, msg_count, &log); - let expected_msg_count = prefix_msg_count + msg_count; - let slice = stream.slice(stream_begin, Some(expected_msg_count)); - let certified_slice = in_slice(&stream, stream_begin, stream_begin, expected_msg_count, &log); - - let stream_position = ExpectedIndices { - message_index: stream_begin, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Actual return value does not matter as long as it's `Ok(_)`. - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(move |_, _, _| Ok(slice.clone())); - let metrics_registry = MetricsRegistry::new(); - let pool = Arc::new(Mutex::new(CertifiedSlicePool::new(Arc::new(certified_stream_store) as Arc<_>, &metrics_registry))); - let prefix_size_bytes = UnpackedStreamSlice::try_from(certified_prefix.clone()).unwrap().count_bytes(); - { - let mut pool = pool.lock().unwrap(); - pool.put(REMOTE_SUBNET, certified_prefix, REGISTRY_VERSION, log.clone()).unwrap(); - pool.garbage_collect(btreemap! [REMOTE_SUBNET => stream_position.clone()]); + let refill_handle = PoolRefillTask::start( + Arc::clone(&pool), + endpoint_resolver, + xnet_client, + runtime.handle().clone(), + Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), + log, + ); + refill_handle.trigger_refill(registry.get_latest_version()); + + runtime.block_on(async { + let mut count: u64 = 0; + // Keep polling until a slice is present in the pool. + loop { + if let (_, Some(_), _, _) = pool.lock().unwrap().slice_stats(REMOTE_SUBNET) { + break; + } + count += 1; + if count > 50 { + panic!("refill task failed to complete within 5 seconds"); + } + tokio::time::sleep(Duration::from_millis(100)).await; } + }); - let registry = get_registry_for_test(); - let proximity_map = Arc::new(ProximityMap::new(OWN_NODE, registry.clone(), &metrics_registry, log.clone())); - let endpoint_resolver = - XNetEndpointResolver::new(registry.clone(), OWN_NODE, OWN_SUBNET, proximity_map, log.clone()); - let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - prefix_size_bytes - 350) * 98 / 100; - let url = endpoint_resolver - .xnet_endpoint_url(REMOTE_SUBNET, stream_begin, from, byte_limit) + assert_opt_slices_eq( + Some(certified_slice), + pool.lock() + .unwrap() + .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) .unwrap() - .url - .to_string(); - - let xnet_client = Arc::new(FakeXNetClient { - results: btreemap![ - url => Ok(certified_suffix), - ], - }); - - let refill_handle = PoolRefillTask::start( - Arc::clone(&pool), - endpoint_resolver, - xnet_client, - runtime.handle().clone(), - Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), + .map(|(slice, _)| slice), + ); + }); +} + +/// Tests refilling a pool with an already existing slice, requiring an +/// append. 
+#[test_strategy::proptest] +fn refill_pool_append( + #[strategy(arb_stream_slice( + 10, // min_size + 15, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + // Bump `from` so we always get a non-empty prefix. + let from = from.increment(); + let msg_count = msg_count - 1; + + with_test_replica_logger(|log| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + + let stream_begin = stream.messages_begin(); + let prefix_msg_count = (from - stream_begin).get() as usize; + let certified_prefix = + in_slice(&stream, stream_begin, stream_begin, prefix_msg_count, &log); + let certified_suffix = in_slice(&stream, stream_begin, from, msg_count, &log); + let expected_msg_count = prefix_msg_count + msg_count; + let slice = stream.slice(stream_begin, Some(expected_msg_count)); + let certified_slice = in_slice( + &stream, + stream_begin, + stream_begin, + expected_msg_count, + &log, + ); + + let stream_position = ExpectedIndices { + message_index: stream_begin, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Actual return value does not matter as long as it's `Ok(_)`. + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(move |_, _, _| Ok(slice.clone())); + let metrics_registry = MetricsRegistry::new(); + let pool = Arc::new(Mutex::new(CertifiedSlicePool::new( + Arc::new(certified_stream_store) as Arc<_>, + &metrics_registry, + ))); + let prefix_size_bytes = UnpackedStreamSlice::try_from(certified_prefix.clone()) + .unwrap() + .count_bytes(); + { + let mut pool = pool.lock().unwrap(); + pool.put( + REMOTE_SUBNET, + certified_prefix, + REGISTRY_VERSION, log.clone(), - ); - refill_handle.trigger_refill(registry.get_latest_version()); - - runtime.block_on(async { - let mut count: u64 = 0; - // Keep polling until the pooled slice has `expected_msg_count` messages. - loop { - let (_, _, msg_count, _) = pool.lock().unwrap().slice_stats(REMOTE_SUBNET); - if msg_count == expected_msg_count { - break; - } - count += 1; - if count > 50 { - panic!("refill task failed to complete within 5 seconds"); - } - tokio::time::sleep(Duration::from_millis(100)).await; + ) + .unwrap(); + pool.garbage_collect(btreemap! [REMOTE_SUBNET => stream_position.clone()]); + } + + let registry = get_registry_for_test(); + let proximity_map = Arc::new(ProximityMap::new( + OWN_NODE, + registry.clone(), + &metrics_registry, + log.clone(), + )); + let endpoint_resolver = XNetEndpointResolver::new( + registry.clone(), + OWN_NODE, + OWN_SUBNET, + proximity_map, + log.clone(), + ); + let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - prefix_size_bytes - 350) * 98 / 100; + let url = endpoint_resolver + .xnet_endpoint_url(REMOTE_SUBNET, stream_begin, from, byte_limit) + .unwrap() + .url + .to_string(); + + let xnet_client = Arc::new(FakeXNetClient { + results: btreemap![ + url => Ok(certified_suffix), + ], + }); + + let refill_handle = PoolRefillTask::start( + Arc::clone(&pool), + endpoint_resolver, + xnet_client, + runtime.handle().clone(), + Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), + log.clone(), + ); + refill_handle.trigger_refill(registry.get_latest_version()); + + runtime.block_on(async { + let mut count: u64 = 0; + // Keep polling until the pooled slice has `expected_msg_count` messages. 
+ loop { + let (_, _, msg_count, _) = pool.lock().unwrap().slice_stats(REMOTE_SUBNET); + if msg_count == expected_msg_count { + break; } - }); - - assert_opt_slices_eq( - Some(certified_slice), - pool.lock() - .unwrap() - .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); + count += 1; + if count > 50 { + panic!("refill task failed to complete within 5 seconds"); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } }); - } - /// Tests handling of an invalid slice when refilling the pool. - #[test] - fn refill_pool_put_invalid_slice( - (stream, from, msg_count) in arb_stream_slice(10, 15, 0, 10), - ) { - with_test_replica_logger(|log| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - - let stream_position = ExpectedIndices { - message_index: from, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(|_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); - let metrics_registry = MetricsRegistry::new(); - let pool = Arc::new(Mutex::new(CertifiedSlicePool::new(Arc::new(certified_stream_store) as Arc<_>, &metrics_registry))); + assert_opt_slices_eq( + Some(certified_slice), pool.lock() .unwrap() - .garbage_collect(btreemap! [REMOTE_SUBNET => stream_position.clone()]); - - let registry = get_registry_for_test(); - let proximity_map = Arc::new(ProximityMap::new(OWN_NODE, registry.clone(), &metrics_registry, log.clone())); - let endpoint_resolver = - XNetEndpointResolver::new(registry.clone(), OWN_NODE, OWN_SUBNET, proximity_map, log.clone()); - let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - 350) * 98 / 100; - let url = endpoint_resolver - .xnet_endpoint_url(REMOTE_SUBNET, from, from, byte_limit) + .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) .unwrap() - .url - .to_string(); - - let certified_slice = in_slice(&stream, from, from, msg_count, &log); - let xnet_client = Arc::new(FakeXNetClient { - results: btreemap![ - url => Ok(certified_slice), - ], - }); - - let refill_handle = PoolRefillTask::start( - Arc::clone(&pool), - endpoint_resolver, - xnet_client, - runtime.handle().clone(), - Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), - log, - ); - refill_handle.trigger_refill(registry.get_latest_version()); - - runtime.block_on(async { - let mut count: u64 = 0; - // Keep polling until we observe a `DecodeStreamError`. - loop { - if let Some(1) = fetch_int_counter_vec(&metrics_registry, METRIC_PULL_ATTEMPT_COUNT).get(&btreemap![LABEL_STATUS.into() => "DecodeStreamError".into()]) - { - break; - } - count += 1; - if count > 50 { - panic!("refill task failed to complete within 5 seconds"); - } - tokio::time::sleep(Duration::from_millis(100)).await; - } - }); + .map(|(slice, _)| slice), + ); + }); +} - assert!( - pool.lock() - .unwrap() - .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) - .unwrap() - .is_none(), - ); +/// Tests handling of an invalid slice when refilling the pool. 
+#[test_strategy::proptest] +fn refill_pool_put_invalid_slice( + #[strategy(arb_stream_slice( + 10, // min_size + 15, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + with_test_replica_logger(|log| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + + let stream_position = ExpectedIndices { + message_index: from, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(|_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); + let metrics_registry = MetricsRegistry::new(); + let pool = Arc::new(Mutex::new(CertifiedSlicePool::new( + Arc::new(certified_stream_store) as Arc<_>, + &metrics_registry, + ))); + pool.lock() + .unwrap() + .garbage_collect(btreemap! [REMOTE_SUBNET => stream_position.clone()]); + + let registry = get_registry_for_test(); + let proximity_map = Arc::new(ProximityMap::new( + OWN_NODE, + registry.clone(), + &metrics_registry, + log.clone(), + )); + let endpoint_resolver = XNetEndpointResolver::new( + registry.clone(), + OWN_NODE, + OWN_SUBNET, + proximity_map, + log.clone(), + ); + let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - 350) * 98 / 100; + let url = endpoint_resolver + .xnet_endpoint_url(REMOTE_SUBNET, from, from, byte_limit) + .unwrap() + .url + .to_string(); + + let certified_slice = in_slice(&stream, from, from, msg_count, &log); + let xnet_client = Arc::new(FakeXNetClient { + results: btreemap![ + url => Ok(certified_slice), + ], }); - } - /// Tests validation failure while refilling a pool with an already existing - /// slice, requiring an append. - #[test] - fn refill_pool_append_invalid_slice( - (stream, from, msg_count) in arb_stream_slice(10, 15, 0, 10), - ) { - // Bump `from` so we always get a non-empty prefix. - let from = from.increment(); - let msg_count = msg_count - 1; - - with_test_replica_logger(|log| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - - let stream_begin = stream.messages_begin(); - let prefix_msg_count = (from - stream_begin).get() as usize; - let certified_prefix = in_slice(&stream, stream_begin, stream_begin, prefix_msg_count, &log); - let certified_suffix = in_slice(&stream, stream_begin, from, msg_count, &log); - let expected_msg_count = prefix_msg_count + msg_count; - let slice = stream.slice(stream_begin, Some(expected_msg_count)); - - let stream_position = ExpectedIndices { - message_index: stream_begin, - signal_index: stream.signals_end(), - }; - - let mut certified_stream_store = MockCertifiedStreamStore::new(); - // Accept the prefix as valid, but fail validation for the merged slice. 
- certified_stream_store - .expect_decode_certified_stream_slice() - .with(always(), always(), eq(certified_prefix.clone())) - .returning(move |_, _, _| Ok(slice.clone())); - certified_stream_store - .expect_decode_certified_stream_slice() - .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); - let metrics_registry = MetricsRegistry::new(); - let pool = Arc::new(Mutex::new(CertifiedSlicePool::new(Arc::new(certified_stream_store) as Arc<_>, &metrics_registry))); - let prefix_size_bytes = UnpackedStreamSlice::try_from(certified_prefix.clone()).unwrap().count_bytes(); - - { - let mut pool = pool.lock().unwrap(); - pool.put(REMOTE_SUBNET, certified_prefix.clone(), REGISTRY_VERSION, log.clone()).unwrap(); - pool.garbage_collect(btreemap! [REMOTE_SUBNET => stream_position.clone()]); + let refill_handle = PoolRefillTask::start( + Arc::clone(&pool), + endpoint_resolver, + xnet_client, + runtime.handle().clone(), + Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), + log, + ); + refill_handle.trigger_refill(registry.get_latest_version()); + + runtime.block_on(async { + let mut count: u64 = 0; + // Keep polling until we observe a `DecodeStreamError`. + loop { + if let Some(1) = fetch_int_counter_vec(&metrics_registry, METRIC_PULL_ATTEMPT_COUNT) + .get(&btreemap![LABEL_STATUS.into() => "DecodeStreamError".into()]) + { + break; + } + count += 1; + if count > 50 { + panic!("refill task failed to complete within 5 seconds"); + } + tokio::time::sleep(Duration::from_millis(100)).await; } + }); - let registry = get_registry_for_test(); - let proximity_map = Arc::new(ProximityMap::new(OWN_NODE, registry.clone(), &metrics_registry, log.clone())); - let endpoint_resolver = - XNetEndpointResolver::new(registry.clone(), OWN_NODE, OWN_SUBNET, proximity_map, log.clone()); - let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - prefix_size_bytes - 350) * 98 / 100; - let url = endpoint_resolver - .xnet_endpoint_url(REMOTE_SUBNET, stream_begin, from, byte_limit) - .unwrap() - .url - .to_string(); - - let xnet_client = Arc::new(FakeXNetClient { - results: btreemap![ - url => Ok(certified_suffix), - ], - }); - - let refill_handle = PoolRefillTask::start( - Arc::clone(&pool), - endpoint_resolver, - xnet_client, - runtime.handle().clone(), - Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), + assert!(pool + .lock() + .unwrap() + .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) + .unwrap() + .is_none(),); + }); +} + +/// Tests validation failure while refilling a pool with an already existing +/// slice, requiring an append. +#[test_strategy::proptest] +fn refill_pool_append_invalid_slice( + #[strategy(arb_stream_slice( + 10, // min_size + 15, // max_size + 0, // min_signal_count + 10, // max_signal_count + ))] + test_slice: (Stream, StreamIndex, usize), +) { + let (stream, from, msg_count) = test_slice; + + // Bump `from` so we always get a non-empty prefix. 
+ let from = from.increment(); + let msg_count = msg_count - 1; + + with_test_replica_logger(|log| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + + let stream_begin = stream.messages_begin(); + let prefix_msg_count = (from - stream_begin).get() as usize; + let certified_prefix = + in_slice(&stream, stream_begin, stream_begin, prefix_msg_count, &log); + let certified_suffix = in_slice(&stream, stream_begin, from, msg_count, &log); + let expected_msg_count = prefix_msg_count + msg_count; + let slice = stream.slice(stream_begin, Some(expected_msg_count)); + + let stream_position = ExpectedIndices { + message_index: stream_begin, + signal_index: stream.signals_end(), + }; + + let mut certified_stream_store = MockCertifiedStreamStore::new(); + // Accept the prefix as valid, but fail validation for the merged slice. + certified_stream_store + .expect_decode_certified_stream_slice() + .with(always(), always(), eq(certified_prefix.clone())) + .returning(move |_, _, _| Ok(slice.clone())); + certified_stream_store + .expect_decode_certified_stream_slice() + .returning(move |_, _, _| Err(DecodeStreamError::InvalidSignature(REMOTE_SUBNET))); + let metrics_registry = MetricsRegistry::new(); + let pool = Arc::new(Mutex::new(CertifiedSlicePool::new( + Arc::new(certified_stream_store) as Arc<_>, + &metrics_registry, + ))); + let prefix_size_bytes = UnpackedStreamSlice::try_from(certified_prefix.clone()) + .unwrap() + .count_bytes(); + + { + let mut pool = pool.lock().unwrap(); + pool.put( + REMOTE_SUBNET, + certified_prefix.clone(), + REGISTRY_VERSION, log.clone(), - ); - refill_handle.trigger_refill(registry.get_latest_version()); - - runtime.block_on(async { - let mut count: u64 = 0; - // Keep polling until we observe a `DecodeStreamError`. - loop { - if let Some(1) = fetch_int_counter_vec(&metrics_registry, METRIC_PULL_ATTEMPT_COUNT).get(&btreemap![LABEL_STATUS.into() => "DecodeStreamError".into()]) - { - break; - } - count += 1; - if count > 50 { - panic!("refill task failed to complete within 5 seconds"); - } - tokio::time::sleep(Duration::from_millis(100)).await; + ) + .unwrap(); + pool.garbage_collect(btreemap! [REMOTE_SUBNET => stream_position.clone()]); + } + + let registry = get_registry_for_test(); + let proximity_map = Arc::new(ProximityMap::new( + OWN_NODE, + registry.clone(), + &metrics_registry, + log.clone(), + )); + let endpoint_resolver = XNetEndpointResolver::new( + registry.clone(), + OWN_NODE, + OWN_SUBNET, + proximity_map, + log.clone(), + ); + let byte_limit = (POOL_SLICE_BYTE_SIZE_MAX - prefix_size_bytes - 350) * 98 / 100; + let url = endpoint_resolver + .xnet_endpoint_url(REMOTE_SUBNET, stream_begin, from, byte_limit) + .unwrap() + .url + .to_string(); + + let xnet_client = Arc::new(FakeXNetClient { + results: btreemap![ + url => Ok(certified_suffix), + ], + }); + + let refill_handle = PoolRefillTask::start( + Arc::clone(&pool), + endpoint_resolver, + xnet_client, + runtime.handle().clone(), + Arc::new(XNetPayloadBuilderMetrics::new(&metrics_registry)), + log.clone(), + ); + refill_handle.trigger_refill(registry.get_latest_version()); + + runtime.block_on(async { + let mut count: u64 = 0; + // Keep polling until we observe a `DecodeStreamError`. + loop { + if let Some(1) = fetch_int_counter_vec(&metrics_registry, METRIC_PULL_ATTEMPT_COUNT) + .get(&btreemap![LABEL_STATUS.into() => "DecodeStreamError".into()]) + { + break; } - }); - - // Only the prefix is pooled. 
- assert_opt_slices_eq( - Some(certified_prefix), - pool.lock() - .unwrap() - .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) - .unwrap() - .map(|(slice, _)| slice), - ); + count += 1; + if count > 50 { + panic!("refill task failed to complete within 5 seconds"); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } }); - } + + // Only the prefix is pooled. + assert_opt_slices_eq( + Some(certified_prefix), + pool.lock() + .unwrap() + .take_slice(REMOTE_SUBNET, Some(&stream_position), None, None) + .unwrap() + .map(|(slice, _)| slice), + ); + }); }