diff --git a/.github/scripts/build-server.sh b/.github/scripts/build-server.sh index 120a9065a4..16ebcf0ebc 100755 --- a/.github/scripts/build-server.sh +++ b/.github/scripts/build-server.sh @@ -6,8 +6,7 @@ environment=$1 aws s3 sync .git s3://tlsn-deploy/$environment/.git --delete -cd notary-server -cargo build --release -aws s3 cp target/release/notary-server s3://tlsn-deploy/$environment/ +cargo build -p notary-server --release +aws s3 cp ./target/release/notary-server s3://tlsn-deploy/$environment/ exit 0 diff --git a/.github/scripts/deploy-server.sh b/.github/scripts/deploy-server.sh index 844bc6be3e..ff86ca1742 100755 --- a/.github/scripts/deploy-server.sh +++ b/.github/scripts/deploy-server.sh @@ -4,11 +4,11 @@ set -ex environment=$1 branch=$2 -INSTANCE_ID=$(aws ec2 describe-instances --filters Name=tag:Name,Values=[tlsnotary-backend] --query "Reservations[*].Instances[*][InstanceId]" --output text) +INSTANCE_ID=$(aws ec2 describe-instances --filters Name=tag:Name,Values=[tlsnotary-backend-v1] Name=instance-state-name,Values=[running] --query "Reservations[*].Instances[*][InstanceId]" --output text) aws ec2 create-tags --resources $INSTANCE_ID --tags "Key=$environment,Value=$branch" COMMIT_HASH=$(git rev-parse HEAD) -DEPLOY_ID=$(aws deploy create-deployment --application-name tlsn-$environment --deployment-group-name tlsn-$environment-group --github-location repository=$GITHUB_REPOSITORY,commitId=$COMMIT_HASH --ignore-application-stop-failures --file-exists OVERWRITE --output text) +DEPLOY_ID=$(aws deploy create-deployment --application-name tlsn-$environment-v1 --deployment-group-name tlsn-$environment-v1-group --github-location repository=$GITHUB_REPOSITORY,commitId=$COMMIT_HASH --ignore-application-stop-failures --file-exists OVERWRITE --output text) while true; do STATUS=$(aws deploy get-deployment --deployment-id $DEPLOY_ID --query 'deploymentInfo.status' --output text) diff --git a/.github/scripts/modify-proxy.sh b/.github/scripts/modify-proxy.sh new 
file mode 100755 index 0000000000..921b1641f3 --- /dev/null +++ b/.github/scripts/modify-proxy.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# This script is triggered by Deploy server workflow in order to send an execution command of cd-scripts/modify_proxy.sh via AWS SSM to the proxy server + +set -e + +GH_OWNER="tlsnotary" +GH_REPO="tlsn" +BACKEND_INSTANCE_ID=$(aws ec2 describe-instances --filters Name=tag:Name,Values=[tlsnotary-backend-v1] Name=instance-state-name,Values=[running] --query "Reservations[*].Instances[*][InstanceId]" --output text) +PROXY_INSTANCE_ID=$(aws ec2 describe-instances --filters Name=tag:Name,Values=[tlsnotary-web] Name=instance-state-name,Values=[running] --query "Reservations[*].Instances[*][InstanceId]" --output text) +TAGS=$(aws ec2 describe-instances --instance-ids $BACKEND_INSTANCE_ID --query 'Reservations[*].Instances[*].Tags') + +TAG=$(echo $TAGS | jq -r '.[][][] | select(.Key == "stable").Value') +PORT=$(echo $TAGS | jq -r '.[][][] | select(.Key == "port").Value') + +COMMAND_ID=$(aws ssm send-command --document-name "AWS-RunRemoteScript" --instance-ids $PROXY_INSTANCE_ID --parameters '{"sourceType":["GitHub"],"sourceInfo":["{\"owner\":\"'${GH_OWNER}'\", \"repository\":\"'${GH_REPO}'\", \"getOptions\":\"branch:'${TAG}'\", \"path\": \"cd-scripts\"}"],"commandLine":["modify_proxy.sh '${PORT}' '${TAG}' "]}' --output text --query "Command.CommandId") + +while true; do + SSM_STATUS=$(aws ssm list-command-invocations --command-id $COMMAND_ID --details --query "CommandInvocations[].Status" --output text) + + if [ $SSM_STATUS != "Success" ] && [ $SSM_STATUS != "InProgress" ]; then + echo "Proxy modification failed" + aws ssm list-command-invocations --command-id $COMMAND_ID --details --query "CommandInvocations[].CommandPlugins[].{Status:Status,Output:Output}" + exit 1 + elif [ $SSM_STATUS = "Success" ]; then + aws ssm list-command-invocations --command-id $COMMAND_ID --details --query 
"CommandInvocations[].CommandPlugins[].{Status:Status,Output:Output}" + echo "Success" + break + fi + + sleep 2 +done + +exit 0 diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 0000000000..61a1eac9c7 --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,27 @@ +name: Run Benchmarks +on: + # manual trigger + workflow_dispatch: + +jobs: + run-benchmarks: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Build Docker Image + run: | + docker build -t tlsn-bench . -f ./crates/benches/benches.Dockerfile + + - name: Run Benchmarks + run: | + docker run --privileged -v ${{ github.workspace }}/crates/benches/:/benches tlsn-bench + + - name: Upload runtime_vs_latency.html + uses: actions/upload-artifact@v4 + with: + name: benchmark_graphs + path: | + ./crates/benches/runtime_vs_latency.html + ./crates/benches/runtime_vs_bandwidth.html diff --git a/.github/workflows/cd-server.yml b/.github/workflows/cd-server.yml index 562f6b9058..123dc8cc94 100644 --- a/.github/workflows/cd-server.yml +++ b/.github/workflows/cd-server.yml @@ -44,13 +44,13 @@ jobs: exit 1 fi - - name: Wait for test workflow to succeed + - name: Wait for integration test workflow to succeed if: github.event_name == 'push' uses: lewagon/wait-on-check-action@v1.3.1 with: ref: ${{ github.ref }} - # Have to be specify '(notary-server)', as we are using matrix for build_and_test job in ci.yml, else it will fail, more details [here](https://github.com/lewagon/wait-on-check-action#check-name) - check-name: 'Build and test (notary-server)' + # More details [here](https://github.com/lewagon/wait-on-check-action#check-name) + check-name: 'Run tests release build' repo-token: ${{ secrets.GITHUB_TOKEN }} # How frequent (in seconds) this job will call GitHub API to check the status of the job specified at 'check-name' wait-interval: 60 @@ -71,12 +71,6 @@ jobs: uses: dtolnay/rust-toolchain@stable with: toolchain: 
stable - components: clippy - - - name: Use caching - uses: Swatinem/rust-cache@v2.5.0 - with: - workspaces: ${{ matrix.package }} -> target - name: Cargo build run: | @@ -85,3 +79,8 @@ jobs: - name: Trigger Deployment run: | .github/scripts/deploy-server.sh ${{ steps.manipulate.outputs.env }} $GITHUB_REF_NAME + + - name: Modify Proxy + if: ${{ steps.manipulate.outputs.env == 'stable' }} + run: | + .github/scripts/modify-proxy.sh diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 9779e52426..60767652ad 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -16,18 +16,18 @@ jobs: contents: read packages: write steps: - - name: Wait for test workflow to succeed + - name: Wait for integration test workflow to succeed uses: lewagon/wait-on-check-action@v1.3.1 with: ref: ${{ github.ref }} - # Have to be specify '(notary-server)', as we are using matrix for build_and_test job in ci.yml, else it will fail, more details [here](https://github.com/lewagon/wait-on-check-action#check-name) - check-name: 'Build and test (notary-server)' + # More details [here](https://github.com/lewagon/wait-on-check-action#check-name) + check-name: 'Run tests release build' repo-token: ${{ secrets.GITHUB_TOKEN }} # How frequent (in seconds) this job will call GitHub API to check the status of the job specified at 'check-name' wait-interval: 60 - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Log in to the Container registry uses: docker/login-action@v2 @@ -49,4 +49,4 @@ jobs: push: true tags: ${{ steps.meta-notary-server.outputs.tags }} labels: ${{ steps.meta-notary-server.outputs.labels }} - file: ./notary-server/notary-server.Dockerfile + file: ./crates/notary/server/notary-server.Dockerfile diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 27197dfdfb..3ab2a2f6de 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,49 +13,39 @@ on: env: CARGO_TERM_COLOR: always 
CARGO_REGISTRIES_CRATES_IO_PROTOCOL: sparse + # We need a higher number of parallel rayon tasks than the default (which is 4) + # in order to prevent a deadlock, c.f. + # - https://github.com/tlsnotary/tlsn/issues/548 + # - https://github.com/privacy-scaling-explorations/mpz/issues/178 + # 32 seems to be big enough for the foreseeable future + RAYON_NUM_THREADS: 32 jobs: - build_and_test: - name: Build and test - if: ( ! github.event.pull_request.draft ) + fmt: + name: Check formatting runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - package: - - components/integration-tests - - components/uid-mux - - components/cipher - - components/universal-hash - - components/aead - - components/key-exchange - - components/point-addition - - components/prf - - components/tls - - tlsn - - notary-server - include: - - package: components/integration-tests - release: true - - package: notary-server - release: true - - package: tlsn - all-features: true - defaults: - run: - working-directory: ${{ matrix.package }} steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 + # We use nightly to support `imports_granularity` feature - name: Install nightly rust toolchain with rustfmt uses: dtolnay/rust-toolchain@stable with: toolchain: nightly components: rustfmt - - name: "Check formatting" + - name: Use caching + uses: Swatinem/rust-cache@v2.7.3 + + - name: Check formatting run: cargo +nightly fmt --check --all + build-and-test: + name: Build and test + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 - name: Install stable rust toolchain uses: dtolnay/rust-toolchain@stable @@ -63,32 +53,89 @@ jobs: toolchain: stable components: clippy - - name: "Clippy" - run: cargo clippy --all-features --examples -- -D warnings - - name: Use caching - uses: Swatinem/rust-cache@v2.5.0 + uses: Swatinem/rust-cache@v2.7.3 + + - name: Clippy + run: cargo clippy --all-features --all-targets -- -D warnings + + - 
name: Build + run: cargo build --all-targets + + - name: Test + run: cargo test + build-wasm: + name: Build and test wasm + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install stable rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + targets: wasm32-unknown-unknown + toolchain: stable + + - name: Install nightly rust toolchain + uses: dtolnay/rust-toolchain@stable with: - workspaces: ${{ matrix.package }} -> target + targets: wasm32-unknown-unknown,x86_64-unknown-linux-gnu + toolchain: nightly + components: rust-src + + - name: Install chromedriver + run: | + sudo apt-get update + sudo apt-get install -y chromium-chromedriver - - name: "Build" - run: cargo build ${{ matrix.release && '--release' }} + - name: Install wasm-pack + run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh + + - name: Use caching + uses: Swatinem/rust-cache@v2.7.3 - - name: "Test" - if: ${{ matrix.release != true }} - run: cargo test --lib --bins --tests --examples --workspace + - name: Run tests + run: | + cd crates/wasm-test-runner + ./run.sh + tests-integration: + name: Run tests release build + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 - - name: "Test all features" - if: ${{ matrix.release != true && matrix.all-features == true }} - run: cargo test --lib --bins --tests --examples --workspace --all-features + - name: Install stable rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable - - name: "Integration Test" - if: ${{ matrix.release == true }} - run: cargo test --release --tests + - name: Use caching + uses: Swatinem/rust-cache@v2.7.3 - - name: "Integration Test all features" - if: ${{ matrix.release == true && matrix.all-features == true }} - run: cargo test --release --tests --all-features + - name: Add custom DNS entry to /etc/hosts for notary TLS test + run: echo "127.0.0.1 tlsnotaryserver.io" | sudo tee -a 
/etc/hosts - - name: "Check that benches compile" - run: cargo bench --no-run + - name: Run integration tests + run: cargo test --profile tests-integration --workspace --exclude tlsn-tls-client --exclude tlsn-tls-core -- --include-ignored + coverage: + runs-on: ubuntu-latest + env: + CARGO_TERM_COLOR: always + steps: + - uses: actions/checkout@v4 + - name: Install stable rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + - name: Generate code coverage + run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: lcov.info + fail_ci_if_error: true \ No newline at end of file diff --git a/.github/workflows/rebase.yml b/.github/workflows/rebase.yml new file mode 100644 index 0000000000..d28b07aac8 --- /dev/null +++ b/.github/workflows/rebase.yml @@ -0,0 +1,24 @@ +name: Automatic Rebase +on: + issue_comment: + types: [created] +jobs: + rebase: + name: Rebase + runs-on: ubuntu-latest + if: >- + github.event.issue.pull_request != '' && + contains(github.event.comment.body, '/rebase') && + github.event.comment.author_association == 'MEMBER' + steps: + - name: Checkout the latest code + uses: actions/checkout@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + fetch-depth: 0 # otherwise, you will fail to push refs to dest repo + - name: Automatic Rebase + uses: cirrus-actions/rebase@1.8 + with: + autosquash: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/rustdoc.yml b/.github/workflows/rustdoc.yml index 11fd84c022..5c86ee6cd8 100644 --- a/.github/workflows/rustdoc.yml +++ b/.github/workflows/rustdoc.yml @@ -12,10 +12,9 @@ env: jobs: rustdoc: - if: ( ! 
github.event.pull_request.draft ) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Rust Toolchain (Stable) uses: dtolnay/rust-toolchain@stable @@ -23,18 +22,18 @@ jobs: toolchain: stable - name: "rustdoc" - run: cd tlsn; cargo doc -p tlsn-core -p tlsn-prover -p tlsn-verifier --no-deps --all-features + run: cargo doc -p tlsn-core -p tlsn-prover -p tlsn-verifier --no-deps --all-features # --target-dir ${GITHUB_WORKSPACE}/docs # https://dev.to/deciduously/prepare-your-rust-api-docs-for-github-pages-2n5i - name: "Add index file -> tlsn_prover" run: | - echo "" > tlsn/target/doc/index.html + echo "" > target/doc/index.html - name: Deploy uses: peaceiris/actions-gh-pages@v3 if: ${{ github.ref == 'refs/heads/dev' }} with: github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: tlsn/target/doc/ + publish_dir: target/doc/ # cname: rustdocs.tlsnotary.org diff --git a/.github/workflows/wasm.yml b/.github/workflows/wasm.yml deleted file mode 100644 index 6b87afb181..0000000000 --- a/.github/workflows/wasm.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: wasm-build - -on: - push: - branches: - - dev - tags: - - "[v]?[0-9]+.[0-9]+.[0-9]+*" - pull_request: - branches: - - dev - -env: - CARGO_TERM_COLOR: always - CARGO_REGISTRIES_CRATES_IO_PROTOCOL: sparse - -jobs: - build_and_test: - name: Build for target wasm32-unknown-unknown - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - package: - - tlsn/tlsn-core - - tlsn/tlsn-prover - - components/tls/tls-client - defaults: - run: - working-directory: ${{ matrix.package }} - steps: - - name: Checkout repository - uses: actions/checkout@v3 - - - name: Install stable rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - targets: wasm32-unknown-unknown - toolchain: stable - - - name: Use caching - uses: Swatinem/rust-cache@v2.5.0 - with: - workspaces: ${{ matrix.package }} -> ../target - - - name: "Build" - run: cargo build --target wasm32-unknown-unknown diff 
--git a/.gitignore b/.gitignore index c169048c8c..f79cb086da 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ Cargo.lock # logs *.log + +# metrics +*.csv \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index ef88a6b9c9..0000000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,7 +0,0 @@ -# Changelog -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [Unreleased] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4147bd9727..ea212bbdbc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,27 +16,20 @@ keywords. Try to do one pull request per change. -### Updating the changelog +## Linting -Update the changes you have made in -[CHANGELOG](CHANGELOG.md) -file under the **Unreleased** section. +Before a Pull Request (PR) can be merged, the Continuous Integration (CI) pipeline automatically lints all code using [Clippy](https://doc.rust-lang.org/stable/clippy/usage.html). To ensure your code is free of linting issues before creating a PR, run the following command: -Add the changes of your pull request to one of the following subsections, -depending on the types of changes defined by -[Keep a changelog](https://keepachangelog.com/en/1.0.0/): - -- `Added` for new features. -- `Changed` for changes in existing functionality. -- `Deprecated` for soon-to-be removed features. -- `Removed` for now removed features. -- `Fixed` for any bug fixes. -- `Security` in case of vulnerabilities. +```sh +cargo clippy --all-features --all-targets -- -D warnings +``` -If the required subsection does not exist yet under **Unreleased**, create it! +This command will lint your code with all features and targets enabled, and treat any warnings as errors, ensuring that your code meets the required standards. 
## Style +This repository includes a `rustfmt.toml` file with custom formatting settings that are automatically validated by CI before any Pull Requests (PRs) can be merged. To ensure your code adheres to these standards, format your code using this configuration before submitting a PR. We strongly recommend enabling *auto format on save* to streamline this process. In Visual Studio Code (VSCode), you can enable this feature by turning on [`editor.formatOnSave`](https://code.visualstudio.com/docs/editor/codebasics#_formatting) in the settings. + ### Capitalization and punctuation Both line comments and doc comments must be capitalized. Each sentence must end with a period. @@ -61,6 +54,7 @@ Comments for function arguments must adhere to this pattern: /// Performs a certain computation. Any other description of the function. /// /// # Arguments +/// /// * `arg1` - The first argument. /// * `arg2` - The second argument. pub fn compute(... diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000000..e36d547f73 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,148 @@ +[workspace] +members = [ + "crates/benches", + "crates/common", + "crates/components/aead", + "crates/components/authdecode/authdecode", + "crates/components/authdecode/authdecode-core", + "crates/components/authdecode/transcript", + "crates/components/block-cipher", + "crates/components/hmac-sha256", + "crates/components/hmac-sha256-circuits", + "crates/components/key-exchange", + "crates/components/poseidon-halo2", + "crates/components/stream-cipher", + "crates/components/universal-hash", + "crates/core", + "crates/data-fixtures", + "crates/examples", + "crates/formats", + "crates/notary/client", + "crates/notary/server", + "crates/notary/tests-integration", + "crates/prover", + "crates/server-fixture/certs", + "crates/server-fixture/server", + "crates/tests-integration", + "crates/tls/backend", + "crates/tls/client", + "crates/tls/client-async", + "crates/tls/core", + "crates/tls/mpc", + 
"crates/tls/server-fixture", + "crates/verifier", + "crates/wasm", + "crates/wasm-test-runner", +] +resolver = "2" + +[profile.tests-integration] +inherits = "release" +opt-level = 1 + +[workspace.dependencies] +notary-client = { path = "crates/notary/client" } +notary-server = { path = "crates/notary/server" } +poseidon-halo2 = { path = "crates/components/poseidon-halo2" } +tlsn-authdecode = { path = "crates/components/authdecode/authdecode" } +tlsn-authdecode-core = { path = "crates/components/authdecode/authdecode-core" } +tlsn-authdecode-transcript = { path = "crates/components/authdecode/transcript" } +tls-server-fixture = { path = "crates/tls/server-fixture" } +tlsn-aead = { path = "crates/components/aead" } +tlsn-block-cipher = { path = "crates/components/block-cipher" } +tlsn-common = { path = "crates/common" } +tlsn-core = { path = "crates/core" } +tlsn-data-fixtures = { path = "crates/data-fixtures" } +tlsn-formats = { path = "crates/formats" } +tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" } +tlsn-hmac-sha256-circuits = { path = "crates/components/hmac-sha256-circuits" } +tlsn-key-exchange = { path = "crates/components/key-exchange" } +tlsn-prover = { path = "crates/prover" } +tlsn-server-fixture = { path = "crates/server-fixture/server" } +tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" } +tlsn-stream-cipher = { path = "crates/components/stream-cipher" } +tlsn-tls-backend = { path = "crates/tls/backend" } +tlsn-tls-client = { path = "crates/tls/client" } +tlsn-tls-client-async = { path = "crates/tls/client-async" } +tlsn-tls-core = { path = "crates/tls/core" } +tlsn-tls-mpc = { path = "crates/tls/mpc" } +tlsn-universal-hash = { path = "crates/components/universal-hash" } +tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "e7b2db6" } +tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "e7b2db6" } +tlsn-verifier = { path = "crates/verifier" } + +mpz-circuits = { git = 
"https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } + +serio = { version = "0.1" } +spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "e7b2db6" } +uid-mux = { version = "0.1", features = ["serio"] } + +aes = { version = "0.8" } +aes-gcm = { version = "0.9" } +anyhow = { version = "1.0" } +async-trait = { version = "0.1" } +async-tungstenite = { version = "0.25" } +axum = { version = "0.7" } +bcs = { version = "0.1" } +bincode = { version = "1.3" } +blake3 = { version = "1.5" } +bytes = { version = "1.4" } +chrono = { version = "0.4" } +cipher = { version = "0.4" } +criterion = { version = "0.5" } +ctr = { version = "0.9" } +derive_builder = { version = "0.12" } +digest = { version = "0.10" } +elliptic-curve = { version = "0.13" } +enum-try-as-inner = { version = "0.1" } +env_logger = { version = "0.10" } +futures = { version = "0.3" } +futures-rustls = { version = "0.26" } +futures-util = { version = "0.3" } +generic-array = { version = "0.14" } +hex = { version = "0.4" } +hmac = { version = "0.12" } +http = { version = "1.1" } +http-body-util = { version = "0.1" } +hyper = { version = "1.1" } +hyper-util = { version = "0.1" } +k256 = { version = "0.13" } +log = { version = "0.4" } +once_cell = { version = "1.19" } +opaque-debug = { version = "0.3" } +p256 = { version = 
"0.13" } +pkcs8 = { version = "0.10" } +pin-project-lite = { version = "0.2" } +rand = { version = "0.8" } +rand_chacha = { version = "0.3" } +rand_core = { version = "0.6" } +regex = { version = "1.10" } +ring = { version = "0.17" } +rs_merkle = { git = "https://github.com/tlsnotary/rs-merkle.git", rev = "85f3e82" } +rstest = { version = "0.17" } +rustls = { version = "0.21" } +rustls-pemfile = { version = "1.0" } +sct = { version = "0.7" } +serde = { version = "1.0" } +serde_json = { version = "1.0" } +sha2 = { version = "0.10" } +signature = { version = "2.2" } +thiserror = { version = "1.0" } +tokio = { version = "1.38" } +tokio-rustls = { version = "0.24" } +tokio-util = { version = "0.7" } +tracing = { version = "0.1" } +tracing-subscriber = { version = "0.3" } +uuid = { version = "1.4" } +web-time = { version = "0.2" } +webpki = { version = "0.22" } +webpki-roots = { version = "0.26" } +ws_stream_tungstenite = { version = "0.13" } +zeroize = { version = "1.8" } diff --git a/README.md b/README.md index e127563920..1d5bf00596 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg [apache-badge]: https://img.shields.io/github/license/saltstack/salt -[actions-badge]: https://github.com/tlsnotary/tlsn/actions/workflows/ci.yml/badge.svg -[actions-url]: https://github.com/tlsnotary/tlsn/actions?query=workflow%3Arust+branch%3Adev +[actions-badge]: https://github.com/tlsnotary/tlsn/actions/workflows/ci.yml/badge.svg?branch=dev +[actions-url]: https://github.com/tlsnotary/tlsn/actions?query=workflow%3Aci+branch%3Adev [Website](https://tlsnotary.org) | [Documentation](https://docs.tlsnotary.org) | @@ -18,6 +18,8 @@ # TLSNotary +**Data provenance and privacy with secure multi-party computation** + ## ⚠️ Notice This project is currently under active development and should not be used in production. Expect bugs and regular major breaking changes. 
@@ -30,25 +32,44 @@ All crates in this repository are licensed under either of at your option. -## Overview +## Branches + +- [`main`](https://github.com/tlsnotary/tlsn/tree/main) + - Default branch — points to the latest release. + - This is stable and suitable for most users. +- [`dev`](https://github.com/tlsnotary/tlsn/tree/dev) + - Development branch — contains the latest PRs. + - Developers should submit their PRs against this branch. + +## Directory + +- [examples](./crates/examples/): Examples on how to use the TLSNotary protocol. +- [tlsn-prover](./crates/prover/): The library for the prover component. +- [tlsn-verifier](./crates/verifier/): The library for the verifier component. +- [notary](./crates/notary/): Implements the [notary server](https://docs.tlsnotary.org/intro.html#tls-verification-with-a-general-purpose-notary) and its client. +- [components](./crates/components/): Houses low-level libraries. + +This repository contains the source code for the Rust implementation of the TLSNotary protocol. For additional tools and implementations related to TLSNotary, visit . This includes repositories such as [`tlsn-js`](https://github.com/tlsnotary/tlsn-js), [`tlsn-extension`](https://github.com/tlsnotary/tlsn-extension), [`explorer`](https://github.com/tlsnotary/explorer), among others. -- **tls**: Home of the TLS logic of our protocol like handshake en-/decryption, ghash, **currently outdated** -- **utils**: Utility functions which are frequently used everywhere -- **actors**: Provides actors, which implement protocol-specific functionality using - the actor pattern. They usually wrap an aio module -- **universal-hash**: Implements ghash, which is used AES-GCM. Poly-1305 coming soon. -- **point-addition**: Used in key-exchange and allows to compute a two party sharing of - an EC curve point -### General remarks +## Development -- the TLSNotary codebase makes heavy use of async Rust. 
Usually an aio - crate/module implements the network IO and wraps a core crate/module which - provides the protocol implementation. This is a frequent pattern you will - encounter in the codebase. -- some protocols are implemented using the actor pattern to facilitate - asynchronous message processing with shared state. +> [!IMPORTANT] +> **Note on Rust-to-WASM Compilation**: This project requires compiling Rust into WASM, which needs [`clang`](https://clang.llvm.org/) version 16.0.0 or newer. MacOS users, be aware that Xcode's default `clang` might be older. If you encounter the error `No available targets are compatible with triple "wasm32-unknown-unknown"`, it's likely due to an outdated `clang`. Updating `clang` to a newer version should resolve this issue. +> +> For MacOS aarch64 users, if Apple's default `clang` isn't working, try installing `llvm` via Homebrew (`brew install llvm`). You can then prioritize the Homebrew `clang` over the default macOS version by modifying your `PATH`. Add the following line to your shell configuration file (e.g., `.bashrc`, `.zshrc`): +> ```sh +> export PATH="/opt/homebrew/opt/llvm/bin:$PATH" +> ``` +If you run into this error: +``` +Could not find directory of OpenSSL installation, and this `-sys` crate cannot + proceed without this knowledge. If OpenSSL is installed and this crate had + trouble finding it, you can set the `OPENSSL_DIR` environment variable for the + compilation process. +``` +Make sure you have the development packages of OpenSSL installed (`libssl-dev` on Ubuntu or `openssl-devel` on Fedora). ## Contribution @@ -56,4 +77,4 @@ Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. -See [CONTRIBUTING.md](CONTRIBUTING.md). \ No newline at end of file +See [CONTRIBUTING.md](CONTRIBUTING.md). 
diff --git a/appspec-scripts/after_install.sh b/appspec-scripts/after_install.sh deleted file mode 100755 index 062ad35599..0000000000 --- a/appspec-scripts/after_install.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -set -e -export PATH=$PATH:/home/ubuntu/.cargo/bin - -APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') - -# Prepare directory -sudo rm -rf ~/$APP_NAME/tlsn -sudo mv ~/tlsn/ ~/$APP_NAME -sudo mkdir -p ~/$APP_NAME/tlsn/notary-server/target/release -sudo chown -R ubuntu.ubuntu ~/$APP_NAME - -# Download .git directory -aws s3 cp s3://tlsn-deploy/$APP_NAME/.git ~/$APP_NAME/tlsn/.git --recursive - -# Download binary -aws s3 cp s3://tlsn-deploy/$APP_NAME/notary-server ~/$APP_NAME/tlsn/notary-server/target/release -chmod +x ~/$APP_NAME/tlsn/notary-server/target/release/notary-server - -exit 0 diff --git a/appspec-scripts/before_install.sh b/appspec-scripts/before_install.sh deleted file mode 100755 index 07b5380508..0000000000 --- a/appspec-scripts/before_install.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -#set -e - -APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') - -if [ ! 
-d $APP_NAME ]; then - mkdir ~/$APP_NAME -fi - -exit 0 diff --git a/appspec-scripts/start_app.sh b/appspec-scripts/start_app.sh deleted file mode 100755 index 08cd49ab22..0000000000 --- a/appspec-scripts/start_app.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -set -e -export PATH=$PATH:/home/ubuntu/.cargo/bin - -APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') - -cd ~/$APP_NAME/tlsn/notary-server -target/release/notary-server --config-file ~/.notary/$APP_NAME/config.yaml &> ~/$APP_NAME/tlsn/notary.log & - -exit 0 diff --git a/appspec-scripts/stop_app.sh b/appspec-scripts/stop_app.sh deleted file mode 100755 index 3a30a291b1..0000000000 --- a/appspec-scripts/stop_app.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -e - -APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') - -PID=$(pgrep -f notary.*$APP_NAME) -kill -15 $PID - -exit 0 diff --git a/appspec-scripts/validate_app.sh b/appspec-scripts/validate_app.sh deleted file mode 100755 index 8efbc9727e..0000000000 --- a/appspec-scripts/validate_app.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -set -e - -# Verify proccess is running -APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') - -pgrep -f notary.*$APP_NAME -[ $? 
-eq 0 ] || exit 1 - -# Verify that listening sockets exist -if [ "$APPLICATION_NAME" == "tlsn-nightly" ]; then - port=7048 -else - port=7047 -fi - -exposed_ports=$(netstat -lnt4 | egrep -cw $port) -[ $exposed_ports -eq 1 ] || exit 1 - -exit 0 diff --git a/appspec.yml b/appspec.yml index c6cf353d44..c1f6125f4c 100644 --- a/appspec.yml +++ b/appspec.yml @@ -10,22 +10,22 @@ permissions: group: ubuntu hooks: BeforeInstall: - - location: appspec-scripts/before_install.sh + - location: cd-scripts/appspec-scripts/before_install.sh timeout: 300 runas: ubuntu AfterInstall: - - location: appspec-scripts/after_install.sh + - location: cd-scripts/appspec-scripts/after_install.sh timeout: 300 runas: ubuntu ApplicationStart: - - location: appspec-scripts/start_app.sh + - location: cd-scripts/appspec-scripts/start_app.sh timeout: 300 runas: ubuntu ApplicationStop: - - location: appspec-scripts/stop_app.sh + - location: cd-scripts/appspec-scripts/stop_app.sh timeout: 300 runas: ubuntu ValidateService: - - location: appspec-scripts/validate_app.sh + - location: cd-scripts/appspec-scripts/validate_app.sh timeout: 300 runas: ubuntu diff --git a/build_all.sh b/build_all.sh deleted file mode 100755 index e6945b7af4..0000000000 --- a/build_all.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -for package in components/uid-mux components/actors/actor-ot components/cipher components/universal-hash components/aead components/key-exchange components/point-addition components/prf components/tls tlsn; do - pushd $package - # cargo update - cargo clean - cargo build - cargo test - cargo clippy --all-features -- -D warnings || exit - popd -done diff --git a/cd-scripts/appspec-scripts/after_install.sh b/cd-scripts/appspec-scripts/after_install.sh new file mode 100755 index 0000000000..a6041c08db --- /dev/null +++ b/cd-scripts/appspec-scripts/after_install.sh @@ -0,0 +1,35 @@ +#!/bin/bash +set -e + +TAG=$(curl http://169.254.169.254/latest/meta-data/tags/instance/stable) +APP_NAME=$(echo 
$APPLICATION_NAME | awk -F- '{ print $2 }') + +if [ $APP_NAME = "stable" ]; then + # Prepare directories for stable versions + sudo mkdir ~/${APP_NAME}_${TAG} + sudo mv ~/tlsn ~/${APP_NAME}_${TAG} + sudo mkdir -p ~/${APP_NAME}_${TAG}/tlsn/notary/target/release + sudo chown -R ubuntu.ubuntu ~/${APP_NAME}_${TAG} + + # Download .git directory + aws s3 cp s3://tlsn-deploy/$APP_NAME/.git ~/${APP_NAME}_${TAG}/tlsn/.git --recursive + + # Download binary + aws s3 cp s3://tlsn-deploy/$APP_NAME/notary-server ~/${APP_NAME}_${TAG}/tlsn/notary/target/release + chmod +x ~/${APP_NAME}_${TAG}/tlsn/notary/target/release/notary-server +else + # Prepare directory for dev + sudo rm -rf ~/$APP_NAME/tlsn + sudo mv ~/tlsn/ ~/$APP_NAME + sudo mkdir -p ~/$APP_NAME/tlsn/notary/target/release + sudo chown -R ubuntu.ubuntu ~/$APP_NAME + + # Download .git directory + aws s3 cp s3://tlsn-deploy/$APP_NAME/.git ~/$APP_NAME/tlsn/.git --recursive + + # Download binary + aws s3 cp s3://tlsn-deploy/$APP_NAME/notary-server ~/$APP_NAME/tlsn/notary/target/release + chmod +x ~/$APP_NAME/tlsn/notary/target/release/notary-server +fi + +exit 0 diff --git a/cd-scripts/appspec-scripts/before_install.sh b/cd-scripts/appspec-scripts/before_install.sh new file mode 100755 index 0000000000..76a47c6115 --- /dev/null +++ b/cd-scripts/appspec-scripts/before_install.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e + +APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') + +if [ $APP_NAME = "stable" ]; then + VERSIONS_DEPLOYED=$(find ~/ -maxdepth 1 -type d -name 'stable_*') + VERSIONS_DEPLOYED_COUNT=$(echo $VERSIONS_DEPLOYED | wc -w) + + if [ $VERSIONS_DEPLOYED_COUNT -gt 3 ]; then + echo "More than 3 stable versions found" + exit 1 + fi +else + if [ ! 
-d ~/$APP_NAME ]; then + mkdir ~/$APP_NAME + fi +fi + +exit 0 diff --git a/cd-scripts/appspec-scripts/start_app.sh b/cd-scripts/appspec-scripts/start_app.sh new file mode 100755 index 0000000000..0d449f6222 --- /dev/null +++ b/cd-scripts/appspec-scripts/start_app.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Port tagging will also be used to manipulate proxy server via modify_proxy.sh script +set -ex + +TAG=$(curl http://169.254.169.254/latest/meta-data/tags/instance/stable) +APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') + +if [ $APP_NAME = "stable" ]; then + # Check if all stable ports are in use. If true, terminate the deployment + [[ $(netstat -lnt4 | egrep -c ':(7047|7057|7067)\s') -eq 3 ]] && { echo "All stable ports are in use"; exit 1; } + STABLE_PORTS="7047 7057 7067" + for PORT in $STABLE_PORTS; do + PORT_LISTENING=$(netstat -lnt4 | egrep -cw $PORT || true) + if [ $PORT_LISTENING -eq 0 ]; then + ~/${APP_NAME}_${TAG}/tlsn/notary/target/release/notary-server --config-file ~/.notary/${APP_NAME}_${PORT}/config.yaml &> ~/${APP_NAME}_${TAG}/tlsn/notary.log & + # Create a tag that will be used for service validation + INSTANCE_ID=$(curl http://169.254.169.254/latest/meta-data/instance-id) + aws ec2 create-tags --resources $INSTANCE_ID --tags "Key=port,Value=$PORT" + break + fi + done +else + ~/$APP_NAME/tlsn/notary/target/release/notary-server --config-file ~/.notary/$APP_NAME/config.yaml &> ~/$APP_NAME/tlsn/notary.log & +fi + +exit 0 diff --git a/cd-scripts/appspec-scripts/stop_app.sh b/cd-scripts/appspec-scripts/stop_app.sh new file mode 100755 index 0000000000..ae92a8c06c --- /dev/null +++ b/cd-scripts/appspec-scripts/stop_app.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# AWS CodeDeploy hook sequence: https://docs.aws.amazon.com/codedeploy/latest/userguide/reference-appspec-file-structure-hooks.html#appspec-hooks-server +set -ex + +APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') + +if [ $APP_NAME = "stable" ]; then + VERSIONS_DEPLOYED=$(find ~/ -maxdepth 
1 -type d -name 'stable_*') + VERSIONS_DEPLOYED_COUNT=$(echo $VERSIONS_DEPLOYED | wc -w) + + # Remove oldest version if exists + if [ $VERSIONS_DEPLOYED_COUNT -eq 3 ]; then + echo "Candidate versions to be removed:" + OLDEST_DIR="" + OLDEST_TIME="" + + for DIR in $VERSIONS_DEPLOYED; do + TIME=$(stat -c %W $DIR) + + if [ -z $OLDEST_TIME ] || [ $TIME -lt $OLDEST_TIME ]; then + OLDEST_DIR=$DIR + OLDEST_TIME=$TIME + fi + done + + echo "The oldest version is running under: $OLDEST_DIR" + PID=$(lsof $OLDEST_DIR/tlsn/notary/target/release/notary-server | awk '{ print $2 }' | tail -1) + kill -15 $PID || true + rm -rf $OLDEST_DIR + fi +else + PID=$(pgrep -f notary.*$APP_NAME) + kill -15 $PID || true +fi + +exit 0 diff --git a/cd-scripts/appspec-scripts/validate_app.sh b/cd-scripts/appspec-scripts/validate_app.sh new file mode 100755 index 0000000000..3921fd337f --- /dev/null +++ b/cd-scripts/appspec-scripts/validate_app.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e + +# Verify process is running +APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }') + +# Verify that listening sockets exist +if [ $APP_NAME = "stable" ]; then + PORT=$(curl http://169.254.169.254/latest/meta-data/tags/instance/port) + ps -ef | grep notary.*$APP_NAME.*$PORT | grep -v grep + [ $? -eq 0 ] || exit 1 +else + PORT=7048 + pgrep -f notary.*$APP_NAME + [ $?
-eq 0 ] || exit 1 +fi + +EXPOSED_PORTS=$(netstat -lnt4 | egrep -cw $PORT) +[ $EXPOSED_PORTS -eq 1 ] || exit 1 + +exit 0 diff --git a/cd-scripts/modify_proxy.sh b/cd-scripts/modify_proxy.sh new file mode 100755 index 0000000000..1abb5e5665 --- /dev/null +++ b/cd-scripts/modify_proxy.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# This script is executed on proxy side, in order to assign the available port to latest stable version +set -e + +PORT=$1 +VERSION=$2 + +sed -i "/# Port $PORT/{n;s/v[0-9].[0-9].[0-9]-[a-z]*.[0-9]*/$VERSION/g}" /etc/nginx/sites-available/tlsnotary-pse +sed -i "/# Port $PORT/{n;n;s/v[0-9].[0-9].[0-9]-[a-z]*.[0-9]*/$VERSION/g}" /etc/nginx/sites-available/tlsnotary-pse + +nginx -t +nginx -s reload + +exit 0 diff --git a/components/aead/Cargo.toml b/components/aead/Cargo.toml deleted file mode 100644 index 546bc9a514..0000000000 --- a/components/aead/Cargo.toml +++ /dev/null @@ -1,41 +0,0 @@ -[package] -name = "tlsn-aead" -authors = ["TLSNotary Team"] -description = "This crate provides an implementation of a two-party version of AES-GCM behind an AEAD trait" -keywords = ["tls", "mpc", "2pc", "aead", "aes", "aes-gcm"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[lib] -name = "aead" - -[features] -default = ["mock"] -mock = [] -tracing = [ - "dep:tracing", - "tlsn-block-cipher/tracing", - "tlsn-stream-cipher/tracing", - "tlsn-universal-hash/tracing", -] - -[dependencies] -tlsn-block-cipher = { path = "../cipher/block-cipher" } -tlsn-stream-cipher = { path = "../cipher/stream-cipher" } -tlsn-universal-hash = { path = "../universal-hash" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } - -async-trait = "0.1" -derive_builder = "0.12" -thiserror = "1" -futures = "0.3" -serde 
= "1" -tracing = { version = "0.1", optional = true } - -[dev-dependencies] -tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] } -aes-gcm = "0.10" diff --git a/components/aead/src/aes_gcm/tag.rs b/components/aead/src/aes_gcm/tag.rs deleted file mode 100644 index f3f3d1c71a..0000000000 --- a/components/aead/src/aes_gcm/tag.rs +++ /dev/null @@ -1,64 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::ops::Add; - -use crate::AeadError; - -pub(crate) const AES_GCM_TAG_LEN: usize = 16; - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub(crate) struct AesGcmTagShare(pub(crate) [u8; 16]); - -impl AesGcmTagShare { - pub(crate) fn from_unchecked(share: &[u8]) -> Result { - if share.len() != 16 { - return Err(AeadError::ValidationError( - "Received tag share is not 16 bytes long".to_string(), - )); - } - let mut result = [0u8; 16]; - result.copy_from_slice(share); - Ok(Self(result)) - } -} - -impl AsRef<[u8]> for AesGcmTagShare { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl Add for AesGcmTagShare { - type Output = Vec; - - fn add(self, rhs: Self) -> Self::Output { - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(a, b)| a ^ b) - .collect() - } -} - -/// Builds padded data for GHASH -#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", ret))] -pub(crate) fn build_ghash_data(mut aad: Vec, mut ciphertext: Vec) -> Vec { - let associated_data_bitlen = (aad.len() as u64) * 8; - let text_bitlen = (ciphertext.len() as u64) * 8; - - let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128); - - // pad data to be a multiple of 16 bytes - let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize; - aad.resize(aad_padded_block_count * 16, 0); - - let ciphertext_padded_block_count = - (ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize; - ciphertext.resize(ciphertext_padded_block_count * 16, 0); - - let mut data: Vec = Vec::with_capacity(aad.len() + ciphertext.len() + 16); 
- data.extend(aad); - data.extend(ciphertext); - data.extend_from_slice(&len_block.to_be_bytes()); - - data -} diff --git a/components/aead/src/msg.rs b/components/aead/src/msg.rs deleted file mode 100644 index a15d318633..0000000000 --- a/components/aead/src/msg.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Message types for AEAD protocols. - -use serde::{Deserialize, Serialize}; - -use mpz_core::{commit::Decommitment, hash::Hash}; - -/// Aead messages. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum AeadMessage { - TagShareCommitment(Hash), - TagShareDecommitment(Decommitment), - TagShare(TagShare), -} - -/// A tag share. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TagShare { - /// The share of the tag. - pub share: Vec, -} - -impl From for TagShare { - fn from(tag_share: crate::aes_gcm::AesGcmTagShare) -> Self { - Self { - share: tag_share.0.to_vec(), - } - } -} diff --git a/components/cipher/Cargo.toml b/components/cipher/Cargo.toml deleted file mode 100644 index f75fa46e89..0000000000 --- a/components/cipher/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -[workspace] -members = ["stream-cipher", "block-cipher"] -resolver = "2" - -[workspace.dependencies] -# tlsn -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } - -# crypto -aes = "0.8" -ctr = "0.9.2" -cipher = "0.4.3" - -# async -async-trait = "0.1" -futures = "0.3" -tokio = { version = "1", default-features = false } - -# testing -rstest = "0.17" -criterion = "0.5" - -# error/log -thiserror = "1" -tracing = "0.1" - -# misc -derive_builder = "0.12" diff --git a/components/cipher/block-cipher/src/cipher.rs b/components/cipher/block-cipher/src/cipher.rs deleted file mode 100644 index 21d2713606..0000000000 --- 
a/components/cipher/block-cipher/src/cipher.rs +++ /dev/null @@ -1,182 +0,0 @@ -use std::marker::PhantomData; - -use async_trait::async_trait; - -use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Memory}; -use utils::id::NestedId; - -use crate::{BlockCipher, BlockCipherCircuit, BlockCipherConfig, BlockCipherError}; - -struct State { - execution_id: NestedId, - key: Option, -} - -/// An MPC block cipher -pub struct MpcBlockCipher -where - C: BlockCipherCircuit, - E: Memory + Execute + Decode + DecodePrivate + Send + Sync, -{ - state: State, - - executor: E, - - _cipher: PhantomData, -} - -impl MpcBlockCipher -where - C: BlockCipherCircuit, - E: Memory + Execute + Decode + DecodePrivate + Send + Sync, -{ - /// Creates a new MPC block cipher - /// - /// # Arguments - /// - /// * `config` - The configuration for the block cipher - /// * `executor` - The executor to use for the MPC - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(executor)) - )] - pub fn new(config: BlockCipherConfig, executor: E) -> Self { - let execution_id = NestedId::new(&config.id).append_counter(); - Self { - state: State { - execution_id, - key: None, - }, - executor, - _cipher: PhantomData, - } - } -} - -#[async_trait] -impl BlockCipher for MpcBlockCipher -where - C: BlockCipherCircuit, - E: Memory + Execute + Decode + DecodePrivate + Send + Sync + Send, -{ - #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip(self)))] - fn set_key(&mut self, key: ValueRef) { - self.state.key = Some(key); - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self, plaintext), err) - )] - async fn encrypt_private(&mut self, plaintext: Vec) -> Result, BlockCipherError> { - let len = plaintext.len(); - let block: C::BLOCK = plaintext - .try_into() - .map_err(|_| BlockCipherError::InvalidInputLength(C::BLOCK_LEN, len))?; - - let key = self.state.key.clone().ok_or(BlockCipherError::KeyNotSet)?; - - let id = 
self.state.execution_id.increment_in_place().to_string(); - - let msg = self - .executor - .new_private_input::(&format!("{}/msg", &id))?; - let ciphertext = self - .executor - .new_output::(&format!("{}/ciphertext", &id))?; - - self.executor.assign(&msg, block)?; - - self.executor - .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) - .await?; - - let mut outputs = self.executor.decode(&[ciphertext]).await?; - - let ciphertext: C::BLOCK = if let Ok(ciphertext) = outputs - .pop() - .expect("ciphertext should be present") - .try_into() - { - ciphertext - } else { - panic!("ciphertext should be a block") - }; - - Ok(ciphertext.into()) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn encrypt_blind(&mut self) -> Result, BlockCipherError> { - let key = self.state.key.clone().ok_or(BlockCipherError::KeyNotSet)?; - - let id = self.state.execution_id.increment_in_place().to_string(); - - let msg = self - .executor - .new_blind_input::(&format!("{}/msg", &id))?; - let ciphertext = self - .executor - .new_output::(&format!("{}/ciphertext", &id))?; - - self.executor - .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) - .await?; - - let mut outputs = self.executor.decode(&[ciphertext]).await?; - - let ciphertext: C::BLOCK = if let Ok(ciphertext) = outputs - .pop() - .expect("ciphertext should be present") - .try_into() - { - ciphertext - } else { - panic!("ciphertext should be a block") - }; - - Ok(ciphertext.into()) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self, plaintext), err) - )] - async fn encrypt_share(&mut self, plaintext: Vec) -> Result, BlockCipherError> { - let len = plaintext.len(); - let block: C::BLOCK = plaintext - .try_into() - .map_err(|_| BlockCipherError::InvalidInputLength(C::BLOCK_LEN, len))?; - - let key = self.state.key.clone().ok_or(BlockCipherError::KeyNotSet)?; - - let id = 
self.state.execution_id.increment_in_place().to_string(); - - let msg = self - .executor - .new_public_input::(&format!("{}/msg", &id))?; - let ciphertext = self - .executor - .new_output::(&format!("{}/ciphertext", &id))?; - - self.executor.assign(&msg, block)?; - - self.executor - .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) - .await?; - - let mut outputs = self.executor.decode_shared(&[ciphertext]).await?; - - let share: C::BLOCK = - if let Ok(share) = outputs.pop().expect("share should be present").try_into() { - share - } else { - panic!("share should be a block") - }; - - Ok(share.into()) - } -} diff --git a/components/cipher/block-cipher/src/lib.rs b/components/cipher/block-cipher/src/lib.rs deleted file mode 100644 index 37b84afaaa..0000000000 --- a/components/cipher/block-cipher/src/lib.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! This crate provides a 2PC block cipher implementation. -//! -//! Both parties work together to encrypt or share an encrypted block using a shared key. 
- -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![deny(unsafe_code)] - -mod cipher; -mod circuit; -mod config; - -use async_trait::async_trait; - -use mpz_garble::value::ValueRef; - -pub use crate::{ - cipher::MpcBlockCipher, - circuit::{Aes128, BlockCipherCircuit}, -}; -pub use config::{BlockCipherConfig, BlockCipherConfigBuilder, BlockCipherConfigBuilderError}; - -/// Errors that can occur when using the block cipher -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum BlockCipherError { - #[error(transparent)] - MemoryError(#[from] mpz_garble::MemoryError), - #[error(transparent)] - ExecutionError(#[from] mpz_garble::ExecutionError), - #[error(transparent)] - DecodeError(#[from] mpz_garble::DecodeError), - #[error("Cipher key not set")] - KeyNotSet, - #[error("Input does not match block length: expected {0}, got {1}")] - InvalidInputLength(usize, usize), -} - -/// A trait for MPC block ciphers -#[async_trait] -pub trait BlockCipher: Send + Sync -where - Cipher: BlockCipherCircuit, -{ - /// Sets the key for the block cipher. - fn set_key(&mut self, key: ValueRef); - - /// Encrypts the given plaintext keeping it hidden from the other party(s). - /// - /// Returns the ciphertext - /// - /// * `plaintext` - The plaintext to encrypt - async fn encrypt_private(&mut self, plaintext: Vec) -> Result, BlockCipherError>; - - /// Encrypts a plaintext provided by the other party(s). - /// - /// Returns the ciphertext - async fn encrypt_blind(&mut self) -> Result, BlockCipherError>; - - /// Encrypts a plaintext provided by both parties. Fails if the - /// plaintext provided by both parties does not match. 
- /// - /// Returns an additive share of the ciphertext - /// - /// * `plaintext` - The plaintext to encrypt - async fn encrypt_share(&mut self, plaintext: Vec) -> Result, BlockCipherError>; -} - -#[cfg(test)] -mod tests { - use super::*; - - use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory, Vm}; - - use crate::circuit::Aes128; - - use ::aes::Aes128 as TestAes128; - use ::cipher::{BlockEncrypt, KeyInit}; - - fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] { - let mut msg = msg.into(); - let cipher = TestAes128::new(&key.into()); - cipher.encrypt_block(&mut msg); - msg.into() - } - - #[tokio::test] - async fn test_block_cipher_blind() { - let leader_config = BlockCipherConfig::builder().id("test").build().unwrap(); - let follower_config = BlockCipherConfig::builder().id("test").build().unwrap(); - - let key = [0u8; 16]; - - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; - let leader_thread = leader_vm.new_thread("test").await.unwrap(); - let follower_thread = follower_vm.new_thread("test").await.unwrap(); - - // Key is public just for this test, typically it is private - let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); - - leader_thread.assign(&leader_key, key).unwrap(); - follower_thread.assign(&follower_key, key).unwrap(); - - let mut leader = MpcBlockCipher::::new(leader_config, leader_thread); - leader.set_key(leader_key); - - let mut follower = MpcBlockCipher::::new(follower_config, follower_thread); - follower.set_key(follower_key); - - let plaintext = [0u8; 16]; - - let (leader_ciphertext, follower_ciphertext) = tokio::try_join!( - leader.encrypt_private(plaintext.to_vec()), - follower.encrypt_blind() - ) - .unwrap(); - - let expected = aes128(key, plaintext); - - assert_eq!(leader_ciphertext, expected.to_vec()); - assert_eq!(leader_ciphertext, follower_ciphertext); - } - - #[tokio::test] - async fn 
test_block_cipher_share() { - let leader_config = BlockCipherConfig::builder().id("test").build().unwrap(); - let follower_config = BlockCipherConfig::builder().id("test").build().unwrap(); - - let key = [0u8; 16]; - - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; - let leader_thread = leader_vm.new_thread("test").await.unwrap(); - let follower_thread = follower_vm.new_thread("test").await.unwrap(); - - // Key is public just for this test, typically it is private - let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); - - leader_thread.assign(&leader_key, key).unwrap(); - follower_thread.assign(&follower_key, key).unwrap(); - - let mut leader = MpcBlockCipher::::new(leader_config, leader_thread); - leader.set_key(leader_key); - - let mut follower = MpcBlockCipher::::new(follower_config, follower_thread); - follower.set_key(follower_key); - - let plaintext = [0u8; 16]; - - let (leader_share, follower_share) = tokio::try_join!( - leader.encrypt_share(plaintext.to_vec()), - follower.encrypt_share(plaintext.to_vec()) - ) - .unwrap(); - - let expected = aes128(key, plaintext); - - let result: [u8; 16] = std::array::from_fn(|i| leader_share[i] ^ follower_share[i]); - - assert_eq!(result, expected); - } -} diff --git a/components/cipher/stream-cipher/benches/mock.rs b/components/cipher/stream-cipher/benches/mock.rs deleted file mode 100644 index bae9b694e4..0000000000 --- a/components/cipher/stream-cipher/benches/mock.rs +++ /dev/null @@ -1,145 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; - -use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory, Vm}; -use tlsn_stream_cipher::{ - Aes128Ctr, CtrCircuit, MpcStreamCipher, StreamCipher, StreamCipherConfigBuilder, -}; - -async fn bench_stream_cipher_encrypt(thread_count: usize, len: usize) { - let (mut leader_vm, mut follower_vm) = 
create_mock_deap_vm("test").await; - - let leader_thread = leader_vm.new_thread("key_config").await.unwrap(); - let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let leader_iv = leader_thread.new_public_input::<[u8; 4]>("iv").unwrap(); - - leader_thread.assign(&leader_key, [0u8; 16]).unwrap(); - leader_thread.assign(&leader_iv, [0u8; 4]).unwrap(); - - let follower_thread = follower_vm.new_thread("key_config").await.unwrap(); - let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let follower_iv = follower_thread.new_public_input::<[u8; 4]>("iv").unwrap(); - - follower_thread.assign(&follower_key, [0u8; 16]).unwrap(); - follower_thread.assign(&follower_iv, [0u8; 4]).unwrap(); - - let leader_thread_pool = leader_vm - .new_thread_pool("mock", thread_count) - .await - .unwrap(); - let follower_thread_pool = follower_vm - .new_thread_pool("mock", thread_count) - .await - .unwrap(); - - let leader_config = StreamCipherConfigBuilder::default() - .id("test".to_string()) - .build() - .unwrap(); - - let follower_config = StreamCipherConfigBuilder::default() - .id("test".to_string()) - .build() - .unwrap(); - - let mut leader = MpcStreamCipher::::new(leader_config, leader_thread_pool); - leader.set_key(leader_key, leader_iv); - - let mut follower = MpcStreamCipher::::new(follower_config, follower_thread_pool); - follower.set_key(follower_key, follower_iv); - - let plaintext = vec![0u8; len]; - let explicit_nonce = vec![0u8; 8]; - - _ = tokio::try_join!( - leader.encrypt_private(explicit_nonce.clone(), plaintext), - follower.encrypt_blind(explicit_nonce, len) - ) - .unwrap(); - - _ = tokio::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); -} - -async fn bench_stream_cipher_zk(thread_count: usize, len: usize) { - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; - - let key = [0u8; 16]; - let iv = [0u8; 4]; - - let leader_thread = 
leader_vm.new_thread("key_config").await.unwrap(); - let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let leader_iv = leader_thread.new_public_input::<[u8; 4]>("iv").unwrap(); - - leader_thread.assign(&leader_key, key).unwrap(); - leader_thread.assign(&leader_iv, iv).unwrap(); - - let follower_thread = follower_vm.new_thread("key_config").await.unwrap(); - let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let follower_iv = follower_thread.new_public_input::<[u8; 4]>("iv").unwrap(); - - follower_thread.assign(&follower_key, key).unwrap(); - follower_thread.assign(&follower_iv, iv).unwrap(); - - let leader_thread_pool = leader_vm - .new_thread_pool("mock", thread_count) - .await - .unwrap(); - let follower_thread_pool = follower_vm - .new_thread_pool("mock", thread_count) - .await - .unwrap(); - - let leader_config = StreamCipherConfigBuilder::default() - .id("test".to_string()) - .build() - .unwrap(); - - let follower_config = StreamCipherConfigBuilder::default() - .id("test".to_string()) - .build() - .unwrap(); - - let mut leader = MpcStreamCipher::::new(leader_config, leader_thread_pool); - leader.set_key(leader_key, leader_iv); - - let mut follower = MpcStreamCipher::::new(follower_config, follower_thread_pool); - follower.set_key(follower_key, follower_iv); - - let plaintext = vec![0u8; len]; - let explicit_nonce = [0u8; 8]; - let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext).unwrap(); - - _ = tokio::try_join!( - leader.prove_plaintext(explicit_nonce.to_vec(), plaintext), - follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext) - ) - .unwrap(); - - _ = tokio::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); -} - -fn criterion_benchmark(c: &mut Criterion) { - let rt = tokio::runtime::Runtime::new().unwrap(); - let thread_count = 8; - let len = 1024; - - let mut group = c.benchmark_group("stream_cipher/encrypt_private"); - 
group.throughput(Throughput::Bytes(len as u64)); - group.bench_function(format!("{}", len), |b| { - b.to_async(&rt) - .iter(|| async { bench_stream_cipher_encrypt(thread_count, len).await }) - }); - - drop(group); - - let mut group = c.benchmark_group("stream_cipher/zk"); - group.throughput(Throughput::Bytes(len as u64)); - group.bench_function(format!("{}", len), |b| { - b.to_async(&rt) - .iter(|| async { bench_stream_cipher_zk(thread_count, len).await }) - }); - - drop(group); -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/components/cipher/stream-cipher/src/stream_cipher.rs b/components/cipher/stream-cipher/src/stream_cipher.rs deleted file mode 100644 index 601bf6921a..0000000000 --- a/components/cipher/stream-cipher/src/stream_cipher.rs +++ /dev/null @@ -1,842 +0,0 @@ -use async_trait::async_trait; -use mpz_circuits::types::Value; -use std::{collections::HashMap, fmt::Debug, marker::PhantomData}; - -use mpz_garble::{ - value::ValueRef, Decode, DecodePrivate, Execute, Memory, Prove, Thread, ThreadPool, Verify, -}; -use utils::id::NestedId; - -use crate::{ - cipher::CtrCircuit, - circuit::build_array_xor, - config::{InputText, KeyBlockConfig, StreamCipherConfig}, - StreamCipher, StreamCipherError, -}; - -/// An MPC stream cipher. -pub struct MpcStreamCipher -where - C: CtrCircuit, - E: Thread + Execute + Decode + DecodePrivate + Send + Sync, -{ - config: StreamCipherConfig, - state: State, - thread_pool: ThreadPool, - - _cipher: PhantomData, -} - -struct State { - /// Encoded key and IV for the cipher. - encoded_key_iv: Option, - /// Key and IV for the cipher. - key_iv: Option, - /// Unique identifier for each execution of the cipher. - execution_id: NestedId, - /// Unique identifier for each byte in the transcript. - transcript_counter: NestedId, - /// Unique identifier for each byte in the ciphertext (prefixed with execution id). 
- ciphertext_counter: NestedId, - /// Persists the transcript counter for each transcript id. - transcript_state: HashMap, -} - -#[derive(Clone)] -struct EncodedKeyAndIv { - key: ValueRef, - iv: ValueRef, -} - -#[derive(Clone)] -struct KeyAndIv { - key: Vec, - iv: Vec, -} - -impl MpcStreamCipher -where - C: CtrCircuit, - E: Thread + Execute + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static, -{ - /// Creates a new counter-mode cipher. - pub fn new(config: StreamCipherConfig, thread_pool: ThreadPool) -> Self { - let execution_id = NestedId::new(&config.id).append_counter(); - let transcript_counter = NestedId::new(&config.transcript_id).append_counter(); - let ciphertext_counter = execution_id.append_string("ciphertext").append_counter(); - - Self { - config, - state: State { - encoded_key_iv: None, - key_iv: None, - execution_id, - transcript_counter, - ciphertext_counter, - transcript_state: HashMap::new(), - }, - thread_pool, - _cipher: PhantomData, - } - } - - /// Returns unique identifiers for the next bytes in the transcript. - fn plaintext_ids(&mut self, len: usize) -> Vec { - (0..len) - .map(|_| { - self.state - .transcript_counter - .increment_in_place() - .to_string() - }) - .collect() - } - - /// Returns unique identifiers for the next bytes in the ciphertext. 
- fn ciphertext_ids(&mut self, len: usize) -> Vec { - (0..len) - .map(|_| { - self.state - .ciphertext_counter - .increment_in_place() - .to_string() - }) - .collect() - } - - async fn compute_keystream( - &mut self, - explicit_nonce: Vec, - start_ctr: usize, - len: usize, - mode: ExecutionMode, - ) -> Result { - let EncodedKeyAndIv { key, iv } = self - .state - .encoded_key_iv - .clone() - .ok_or(StreamCipherError::KeyIvNotSet)?; - - let explicit_nonce_len = explicit_nonce.len(); - let explicit_nonce: C::NONCE = explicit_nonce.try_into().map_err(|_| { - StreamCipherError::InvalidExplicitNonceLength { - expected: C::NONCE_LEN, - actual: explicit_nonce_len, - } - })?; - - // Divide msg length by block size rounding up - let block_count = (len / C::BLOCK_LEN) + (len % C::BLOCK_LEN != 0) as usize; - - let block_configs = (0..block_count) - .map(|i| { - KeyBlockConfig::::new( - key.clone(), - iv.clone(), - explicit_nonce, - (start_ctr + i) as u32, - ) - }) - .collect::>(); - - let execution_id = self.state.execution_id.increment_in_place(); - - let keystream = compute_keystream( - &mut self.thread_pool, - execution_id, - block_configs, - len, - mode, - ) - .await?; - - Ok(keystream) - } - - /// Applies the keystream to the provided input text. 
- #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(self), err) - )] - async fn apply_keystream( - &mut self, - input_text: InputText, - keystream: ValueRef, - mode: ExecutionMode, - ) -> Result { - let execution_id = self.state.execution_id.increment_in_place(); - - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| { - Box::pin(apply_keystream( - thread, - mode, - execution_id, - input_text, - keystream, - )) - }); - - let output_text = scope.wait().await.into_iter().next().unwrap()?; - - Ok(output_text) - } - - async fn decode_public(&mut self, value: ValueRef) -> Result { - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| Box::pin(async move { thread.decode(&[value]).await })); - let mut output = scope.wait().await.into_iter().next().unwrap()?; - Ok(output.pop().unwrap()) - } - - async fn decode_private(&mut self, value: ValueRef) -> Result { - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| Box::pin(async move { thread.decode_private(&[value]).await })); - let mut output = scope.wait().await.into_iter().next().unwrap()?; - Ok(output.pop().unwrap()) - } - - async fn decode_blind(&mut self, value: ValueRef) -> Result<(), StreamCipherError> { - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| Box::pin(async move { thread.decode_blind(&[value]).await })); - scope.wait().await.into_iter().next().unwrap()?; - Ok(()) - } - - async fn prove(&mut self, value: ValueRef) -> Result<(), StreamCipherError> { - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| Box::pin(async move { thread.prove(&[value]).await })); - scope.wait().await.into_iter().next().unwrap()?; - Ok(()) - } - - async fn verify(&mut self, value: ValueRef, expected: Value) -> Result<(), StreamCipherError> { - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| { - Box::pin(async move { thread.verify(&[value], &[expected]).await }) - }); - 
scope.wait().await.into_iter().next().unwrap()?; - Ok(()) - } -} - -#[async_trait] -impl StreamCipher for MpcStreamCipher -where - C: CtrCircuit, - E: Thread + Execute + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static, -{ - fn set_key(&mut self, key: ValueRef, iv: ValueRef) { - self.state.encoded_key_iv = Some(EncodedKeyAndIv { key, iv }); - } - - async fn decode_key_private(&mut self) -> Result<(), StreamCipherError> { - let EncodedKeyAndIv { key, iv } = self - .state - .encoded_key_iv - .clone() - .ok_or(StreamCipherError::KeyIvNotSet)?; - - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| Box::pin(async move { thread.decode_private(&[key, iv]).await })); - let output = scope.wait().await.into_iter().next().unwrap()?; - - let [key, iv]: [_; 2] = output.try_into().expect("decoded 2 values"); - let key: Vec = key.try_into().expect("key is an array"); - let iv: Vec = iv.try_into().expect("iv is an array"); - - self.state.key_iv = Some(KeyAndIv { key, iv }); - - Ok(()) - } - - async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError> { - let EncodedKeyAndIv { key, iv } = self - .state - .encoded_key_iv - .clone() - .ok_or(StreamCipherError::KeyIvNotSet)?; - - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| Box::pin(async move { thread.decode_blind(&[key, iv]).await })); - scope.wait().await.into_iter().next().unwrap()?; - - Ok(()) - } - - fn set_transcript_id(&mut self, id: &str) { - let current_id = self - .state - .transcript_counter - .root() - .expect("root id is set"); - let current_counter = self.state.transcript_counter.clone(); - self.state - .transcript_state - .insert(current_id.to_string(), current_counter); - - if let Some(counter) = self.state.transcript_state.get(id) { - self.state.transcript_counter = counter.clone(); - } else { - self.state.transcript_counter = NestedId::new(id).append_counter(); - } - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = 
"debug", skip(self, plaintext), err) - )] - async fn encrypt_public( - &mut self, - explicit_nonce: Vec, - plaintext: Vec, - ) -> Result, StreamCipherError> { - let keystream = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - plaintext.len(), - ExecutionMode::Mpc, - ) - .await?; - - let plaintext_ids = self.plaintext_ids(plaintext.len()); - let ciphertext = self - .apply_keystream( - InputText::Public { - ids: plaintext_ids, - text: plaintext, - }, - keystream, - ExecutionMode::Mpc, - ) - .await?; - - let ciphertext: Vec = self - .decode_public(ciphertext) - .await? - .try_into() - .expect("ciphertext is array"); - - Ok(ciphertext) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self, plaintext), err) - )] - async fn encrypt_private( - &mut self, - explicit_nonce: Vec, - plaintext: Vec, - ) -> Result, StreamCipherError> { - let keystream = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - plaintext.len(), - ExecutionMode::Mpc, - ) - .await?; - - let plaintext_ids = self.plaintext_ids(plaintext.len()); - let ciphertext = self - .apply_keystream( - InputText::Private { - ids: plaintext_ids, - text: plaintext, - }, - keystream, - ExecutionMode::Mpc, - ) - .await?; - - let ciphertext: Vec = self - .decode_public(ciphertext) - .await? 
- .try_into() - .expect("ciphertext is array"); - - Ok(ciphertext) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn encrypt_blind( - &mut self, - explicit_nonce: Vec, - len: usize, - ) -> Result, StreamCipherError> { - let keystream = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - len, - ExecutionMode::Mpc, - ) - .await?; - - let plaintext_ids = self.plaintext_ids(len); - let ciphertext = self - .apply_keystream( - InputText::Blind { ids: plaintext_ids }, - keystream, - ExecutionMode::Mpc, - ) - .await?; - - let ciphertext: Vec = self - .decode_public(ciphertext) - .await? - .try_into() - .expect("ciphertext is array"); - - Ok(ciphertext) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn decrypt_public( - &mut self, - explicit_nonce: Vec, - ciphertext: Vec, - ) -> Result, StreamCipherError> { - // TODO: We may want to support writing to the transcript when decrypting - // in public mode. - let keystream = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - ciphertext.len(), - ExecutionMode::Mpc, - ) - .await?; - - let ciphertext_ids = self.ciphertext_ids(ciphertext.len()); - let plaintext = self - .apply_keystream( - InputText::Public { - ids: ciphertext_ids, - text: ciphertext, - }, - keystream, - ExecutionMode::Mpc, - ) - .await?; - - let plaintext: Vec = self - .decode_public(plaintext) - .await? 
- .try_into() - .expect("plaintext is array"); - - Ok(plaintext) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn decrypt_private( - &mut self, - explicit_nonce: Vec, - ciphertext: Vec, - ) -> Result, StreamCipherError> { - let keystream_ref = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - ciphertext.len(), - ExecutionMode::Mpc, - ) - .await?; - - let keystream: Vec = self - .decode_private(keystream_ref.clone()) - .await? - .try_into() - .expect("keystream is array"); - - let plaintext = ciphertext - .into_iter() - .zip(keystream) - .map(|(c, k)| c ^ k) - .collect::>(); - - // Prove plaintext encrypts back to ciphertext - let plaintext_ids = self.plaintext_ids(plaintext.len()); - let ciphertext = self - .apply_keystream( - InputText::Private { - ids: plaintext_ids, - text: plaintext.clone(), - }, - keystream_ref, - ExecutionMode::Prove, - ) - .await?; - - self.prove(ciphertext).await?; - - Ok(plaintext) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn decrypt_blind( - &mut self, - explicit_nonce: Vec, - ciphertext: Vec, - ) -> Result<(), StreamCipherError> { - let keystream_ref = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - ciphertext.len(), - ExecutionMode::Mpc, - ) - .await?; - - self.decode_blind(keystream_ref.clone()).await?; - - // Verify the plaintext encrypts back to ciphertext - let plaintext_ids = self.plaintext_ids(ciphertext.len()); - let ciphertext_ref = self - .apply_keystream( - InputText::Blind { ids: plaintext_ids }, - keystream_ref, - ExecutionMode::Verify, - ) - .await?; - - self.verify(ciphertext_ref, ciphertext.into()).await?; - - Ok(()) - } - - async fn prove_plaintext( - &mut self, - explicit_nonce: Vec, - ciphertext: Vec, - ) -> Result, StreamCipherError> { - let KeyAndIv { key, iv } = self - .state - .key_iv - .clone() - 
.ok_or(StreamCipherError::KeyIvNotSet)?; - - let plaintext = C::apply_keystream( - &key, - &iv, - self.config.start_ctr, - &explicit_nonce, - &ciphertext, - )?; - - // Prove plaintext encrypts back to ciphertext - let keystream = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - plaintext.len(), - ExecutionMode::Prove, - ) - .await?; - - let plaintext_ids = self.plaintext_ids(plaintext.len()); - let ciphertext = self - .apply_keystream( - InputText::Private { - ids: plaintext_ids, - text: plaintext.clone(), - }, - keystream, - ExecutionMode::Prove, - ) - .await?; - - self.prove(ciphertext).await?; - - Ok(plaintext) - } - - async fn verify_plaintext( - &mut self, - explicit_nonce: Vec, - ciphertext: Vec, - ) -> Result<(), StreamCipherError> { - let keystream = self - .compute_keystream( - explicit_nonce, - self.config.start_ctr, - ciphertext.len(), - ExecutionMode::Verify, - ) - .await?; - - let plaintext_ids = self.plaintext_ids(ciphertext.len()); - let ciphertext_ref = self - .apply_keystream( - InputText::Blind { ids: plaintext_ids }, - keystream, - ExecutionMode::Verify, - ) - .await?; - - self.verify(ciphertext_ref, ciphertext.into()).await?; - - Ok(()) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(self), err) - )] - async fn share_keystream_block( - &mut self, - explicit_nonce: Vec, - ctr: usize, - ) -> Result, StreamCipherError> { - let EncodedKeyAndIv { key, iv } = self - .state - .encoded_key_iv - .clone() - .ok_or(StreamCipherError::KeyIvNotSet)?; - - let explicit_nonce_len = explicit_nonce.len(); - let explicit_nonce: C::NONCE = explicit_nonce.try_into().map_err(|_| { - StreamCipherError::InvalidExplicitNonceLength { - expected: C::NONCE_LEN, - actual: explicit_nonce_len, - } - })?; - - let block_id = self.state.execution_id.increment_in_place(); - let mut scope = self.thread_pool.new_scope(); - scope.push(move |thread| { - Box::pin(async move { - let key_block = compute_key_block( - thread, - 
block_id, - KeyBlockConfig::::new(key, iv, explicit_nonce, ctr as u32), - ExecutionMode::Mpc, - ) - .await?; - - let share = thread - .decode_shared(&[key_block]) - .await? - .into_iter() - .next() - .unwrap(); - - Ok::<_, StreamCipherError>(share) - }) - }); - - let share: Vec = scope - .wait() - .await - .into_iter() - .next() - .unwrap()? - .try_into() - .expect("share is an array"); - - Ok(share) - } -} - -#[derive(Debug, Clone, Copy)] -enum ExecutionMode { - Mpc, - Prove, - Verify, -} - -async fn apply_keystream( - thread: &mut T, - mode: ExecutionMode, - execution_id: NestedId, - input_text: InputText, - keystream: ValueRef, -) -> Result { - let input_text = match input_text { - InputText::Public { ids, text } => { - let refs = text - .into_iter() - .zip(ids) - .map(|(byte, id)| { - let value_ref = thread.new_public_input::(&id)?; - thread.assign(&value_ref, byte)?; - - Ok::<_, StreamCipherError>(value_ref) - }) - .collect::, _>>()?; - thread.array_from_values(&refs)? - } - InputText::Private { ids, text } => { - let refs = text - .into_iter() - .zip(ids) - .map(|(byte, id)| { - let value_ref = thread.new_private_input::(&id)?; - thread.assign(&value_ref, byte)?; - - Ok::<_, StreamCipherError>(value_ref) - }) - .collect::, _>>()?; - thread.array_from_values(&refs)? - } - InputText::Blind { ids } => { - let refs = ids - .into_iter() - .map(|id| thread.new_blind_input::(&id)) - .collect::, _>>()?; - thread.array_from_values(&refs)? 
- } - }; - - let output_text = thread.new_array_output::( - &execution_id.append_string("output").to_string(), - input_text.len(), - )?; - - let circ = build_array_xor(input_text.len()); - - match mode { - ExecutionMode::Mpc => { - thread - .execute(circ, &[input_text, keystream], &[output_text.clone()]) - .await?; - } - ExecutionMode::Prove => { - thread - .execute_prove(circ, &[input_text, keystream], &[output_text.clone()]) - .await?; - } - ExecutionMode::Verify => { - thread - .execute_verify(circ, &[input_text, keystream], &[output_text.clone()]) - .await?; - } - } - - Ok(output_text) -} - -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(thread_pool), err) -)] -async fn compute_keystream< - T: Thread + Memory + Execute + Prove + Verify + Decode + DecodePrivate + Send + 'static, - C: CtrCircuit, ->( - thread_pool: &mut ThreadPool, - execution_id: NestedId, - configs: Vec>, - len: usize, - mode: ExecutionMode, -) -> Result { - let mut block_id = execution_id.append_counter(); - let mut scope = thread_pool.new_scope(); - - for config in configs { - let block_id = block_id.increment_in_place(); - scope.push(move |thread| Box::pin(compute_key_block(thread, block_id, config, mode))); - } - - let key_blocks = scope - .wait() - .await - .into_iter() - .collect::, _>>()?; - - // Flatten the key blocks into a single array. 
- let keystream = key_blocks - .iter() - .flat_map(|block| block.iter()) - .take(len) - .cloned() - .map(|id| ValueRef::Value { id }) - .collect::>(); - - let mut scope = thread_pool.new_scope(); - scope.push(move |thread| Box::pin(async move { thread.array_from_values(&keystream) })); - - let keystream = scope.wait().await.into_iter().next().unwrap()?; - - Ok(keystream) -} - -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(thread), err) -)] -async fn compute_key_block< - T: Memory + Execute + Prove + Verify + Decode + DecodePrivate + Send, - C: CtrCircuit, ->( - thread: &mut T, - block_id: NestedId, - config: KeyBlockConfig, - mode: ExecutionMode, -) -> Result { - let KeyBlockConfig { - key, - iv, - explicit_nonce, - ctr, - .. - } = config; - - let explicit_nonce_ref = thread.new_public_input::<::NONCE>( - &block_id.append_string("explicit_nonce").to_string(), - )?; - let ctr_ref = thread.new_public_input::<[u8; 4]>(&block_id.append_string("ctr").to_string())?; - let key_block = - thread.new_output::(&block_id.append_string("key_block").to_string())?; - - thread.assign(&explicit_nonce_ref, explicit_nonce)?; - thread.assign(&ctr_ref, ctr.to_be_bytes())?; - - // Execute circuit - match mode { - ExecutionMode::Mpc => { - thread - .execute( - C::circuit(), - &[key, iv, explicit_nonce_ref, ctr_ref], - &[key_block.clone()], - ) - .await?; - } - ExecutionMode::Prove => { - thread - .execute_prove( - C::circuit(), - &[key, iv, explicit_nonce_ref, ctr_ref], - &[key_block.clone()], - ) - .await?; - } - ExecutionMode::Verify => { - thread - .execute_verify( - C::circuit(), - &[key, iv, explicit_nonce_ref, ctr_ref], - &[key_block.clone()], - ) - .await?; - } - } - - Ok(key_block) -} diff --git a/components/integration-tests/Cargo.toml b/components/integration-tests/Cargo.toml deleted file mode 100644 index 6f6402566b..0000000000 --- a/components/integration-tests/Cargo.toml +++ /dev/null @@ -1,37 +0,0 @@ -[package] -name = "integration-tests" 
-version = "0.0.0" -edition = "2021" -publish = false - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] - -[profile.release] -lto = true - - -[dev-dependencies] -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -tlsn-block-cipher = { path = "../cipher/block-cipher" } -tlsn-stream-cipher = { path = "../cipher/stream-cipher" } -tlsn-universal-hash = { path = "../universal-hash" } -tlsn-aead = { path = "../aead" } -tlsn-key-exchange = { path = "../key-exchange" } -tlsn-point-addition = { path = "../point-addition" } -tlsn-hmac-sha256 = { path = "../prf/hmac-sha256" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } - -uid-mux = { path = "../uid-mux" } - -p256 = { version = "0.13" } - -futures = "0.3" -rand_chacha = "0.3" -rand = "0.8" - -tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] } -tokio-util = { version = "0.7", features = ["compat"] } diff --git a/components/integration-tests/tests/test.rs b/components/integration-tests/tests/test.rs deleted file mode 100644 index 04cefbfde0..0000000000 --- a/components/integration-tests/tests/test.rs +++ /dev/null @@ -1,396 +0,0 @@ -use aead::{ - aes_gcm::{AesGcmConfig, MpcAesGcm, Role as AesGcmRole}, - Aead, -}; -use block_cipher::{Aes128, BlockCipherConfigBuilder, MpcBlockCipher}; -use ff::Gf2_128; -use futures::StreamExt; -use hmac_sha256::{MpcPrf, Prf, PrfConfig, SessionKeys}; -use key_exchange::{KeyExchange, KeyExchangeConfig, Role as KeyExchangeRole}; -use mpz_garble::{config::Role as GarbleRole, protocol::deap::DEAPVm, Vm}; -use mpz_ot::{ - actor::kos::{ReceiverActor, SenderActor}, - chou_orlandi::{ - Receiver as BaseReceiver, ReceiverConfig as 
BaseReceiverConfig, Sender as BaseSender, - SenderConfig as BaseSenderConfig, - }, - kos::{Receiver, ReceiverConfig, Sender, SenderConfig}, -}; -use mpz_share_conversion as ff; -use mpz_share_conversion::{ShareConversionReveal, ShareConversionVerify}; -use p256::{NonZeroScalar, PublicKey, SecretKey}; -use point_addition::{MpcPointAddition, Role as PointAdditionRole, P256}; -use rand::SeedableRng; -use rand_chacha::ChaCha20Rng; -use tlsn_stream_cipher::{Aes128Ctr, MpcStreamCipher, StreamCipherConfig}; -use tlsn_universal_hash::ghash::{Ghash, GhashConfig}; -use tokio_util::compat::TokioAsyncReadCompatExt; -use uid_mux::{yamux, UidYamux}; -use utils_aio::{codec::BincodeMux, mux::MuxChannel}; - -const OT_SETUP_COUNT: usize = 50_000; - -/// The following integration test checks the interplay of individual components of the TLSNotary -/// protocol. These are: -/// - channel multiplexing -/// - oblivious transfer -/// - point addition -/// - key exchange -/// - prf -/// - aead cipher (stream cipher + ghash) -#[tokio::test] -async fn test_components() { - let mut rng = ChaCha20Rng::seed_from_u64(0); - - let (leader_socket, follower_socket) = tokio::io::duplex(1 << 25); - - let mut leader_mux = UidYamux::new( - yamux::Config::default(), - leader_socket.compat(), - yamux::Mode::Client, - ); - let mut follower_mux = UidYamux::new( - yamux::Config::default(), - follower_socket.compat(), - yamux::Mode::Server, - ); - - let leader_mux_control = leader_mux.control(); - let follower_mux_control = follower_mux.control(); - - tokio::spawn(async move { leader_mux.run().await.unwrap() }); - tokio::spawn(async move { follower_mux.run().await.unwrap() }); - - let mut leader_mux = BincodeMux::new(leader_mux_control); - let mut follower_mux = BincodeMux::new(follower_mux_control); - - let leader_ot_sender_config = SenderConfig::default(); - let follower_ot_recvr_config = ReceiverConfig::default(); - - let follower_ot_sender_config = 
SenderConfig::builder().sender_commit().build().unwrap(); - let leader_ot_recvr_config = ReceiverConfig::builder().sender_commit().build().unwrap(); - - let (leader_ot_sender_sink, leader_ot_sender_stream) = - leader_mux.get_channel("ot/0").await.unwrap().split(); - - let (follower_ot_recvr_sink, follower_ot_recvr_stream) = - follower_mux.get_channel("ot/0").await.unwrap().split(); - - let (leader_ot_receiver_sink, leader_ot_receiver_stream) = - leader_mux.get_channel("ot/1").await.unwrap().split(); - - let (follower_ot_sender_sink, follower_ot_sender_stream) = - follower_mux.get_channel("ot/1").await.unwrap().split(); - - let mut leader_ot_sender_actor = SenderActor::new( - Sender::new( - leader_ot_sender_config, - BaseReceiver::new(BaseReceiverConfig::default()), - ), - leader_ot_sender_sink, - leader_ot_sender_stream, - ); - - let mut follower_ot_recvr_actor = ReceiverActor::new( - Receiver::new( - follower_ot_recvr_config, - BaseSender::new(BaseSenderConfig::default()), - ), - follower_ot_recvr_sink, - follower_ot_recvr_stream, - ); - - let mut leader_ot_recvr_actor = ReceiverActor::new( - Receiver::new( - leader_ot_recvr_config, - BaseSender::new( - BaseSenderConfig::builder() - .receiver_commit() - .build() - .unwrap(), - ), - ), - leader_ot_receiver_sink, - leader_ot_receiver_stream, - ); - - let mut follower_ot_sender_actor = SenderActor::new( - Sender::new( - follower_ot_sender_config, - BaseReceiver::new( - BaseReceiverConfig::builder() - .receiver_commit() - .build() - .unwrap(), - ), - ), - follower_ot_sender_sink, - follower_ot_sender_stream, - ); - - let leader_ot_sender = leader_ot_sender_actor.sender(); - let follower_ot_recvr = follower_ot_recvr_actor.receiver(); - - let leader_ot_recvr = leader_ot_recvr_actor.receiver(); - let follower_ot_sender = follower_ot_sender_actor.sender(); - - tokio::spawn(async move { - leader_ot_sender_actor.setup(OT_SETUP_COUNT).await.unwrap(); - leader_ot_sender_actor.run().await.unwrap(); - }); - - tokio::spawn(async 
move { - follower_ot_recvr_actor.setup(OT_SETUP_COUNT).await.unwrap(); - follower_ot_recvr_actor.run().await.unwrap(); - }); - - tokio::spawn(async move { - leader_ot_recvr_actor.setup(OT_SETUP_COUNT).await.unwrap(); - leader_ot_recvr_actor.run().await.unwrap(); - }); - - tokio::spawn(async move { - follower_ot_sender_actor - .setup(OT_SETUP_COUNT) - .await - .unwrap(); - follower_ot_sender_actor.run().await.unwrap(); - follower_ot_sender_actor.reveal().await.unwrap(); - }); - - let mut leader_vm = DEAPVm::new( - "vm", - GarbleRole::Leader, - [0u8; 32], - leader_mux.get_channel("vm").await.unwrap(), - Box::new(leader_mux.clone()), - leader_ot_sender.clone(), - leader_ot_recvr.clone(), - ); - - let mut follower_vm = DEAPVm::new( - "vm", - GarbleRole::Follower, - [1u8; 32], - follower_mux.get_channel("vm").await.unwrap(), - Box::new(follower_mux.clone()), - follower_ot_sender.clone(), - follower_ot_recvr.clone(), - ); - - let leader_p256_sender = ff::ConverterSender::::new( - ff::SenderConfig::builder().id("p256/0").build().unwrap(), - leader_ot_sender.clone(), - leader_mux.get_channel("p256/0").await.unwrap(), - ); - - let leader_p256_receiver = ff::ConverterReceiver::::new( - ff::ReceiverConfig::builder().id("p256/1").build().unwrap(), - follower_ot_recvr.clone(), - leader_mux.get_channel("p256/1").await.unwrap(), - ); - - let follower_p256_sender = ff::ConverterSender::::new( - ff::SenderConfig::builder().id("p256/1").build().unwrap(), - leader_ot_sender.clone(), - follower_mux.get_channel("p256/1").await.unwrap(), - ); - - let follower_p256_receiver = ff::ConverterReceiver::::new( - ff::ReceiverConfig::builder().id("p256/0").build().unwrap(), - follower_ot_recvr.clone(), - follower_mux.get_channel("p256/0").await.unwrap(), - ); - - let leader_pa_sender = MpcPointAddition::new(PointAdditionRole::Leader, leader_p256_sender); - let leader_pa_receiver = MpcPointAddition::new(PointAdditionRole::Leader, leader_p256_receiver); - - let follower_pa_sender = - 
MpcPointAddition::new(PointAdditionRole::Follower, follower_p256_sender); - - let follower_pa_receiver = - MpcPointAddition::new(PointAdditionRole::Follower, follower_p256_receiver); - - let mut leader_ke = key_exchange::KeyExchangeCore::new( - leader_mux.get_channel("ke").await.unwrap(), - leader_pa_sender, - leader_pa_receiver, - leader_vm.new_thread("ke").await.unwrap(), - KeyExchangeConfig::builder() - .id("ke") - .role(KeyExchangeRole::Leader) - .build() - .unwrap(), - ); - - let mut follower_ke = key_exchange::KeyExchangeCore::new( - follower_mux.get_channel("ke").await.unwrap(), - follower_pa_sender, - follower_pa_receiver, - follower_vm.new_thread("ke").await.unwrap(), - KeyExchangeConfig::builder() - .id("ke") - .role(KeyExchangeRole::Follower) - .build() - .unwrap(), - ); - - let (leader_pms, follower_pms) = - futures::try_join!(leader_ke.setup(), follower_ke.setup()).unwrap(); - - let mut leader_prf = MpcPrf::new( - PrfConfig::builder() - .role(hmac_sha256::Role::Leader) - .build() - .unwrap(), - leader_vm.new_thread("prf/0").await.unwrap(), - leader_vm.new_thread("prf/1").await.unwrap(), - ); - let mut follower_prf = MpcPrf::new( - PrfConfig::builder() - .role(hmac_sha256::Role::Follower) - .build() - .unwrap(), - follower_vm.new_thread("prf/0").await.unwrap(), - follower_vm.new_thread("prf/1").await.unwrap(), - ); - - futures::try_join!( - leader_prf.setup(leader_pms.into_value()), - follower_prf.setup(follower_pms.into_value()) - ) - .unwrap(); - - let block_cipher_config = BlockCipherConfigBuilder::default() - .id("aes") - .build() - .unwrap(); - let leader_block_cipher = MpcBlockCipher::::new( - block_cipher_config.clone(), - leader_vm.new_thread("block_cipher").await.unwrap(), - ); - let follower_block_cipher = MpcBlockCipher::::new( - block_cipher_config, - follower_vm.new_thread("block_cipher").await.unwrap(), - ); - - let stream_cipher_config = StreamCipherConfig::builder() - .id("aes-ctr") - .transcript_id("tx") - .build() - .unwrap(); - let 
leader_stream_cipher = MpcStreamCipher::::new( - stream_cipher_config.clone(), - leader_vm.new_thread_pool("aes-ctr", 4).await.unwrap(), - ); - let follower_stream_cipher = MpcStreamCipher::::new( - stream_cipher_config, - follower_vm.new_thread_pool("aes-ctr", 4).await.unwrap(), - ); - - let mut leader_gf2 = ff::ConverterSender::::new( - ff::SenderConfig::builder() - .id("gf2") - .record() - .build() - .unwrap(), - leader_ot_sender.clone(), - leader_mux.get_channel("gf2").await.unwrap(), - ); - - let mut follower_gf2 = ff::ConverterReceiver::::new( - ff::ReceiverConfig::builder() - .id("gf2") - .record() - .build() - .unwrap(), - follower_ot_recvr.clone(), - follower_mux.get_channel("gf2").await.unwrap(), - ); - - let ghash_config = GhashConfig::builder() - .id("aes_gcm/ghash") - .initial_block_count(64) - .build() - .unwrap(); - - let leader_ghash = Ghash::new(ghash_config.clone(), leader_gf2.handle().unwrap()); - let follower_ghash = Ghash::new(ghash_config, follower_gf2.handle().unwrap()); - - let mut leader_aead = MpcAesGcm::new( - AesGcmConfig::builder() - .id("aes_gcm") - .role(AesGcmRole::Leader) - .build() - .unwrap(), - leader_mux.get_channel("aes_gcm").await.unwrap(), - Box::new(leader_block_cipher), - Box::new(leader_stream_cipher), - Box::new(leader_ghash), - ); - - let mut follower_aead = MpcAesGcm::new( - AesGcmConfig::builder() - .id("aes_gcm") - .role(AesGcmRole::Follower) - .build() - .unwrap(), - follower_mux.get_channel("aes_gcm").await.unwrap(), - Box::new(follower_block_cipher), - Box::new(follower_stream_cipher), - Box::new(follower_ghash), - ); - - let leader_private_key = SecretKey::random(&mut rng); - let follower_private_key = SecretKey::random(&mut rng); - let server_public_key = PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng)); - - // Setup complete - - let _ = tokio::try_join!( - leader_ke.compute_client_key(leader_private_key), - follower_ke.compute_client_key(follower_private_key) - ) - .unwrap(); - - 
leader_ke.set_server_key(server_public_key); - - tokio::try_join!(leader_ke.compute_pms(), follower_ke.compute_pms()).unwrap(); - - let (leader_session_keys, follower_session_keys) = tokio::try_join!( - leader_prf.compute_session_keys_private([0u8; 32], [0u8; 32]), - follower_prf.compute_session_keys_blind() - ) - .unwrap(); - - let SessionKeys { - client_write_key: leader_key, - client_iv: leader_iv, - .. - } = leader_session_keys; - - let SessionKeys { - client_write_key: follower_key, - client_iv: follower_iv, - .. - } = follower_session_keys; - - tokio::try_join!( - leader_aead.set_key(leader_key, leader_iv), - follower_aead.set_key(follower_key, follower_iv) - ) - .unwrap(); - - let msg = vec![0u8; 4096]; - - let _ = tokio::try_join!( - leader_aead.encrypt_private(vec![0u8; 8], msg.clone(), vec![]), - follower_aead.encrypt_blind(vec![0u8; 8], msg.len(), vec![]) - ) - .unwrap(); - - follower_ot_sender.shutdown().await.unwrap(); - - tokio::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); - tokio::try_join!(leader_gf2.reveal(), follower_gf2.verify()).unwrap(); -} diff --git a/components/key-exchange/Cargo.toml b/components/key-exchange/Cargo.toml deleted file mode 100644 index 0fdcb43a22..0000000000 --- a/components/key-exchange/Cargo.toml +++ /dev/null @@ -1,37 +0,0 @@ -[package] -name = "tlsn-key-exchange" -authors = ["TLSNotary Team"] -description = "Implementation of the TLSNotary-specific key-exchange protocol" -keywords = ["tls", "mpc", "2pc", "pms", "key-exchange"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[lib] -name = "key_exchange" - -[features] -default = ["mock"] -tracing = ["dep:tracing", "tlsn-point-addition/tracing"] -mock = [] - -[dependencies] -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-circuits = { git = 
"https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } -mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -tlsn-point-addition = { path = "../point-addition" } -p256 = { version = "0.13", features = ["ecdh"] } -async-trait = "0.1" -thiserror = "1" -serde = "1" -futures = "0.3" -derive_builder = "0.12" -tracing = { version = "0.1", optional = true } - -[dev-dependencies] -rand_chacha = "0.3" -rand_core = "0.6" -tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] } diff --git a/components/key-exchange/src/circuit.rs b/components/key-exchange/src/circuit.rs deleted file mode 100644 index 82818ea6cc..0000000000 --- a/components/key-exchange/src/circuit.rs +++ /dev/null @@ -1,43 +0,0 @@ -//! This module provides the circuits used in the key exchange protocol - -use std::sync::Arc; - -use mpz_circuits::{circuits::big_num::nbyte_add_mod_trace, Circuit, CircuitBuilder}; - -/// NIST P-256 prime big-endian -static P: [u8; 32] = [ - 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, -]; - -/// Circuit for combining additive shares of the PMS, twice -/// -/// # Inputs -/// -/// 0. PMS_SHARE_A: 32 bytes PMS Additive Share -/// 1. PMS_SHARE_B: 32 bytes PMS Additive Share -/// 2. PMS_SHARE_C: 32 bytes PMS Additive Share -/// 3. PMS_SHARE_D: 32 bytes PMS Additive Share -/// -/// # Outputs -/// 0. PMS1: Pre-master Secret = PMS_SHARE_A + PMS_SHARE_B -/// 1. PMS2: Pre-master Secret = PMS_SHARE_C + PMS_SHARE_D -/// 2. 
EQ: Equality check of PMS1 and PMS2 -pub(crate) fn build_pms_circuit() -> Arc { - let builder = CircuitBuilder::new(); - let share_a = builder.add_array_input::(); - let share_b = builder.add_array_input::(); - let share_c = builder.add_array_input::(); - let share_d = builder.add_array_input::(); - - let a = nbyte_add_mod_trace(builder.state(), share_a, share_b, P); - let b = nbyte_add_mod_trace(builder.state(), share_c, share_d, P); - - let eq: [_; 32] = std::array::from_fn(|i| a[i] ^ b[i]); - - builder.add_output(a); - builder.add_output(b); - builder.add_output(eq); - - Arc::new(builder.build().expect("pms circuit is valid")) -} diff --git a/components/key-exchange/src/exchange.rs b/components/key-exchange/src/exchange.rs deleted file mode 100644 index a84bdd8548..0000000000 --- a/components/key-exchange/src/exchange.rs +++ /dev/null @@ -1,621 +0,0 @@ -//! This module implements the key exchange logic - -use async_trait::async_trait; -use futures::{SinkExt, StreamExt}; -use mpz_garble::{value::ValueRef, Decode, Execute, Load, Memory}; - -use mpz_share_conversion_core::fields::{p256::P256, Field}; -use p256::{EncodedPoint, PublicKey, SecretKey}; -use point_addition::PointAddition; -use std::fmt::Debug; - -use utils_aio::expect_msg_or_err; - -use crate::{ - circuit::build_pms_circuit, - config::{KeyExchangeConfig, Role}, - KeyExchange, KeyExchangeChannel, KeyExchangeError, KeyExchangeMessage, Pms, -}; - -enum State { - Initialized, - Setup { - share_a: ValueRef, - share_b: ValueRef, - share_c: ValueRef, - share_d: ValueRef, - pms_1: ValueRef, - pms_2: ValueRef, - eq: ValueRef, - }, - KeyExchange { - share_a: ValueRef, - share_b: ValueRef, - share_c: ValueRef, - share_d: ValueRef, - pms_1: ValueRef, - pms_2: ValueRef, - eq: ValueRef, - }, - Complete, - Error, -} - -/// The instance for performing the key exchange protocol -/// -/// Can be either a leader or a follower depending on the `role` field in [KeyExchangeConfig] -pub struct KeyExchangeCore { - /// A 
channel for exchanging messages between leader and follower - channel: KeyExchangeChannel, - /// The sender instance for performing point addition - point_addition_sender: PS, - /// The receiver instance for performing point addition - point_addition_receiver: PR, - /// MPC executor - executor: E, - /// The private key of the party behind this instance, either follower or leader - private_key: Option, - /// The public key of the server - server_key: Option, - /// The config used for the key exchange protocol - config: KeyExchangeConfig, - /// The state of the protocol - state: State, -} - -impl Debug for KeyExchangeCore -where - PS: PointAddition + Send + Debug, - PR: PointAddition + Send + Debug, - E: Memory + Execute + Decode + Send, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("KeyExchangeCore") - .field("channel", &"{{ ... }}") - .field("point_addition_sender", &"{{ ... }}") - .field("point_addition_receiver", &"{{ ... }}") - .field("executor", &"{{ ... }}") - .field("private_key", &"{{ ... 
}}") - .field("server_key", &self.server_key) - .field("config", &self.config) - .finish() - } -} - -impl KeyExchangeCore -where - PS: PointAddition + Send + Debug, - PR: PointAddition + Send + Debug, - E: Memory + Execute + Decode + Send, -{ - /// Creates a new [KeyExchangeCore] - /// - /// * `channel` - The channel for sending messages between leader and follower - /// * `point_addition_sender` - The point addition sender instance used during key exchange - /// * `point_addition_receiver` - The point addition receiver instance used during key exchange - /// * `executor` - The MPC executor - /// * `config` - The config used for the key exchange protocol - #[cfg_attr( - feature = "tracing", - tracing::instrument( - level = "info", - skip(channel, executor, point_addition_sender, point_addition_receiver), - ret - ) - )] - pub fn new( - channel: KeyExchangeChannel, - point_addition_sender: PS, - point_addition_receiver: PR, - executor: E, - config: KeyExchangeConfig, - ) -> Self { - Self { - channel, - point_addition_sender, - point_addition_receiver, - executor, - private_key: None, - server_key: None, - config, - state: State::Initialized, - } - } - - async fn compute_pms_shares(&mut self) -> Result<(P256, P256), KeyExchangeError> { - let state = std::mem::replace(&mut self.state, State::Error); - - let State::Setup { - share_a, - share_b, - share_c, - share_d, - pms_1, - pms_2, - eq, - } = state - else { - todo!() - }; - - let server_key = match self.config.role() { - Role::Leader => { - // Send server public key to follower - if let Some(server_key) = &self.server_key { - self.channel - .send(KeyExchangeMessage::ServerPublicKey((*server_key).into())) - .await?; - - *server_key - } else { - return Err(KeyExchangeError::NoServerKey); - } - } - Role::Follower => { - // Receive server's public key from leader - let message = - expect_msg_or_err!(self.channel, KeyExchangeMessage::ServerPublicKey)?; - let server_key = message.try_into()?; - - self.server_key = 
Some(server_key); - - server_key - } - }; - - let private_key = self - .private_key - .take() - .ok_or(KeyExchangeError::NoPrivateKey)?; - - // Compute the leader's/follower's share of the pre-master secret - // - // We need to mimic the [diffie-hellman](p256::ecdh::diffie_hellman) function without the - // [SharedSecret](p256::ecdh::SharedSecret) wrapper, because this makes it harder to get - // the result as an EC curve point. - let shared_secret = { - let public_projective = server_key.to_projective(); - (public_projective * private_key.to_nonzero_scalar().as_ref()).to_affine() - }; - - let encoded_point = EncodedPoint::from(PublicKey::from_affine(shared_secret)?); - let (sender_share, receiver_share) = futures::try_join!( - self.point_addition_sender - .compute_x_coordinate_share(encoded_point), - self.point_addition_receiver - .compute_x_coordinate_share(encoded_point) - )?; - - self.state = State::KeyExchange { - share_a, - share_b, - share_c, - share_d, - pms_1, - pms_2, - eq, - }; - - match self.config.role() { - Role::Leader => Ok((sender_share, receiver_share)), - Role::Follower => Ok((receiver_share, sender_share)), - } - } - - async fn compute_pms_for( - &mut self, - pms_share1: P256, - pms_share2: P256, - ) -> Result { - let state = std::mem::replace(&mut self.state, State::Error); - - let State::KeyExchange { - share_a, - share_b, - share_c, - share_d, - pms_1, - pms_2, - eq, - } = state - else { - todo!() - }; - - let pms_share1: [u8; 32] = pms_share1 - .to_be_bytes() - .try_into() - .expect("pms share is 32 bytes"); - let pms_share2: [u8; 32] = pms_share2 - .to_be_bytes() - .try_into() - .expect("pms share is 32 bytes"); - - match self.config.role() { - Role::Leader => { - self.executor.assign(&share_a, pms_share1)?; - self.executor.assign(&share_c, pms_share2)?; - } - Role::Follower => { - self.executor.assign(&share_b, pms_share1)?; - self.executor.assign(&share_d, pms_share2)?; - } - } - - self.executor - .execute( - build_pms_circuit(), - 
&[share_a, share_b, share_c, share_d], - &[pms_1.clone(), pms_2, eq.clone()], - ) - .await?; - - #[cfg(feature = "tracing")] - tracing::event!(tracing::Level::DEBUG, "Successfully executed PMS circuit!"); - - let mut outputs = self.executor.decode(&[eq]).await?; - - let eq: [u8; 32] = outputs.remove(0).try_into().expect("eq is 32 bytes"); - - // Eq should be all zeros if pms_1 == pms_2 - if eq != [0u8; 32] { - return Err(KeyExchangeError::CheckFailed); - } - - self.state = State::Complete; - - // Both parties use pms_1 as the pre-master secret - Ok(Pms::new(pms_1)) - } -} - -#[async_trait] -impl KeyExchange for KeyExchangeCore -where - PS: PointAddition + Send + Debug, - PR: PointAddition + Send + Debug, - E: Memory + Load + Execute + Decode + Send, -{ - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(self), ret) - )] - fn server_key(&self) -> Option { - self.server_key - } - - /// Set the server's public key - #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip(self)))] - fn set_server_key(&mut self, server_key: PublicKey) { - self.server_key = Some(server_key); - } - - async fn setup(&mut self) -> Result { - let state = std::mem::replace(&mut self.state, State::Error); - - let State::Initialized = state else { - return Err(KeyExchangeError::InvalidState( - "expected to be in Initialized state".to_string(), - )); - }; - - let (share_a, share_b, share_c, share_d) = match self.config.role() { - Role::Leader => { - let share_a = self - .executor - .new_private_input::<[u8; 32]>("pms/share_a") - .unwrap(); - let share_b = self - .executor - .new_blind_input::<[u8; 32]>("pms/share_b") - .unwrap(); - let share_c = self - .executor - .new_private_input::<[u8; 32]>("pms/share_c") - .unwrap(); - let share_d = self - .executor - .new_blind_input::<[u8; 32]>("pms/share_d") - .unwrap(); - - (share_a, share_b, share_c, share_d) - } - Role::Follower => { - let share_a = self - .executor - .new_blind_input::<[u8; 
32]>("pms/share_a") - .unwrap(); - let share_b = self - .executor - .new_private_input::<[u8; 32]>("pms/share_b") - .unwrap(); - let share_c = self - .executor - .new_blind_input::<[u8; 32]>("pms/share_c") - .unwrap(); - let share_d = self - .executor - .new_private_input::<[u8; 32]>("pms/share_d") - .unwrap(); - - (share_a, share_b, share_c, share_d) - } - }; - - let pms_1 = self.executor.new_output::<[u8; 32]>("pms/1")?; - let pms_2 = self.executor.new_output::<[u8; 32]>("pms/2")?; - let eq = self.executor.new_output::<[u8; 32]>("pms/eq")?; - - self.executor - .load( - build_pms_circuit(), - &[ - share_a.clone(), - share_b.clone(), - share_c.clone(), - share_d.clone(), - ], - &[pms_1.clone(), pms_2.clone(), eq.clone()], - ) - .await?; - - self.state = State::Setup { - share_a, - share_b, - share_c, - share_d, - pms_1: pms_1.clone(), - pms_2, - eq, - }; - - Ok(Pms::new(pms_1)) - } - - /// Compute the client's public key - /// - /// The client's public key in this context is the combined public key (EC point addition) of - /// the leader's public key and the follower's public key. 
- #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(self, private_key), ret, err) - )] - async fn compute_client_key( - &mut self, - private_key: SecretKey, - ) -> Result, KeyExchangeError> { - let public_key = private_key.public_key(); - self.private_key = Some(private_key); - - match self.config.role() { - Role::Leader => { - // Receive public key from follower - let message = - expect_msg_or_err!(self.channel, KeyExchangeMessage::FollowerPublicKey)?; - let follower_public_key: PublicKey = message.try_into()?; - - // Combine public keys - let client_public_key = PublicKey::from_affine( - (public_key.to_projective() + follower_public_key.to_projective()).to_affine(), - )?; - - Ok(Some(client_public_key)) - } - Role::Follower => { - // Send public key to leader - self.channel - .send(KeyExchangeMessage::FollowerPublicKey(public_key.into())) - .await?; - - Ok(None) - } - } - } - - /// Computes the PMS - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(self), err) - )] - async fn compute_pms(&mut self) -> Result { - let (pms_share1, pms_share2) = self.compute_pms_shares().await?; - - self.compute_pms_for(pms_share1, pms_share2).await - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use mpz_garble::{ - protocol::deap::mock::{ - create_mock_deap_vm, MockFollower, MockFollowerThread, MockLeader, MockLeaderThread, - }, - Vm, - }; - use mpz_share_conversion_core::fields::{p256::P256, Field}; - use p256::{NonZeroScalar, PublicKey, SecretKey}; - use rand_chacha::ChaCha20Rng; - use rand_core::SeedableRng; - - use crate::{ - mock::{create_mock_key_exchange_pair, MockKeyExchange}, - KeyExchangeError, - }; - - async fn create_pair() -> ( - ( - MockKeyExchange, - MockKeyExchange, - ), - (MockLeader, MockFollower), - ) { - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; - ( - create_mock_key_exchange_pair( - "test", - leader_vm.new_thread("ke").await.unwrap(), - 
follower_vm.new_thread("ke").await.unwrap(), - ), - (leader_vm, follower_vm), - ) - } - - #[tokio::test] - async fn test_key_exchange() { - let mut rng = ChaCha20Rng::from_seed([0_u8; 32]); - - let leader_private_key = SecretKey::random(&mut rng); - let follower_private_key = SecretKey::random(&mut rng); - let server_public_key = PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng)); - - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = create_pair().await; - - let client_public_key = perform_key_exchange( - &mut leader, - &mut follower, - leader_private_key.clone(), - follower_private_key.clone(), - server_public_key, - ) - .await; - - let expected_client_public_key = PublicKey::from_affine( - (leader_private_key.public_key().to_projective() - + follower_private_key.public_key().to_projective()) - .to_affine(), - ) - .unwrap(); - - assert_eq!(client_public_key, expected_client_public_key); - } - - #[tokio::test] - async fn test_compute_pms_share() { - let mut rng = ChaCha20Rng::from_seed([0_u8; 32]); - - let leader_private_key = SecretKey::random(&mut rng); - let follower_private_key = SecretKey::random(&mut rng); - let server_private_key = NonZeroScalar::random(&mut rng); - let server_public_key = PublicKey::from_secret_scalar(&server_private_key); - - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = create_pair().await; - - let client_public_key = perform_key_exchange( - &mut leader, - &mut follower, - leader_private_key.clone(), - follower_private_key.clone(), - server_public_key, - ) - .await; - - leader.set_server_key(server_public_key); - - let ((l_pms1, l_pms2), (f_pms1, f_pms2)) = - tokio::try_join!(leader.compute_pms_shares(), follower.compute_pms_shares()).unwrap(); - - let expected_ecdh_x = - p256::ecdh::diffie_hellman(server_private_key, client_public_key.as_affine()); - - assert_eq!( - expected_ecdh_x.raw_secret_bytes().to_vec(), - (l_pms1 + f_pms1).to_be_bytes() - ); - assert_eq!( - 
expected_ecdh_x.raw_secret_bytes().to_vec(), - (l_pms2 + f_pms2).to_be_bytes() - ); - assert_eq!(l_pms1 + f_pms1, l_pms2 + f_pms2); - assert_ne!(l_pms1, f_pms1); - assert_ne!(l_pms2, f_pms2); - assert_ne!(l_pms1, l_pms2); - assert_ne!(f_pms1, f_pms2); - } - - #[tokio::test] - async fn test_compute_pms() { - let mut rng = ChaCha20Rng::from_seed([0_u8; 32]); - - let leader_private_key = SecretKey::random(&mut rng); - let follower_private_key = SecretKey::random(&mut rng); - let server_private_key = NonZeroScalar::random(&mut rng); - let server_public_key = PublicKey::from_secret_scalar(&server_private_key); - - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = create_pair().await; - - _ = perform_key_exchange( - &mut leader, - &mut follower, - leader_private_key.clone(), - follower_private_key.clone(), - server_public_key, - ) - .await; - - leader.set_server_key(server_public_key); - - let (_leader_pms, _follower_pms) = - tokio::try_join!(leader.compute_pms(), follower.compute_pms()).unwrap(); - - assert_eq!(leader.server_key.unwrap(), server_public_key); - assert_eq!(follower.server_key.unwrap(), server_public_key); - } - - #[tokio::test] - async fn test_compute_pms_fail() { - let mut rng = ChaCha20Rng::from_seed([0_u8; 32]); - - let leader_private_key = SecretKey::random(&mut rng); - let follower_private_key = SecretKey::random(&mut rng); - let server_private_key = NonZeroScalar::random(&mut rng); - let server_public_key = PublicKey::from_secret_scalar(&server_private_key); - - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = create_pair().await; - - _ = perform_key_exchange( - &mut leader, - &mut follower, - leader_private_key.clone(), - follower_private_key.clone(), - server_public_key, - ) - .await; - - leader.set_server_key(server_public_key); - - let ((mut l_pms1, l_pms2), (f_pms1, f_pms2)) = - tokio::try_join!(leader.compute_pms_shares(), follower.compute_pms_shares()).unwrap(); - - l_pms1 = l_pms1 + P256::one(); - - let err = 
tokio::try_join!( - leader.compute_pms_for(l_pms1, l_pms2), - follower.compute_pms_for(f_pms1, f_pms2) - ) - .unwrap_err(); - - assert!(matches!(err, KeyExchangeError::CheckFailed)); - } - - async fn perform_key_exchange( - leader: &mut impl KeyExchange, - follower: &mut impl KeyExchange, - leader_private_key: SecretKey, - follower_private_key: SecretKey, - server_public_key: PublicKey, - ) -> PublicKey { - tokio::try_join!(leader.setup(), follower.setup()).unwrap(); - - let (client_public_key, _) = tokio::try_join!( - leader.compute_client_key(leader_private_key), - follower.compute_client_key(follower_private_key) - ) - .unwrap(); - - leader.set_server_key(server_public_key); - - client_public_key.unwrap() - } -} diff --git a/components/key-exchange/src/lib.rs b/components/key-exchange/src/lib.rs deleted file mode 100644 index 945cc1959e..0000000000 --- a/components/key-exchange/src/lib.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! # The Key Exchange Protocol -//! -//! This crate implements a key exchange protocol with 3 parties, namely server, leader and -//! follower. The goal is to end up with a shared secret (ECDH) between the server and the client. -//! The client in this context is leader and follower combined, which means that each of them will -//! end up with a share of the shared secret. The leader will do all the necessary communication -//! with the server alone and forward all messages from and to the follower. -//! -//! A detailed description of this protocol can be found in our documentation -//! . 
- -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -mod circuit; -mod config; -mod exchange; -#[cfg(feature = "mock")] -pub mod mock; -pub mod msg; - -pub use config::{ - KeyExchangeConfig, KeyExchangeConfigBuilder, KeyExchangeConfigBuilderError, Role, -}; -pub use exchange::KeyExchangeCore; -pub use msg::KeyExchangeMessage; - -/// A channel for exchanging key exchange messages -pub type KeyExchangeChannel = Box>; - -use async_trait::async_trait; -use mpz_garble::value::ValueRef; -use p256::{PublicKey, SecretKey}; -use utils_aio::duplex::Duplex; - -/// Pre-master secret. -#[derive(Debug, Clone)] -pub struct Pms(ValueRef); - -impl Pms { - /// Create a new PMS - pub fn new(value: ValueRef) -> Self { - Self(value) - } - - /// Get the value of the PMS - pub fn into_value(self) -> ValueRef { - self.0 - } -} - -/// An error that can occur during the key exchange protocol -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum KeyExchangeError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - MemoryError(#[from] mpz_garble::MemoryError), - #[error(transparent)] - LoadError(#[from] mpz_garble::LoadError), - #[error(transparent)] - ExecutionError(#[from] mpz_garble::ExecutionError), - #[error(transparent)] - DecodeError(#[from] mpz_garble::DecodeError), - #[error(transparent)] - PointAdditionError(#[from] point_addition::PointAdditionError), - #[error(transparent)] - PublicKey(#[from] p256::elliptic_curve::Error), - #[error(transparent)] - KeyParseError(#[from] msg::KeyParseError), - #[error("Server Key not set")] - NoServerKey, - #[error("Private key not set")] - NoPrivateKey, - #[error("invalid state: {0}")] - InvalidState(String), - #[error("PMS equality check failed")] - CheckFailed, -} - -/// A trait for the 3-party key exchange protocol -#[async_trait] -pub trait KeyExchange { - /// Get the server's public key - fn server_key(&self) -> Option; - - /// Set the 
server's public key - fn set_server_key(&mut self, server_key: PublicKey); - - /// Performs any necessary one-time setup, returning a reference to the PMS. - /// - /// The PMS will not be assigned until `compute_pms` is called. - async fn setup(&mut self) -> Result; - - /// Compute the client's public key - /// - /// The client's public key in this context is the combined public key (EC point addition) of - /// the leader's public key and the follower's public key. - async fn compute_client_key( - &mut self, - private_key: SecretKey, - ) -> Result, KeyExchangeError>; - - /// Computes the PMS - async fn compute_pms(&mut self) -> Result; -} diff --git a/components/key-exchange/src/mock.rs b/components/key-exchange/src/mock.rs deleted file mode 100644 index 1da507764f..0000000000 --- a/components/key-exchange/src/mock.rs +++ /dev/null @@ -1,56 +0,0 @@ -//! This module provides mock types for key exchange leader and follower and a function to create -//! such a pair - -use crate::{KeyExchangeConfig, KeyExchangeCore, KeyExchangeMessage, Role}; - -use mpz_garble::{Decode, Execute, Memory}; -use point_addition::mock::{ - mock_point_converter_pair, MockPointAdditionReceiver, MockPointAdditionSender, -}; -use utils_aio::duplex::MemoryDuplex; - -/// A mock key exchange instance -pub type MockKeyExchange = - KeyExchangeCore; - -/// Create a mock pair of key exchange leader and follower -pub fn create_mock_key_exchange_pair( - id: &str, - leader_executor: E, - follower_executor: E, -) -> (MockKeyExchange, MockKeyExchange) { - let (leader_pa_sender, follower_pa_recvr) = mock_point_converter_pair(&format!("{}/pa/0", id)); - let (follower_pa_sender, leader_pa_recvr) = mock_point_converter_pair(&format!("{}/pa/1", id)); - - let (leader_channel, follower_channel) = MemoryDuplex::::new(); - - let key_exchange_config_leader = KeyExchangeConfig::builder() - .id(id) - .role(Role::Leader) - .build() - .unwrap(); - - let key_exchange_config_follower = KeyExchangeConfig::builder() - 
.id(id) - .role(Role::Follower) - .build() - .unwrap(); - - let leader = KeyExchangeCore::new( - Box::new(leader_channel), - leader_pa_sender, - leader_pa_recvr, - leader_executor, - key_exchange_config_leader, - ); - - let follower = KeyExchangeCore::new( - Box::new(follower_channel), - follower_pa_sender, - follower_pa_recvr, - follower_executor, - key_exchange_config_follower, - ); - - (leader, follower) -} diff --git a/components/point-addition/Cargo.toml b/components/point-addition/Cargo.toml deleted file mode 100644 index 3772569800..0000000000 --- a/components/point-addition/Cargo.toml +++ /dev/null @@ -1,31 +0,0 @@ -[package] -name = "tlsn-point-addition" -authors = ["TLSNotary Team"] -description = "Addition of EC points using 2PC, producing additive secret-shares of the resulting x-coordinate" -keywords = ["tls", "mpc", "2pc", "ecc", "elliptic"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[lib] -name = "point_addition" - -[features] -default = ["mock"] -mock = ["dep:mpz-core"] -tracing = ["dep:tracing"] - -[dependencies] -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54", optional = true } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -p256 = { version = "0.13", features = ["arithmetic"] } -tracing = { version = "0.1", optional = true } -async-trait = "0.1" -thiserror = "1" - -[dev-dependencies] -tokio = { version = "1.23", features = ["macros", "rt", "rt-multi-thread"] } -rand_chacha = "0.3" -rand = "0.8" diff --git a/components/point-addition/src/conversion.rs b/components/point-addition/src/conversion.rs deleted file mode 100644 index 1da5e4c4c2..0000000000 --- a/components/point-addition/src/conversion.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! 
This module implements a secure two-party computation protocol for adding two private EC points -//! and secret-sharing the resulting x coordinate (the shares are field elements of the field -//! underlying the elliptic curve). -//! This protocol has semi-honest security. -//! -//! The protocol is described in - -use std::marker::PhantomData; - -use super::{PointAddition, PointAdditionError}; -use async_trait::async_trait; -use mpz_share_conversion::ShareConversion; -use mpz_share_conversion_core::fields::{p256::P256, Field}; -use p256::EncodedPoint; - -/// The instance used for adding the curve points -#[derive(Debug)] -pub struct MpcPointAddition -where - F: Field, - C: ShareConversion, -{ - /// Indicates which role this converter instance will fulfill - role: Role, - /// The share converter - converter: C, - - _field: PhantomData, -} - -/// The role: either Leader or Follower -/// -/// Follower needs to perform an inversion operation on the point during point addition -#[allow(missing_docs)] -#[derive(Debug, Clone, Copy)] -pub enum Role { - Leader, - Follower, -} - -impl Role { - /// Adapt the point depending on the role - /// - /// One party needs to adapt the coordinates. We decided that this is the follower's job. 
- fn adapt_point(&self, [x, y]: [V; 2]) -> [V; 2] { - match self { - Role::Leader => [x, y], - Role::Follower => [-x, -y], - } - } -} - -impl MpcPointAddition -where - F: Field, - C: ShareConversion + std::fmt::Debug, -{ - /// Create a new [MpcPointAddition] instance - #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", ret))] - pub fn new(role: Role, converter: C) -> Self { - Self { - converter, - role, - _field: PhantomData, - } - } - - /// Perform the conversion of P = A + B => P_x = a + b - /// - /// Since we are only interested in the x-coordinate of P (for the PMS) and because elliptic - /// curve point addition is an expensive operation in 2PC, we secret-share the x-coordinate - /// of P as a simple addition of field elements between the two parties. So we go from an EC - /// point addition to an addition of field elements for the x-coordinate. - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(point), err) - )] - async fn convert(&mut self, point: [F; 2]) -> Result { - let [x, y] = point; - let [x_n, y_n] = self.role.adapt_point([x, y]); - - let a2m_output = self.converter.to_multiplicative(vec![y_n, x_n]).await?; - - let a = a2m_output[0]; - let b = a2m_output[1]; - - let c = a * b.inverse(); - let c = c * c; - - let d = self.converter.to_additive(vec![c]).await?[0]; - let x_r = d + -x; - - Ok(x_r) - } -} - -#[async_trait] -impl PointAddition for MpcPointAddition -where - C: ShareConversion + Send + Sync + std::fmt::Debug, -{ - type Point = EncodedPoint; - type XCoordinate = P256; - - async fn compute_x_coordinate_share( - &mut self, - point: Self::Point, - ) -> Result { - let [x, y] = point_to_p256(point)?; - self.convert([x, y]).await - } -} - -/// Convert the external library's point type to our library's field type -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(point), err) -)] -pub(crate) fn point_to_p256(point: EncodedPoint) -> Result<[P256; 2], PointAdditionError> { - let 
mut x: [u8; 32] = (*point.x().ok_or(PointAdditionError::Coordinates)?).into(); - let mut y: [u8; 32] = (*point.y().ok_or(PointAdditionError::Coordinates)?).into(); - - // reverse to little endian - x.reverse(); - y.reverse(); - - let x = P256::try_from(x).unwrap(); - let y = P256::try_from(y).unwrap(); - - Ok([x, y]) -} diff --git a/components/point-addition/src/lib.rs b/components/point-addition/src/lib.rs deleted file mode 100644 index de7ef2ea0f..0000000000 --- a/components/point-addition/src/lib.rs +++ /dev/null @@ -1,115 +0,0 @@ -//! A secure two-party computation (2PC) library for converting additive shares of an elliptic -//! curve (EC) point into additive shares of said point's x-coordinate. The additive shares of the -//! x-coordinate are finite field elements. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -use async_trait::async_trait; -use mpz_share_conversion::ShareConversionError; -use mpz_share_conversion_core::fields::Field; - -mod conversion; - -/// A mock implementation of the [PointAddition] trait -#[cfg(feature = "mock")] -pub mod mock; - -pub use conversion::{MpcPointAddition, Role}; -pub use mpz_share_conversion_core::fields::p256::P256; - -/// The error type for [PointAddition] -#[allow(missing_docs)] -#[derive(Debug, thiserror::Error)] -pub enum PointAdditionError { - #[error(transparent)] - ShareConversion(#[from] ShareConversionError), - #[error("Unable to get coordinates from elliptic curve point")] - Coordinates, -} - -/// A trait for secret-sharing the sum of two elliptic curve points as a sum of field elements -/// -/// This trait is for securely secret-sharing the addition of two elliptic curve points. -/// Let `P + Q = O = (x, y)`. Each party receives additive shares of the x-coordinate. 
-#[async_trait] -pub trait PointAddition { - /// The elliptic curve point type - type Point; - /// The x-coordinate type for the finite field underlying the EC - type XCoordinate: Field; - - /// Adds two elliptic curve points in 2PC, returning respective secret shares - /// of the resulting x-coordinate to both parties. - async fn compute_x_coordinate_share( - &mut self, - point: Self::Point, - ) -> Result; -} - -#[cfg(test)] -mod tests { - use crate::{conversion::point_to_p256, mock::mock_point_converter_pair, PointAddition}; - use mpz_core::Block; - use mpz_share_conversion_core::{fields::p256::P256, Field}; - use p256::{ - elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}, - EncodedPoint, NonZeroScalar, ProjectivePoint, PublicKey, - }; - use rand::{Rng, SeedableRng}; - use rand_chacha::ChaCha12Rng; - - #[tokio::test] - async fn test_point_conversion() { - let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); - - let p1: [u8; 32] = rng.gen(); - let p2: [u8; 32] = rng.gen(); - - let p1 = curve_point_from_be_bytes(p1); - let p2 = curve_point_from_be_bytes(p2); - - let p = add_curve_points(&p1, &p2); - - let (mut c1, mut c2) = mock_point_converter_pair("test"); - - let c1_fut = c1.compute_x_coordinate_share(p1); - let c2_fut = c2.compute_x_coordinate_share(p2); - - let (c1_output, c2_output) = tokio::join!(c1_fut, c2_fut); - let (c1_output, c2_output) = (c1_output.unwrap(), c2_output.unwrap()); - - assert_eq!(point_to_p256(p).unwrap()[0], c1_output + c2_output); - } - - #[test] - fn test_point_to_p256() { - let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); - - let p_expected: [u8; 32] = rng.gen(); - let p_expected = curve_point_from_be_bytes(p_expected); - - let p256: [P256; 2] = point_to_p256(p_expected).unwrap(); - - let x: [u8; 32] = p256[0].to_be_bytes().try_into().unwrap(); - let y: [u8; 32] = p256[1].to_be_bytes().try_into().unwrap(); - - let p = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), false); - - assert_eq!(p_expected, p); - } - - fn 
curve_point_from_be_bytes(bytes: [u8; 32]) -> EncodedPoint { - let scalar = NonZeroScalar::from_repr(bytes.into()).unwrap(); - let pk = PublicKey::from_secret_scalar(&scalar); - pk.to_encoded_point(false) - } - - fn add_curve_points(p1: &EncodedPoint, p2: &EncodedPoint) -> EncodedPoint { - let p1 = ProjectivePoint::from_encoded_point(p1).unwrap(); - let p2 = ProjectivePoint::from_encoded_point(p2).unwrap(); - let p = p1 + p2; - p.to_encoded_point(false) - } -} diff --git a/components/point-addition/src/mock.rs b/components/point-addition/src/mock.rs deleted file mode 100644 index 4860ce8421..0000000000 --- a/components/point-addition/src/mock.rs +++ /dev/null @@ -1,30 +0,0 @@ -use crate::{MpcPointAddition, Role}; -use mpz_share_conversion::{ - mock::{mock_converter_pair, MockConverterReceiver, MockConverterSender}, - ReceiverConfig, SenderConfig, -}; -use mpz_share_conversion_core::fields::p256::P256; - -/// A mock point addition sender implementing [MpcPointAddition] for [P256] -pub type MockPointAdditionSender = MpcPointAddition>; - -/// A mock point addition receiver implementing [MpcPointAddition] for [P256] -pub type MockPointAdditionReceiver = MpcPointAddition>; - -/// Create a pair of [MpcPointAddition] instances -pub fn mock_point_converter_pair(id: &str) -> (MockPointAdditionSender, MockPointAdditionReceiver) { - let (sender, receiver) = mock_converter_pair( - SenderConfig::builder() - .id(format!("{}/converter", id)) - .build() - .unwrap(), - ReceiverConfig::builder() - .id(format!("{}/converter", id)) - .build() - .unwrap(), - ); - ( - MpcPointAddition::new(Role::Leader, sender), - MpcPointAddition::new(Role::Follower, receiver), - ) -} diff --git a/components/prf/Cargo.toml b/components/prf/Cargo.toml deleted file mode 100644 index f3e91fbd87..0000000000 --- a/components/prf/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[workspace] -members = ["hmac-sha256-circuits", "hmac-sha256"] -resolver = "2" - -[workspace.dependencies] -# tlsn -mpz-circuits = { git = 
"https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } - -# async -async-trait = "0.1" -futures = "0.3" -tokio = "1" - -# error/log -thiserror = "1" -tracing = "0.1" - -# testing -criterion = "0.5" diff --git a/components/prf/hmac-sha256/benches/prf.rs b/components/prf/hmac-sha256/benches/prf.rs deleted file mode 100644 index 25cafcfaa9..0000000000 --- a/components/prf/hmac-sha256/benches/prf.rs +++ /dev/null @@ -1,93 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion}; - -use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role}; -use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory, Vm}; - -#[allow(clippy::unit_arg)] -fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("prf"); - group.sample_size(10); - let rt = tokio::runtime::Runtime::new().unwrap(); - - group.bench_function("prf_setup", |b| b.to_async(&rt).iter(setup)); - group.bench_function("prf", |b| b.to_async(&rt).iter(prf)); -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); - -async fn setup() { - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("bench").await; - - let mut leader = MpcPrf::new( - PrfConfig::builder().role(Role::Leader).build().unwrap(), - leader_vm.new_thread("prf/0").await.unwrap(), - leader_vm.new_thread("prf/1").await.unwrap(), - ); - let mut follower = MpcPrf::new( - PrfConfig::builder().role(Role::Follower).build().unwrap(), - follower_vm.new_thread("prf/0").await.unwrap(), - follower_vm.new_thread("prf/1").await.unwrap(), - ); - - let leader_thread = leader_vm.new_thread("setup").await.unwrap(); - let follower_thread = follower_vm.new_thread("setup").await.unwrap(); - - let leader_pms = leader_thread.new_public_input::<[u8; 32]>("pms").unwrap(); - let follower_pms = follower_thread.new_public_input::<[u8; 32]>("pms").unwrap(); - - futures::try_join!(leader.setup(leader_pms), 
follower.setup(follower_pms)).unwrap(); -} - -async fn prf() { - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("bench").await; - - let mut leader = MpcPrf::new( - PrfConfig::builder().role(Role::Leader).build().unwrap(), - leader_vm.new_thread("prf/0").await.unwrap(), - leader_vm.new_thread("prf/1").await.unwrap(), - ); - let mut follower = MpcPrf::new( - PrfConfig::builder().role(Role::Follower).build().unwrap(), - follower_vm.new_thread("prf/0").await.unwrap(), - follower_vm.new_thread("prf/1").await.unwrap(), - ); - - let pms = [42u8; 32]; - - let client_random = [0u8; 32]; - let server_random = [1u8; 32]; - let cf_hs_hash = [2u8; 32]; - let sf_hs_hash = [3u8; 32]; - - let leader_thread = leader_vm.new_thread("setup").await.unwrap(); - let follower_thread = follower_vm.new_thread("setup").await.unwrap(); - - let leader_pms = leader_thread.new_public_input::<[u8; 32]>("pms").unwrap(); - let follower_pms = follower_thread.new_public_input::<[u8; 32]>("pms").unwrap(); - - leader_thread.assign(&leader_pms, pms).unwrap(); - follower_thread.assign(&follower_pms, pms).unwrap(); - - futures::try_join!(leader.setup(leader_pms), follower.setup(follower_pms)).unwrap(); - - let (_leader_keys, _follower_keys) = futures::try_join!( - leader.compute_session_keys_private(client_random, server_random), - follower.compute_session_keys_blind() - ) - .unwrap(); - - let _ = futures::try_join!( - leader.compute_client_finished_vd_private(cf_hs_hash), - follower.compute_client_finished_vd_blind() - ) - .unwrap(); - - let _ = futures::try_join!( - leader.compute_server_finished_vd_private(sf_hs_hash), - follower.compute_server_finished_vd_blind() - ) - .unwrap(); - - futures::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); -} diff --git a/components/prf/hmac-sha256/src/error.rs b/components/prf/hmac-sha256/src/error.rs deleted file mode 100644 index 356f82b001..0000000000 --- a/components/prf/hmac-sha256/src/error.rs +++ /dev/null @@ -1,45 +0,0 @@ -use 
std::error::Error; - -use crate::prf::state::StateError; - -/// Errors that can occur during PRF computation. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum PrfError { - #[error("MPC backend error: {0:?}")] - Mpc(Box), - #[error("role error: {0:?}")] - RoleError(String), - #[error("Invalid state: {0}")] - InvalidState(String), -} - -impl From for PrfError { - fn from(err: StateError) -> Self { - PrfError::InvalidState(err.to_string()) - } -} - -impl From for PrfError { - fn from(err: mpz_garble::MemoryError) -> Self { - PrfError::Mpc(Box::new(err)) - } -} - -impl From for PrfError { - fn from(err: mpz_garble::LoadError) -> Self { - PrfError::Mpc(Box::new(err)) - } -} - -impl From for PrfError { - fn from(err: mpz_garble::ExecutionError) -> Self { - PrfError::Mpc(Box::new(err)) - } -} - -impl From for PrfError { - fn from(err: mpz_garble::DecodeError) -> Self { - PrfError::Mpc(Box::new(err)) - } -} diff --git a/components/prf/hmac-sha256/src/prf.rs b/components/prf/hmac-sha256/src/prf.rs deleted file mode 100644 index d501dc19c4..0000000000 --- a/components/prf/hmac-sha256/src/prf.rs +++ /dev/null @@ -1,475 +0,0 @@ -use std::{ - fmt::Debug, - sync::{Arc, OnceLock}, -}; - -use async_trait::async_trait; - -use hmac_sha256_circuits::{build_session_keys, build_verify_data}; -use mpz_circuits::Circuit; -use mpz_garble::{ - config::Visibility, value::ValueRef, Decode, DecodePrivate, Execute, Load, Memory, -}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -use crate::{Prf, PrfConfig, PrfError, Role, SessionKeys, CF_LABEL, SF_LABEL}; - -#[cfg(feature = "tracing")] -use tracing::instrument; - -/// Circuit for computing TLS session keys. -static SESSION_KEYS_CIRC: OnceLock> = OnceLock::new(); -/// Circuit for computing TLS client verify data. -static CLIENT_VD_CIRC: OnceLock> = OnceLock::new(); -/// Circuit for computing TLS server verify data. 
-static SERVER_VD_CIRC: OnceLock> = OnceLock::new(); - -enum Msg { - Cf, - Sf, -} - -#[derive(Debug)] -pub(crate) struct Randoms { - pub(crate) client_random: ValueRef, - pub(crate) server_random: ValueRef, -} - -#[derive(Debug, Clone)] -pub(crate) struct HashState { - pub(crate) ms_outer_hash_state: ValueRef, - pub(crate) ms_inner_hash_state: ValueRef, -} - -#[derive(Debug)] -pub(crate) struct VerifyData { - pub(crate) handshake_hash: ValueRef, - pub(crate) vd: ValueRef, -} - -/// MPC PRF for computing TLS HMAC-SHA256 PRF. -pub struct MpcPrf { - config: PrfConfig, - state: state::State, - thread_0: E, - thread_1: E, -} - -impl Debug for MpcPrf { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MpcPrf") - .field("config", &self.config) - .field("state", &self.state) - .finish() - } -} - -impl MpcPrf -where - E: Load + Memory + Execute + DecodePrivate + Send, -{ - /// Creates a new instance of the PRF. - pub fn new(config: PrfConfig, thread_0: E, thread_1: E) -> MpcPrf { - MpcPrf { - config, - state: state::State::Initialized, - thread_0, - thread_1, - } - } - - /// Executes a circuit which computes TLS session keys. 
- async fn execute_session_keys( - &mut self, - randoms: Option<([u8; 32], [u8; 32])>, - ) -> Result { - let state::SessionKeys { - pms, - randoms: randoms_refs, - hash_state, - keys, - cf_vd, - sf_vd, - } = std::mem::replace(&mut self.state, state::State::Error).try_into_session_keys()?; - - let circ = SESSION_KEYS_CIRC - .get() - .expect("session keys circuit is set"); - - if let Some((client_random, server_random)) = randoms { - self.thread_0 - .assign(&randoms_refs.client_random, client_random)?; - self.thread_0 - .assign(&randoms_refs.server_random, server_random)?; - } - - self.thread_0 - .execute( - circ.clone(), - &[pms, randoms_refs.client_random, randoms_refs.server_random], - &[ - keys.client_write_key.clone(), - keys.server_write_key.clone(), - keys.client_iv.clone(), - keys.server_iv.clone(), - hash_state.ms_outer_hash_state.clone(), - hash_state.ms_inner_hash_state.clone(), - ], - ) - .await?; - - self.state = state::State::ClientFinished(state::ClientFinished { - hash_state, - cf_vd, - sf_vd, - }); - - Ok(keys) - } - - async fn execute_cf_vd( - &mut self, - handshake_hash: Option<[u8; 32]>, - ) -> Result, PrfError> { - let state::ClientFinished { - hash_state, - cf_vd, - sf_vd, - } = std::mem::replace(&mut self.state, state::State::Error).try_into_client_finished()?; - - let circ = CLIENT_VD_CIRC.get().expect("client vd circuit is set"); - - if let Some(handshake_hash) = handshake_hash { - self.thread_0 - .assign(&cf_vd.handshake_hash, handshake_hash)?; - } - - self.thread_0 - .execute( - circ.clone(), - &[ - hash_state.ms_outer_hash_state.clone(), - hash_state.ms_inner_hash_state.clone(), - cf_vd.handshake_hash, - ], - &[cf_vd.vd.clone()], - ) - .await?; - - let vd = if handshake_hash.is_some() { - let mut outputs = self.thread_0.decode_private(&[cf_vd.vd]).await?; - let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); - - Some(vd) - } else { - self.thread_0.decode_blind(&[cf_vd.vd]).await?; - - None - }; - - self.state = 
state::State::ServerFinished(state::ServerFinished { hash_state, sf_vd }); - - Ok(vd) - } - - async fn execute_sf_vd( - &mut self, - handshake_hash: Option<[u8; 32]>, - ) -> Result, PrfError> { - let state::ServerFinished { hash_state, sf_vd } = - std::mem::replace(&mut self.state, state::State::Error).try_into_server_finished()?; - - let circ = SERVER_VD_CIRC.get().expect("server vd circuit is set"); - - if let Some(handshake_hash) = handshake_hash { - self.thread_1 - .assign(&sf_vd.handshake_hash, handshake_hash)?; - } - - self.thread_1 - .execute( - circ.clone(), - &[ - hash_state.ms_outer_hash_state, - hash_state.ms_inner_hash_state, - sf_vd.handshake_hash, - ], - &[sf_vd.vd.clone()], - ) - .await?; - - let vd = if handshake_hash.is_some() { - let mut outputs = self.thread_1.decode_private(&[sf_vd.vd]).await?; - let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); - - Some(vd) - } else { - self.thread_1.decode_blind(&[sf_vd.vd]).await?; - - None - }; - - self.state = state::State::Complete; - - Ok(vd) - } -} - -#[async_trait] -impl Prf for MpcPrf -where - E: Memory + Load + Execute + Decode + DecodePrivate + Send, -{ - #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] - async fn setup(&mut self, pms: ValueRef) -> Result<(), PrfError> { - std::mem::replace(&mut self.state, state::State::Error).try_into_initialized()?; - - let visibility = match self.config.role { - Role::Leader => Visibility::Private, - Role::Follower => Visibility::Blind, - }; - - // Perform pre-computation for all circuits. 
- let (randoms, hash_state, keys) = - setup_session_keys(&mut self.thread_0, pms.clone(), visibility).await?; - - let (cf_vd, sf_vd) = futures::try_join!( - setup_finished_msg(&mut self.thread_0, Msg::Cf, hash_state.clone(), visibility), - setup_finished_msg(&mut self.thread_1, Msg::Sf, hash_state.clone(), visibility), - )?; - - self.state = state::State::SessionKeys(state::SessionKeys { - pms, - randoms, - hash_state, - keys, - cf_vd, - sf_vd, - }); - - Ok(()) - } - - #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] - async fn compute_session_keys_private( - &mut self, - client_random: [u8; 32], - server_random: [u8; 32], - ) -> Result { - if self.config.role != Role::Leader { - return Err(PrfError::RoleError( - "only leader can provide inputs".to_string(), - )); - } - - self.execute_session_keys(Some((client_random, server_random))) - .await - } - - #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] - async fn compute_client_finished_vd_private( - &mut self, - handshake_hash: [u8; 32], - ) -> Result<[u8; 12], PrfError> { - if self.config.role != Role::Leader { - return Err(PrfError::RoleError( - "only leader can provide inputs".to_string(), - )); - } - - self.execute_cf_vd(Some(handshake_hash)) - .await - .map(|hash| hash.expect("vd is decoded")) - } - - #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] - async fn compute_server_finished_vd_private( - &mut self, - handshake_hash: [u8; 32], - ) -> Result<[u8; 12], PrfError> { - if self.config.role != Role::Leader { - return Err(PrfError::RoleError( - "only leader can provide inputs".to_string(), - )); - } - - self.execute_sf_vd(Some(handshake_hash)) - .await - .map(|hash| hash.expect("vd is decoded")) - } - - #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] - async fn compute_session_keys_blind(&mut self) -> Result { - if self.config.role != Role::Follower { - return Err(PrfError::RoleError( - "leader must 
provide inputs".to_string(), - )); - } - - self.execute_session_keys(None).await - } - - #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] - async fn compute_client_finished_vd_blind(&mut self) -> Result<(), PrfError> { - if self.config.role != Role::Follower { - return Err(PrfError::RoleError( - "leader must provide inputs".to_string(), - )); - } - - self.execute_cf_vd(None).await.map(|_| ()) - } - - #[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self), err))] - async fn compute_server_finished_vd_blind(&mut self) -> Result<(), PrfError> { - if self.config.role != Role::Follower { - return Err(PrfError::RoleError( - "leader must provide inputs".to_string(), - )); - } - - self.execute_sf_vd(None).await.map(|_| ()) - } -} - -pub(crate) mod state { - use super::*; - use enum_try_as_inner::EnumTryAsInner; - - #[derive(Debug, EnumTryAsInner)] - #[derive_err(Debug)] - pub(crate) enum State { - Initialized, - SessionKeys(SessionKeys), - ClientFinished(ClientFinished), - ServerFinished(ServerFinished), - Complete, - Error, - } - - #[derive(Debug)] - pub(crate) struct SessionKeys { - pub(crate) pms: ValueRef, - pub(crate) randoms: Randoms, - pub(crate) hash_state: HashState, - pub(crate) keys: crate::SessionKeys, - pub(crate) cf_vd: VerifyData, - pub(crate) sf_vd: VerifyData, - } - - #[derive(Debug)] - pub(crate) struct ClientFinished { - pub(crate) hash_state: HashState, - pub(crate) cf_vd: VerifyData, - pub(crate) sf_vd: VerifyData, - } - - #[derive(Debug)] - pub(crate) struct ServerFinished { - pub(crate) hash_state: HashState, - pub(crate) sf_vd: VerifyData, - } -} - -async fn setup_session_keys( - thread: &mut T, - pms: ValueRef, - visibility: Visibility, -) -> Result<(Randoms, HashState, SessionKeys), PrfError> { - let client_random = thread.new_input::<[u8; 32]>("client_finished", visibility)?; - let server_random = thread.new_input::<[u8; 32]>("server_finished", visibility)?; - - let client_write_key = 
thread.new_output::<[u8; 16]>("client_write_key")?; - let server_write_key = thread.new_output::<[u8; 16]>("server_write_key")?; - let client_iv = thread.new_output::<[u8; 4]>("client_write_iv")?; - let server_iv = thread.new_output::<[u8; 4]>("server_write_iv")?; - - let ms_outer_hash_state = thread.new_output::<[u32; 8]>("ms_outer_hash_state")?; - let ms_inner_hash_state = thread.new_output::<[u32; 8]>("ms_inner_hash_state")?; - - if SESSION_KEYS_CIRC.get().is_none() { - _ = SESSION_KEYS_CIRC.set(Backend::spawn(build_session_keys).await); - } - - let circ = SESSION_KEYS_CIRC - .get() - .expect("session keys circuit is set"); - - thread - .load( - circ.clone(), - &[pms, client_random.clone(), server_random.clone()], - &[ - client_write_key.clone(), - server_write_key.clone(), - client_iv.clone(), - server_iv.clone(), - ms_outer_hash_state.clone(), - ms_inner_hash_state.clone(), - ], - ) - .await?; - - Ok(( - Randoms { - client_random, - server_random, - }, - HashState { - ms_outer_hash_state, - ms_inner_hash_state, - }, - SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - }, - )) -} - -async fn setup_finished_msg( - thread: &mut T, - msg: Msg, - hash_state: HashState, - visibility: Visibility, -) -> Result { - let name = match msg { - Msg::Cf => String::from("client_finished"), - Msg::Sf => String::from("server_finished"), - }; - - let handshake_hash = - thread.new_input::<[u8; 32]>(&format!("{name}/handshake_hash"), visibility)?; - let vd = thread.new_output::<[u8; 12]>(&format!("{name}/vd"))?; - - let circ = match msg { - Msg::Cf => &CLIENT_VD_CIRC, - Msg::Sf => &SERVER_VD_CIRC, - }; - - let label = match msg { - Msg::Cf => CF_LABEL, - Msg::Sf => SF_LABEL, - }; - - if circ.get().is_none() { - _ = circ.set(Backend::spawn(move || build_verify_data(label)).await); - } - - let circ = circ.get().expect("session keys circuit is set"); - - thread - .load( - circ.clone(), - &[ - hash_state.ms_outer_hash_state, - 
hash_state.ms_inner_hash_state, - handshake_hash.clone(), - ], - &[vd.clone()], - ) - .await?; - - Ok(VerifyData { handshake_hash, vd }) -} diff --git a/components/tls/Cargo.toml b/components/tls/Cargo.toml deleted file mode 100644 index 4632770094..0000000000 --- a/components/tls/Cargo.toml +++ /dev/null @@ -1,56 +0,0 @@ -[workspace] -members = [ - "tls-client", - "tls-backend", - "tls-core", - "tls-mpc", - "tls-client-async", - "tls-server-fixture", -] -resolver = "2" - -[workspace.dependencies] -# rand -rand = "0.8" -rand_chacha = "0.3" - -# crypto -aes = "0.8" -aes-gcm = "0.9" -sha2 = "0.10" -hmac = "0.12" -sct = "0.7" -digest = "0.10" -webpki = "0.22" -webpki-roots = "0.26" -ring = "0.17" -p256 = "0.13" -rustls-pemfile = "1" -rustls = "0.20" -async-rustls = "0.4" - -# async -async-trait = "0.1" -futures = "0.3" -tokio = "1" -tokio-util = "0.7" -hyper = "0.14" - -# serialization -bytes = "1" -serde = "1" - -# error/log -tracing = "0.1" -tracing-subscriber = "0.3" -thiserror = "1" -log = "0.4" -env_logger = "0.10" - -# testing -rstest = "0.12" - -# misc -derive_builder = "0.12" -enum-try-as-inner = "0.1" -web-time = "0.2" diff --git a/components/tls/tls-client/build.rs b/components/tls/tls-client/build.rs deleted file mode 100644 index 9c73252a65..0000000000 --- a/components/tls/tls-client/build.rs +++ /dev/null @@ -1,13 +0,0 @@ -/// This build script allows us to enable the `read_buf` language feature only -/// for Rust Nightly. -/// -/// See the comment in lib.rs to understand why we need this. - -#[cfg_attr(feature = "read_buf", rustversion::not(nightly))] -fn main() {} - -#[cfg(feature = "read_buf")] -#[rustversion::nightly] -fn main() { - println!("cargo:rustc-cfg=read_buf"); -} diff --git a/components/tls/tls-client/tests/client_cert_verifier.rs b/components/tls/tls-client/tests/client_cert_verifier.rs deleted file mode 100644 index 3f5d19c03f..0000000000 --- a/components/tls/tls-client/tests/client_cert_verifier.rs +++ /dev/null @@ -1,292 +0,0 @@ -//! 
Tests for configuring and using a [`ClientCertVerifier`] for a server. - -#![cfg(feature = "dangerous_configuration")] - -mod common; - -use crate::common::{ - dns_name, do_handshake_until_both_error, do_handshake_until_error, get_client_root_store, - make_client_config_with_versions, make_client_config_with_versions_with_auth, - make_pair_for_arc_configs, ErrorFromPeer, KeyType, ALL_KEY_TYPES, -}; -use rustls::{ - client::WebPkiVerifier, - internal::msgs::enums::{AlertDescription, ContentType}, - server::{ClientCertVerified, ClientCertVerifier}, - Certificate, ClientConnection, DistinguishedNames, Error, ServerConfig, ServerConnection, - SignatureScheme, -}; -use std::sync::Arc; - -// Client is authorized! -fn ver_ok() -> Result { - Ok(rustls::server::ClientCertVerified::assertion()) -} - -// Use when we shouldn't even attempt verification -fn ver_unreachable() -> Result { - unreachable!() -} - -// Verifier that returns an error that we can expect -fn ver_err() -> Result { - Err(Error::General("test err".to_string())) -} - -fn server_config_with_verifier( - kt: KeyType, - client_cert_verifier: MockClientVerifier, -) -> ServerConfig { - ServerConfig::builder() - .with_safe_defaults() - .with_client_cert_verifier(Arc::new(client_cert_verifier)) - .with_single_cert(kt.get_chain(), kt.get_key()) - .unwrap() -} - -#[test] -// Happy path, we resolve to a root, it is verified OK, should be able to connect -fn client_verifier_works() { - for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier { - verified: ver_ok, - subjects: Some(get_client_root_store(*kt).subjects()), - mandatory: Some(true), - offered_schemes: None, - }; - - let server_config = server_config_with_verifier(*kt, client_verifier); - let server_config = Arc::new(server_config); - - for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); - let (mut client, mut server) = - 
make_pair_for_arc_configs(&Arc::new(client_config.clone()), &server_config); - let err = do_handshake_until_error(&mut client, &mut server); - assert_eq!(err, Ok(())); - } - } -} - -// Server offers no verification schemes -#[test] -fn client_verifier_no_schemes() { - for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier { - verified: ver_ok, - subjects: Some(get_client_root_store(*kt).subjects()), - mandatory: Some(true), - offered_schemes: Some(vec![]), - }; - - let server_config = server_config_with_verifier(*kt, client_verifier); - let server_config = Arc::new(server_config); - - for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); - let (mut client, mut server) = - make_pair_for_arc_configs(&Arc::new(client_config.clone()), &server_config); - let err = do_handshake_until_error(&mut client, &mut server); - assert_eq!( - err, - Err(ErrorFromPeer::Client(Error::CorruptMessagePayload( - ContentType::Handshake - ))) - ); - } - } -} - -// Common case, we do not find a root store to resolve to -#[test] -fn client_verifier_no_root() { - for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier { - verified: ver_ok, - subjects: None, - mandatory: Some(true), - offered_schemes: None, - }; - - let server_config = server_config_with_verifier(*kt, client_verifier); - let server_config = Arc::new(server_config); - - for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); - let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client = - ClientConnection::new(Arc::new(client_config), dns_name("notlocalhost")).unwrap(); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - ErrorFromPeer::Server(Error::General( - "client rejected by client_auth_root_subjects".into() - )), - 
ErrorFromPeer::Client(Error::AlertReceived(AlertDescription::AccessDenied)) - ]) - ); - } - } -} - -// If we cannot resolve a root, we cannot decide if auth is mandatory -#[test] -fn client_verifier_no_auth_no_root() { - for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier { - verified: ver_unreachable, - subjects: None, - mandatory: Some(true), - offered_schemes: None, - }; - - let server_config = server_config_with_verifier(*kt, client_verifier); - let server_config = Arc::new(server_config); - - for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); - let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client = - ClientConnection::new(Arc::new(client_config), dns_name("notlocalhost")).unwrap(); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - ErrorFromPeer::Server(Error::General( - "client rejected by client_auth_root_subjects".into() - )), - ErrorFromPeer::Client(Error::AlertReceived(AlertDescription::AccessDenied)) - ]) - ); - } - } -} - -// If we do have a root, we must do auth -#[test] -fn client_verifier_no_auth_yes_root() { - for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier { - verified: ver_unreachable, - subjects: Some(get_client_root_store(*kt).subjects()), - mandatory: Some(true), - offered_schemes: None, - }; - - let server_config = server_config_with_verifier(*kt, client_verifier); - let server_config = Arc::new(server_config); - - for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); - let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client = - ClientConnection::new(Arc::new(client_config), dns_name("localhost")).unwrap(); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - 
ErrorFromPeer::Server(Error::NoCertificatesPresented), - ErrorFromPeer::Client(Error::AlertReceived( - AlertDescription::CertificateRequired - )) - ]) - ); - } - } -} - -#[test] -// Triple checks we propagate the rustls::Error through -fn client_verifier_fails_properly() { - for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier { - verified: ver_err, - subjects: Some(get_client_root_store(*kt).subjects()), - mandatory: Some(true), - offered_schemes: None, - }; - - let server_config = server_config_with_verifier(*kt, client_verifier); - let server_config = Arc::new(server_config); - - for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); - let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client = - ClientConnection::new(Arc::new(client_config), dns_name("localhost")).unwrap(); - let err = do_handshake_until_error(&mut client, &mut server); - assert_eq!( - err, - Err(ErrorFromPeer::Server(Error::General("test err".into()))) - ); - } - } -} - -#[test] -// If a verifier returns a None on Mandatory-ness, then we error out -fn client_verifier_must_determine_client_auth_requirement_to_continue() { - for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier { - verified: ver_ok, - subjects: Some(get_client_root_store(*kt).subjects()), - mandatory: None, - offered_schemes: None, - }; - - let server_config = server_config_with_verifier(*kt, client_verifier); - let server_config = Arc::new(server_config); - - for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); - let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client = - ClientConnection::new(Arc::new(client_config), dns_name("localhost")).unwrap(); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - 
ErrorFromPeer::Server(Error::General( - "client rejected by client_auth_mandatory".into() - )), - ErrorFromPeer::Client(Error::AlertReceived(AlertDescription::AccessDenied)) - ]) - ); - } - } -} - -pub struct MockClientVerifier { - pub verified: fn() -> Result, - pub subjects: Option, - pub mandatory: Option, - pub offered_schemes: Option>, -} - -impl ClientCertVerifier for MockClientVerifier { - fn client_auth_mandatory(&self) -> Option { - self.mandatory - } - - fn client_auth_root_subjects(&self) -> Option { - self.subjects.as_ref().cloned() - } - - fn verify_client_cert( - &self, - _end_entity: &Certificate, - _intermediates: &[Certificate], - _now: web_time::SystemTime, - ) -> Result { - (self.verified)() - } - - fn supported_verify_schemes(&self) -> Vec { - if let Some(schemes) = &self.offered_schemes { - schemes.clone() - } else { - WebPkiVerifier::verification_schemes() - } - } -} diff --git a/components/tls/tls-client/tests/key_log_file_env.rs b/components/tls/tls-client/tests/key_log_file_env.rs deleted file mode 100644 index 84f253270b..0000000000 --- a/components/tls/tls-client/tests/key_log_file_env.rs +++ /dev/null @@ -1,105 +0,0 @@ -//! Tests of [`tls_client::KeyLogFile`] that require us to set environment variables. -//! -//! vvvv -//! Every test you add to this file MUST execute through `serialized()`. -//! ^^^^ -//! -//! See https://github.com/rust-lang/rust/issues/90308; despite not being marked -//! `unsafe`, `env::var::set_var` is an unsafe function. These tests are separated -//! from the rest of the tests so that their use of `set_ver` is less likely to -//! affect them; as of the time these tests were moved to this file, Cargo will -//! compile each test suite file to a separate executable, so these will be run -//! in a completely separate process. This way, executing every test through -//! `serialized()` will cause them to be run one at a time. -//! -//! Note: If/when we add new constructors to `KeyLogFile` to allow constructing -//! 
one from a path directly (without using an environment variable), then those -//! tests SHOULD NOT go in this file. -//! -//! XXX: These tests don't actually test the functionality; they just ensure -//! the code coverage doesn't complain it isn't covered. TODO: Verify that the -//! file was created successfully, with the right permissions, etc., and that it -//! contains something like what we expect. - -#![allow(dead_code, unused_imports)] - -mod common; - -use crate::common::{ - do_handshake, make_client_config_with_versions, make_pair_for_arc_configs, make_server_config, - receive, send, KeyType, -}; -use std::{ - env, - io::Write, - sync::{Arc, Mutex, Once}, -}; - -/// Approximates `#[serial]` from the `serial_test` crate. -/// -/// No attempt is made to recover from a poisoned mutex, which will -/// happen when `f` panics. In other words, all the tests that use -/// `serialized` will start failing after one test panics. -fn serialized(f: impl FnOnce()) { - // Ensure every test is run serialized - // TODO: Use `std::sync::Lazy` once that is stable. - static mut MUTEX: Option> = None; - static ONCE: Once = Once::new(); - ONCE.call_once(|| unsafe { - MUTEX = Some(Mutex::new(())); - }); - let mutex = unsafe { MUTEX.as_mut() }; - - let _guard = mutex.unwrap().lock().unwrap(); - - // XXX: NOT thread safe. 
- env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt"); - - f() -} - -// #[test] -// fn exercise_key_log_file_for_client() { -// serialized(|| { -// let server_config = Arc::new(make_server_config(KeyType::Rsa)); -// env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt"); - -// for version in tls_client::ALL_VERSIONS { -// let mut client_config = make_client_config_with_versions(KeyType::Rsa, &[version]); -// client_config.key_log = Arc::new(tls_client::KeyLogFile::new()); - -// let (mut client, mut server) = -// make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - -// assert_eq!(5, client.writer().write(b"hello").unwrap()); - -// do_handshake(&mut client, &mut server); -// send(&mut client, &mut server); -// server.process_new_packets().unwrap(); -// } -// }) -// } - -// #[test] -// fn exercise_key_log_file_for_server() { -// serialized(|| { -// let mut server_config = make_server_config(KeyType::Rsa); - -// env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt"); -// server_config.key_log = Arc::new(rustls::KeyLogFile::new()); - -// let server_config = Arc::new(server_config); - -// for version in tls_client::ALL_VERSIONS { -// let client_config = make_client_config_with_versions(KeyType::Rsa, &[version]); -// let (mut client, mut server) = -// make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - -// assert_eq!(5, client.writer().write(b"hello").unwrap()); - -// do_handshake(&mut client, &mut server); -// send(&mut client, &mut server); -// server.process_new_packets().unwrap(); -// } -// }) -// } diff --git a/components/tls/tls-client/tests/server_cert_verifier.rs b/components/tls/tls-client/tests/server_cert_verifier.rs deleted file mode 100644 index 829c0ade0a..0000000000 --- a/components/tls/tls-client/tests/server_cert_verifier.rs +++ /dev/null @@ -1,266 +0,0 @@ -//! Tests for configuring and using a [`ServerCertVerifier`] for a client. 
- -#![cfg(feature = "dangerous_configuration")] - -mod common; -use crate::common::{ - do_handshake, do_handshake_until_both_error, make_client_config_with_versions, - make_pair_for_arc_configs, make_server_config, ErrorFromPeer, ALL_KEY_TYPES, -}; -use rustls::{ - client::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier, WebPkiVerifier}, - internal::msgs::{enums::AlertDescription, handshake::DigitallySignedStruct}, - Certificate, Error, SignatureScheme, -}; -use std::sync::Arc; - -#[test] -fn client_can_override_certificate_verification() { - for kt in ALL_KEY_TYPES.iter() { - let verifier = Arc::new(MockServerVerifier::accepts_anything()); - - let server_config = Arc::new(make_server_config(*kt)); - - for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(*kt, &[version]); - client_config - .dangerous() - .set_certificate_verifier(verifier.clone()); - - let (mut client, mut server) = - make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - do_handshake(&mut client, &mut server); - } - } -} - -#[test] -fn client_can_override_certificate_verification_and_reject_certificate() { - for kt in ALL_KEY_TYPES.iter() { - let verifier = Arc::new(MockServerVerifier::rejects_certificate( - Error::CorruptMessage, - )); - - let server_config = Arc::new(make_server_config(*kt)); - - for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(*kt, &[version]); - client_config - .dangerous() - .set_certificate_verifier(verifier.clone()); - - let (mut client, mut server) = - make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - ErrorFromPeer::Client(Error::CorruptMessage), - ErrorFromPeer::Server(Error::AlertReceived(AlertDescription::BadCertificate)) - ]) - ); - } - } -} - -#[test] -fn 
client_can_override_certificate_verification_and_reject_tls12_signatures() { - for kt in ALL_KEY_TYPES.iter() { - let mut client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS12]); - let verifier = Arc::new(MockServerVerifier::rejects_tls12_signatures( - Error::CorruptMessage, - )); - - client_config.dangerous().set_certificate_verifier(verifier); - - let server_config = Arc::new(make_server_config(*kt)); - - let (mut client, mut server) = - make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - ErrorFromPeer::Client(Error::CorruptMessage), - ErrorFromPeer::Server(Error::AlertReceived(AlertDescription::BadCertificate)) - ]) - ); - } -} - -#[test] -fn client_can_override_certificate_verification_and_reject_tls13_signatures() { - for kt in ALL_KEY_TYPES.iter() { - let mut client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS13]); - let verifier = Arc::new(MockServerVerifier::rejects_tls13_signatures( - Error::CorruptMessage, - )); - - client_config.dangerous().set_certificate_verifier(verifier); - - let server_config = Arc::new(make_server_config(*kt)); - - let (mut client, mut server) = - make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - ErrorFromPeer::Client(Error::CorruptMessage), - ErrorFromPeer::Server(Error::AlertReceived(AlertDescription::BadCertificate)) - ]) - ); - } -} - -#[test] -fn client_can_override_certificate_verification_and_offer_no_signature_schemes() { - for kt in ALL_KEY_TYPES.iter() { - let verifier = Arc::new(MockServerVerifier::offers_no_signature_schemes()); - - let server_config = Arc::new(make_server_config(*kt)); - - for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(*kt, &[version]); - client_config - 
.dangerous() - .set_certificate_verifier(verifier.clone()); - - let (mut client, mut server) = - make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - let errs = do_handshake_until_both_error(&mut client, &mut server); - assert_eq!( - errs, - Err(vec![ - ErrorFromPeer::Server(Error::PeerIncompatibleError( - "no overlapping sigschemes".into() - )), - ErrorFromPeer::Client(Error::AlertReceived(AlertDescription::HandshakeFailure)), - ]) - ); - } - } -} - -pub struct MockServerVerifier { - cert_rejection_error: Option, - tls12_signature_error: Option, - tls13_signature_error: Option, - wants_scts: bool, - signature_schemes: Vec, -} - -impl ServerCertVerifier for MockServerVerifier { - fn verify_server_cert( - &self, - end_entity: &rustls::Certificate, - intermediates: &[rustls::Certificate], - server_name: &rustls::ServerName, - scts: &mut dyn Iterator, - oscp_response: &[u8], - now: web_time::SystemTime, - ) -> Result { - let scts: Vec> = scts.map(|x| x.to_owned()).collect(); - println!( - "verify_server_cert({:?}, {:?}, {:?}, {:?}, {:?}, {:?})", - end_entity, intermediates, server_name, scts, oscp_response, now - ); - if let Some(error) = &self.cert_rejection_error { - Err(error.clone()) - } else { - Ok(ServerCertVerified::assertion()) - } - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &Certificate, - dss: &DigitallySignedStruct, - ) -> Result { - println!( - "verify_tls12_signature({:?}, {:?}, {:?})", - message, cert, dss - ); - if let Some(error) = &self.tls12_signature_error { - Err(error.clone()) - } else { - Ok(HandshakeSignatureValid::assertion()) - } - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &Certificate, - dss: &DigitallySignedStruct, - ) -> Result { - println!( - "verify_tls13_signature({:?}, {:?}, {:?})", - message, cert, dss - ); - if let Some(error) = &self.tls13_signature_error { - Err(error.clone()) - } else { - Ok(HandshakeSignatureValid::assertion()) - } - } - - fn 
supported_verify_schemes(&self) -> Vec { - self.signature_schemes.clone() - } - - fn request_scts(&self) -> bool { - println!("request_scts? {:?}", self.wants_scts); - self.wants_scts - } -} - -impl MockServerVerifier { - pub fn accepts_anything() -> Self { - MockServerVerifier { - cert_rejection_error: None, - ..Default::default() - } - } - - pub fn rejects_certificate(err: Error) -> Self { - MockServerVerifier { - cert_rejection_error: Some(err), - ..Default::default() - } - } - - pub fn rejects_tls12_signatures(err: Error) -> Self { - MockServerVerifier { - tls12_signature_error: Some(err), - ..Default::default() - } - } - - pub fn rejects_tls13_signatures(err: Error) -> Self { - MockServerVerifier { - tls13_signature_error: Some(err), - ..Default::default() - } - } - - pub fn offers_no_signature_schemes() -> Self { - MockServerVerifier { - signature_schemes: vec![], - ..Default::default() - } - } -} - -impl Default for MockServerVerifier { - fn default() -> Self { - MockServerVerifier { - cert_rejection_error: None, - tls12_signature_error: None, - tls13_signature_error: None, - wants_scts: false, - signature_schemes: WebPkiVerifier::verification_schemes(), - } - } -} diff --git a/components/tls/tls-mpc/Cargo.toml b/components/tls/tls-mpc/Cargo.toml deleted file mode 100644 index f1233773b1..0000000000 --- a/components/tls/tls-mpc/Cargo.toml +++ /dev/null @@ -1,68 +0,0 @@ -[package] -name = "tlsn-tls-mpc" -authors = ["TLSNotary Team"] -description = "Implementation of the backend trait for 2PC" -keywords = ["tls", "mpc", "2pc"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[lib] -name = "tls_mpc" - -[features] -default = ["tracing"] -tracing = [ - "dep:tracing", - "tlsn-block-cipher/tracing", - "tlsn-stream-cipher/tracing", - "tlsn-universal-hash/tracing", - "tlsn-aead/tracing", - "tlsn-key-exchange/tracing", - "tlsn-point-addition/tracing", - "tlsn-hmac-sha256/tracing", - 
"tlsn-tls-client-async/tracing", - "uid-mux/tracing", -] - -[dependencies] -tlsn-tls-core = { path = "../tls-core", features = ["serde"] } -tlsn-tls-backend = { path = "../tls-backend" } - -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } - -tlsn-block-cipher = { path = "../../cipher/block-cipher" } -tlsn-stream-cipher = { path = "../../cipher/stream-cipher" } -tlsn-universal-hash = { path = "../../universal-hash" } -tlsn-aead = { path = "../../aead" } -tlsn-key-exchange = { path = "../../key-exchange" } -tlsn-point-addition = { path = "../../point-addition" } -tlsn-hmac-sha256 = { path = "../../prf/hmac-sha256" } - -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } - -p256.workspace = true -rand.workspace = true -futures.workspace = true -async-trait.workspace = true -serde.workspace = true -derive_builder.workspace = true -enum-try-as-inner.workspace = true -thiserror.workspace = true -tracing = { workspace = true, optional = true } -ludi = { git = "https://github.com/sinui0/ludi", rev = "b590de5" } - -[dev-dependencies] -tlsn-tls-client = { path = "../tls-client" } -tlsn-tls-client-async = { path = "../tls-client-async" } -tls-server-fixture = { path = "../tls-server-fixture" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -uid-mux = { path = "../../uid-mux" } - -tracing-subscriber.workspace = true - -tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } -tokio-util = { workspace = true, features = ["compat"] } diff --git a/components/tls/tls-mpc/src/config.rs b/components/tls/tls-mpc/src/config.rs deleted file mode 100644 index d2b16dcaf2..0000000000 --- a/components/tls/tls-mpc/src/config.rs +++ /dev/null @@ -1,121 +0,0 
@@ -use derive_builder::Builder; - -static DEFAULT_OPAQUE_TX_TRANSCRIPT_ID: &str = "opaque_tx"; -static DEFAULT_OPAQUE_RX_TRANSCRIPT_ID: &str = "opaque_rx"; -static DEFAULT_TX_TRANSCRIPT_ID: &str = "tx"; -static DEFAULT_RX_TRANSCRIPT_ID: &str = "rx"; - -/// Configuration options which are common to both the leader and the follower -#[derive(Debug, Clone, Builder)] -pub struct MpcTlsCommonConfig { - /// The id of the tls session. - #[builder(setter(into))] - id: String, - /// The number of threads to use - #[builder(default = "8")] - num_threads: usize, - /// Tx transcript ID - #[builder(setter(into), default = "DEFAULT_TX_TRANSCRIPT_ID.to_string()")] - tx_transcript_id: String, - /// Rx transcript ID - #[builder(setter(into), default = "DEFAULT_RX_TRANSCRIPT_ID.to_string()")] - rx_transcript_id: String, - /// Opaque Tx transcript ID - #[builder(setter(into), default = "DEFAULT_OPAQUE_TX_TRANSCRIPT_ID.to_string()")] - opaque_tx_transcript_id: String, - /// Opaque Rx transcript ID - #[builder(setter(into), default = "DEFAULT_OPAQUE_RX_TRANSCRIPT_ID.to_string()")] - opaque_rx_transcript_id: String, - /// Maximum size of the transcript in bytes. - /// 16 KiB by default. - #[builder(default = "1 << 14")] - max_transcript_size: usize, - /// Whether the leader commits to the handshake data. - #[builder(default = "true")] - handshake_commit: bool, -} - -impl MpcTlsCommonConfig { - /// Creates a new builder for `MpcTlsCommonConfig`. - pub fn builder() -> MpcTlsCommonConfigBuilder { - MpcTlsCommonConfigBuilder::default() - } - - /// Returns the id of the tls session. - pub fn id(&self) -> &str { - &self.id - } - - /// Returns the number of threads to use. - pub fn num_threads(&self) -> usize { - self.num_threads - } - - /// Returns the tx transcript id. - pub fn tx_transcript_id(&self) -> &str { - &self.tx_transcript_id - } - - /// Returns the rx transcript id. 
- pub fn rx_transcript_id(&self) -> &str { - &self.rx_transcript_id - } - - /// Returns the opaque tx transcript id. - pub fn opaque_tx_transcript_id(&self) -> &str { - &self.opaque_tx_transcript_id - } - - /// Returns the opaque rx transcript id. - pub fn opaque_rx_transcript_id(&self) -> &str { - &self.opaque_rx_transcript_id - } - - /// Returns the maximum size of the transcript in bytes. - pub fn max_transcript_size(&self) -> usize { - self.max_transcript_size - } - - /// Whether the leader commits to the handshake data. - pub fn handshake_commit(&self) -> bool { - self.handshake_commit - } -} - -/// Configuration for the leader -#[allow(missing_docs)] -#[derive(Debug, Clone, Builder)] -pub struct MpcTlsLeaderConfig { - common: MpcTlsCommonConfig, -} - -impl MpcTlsLeaderConfig { - /// Creates a new builder for `MpcTlsLeaderConfig`. - pub fn builder() -> MpcTlsLeaderConfigBuilder { - MpcTlsLeaderConfigBuilder::default() - } - - /// Returns the common config. - pub fn common(&self) -> &MpcTlsCommonConfig { - &self.common - } -} - -/// Configuration for the follower -#[allow(missing_docs)] -#[derive(Debug, Clone, Builder)] -pub struct MpcTlsFollowerConfig { - common: MpcTlsCommonConfig, -} - -impl MpcTlsFollowerConfig { - /// Creates a new builder for `MpcTlsFollowerConfig`. - pub fn builder() -> MpcTlsFollowerConfigBuilder { - MpcTlsFollowerConfigBuilder::default() - } - - /// Returns the common config. 
- pub fn common(&self) -> &MpcTlsCommonConfig { - &self.common - } -} diff --git a/components/tls/tls-mpc/src/setup.rs b/components/tls/tls-mpc/src/setup.rs deleted file mode 100644 index 22eb2917fc..0000000000 --- a/components/tls/tls-mpc/src/setup.rs +++ /dev/null @@ -1,185 +0,0 @@ -use hmac_sha256 as prf; -use key_exchange as ke; -use mpz_garble::{Decode, DecodePrivate, Execute, Load, Prove, Verify, Vm}; -use mpz_share_conversion as ff; -use point_addition as pa; -use tlsn_stream_cipher as stream_cipher; -use tlsn_universal_hash as universal_hash; - -use aead::Aead; -use hmac_sha256::Prf; -use ke::KeyExchange; - -use utils_aio::mux::MuxChannel; - -use crate::{config::MpcTlsCommonConfig, MpcTlsError, TlsRole}; - -/// Helper function for setting up components -#[allow(clippy::type_complexity)] -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip_all, err) -)] -pub async fn setup_components< - M: MuxChannel + MuxChannel + Clone, - VM: Vm + Send, - PS: ff::ShareConversion + Send + Sync + 'static + std::fmt::Debug, - PR: ff::ShareConversion + Send + Sync + 'static + std::fmt::Debug, - GF: ff::ShareConversion + Send + Sync + Clone + 'static + std::fmt::Debug, ->( - config: &MpcTlsCommonConfig, - role: TlsRole, - mux: &mut M, - vm: &mut VM, - p256_send: PS, - p256_recv: PR, - gf: GF, -) -> Result< - ( - Box, - Box, - Box, - Box, - ), - MpcTlsError, -> -where - ::Thread: Execute + Load + Decode + DecodePrivate + Prove + Verify + Send + Sync, -{ - // Set up channels - let (mut mux_0, mut mux_1) = (mux.clone(), mux.clone()); - let (ke_channel, encrypter_channel, decrypter_channel) = futures::try_join!( - mux_0.get_channel("ke"), - mux_1.get_channel("encrypter"), - mux.get_channel("decrypter") - ) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "mux error"))?; - - let (ke_role, pa_role, aead_role) = match role { - TlsRole::Leader => ( - ke::Role::Leader, - pa::Role::Leader, - aead::aes_gcm::Role::Leader, - ), - 
TlsRole::Follower => ( - ke::Role::Follower, - pa::Role::Follower, - aead::aes_gcm::Role::Follower, - ), - }; - - // Key exchange - let ke = ke::KeyExchangeCore::new( - ke_channel, - pa::MpcPointAddition::new(pa_role, p256_send), - pa::MpcPointAddition::new(pa_role, p256_recv), - vm.new_thread("ke").await?, - ke::KeyExchangeConfig::builder() - .id("ke") - .role(ke_role) - .build() - .unwrap(), - ); - - // PRF - let prf_role = match role { - TlsRole::Leader => prf::Role::Leader, - TlsRole::Follower => prf::Role::Follower, - }; - let prf = prf::MpcPrf::new( - prf::PrfConfig::builder().role(prf_role).build().unwrap(), - vm.new_thread("prf/0").await?, - vm.new_thread("prf/1").await?, - ); - - // Encrypter - let block_cipher = block_cipher::MpcBlockCipher::::new( - block_cipher::BlockCipherConfig::builder() - .id("encrypter/block_cipher") - .build() - .unwrap(), - vm.new_thread("encrypter/block_cipher").await?, - ); - - let stream_cipher = stream_cipher::MpcStreamCipher::::new( - stream_cipher::StreamCipherConfig::builder() - .id("encrypter/stream_cipher") - .transcript_id("tx") - .build() - .unwrap(), - vm.new_thread_pool("encrypter/stream_cipher", config.num_threads()) - .await?, - ); - - let ghash = universal_hash::ghash::Ghash::new( - universal_hash::ghash::GhashConfig::builder() - .id("encrypter/ghash") - .initial_block_count(64) - .build() - .unwrap(), - gf.clone(), - ); - - let mut encrypter = aead::aes_gcm::MpcAesGcm::new( - aead::aes_gcm::AesGcmConfig::builder() - .id("encrypter/aes_gcm") - .role(aead_role) - .build() - .unwrap(), - encrypter_channel, - Box::new(block_cipher), - Box::new(stream_cipher), - Box::new(ghash), - ); - - encrypter.set_transcript_id(config.opaque_tx_transcript_id()); - - // Decrypter - let block_cipher = block_cipher::MpcBlockCipher::::new( - block_cipher::BlockCipherConfig::builder() - .id("decrypter/block_cipher") - .build() - .unwrap(), - vm.new_thread("decrypter/block_cipher").await?, - ); - - let stream_cipher = 
stream_cipher::MpcStreamCipher::::new( - stream_cipher::StreamCipherConfig::builder() - .id("decrypter/stream_cipher") - .transcript_id("rx") - .build() - .unwrap(), - vm.new_thread_pool("decrypter/stream_cipher", config.num_threads()) - .await?, - ); - - let ghash = universal_hash::ghash::Ghash::new( - universal_hash::ghash::GhashConfig::builder() - .id("decrypter/ghash") - .initial_block_count(64) - .build() - .unwrap(), - gf, - ); - - let mut decrypter = aead::aes_gcm::MpcAesGcm::new( - aead::aes_gcm::AesGcmConfig::builder() - .id("decrypter/aes_gcm") - .role(aead_role) - .build() - .unwrap(), - decrypter_channel, - Box::new(block_cipher), - Box::new(stream_cipher), - Box::new(ghash), - ); - - decrypter.set_transcript_id(config.opaque_rx_transcript_id()); - - Ok(( - Box::new(ke), - Box::new(prf), - Box::new(encrypter), - Box::new(decrypter), - )) -} diff --git a/components/tls/tls-mpc/tests/test.rs b/components/tls/tls-mpc/tests/test.rs deleted file mode 100644 index e015f72fea..0000000000 --- a/components/tls/tls-mpc/tests/test.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::{sync::Arc, time::Duration}; - -use futures::{AsyncReadExt, AsyncWriteExt, StreamExt}; -use mpz_garble::{config::Role as GarbleRole, protocol::deap::DEAPVm}; -use mpz_ot::{ - actor::kos::{ReceiverActor, SenderActor}, - chou_orlandi::{ - Receiver as BaseReceiver, ReceiverConfig as BaseReceiverConfig, Sender as BaseSender, - SenderConfig as BaseSenderConfig, - }, - kos::{Receiver, ReceiverConfig, Sender, SenderConfig}, -}; -use mpz_share_conversion as ff; -use mpz_share_conversion::{ShareConversionReveal, ShareConversionVerify}; -use tls_client::Certificate; -use tls_client_async::bind_client; -use tls_mpc::{ - setup_components, MpcTlsCommonConfig, MpcTlsFollower, MpcTlsFollowerConfig, MpcTlsLeader, - MpcTlsLeaderConfig, TlsRole, -}; -use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN}; -use tokio_util::compat::TokioAsyncReadCompatExt; -use uid_mux::{yamux, UidYamux}; 
-use utils_aio::{codec::BincodeMux, mux::MuxChannel}; - -#[tokio::test] -#[ignore] -async fn test() { - tracing_subscriber::fmt::init(); - - let (leader_socket, follower_socket) = tokio::io::duplex(1 << 25); - - let mut leader_mux = UidYamux::new( - yamux::Config::default(), - leader_socket.compat(), - yamux::Mode::Client, - ); - let mut follower_mux = UidYamux::new( - yamux::Config::default(), - follower_socket.compat(), - yamux::Mode::Server, - ); - - let leader_mux_control = leader_mux.control(); - let follower_mux_control = follower_mux.control(); - - tokio::spawn(async move { leader_mux.run().await.unwrap() }); - tokio::spawn(async move { follower_mux.run().await.unwrap() }); - - let mut leader_mux = BincodeMux::new(leader_mux_control); - let mut follower_mux = BincodeMux::new(follower_mux_control); - - let leader_ot_sender_config = SenderConfig::default(); - let follower_ot_recvr_config = ReceiverConfig::default(); - - let follower_ot_sender_config = SenderConfig::builder().sender_commit().build().unwrap(); - let leader_ot_recvr_config = ReceiverConfig::builder().sender_commit().build().unwrap(); - - let (leader_ot_sender_sink, leader_ot_sender_stream) = - leader_mux.get_channel("ot/0").await.unwrap().split(); - - let (follower_ot_recvr_sink, follower_ot_recvr_stream) = - follower_mux.get_channel("ot/0").await.unwrap().split(); - - let (leader_ot_receiver_sink, leader_ot_receiver_stream) = - leader_mux.get_channel("ot/1").await.unwrap().split(); - - let (follower_ot_sender_sink, follower_ot_sender_stream) = - follower_mux.get_channel("ot/1").await.unwrap().split(); - - let mut leader_ot_sender_actor = SenderActor::new( - Sender::new( - leader_ot_sender_config, - BaseReceiver::new(BaseReceiverConfig::default()), - ), - leader_ot_sender_sink, - leader_ot_sender_stream, - ); - - let mut follower_ot_recvr_actor = ReceiverActor::new( - Receiver::new( - follower_ot_recvr_config, - BaseSender::new(BaseSenderConfig::default()), - ), - follower_ot_recvr_sink, - 
follower_ot_recvr_stream, - ); - - let mut leader_ot_recvr_actor = ReceiverActor::new( - Receiver::new( - leader_ot_recvr_config, - BaseSender::new( - BaseSenderConfig::builder() - .receiver_commit() - .build() - .unwrap(), - ), - ), - leader_ot_receiver_sink, - leader_ot_receiver_stream, - ); - - let mut follower_ot_sender_actor = SenderActor::new( - Sender::new( - follower_ot_sender_config, - BaseReceiver::new( - BaseReceiverConfig::builder() - .receiver_commit() - .build() - .unwrap(), - ), - ), - follower_ot_sender_sink, - follower_ot_sender_stream, - ); - - let leader_ot_send = leader_ot_sender_actor.sender(); - let follower_ot_recv = follower_ot_recvr_actor.receiver(); - - let leader_ot_recv = leader_ot_recvr_actor.receiver(); - let follower_ot_send = follower_ot_sender_actor.sender(); - - tokio::spawn(async move { - leader_ot_sender_actor.setup(20000).await.unwrap(); - leader_ot_sender_actor.run().await.unwrap(); - }); - - tokio::spawn(async move { - follower_ot_recvr_actor.setup(20000).await.unwrap(); - follower_ot_recvr_actor.run().await.unwrap(); - }); - - tokio::spawn(async move { - leader_ot_recvr_actor.setup(20000).await.unwrap(); - leader_ot_recvr_actor.run().await.unwrap(); - }); - - tokio::spawn(async move { - follower_ot_sender_actor.setup(20000).await.unwrap(); - follower_ot_sender_actor.run().await.unwrap(); - follower_ot_sender_actor.reveal().await.unwrap(); - }); - - let mut leader_vm = DEAPVm::new( - "vm", - GarbleRole::Leader, - [0u8; 32], - leader_mux.get_channel("vm").await.unwrap(), - Box::new(leader_mux.clone()), - leader_ot_send.clone(), - leader_ot_recv.clone(), - ); - - let mut follower_vm = DEAPVm::new( - "vm", - GarbleRole::Follower, - [1u8; 32], - follower_mux.get_channel("vm").await.unwrap(), - Box::new(follower_mux.clone()), - follower_ot_send.clone(), - follower_ot_recv.clone(), - ); - - let leader_p256_send = ff::ConverterSender::::new( - ff::SenderConfig::builder().id("p256/0").build().unwrap(), - leader_ot_send.clone(), - 
leader_mux.get_channel("p256/0").await.unwrap(), - ); - - let leader_p256_recv = ff::ConverterReceiver::::new( - ff::ReceiverConfig::builder().id("p256/1").build().unwrap(), - leader_ot_recv.clone(), - leader_mux.get_channel("p256/1").await.unwrap(), - ); - - let follower_p256_send = ff::ConverterSender::::new( - ff::SenderConfig::builder().id("p256/1").build().unwrap(), - follower_ot_send.clone(), - follower_mux.get_channel("p256/1").await.unwrap(), - ); - - let follower_p256_recv = ff::ConverterReceiver::::new( - ff::ReceiverConfig::builder().id("p256/0").build().unwrap(), - follower_ot_recv.clone(), - follower_mux.get_channel("p256/0").await.unwrap(), - ); - - let mut leader_gf2 = ff::ConverterSender::::new( - ff::SenderConfig::builder() - .id("gf2") - .record() - .build() - .unwrap(), - leader_ot_send.clone(), - leader_mux.get_channel("gf2").await.unwrap(), - ); - - let mut follower_gf2 = ff::ConverterReceiver::::new( - ff::ReceiverConfig::builder() - .id("gf2") - .record() - .build() - .unwrap(), - follower_ot_recv.clone(), - follower_mux.get_channel("gf2").await.unwrap(), - ); - - let common_config = MpcTlsCommonConfig::builder().id("test").build().unwrap(); - - let (leader_ke, leader_prf, leader_encrypter, leader_decrypter) = setup_components( - &common_config, - TlsRole::Leader, - &mut leader_mux, - &mut leader_vm, - leader_p256_send, - leader_p256_recv, - leader_gf2.handle().unwrap(), - ) - .await - .unwrap(); - - let mut leader = MpcTlsLeader::new( - MpcTlsLeaderConfig::builder() - .common(common_config.clone()) - .build() - .unwrap(), - leader_mux.get_channel("test").await.unwrap(), - leader_ke, - leader_prf, - leader_encrypter, - leader_decrypter, - ); - - let (follower_ke, follower_prf, follower_encrypter, follower_decrypter) = setup_components( - &common_config, - TlsRole::Follower, - &mut follower_mux, - &mut follower_vm, - follower_p256_send, - follower_p256_recv, - follower_gf2.handle().unwrap(), - ) - .await - .unwrap(); - - let mut follower = 
MpcTlsFollower::new( - MpcTlsFollowerConfig::builder() - .common(common_config) - .build() - .unwrap(), - follower_mux.get_channel("test").await.unwrap(), - follower_ke, - follower_prf, - follower_encrypter, - follower_decrypter, - ); - - let follower_task = tokio::spawn(async move { - follower.setup().await.unwrap(); - - let (_, fut) = follower.run(); - fut.await.unwrap() - }); - - leader.setup().await.unwrap(); - - let (leader_ctrl, leader_fut) = leader.run(); - let leader_task = tokio::spawn(leader_fut); - - let mut root_store = tls_client::RootCertStore::empty(); - root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap(); - let config = tls_client::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - - let server_name = SERVER_DOMAIN.try_into().unwrap(); - - let client = tls_client::ClientConnection::new( - Arc::new(config), - Box::new(leader_ctrl.clone()), - server_name, - ) - .unwrap(); - - let (client_socket, server_socket) = tokio::io::duplex(1 << 16); - - tokio::spawn(bind_test_server_hyper(server_socket.compat())); - - let (mut conn, conn_fut) = bind_client(client_socket.compat(), client); - - let conn_task = tokio::spawn(conn_fut); - - let msg = concat!( - "POST /echo HTTP/1.1\r\n", - "Host: test-server.io\r\n", - "Connection: keep-alive\r\n", - "Accept-Encoding: identity\r\n", - "Content-Length: 5\r\n", - "\r\n", - "hello", - "\r\n" - ); - - conn.write_all(msg.as_bytes()).await.unwrap(); - - let mut buf = vec![0u8; 48]; - conn.read_exact(&mut buf).await.unwrap(); - - println!("{}", String::from_utf8_lossy(&buf)); - - leader_ctrl.defer_decryption().await.unwrap(); - - let msg = concat!( - "POST /echo HTTP/1.1\r\n", - "Host: test-server.io\r\n", - "Connection: close\r\n", - "Accept-Encoding: identity\r\n", - "Content-Length: 5\r\n", - "\r\n", - "hello", - "\r\n" - ); - - conn.write_all(msg.as_bytes()).await.unwrap(); - - // Wait for the server to reply. 
- tokio::time::sleep(Duration::from_millis(100)).await; - - leader_ctrl.commit().await.unwrap(); - - let mut buf = vec![0u8; 1024]; - conn.read_to_end(&mut buf).await.unwrap(); - - leader_ctrl.close_connection().await.unwrap(); - conn.close().await.unwrap(); - - follower_ot_send.shutdown().await.unwrap(); - - tokio::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); - tokio::try_join!(leader_gf2.reveal(), follower_gf2.verify()).unwrap(); - - conn_task.await.unwrap().unwrap(); - _ = leader_task.await.unwrap(); - _ = follower_task.await.unwrap(); -} diff --git a/components/tls/tls-server-fixture/src/domain.crt b/components/tls/tls-server-fixture/src/domain.crt deleted file mode 100644 index f946fbdc7c..0000000000 --- a/components/tls/tls-server-fixture/src/domain.crt +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDmzCCAoOgAwIBAgIUXYmS5GJNi70RMwqR1zzSjI6gjnIwDQYJKoZIhvcNAQEL -BQAwSjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxEjAQBgNVBAoM -CXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5MB4XDTIzMDYwNzAxMzcyOFoX -DTI0MDYwNjAxMzcyOFowWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3Rh -dGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwL -dGVzdC1zZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQClrwyM -b9JVV56IqAAE7QoAfnyYMRxL/93NV2II35Hq1DrBeXGvQ9EjD0qKMKCNIJFLQUaO -dIQ853+OGez9Q73836bPOqM7hSRdn34bB+4phKnxM2QegEL+0oR6YhXS+9iavAWj -obfgVmmZhtLMXAIRuZrMLVPRNm+MIi4BcEu6Ckgdxhvkvp7Gzb74NXffgdilqTJG -bKdvCuF7IgFZXTi/8ACAZDKuXt8t9w0VUZRC9X4YTNslesW8MaPbmRwlVR8+qDMA -f2UdoUS5t9NGnniM/zsVK+RvbrFD0yXoJ3Pr+NdpFwLFPuhQCX2l0r2edb5Xsw6d -unCm2vZIrajFmzLTAgMBAAGjaDBmMB8GA1UdIwQYMBaAFCzjJ91fETcwwaoAM0mB -Z8r8srGlMAkGA1UdEwQCMAAwGQYDVR0RBBIwEIIOdGVzdC1zZXJ2ZXIuaW8wHQYD -VR0OBBYEFJMXVtfmDSqllg3BuymCj4OgENVnMA0GCSqGSIb3DQEBCwUAA4IBAQBd -W2Y58hHXei5K1wXRKaSZV8uyI5a4F4h+75vNNDGcbU204YRAtwmTYLXZnUCtxhL5 -wDnH00z8Z8s+ZHfDdH/64lxM0VmVKNxIdF6KUMIyvrdK9aL2wRMLSWCZTMBGibs0 -npY6fGgXdAeZnG8iP6ede3tF7vNpr+no+lrsx7ZCYhUg/XvaGsR2wIoMpMhVyDv2 
-jkxc+Xnt/Prr89mQQUQVg2zkwcPrgEM+NwpDMqH3BFVsx6Qu1FO6sAIREewSM6t9 -kgfkzmH97Z5HEjGV2CWjsNBEAPaafAnE8qqvHQkFUmps12LnsEGbZbM/8kxifjNX -V6wbaLYrV6WDttQINST8 ------END CERTIFICATE----- diff --git a/components/tls/tls-server-fixture/src/domain.csr b/components/tls/tls-server-fixture/src/domain.csr deleted file mode 100644 index a4201d2587..0000000000 --- a/components/tls/tls-server-fixture/src/domain.csr +++ /dev/null @@ -1,17 +0,0 @@ ------BEGIN CERTIFICATE REQUEST----- -MIICoDCCAYgCAQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx -ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLdGVz -dC1zZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQClrwyMb9JV -V56IqAAE7QoAfnyYMRxL/93NV2II35Hq1DrBeXGvQ9EjD0qKMKCNIJFLQUaOdIQ8 -53+OGez9Q73836bPOqM7hSRdn34bB+4phKnxM2QegEL+0oR6YhXS+9iavAWjobfg -VmmZhtLMXAIRuZrMLVPRNm+MIi4BcEu6Ckgdxhvkvp7Gzb74NXffgdilqTJGbKdv -CuF7IgFZXTi/8ACAZDKuXt8t9w0VUZRC9X4YTNslesW8MaPbmRwlVR8+qDMAf2Ud -oUS5t9NGnniM/zsVK+RvbrFD0yXoJ3Pr+NdpFwLFPuhQCX2l0r2edb5Xsw6dunCm -2vZIrajFmzLTAgMBAAGgADANBgkqhkiG9w0BAQsFAAOCAQEAUY8AusY1fDOy9Ary -rbkaX+pSUd4HSoElYEvn5ikrguqsTny5+Dvd+4aVGLHSO0/Y/cA1oRq8YqFEC/2C -2PxQkiZY8BdGKTWCT26f/3S197K/lOzxGmXbfZDZVHstzzobMIHFd0NEGBamLH4w -iSnuNRHre48cVIlCx3S2CVskNikJGBiZVfQeNVhkd5zqEkAdYViEycoQ5RyCu3po -HGlUQx3Z2TY4wNv1iKMPj8701C+c1uIosHMrjDPpIqHjS+nbaKPkXVwTF2lbOOcO -111t3fk0EeCdnGq8g6rS9mEL/kBZZNZI0/Pwl3SIBqoxvBv6U+ennqZhnPVJSK07 -//ovUA== ------END CERTIFICATE REQUEST----- diff --git a/components/tls/tls-server-fixture/src/domain.der b/components/tls/tls-server-fixture/src/domain.der deleted file mode 100644 index 0a9191c164..0000000000 Binary files a/components/tls/tls-server-fixture/src/domain.der and /dev/null differ diff --git a/components/tls/tls-server-fixture/src/domain.ext b/components/tls/tls-server-fixture/src/domain.ext deleted file mode 100644 index d1483b93da..0000000000 --- a/components/tls/tls-server-fixture/src/domain.ext +++ /dev/null @@ -1,5 +0,0 @@ -authorityKeyIdentifier=keyid,issuer 
-basicConstraints=CA:FALSE -subjectAltName = @alt_names -[alt_names] -DNS.1 = test-server.io \ No newline at end of file diff --git a/components/tls/tls-server-fixture/src/domain.key b/components/tls/tls-server-fixture/src/domain.key deleted file mode 100644 index 0ca070d38c..0000000000 --- a/components/tls/tls-server-fixture/src/domain.key +++ /dev/null @@ -1,30 +0,0 @@ ------BEGIN ENCRYPTED PRIVATE KEY----- -MIIFHDBOBgkqhkiG9w0BBQ0wQTApBgkqhkiG9w0BBQwwHAQInlBxGC6Q/nwCAggA -MAwGCCqGSIb3DQIJBQAwFAYIKoZIhvcNAwcECPW0f2DL+mvdBIIEyOCxtYAeFPSA -IHXX/UtF0Z0FtFP+d1/osMI0wxJrxz0KM9fN8z7XEkzXIR8ZjMCsduIl/zirl+bB -L9UGAHBNIn3yYpwh9UXhFBC4t8GM5xy89E6GMD7DU5X0xA/B1RvGFlHY+EkyTYr6 -crhY8uxTEoVPaxgbeyZfJaTP4bcZyAVbgIN0TqOX7fk+zIKmCHauRETwJho8ZZ+d -tvucnpA4WqNyLhgWW+ScApzgVEsbBXV80tEDOXEw7TVtaS1VqlKP2pFVVHuCWQWk -GnZT57LY4vzBkmpEnmpNj10IVPg5D9Ro+XNsXv9ji67NU42AOopWngOLhqMEyVS2 -DJL9X5DQYQJKSx7cKVkpNkBMLZC6V7iPhG1WHr+OUgYdNSf6l/59tLLQTpN0u/19 -VsgaS84CqAv4eFOVwfr9QTwPg7BNKbvdzn41IvXA05OT3xs0P2FhbjafFaJrivu+ -0krVb1T7RyiwqJSr8EujufHptIzkax/HE5+g++v6laiHlZXealGmGsYKNkO4Tyn9 -3QU38U41A4LyBrnlWn8mEZxY/SM7nQhYdfZ2CDCwqpBmvozxuY4p6fwcyPXmMmxv -bwzDlIRkk1VxobRKYqoWhCXKCRm9c/WZoymbWghLhEhbwgRAfM6sTAtNKbAooY15 -RGGCZE/QIqroWdMyH9hWQ01KOSA+CWWgsAUr+l6w678Jcx8bOSlbJkjpHVY4Kg0G -kngficAHb4nVNmGXNoDqg/EgDi57J8jNEKxLK3SIs5xwcFqTQl5qF6vJUR64yYS1 -aJe6+AedSvubrSzqnuI7SgX44WoHpOduhu99AAMgHOSAOgg5Nxr29Jyo/h6qKsx8 -u+EDrIcQPsgpITNALmI6yT7uZJYnkFfvuZdiTOqAURVUVLUsnAuCzaUzbBdM2Vbb -rJlOkHJqrX9NxhIl8CSzdGiwdBc1gwi5eL84NY3MTSOD4+HN0u17a9O/EK2OhVNR -Jyt+r0WgE2lmaLKNaRvruwuU9oVG0cfWxRilOxFKnDpk9SMe4ZvUUIGe8wjx2BNs -u87EvPK71+gGLkWxo6qGH6xk/26dO/yjclFHDxjLe63WBTBrUrMucKjMbyxuKYrC -JgNyHskeK7cqOixg4cDMCIfNoHg8nnawF2fAYY+NEKJq29YcAWO1P8PuqpINdgke -Sac+ogpp56ZtFlmgQ6PYOM/24b3dFQwdt46HzK7iluVR7Ihq0uiBO7UDilz9qUPX -sdPSmOYoFST6kgx9Occk6MeJKMBH+f8nVnTauwhBs3s2AZEgHd07MreUF+G1njbI -ukCvL8MHYhwqCCd6EZaH7U7oNGA8roIRPkvpSXeuNhDsabxQIzQNyOy9v3fKNtyt -EAR1bnrKwXX/32dC/JFcJOfzMuhCclw/88miNtzDz/tXym7EUU4UbX5ARoDNHuaJ 
-0DHfEMCTSa2umiqMClTvfOgn9OWI140y6HPqWXmJJReiKkGkDJiBcRiBCGj0uxwz -qrclNr5iDk+USY09qR2gWt30xaBclgzAL9b7YEMulfcrlrOGxrVUlDSezs2YZQRH -156WWeCrruY+MjFZ4w0jhqkRW0osO9TwbyAZTlVYqrn9RChmfnW1Mn6p526ZkDer -nSQuQlLmTNOHGRRFvKg/6Q== ------END ENCRYPTED PRIVATE KEY----- diff --git a/components/tls/tls-server-fixture/src/domain_key.der b/components/tls/tls-server-fixture/src/domain_key.der deleted file mode 100644 index 984284521b..0000000000 Binary files a/components/tls/tls-server-fixture/src/domain_key.der and /dev/null differ diff --git a/components/tls/tls-server-fixture/src/rootCA.crt b/components/tls/tls-server-fixture/src/rootCA.crt deleted file mode 100644 index 728847ce25..0000000000 --- a/components/tls/tls-server-fixture/src/rootCA.crt +++ /dev/null @@ -1,21 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDdTCCAl2gAwIBAgIUYL1NTquyWYgFShuzP3SXUW1LRLswDQYJKoZIhvcNAQEL -BQAwSjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxEjAQBgNVBAoM -CXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5MB4XDTIzMDYwNzAxMzQyN1oX -DTI4MDYwNTAxMzQyN1owSjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3Rh -dGUxEjAQBgNVBAoMCXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5MIIBIjAN -BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyhPglwjjaUxTvlQZK3IOVdiormCz -eHCvyZscBcolbhL8CbrtJYkozfMLwaBSccqBLpMbvulOeKBBB3ll9iffeWkQIcnh -m4ozf7DdG7EhDnp0lBS40UmDV4zHKbqlUzBBt737RGaDSDzf0w2b8tVOA0ZkBXw7 -aGlbj1ikV45GFlyH5b+wCvcboUaJkt4CllqF8uIo4Csjp3EqLMnGW6556HNukLoi -YlqW89VM+7C0yaL43ROjPr0lXPBpCrV9jnsVSaBJ2u3Ae35KgaqFxcMrXzpMArjS -jEXSW6Gi7XdI7bWMpfnLv3eyRkuEzkhCu1PPTA1EBa4eEikAqilU7ukKzQIDAQAB -o1MwUTAdBgNVHQ4EFgQULOMn3V8RNzDBqgAzSYFnyvyysaUwHwYDVR0jBBgwFoAU -LOMn3V8RNzDBqgAzSYFnyvyysaUwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B -AQsFAAOCAQEArIKYiJ6eLqNcy1UCg4vboRWByVG4nN5NOPoBNRT1xTzEPsjEoZpw -EdE8roKD4jG6Pf73yiMsLGhSUA7HBET6NTS5g7wGgGL0QPQIeLHLf1oLWb17Zqmm -i316NBzwPEXx8snK/WcBzPimbcKRDAYJG5UahGJyaIE5tcXORq7tK4YCycM1qTq0 -dx5RjO+ZJ1160nstXRUhRumilPW1FzxOt+wnaJza07zwLF9NphnAh6XGCsisz/KW -8scL+DNu6W0+irixhGl3Ly9idNQLmDUEWoO6sp3MFW7He5iJgzFlIY1oo7uQFkHR 
-GzrhAlDfCcoCHFvHc5MuqFpvgSsWEoC2AQ== ------END CERTIFICATE----- diff --git a/components/tls/tls-server-fixture/src/rootCA.der b/components/tls/tls-server-fixture/src/rootCA.der deleted file mode 100644 index d974cd41ff..0000000000 Binary files a/components/tls/tls-server-fixture/src/rootCA.der and /dev/null differ diff --git a/components/tls/tls-server-fixture/src/rootCA.key b/components/tls/tls-server-fixture/src/rootCA.key deleted file mode 100644 index dbd954eb03..0000000000 --- a/components/tls/tls-server-fixture/src/rootCA.key +++ /dev/null @@ -1,30 +0,0 @@ ------BEGIN ENCRYPTED PRIVATE KEY----- -MIIFHDBOBgkqhkiG9w0BBQ0wQTApBgkqhkiG9w0BBQwwHAQIwMUhKGCZU0kCAggA -MAwGCCqGSIb3DQIJBQAwFAYIKoZIhvcNAwcECJxXqH54fWnLBIIEyI9L1uoU2blz -siT/FPwMzDTgGJCYtre8yIuA81OTURH8V3PM4r8SguPk5xv75WcXu71/+I5hK0s6 -ia1BZHHSQosnu8XrGNTyIaybg1JzrXtunlFUoH64Tg5tXc2G53lbXb7LMGs7YAhh -SG6bmbSerVVOj5Iu1X3+Vdjlsp26BdcOGZ83xZVyRWRWUDjHfrQ0xmMBdB2pe0nD -7cqpSD6WGFXHaRoV+zL/ErjelkqU7V9GSDajZxQxniENEPddpucc7Pz29rqE/2Qf -XX/HD/9hz95RvssYX7b3dczoqheRxIsx72Fn5B/QzEGW/PPaNGi8iQZr8rsoE3S1 -S+nbBKCRiMTwLYu3Z5EjWeTHBtA8tKGPW6HF/9Ga1eEnKTX5+7SgyxbfkeDK1SJN -M77V4qN6R9HvbvduSjO9utDr/zCvHu68Ju9dYCNOiKdLS417pVNtB47EnUiN7kDF -Zrd+lL/8crqNWokocrw2qxLuGMsBqHdvtl6F5O57w9rKZl72gfQTg0dsh5YPN7K9 -xckGaZg0n8nUDAop6bSzXc+0JCwzXfGtKsdq4VjBgy/QtqLX24SgloHJHrl1qyGj -eSnKRvDcazD4XEDEsn9HBIaHHHm8lIrh8IhgsE3UvifgfOf4DO9uasKEdXahZaEu -9WH1mxnY67X+z9IaslMgkzoVnZARq1Swin+SxdY6rWJHucCLO8pCh4R9w6J/olXM -+bs3ivrjygJ2TOYgPV8CJ8IU65Z4C/Wv/Rt4p1QwxK/swkUjINqryh6rg2BGfHE8 -yLUioVimx4c/kziZEq3z8+eq6vs2cB28q3ze39REHAtk3ezsvIirtY/FG10zrWgx -cWHJYUKjYug1RkPWwYr3/vuTh68IrEmGi4TMcojxSUpq1vdc2lk9zu9vXTEsKF00 -NQP5txscxQZwdb5Q01pylcTHOeiEjbQ0hRQ+oelwRniKXAuAp5D7H4LnUcw07wfp -nA1GEqMh7M7/dM3E7eXmzS2izDnwE9E0Va0Qu4fYpu+VPwahgP71bVheYRb1Rr1a -ODnUcReuNqJhZ7q/BTRZ4fK7vrVfuJjecMGNBdFRYH1etQgcZPJzJWnZaMP+S+iO -Jqk/1aMAR5E6ejO5KwExmO9JFQplfbJmrvtGqq1AkkOr+EefR5CK7Tr2jNdEE2xg -b5Z1BmqniJ2I/o9L9AxFFKE9PjXeLLS6XHS3N5VPjLzACJhchQZWI1RbEt6qAW6i 
-LwizIDJLN57mxaxtuQsEg+hD1tNiRIaF3dsCpBq7MTc64/iWphCydYF8adRSzoVy -07GwAmZ9JX/fd6WCWYAo97N0S7/adBqAi/XXA1wqkW/ZpRF63tXNcLJXwLcoeclr -DKpMb0NtjCT5QNJNVGHTUv4RdGEWbN2u6Oq6rfOfYnXvjrzmpAVVkRJ8BlQsTkzN -5T+BV4txwS5a1EH2ERfY4KuQWcSKIP2lKjNvqKCtwbHOSoQ6wmtLJKzw9mLD8JzQ -9wROywbA4CDW4dbSPuHUTmSpfBrt36mfV8+loJ8m3+eqcQNIh44drO3P8wwaIAEe -1YT7Yt75YD2EjtXvw2umxvQP/PWk2t+EI+Fo9/t/U2zPHgFuxevCMzNsRktz108M -rIuTToR3EfYojhEvyMyf7g== ------END ENCRYPTED PRIVATE KEY----- diff --git a/components/uid-mux/Cargo.toml b/components/uid-mux/Cargo.toml deleted file mode 100644 index f1e37a06f9..0000000000 --- a/components/uid-mux/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "uid-mux" -version = "0.1.0-alpha.3" -authors = ["TLSNotary Team"] -description = "Multiplex connections asynchronously." -keywords = ["multiplex", "channel", "futures", "async"] -categories = ["network-programming", "asynchronous"] -license = "MIT OR Apache-2.0" -edition = "2021" - -[features] -tracing = ["dep:tracing"] - -[dependencies] -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } - -async-trait = "0.1" -futures = "0.3" -yamux = "0.11" -tracing = { version = "0.1", optional = true } - -[dev-dependencies] -tokio-util = { version = "0.7", features = ["compat"] } -tokio = { version = "1", features = [ - "macros", - "rt", - "rt-multi-thread", - "time", -] } diff --git a/components/uid-mux/src/lib.rs b/components/uid-mux/src/lib.rs deleted file mode 100644 index 1c1b50f736..0000000000 --- a/components/uid-mux/src/lib.rs +++ /dev/null @@ -1,366 +0,0 @@ -//! This library provides tools to multiplex a connection and uses [yamux] under the hood. -//! -//! To use this library, instantiate a [UidYamux] by providing an underlying socket (anything which -//! implements [AsyncRead] and [AsyncWrite]). After running [run](UidYamux::run) in the background -//! you can create controls with [control](UidYamux::control), which can be easily passed around. -//! 
They allow to open new streams ([get_stream](UidYamuxControl::get_stream)) by providing unique -//! stream ids. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, Mutex}, -}; - -use async_trait::async_trait; - -use futures::{ - channel::oneshot, stream::FuturesUnordered, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, - StreamExt, -}; -use utils_aio::mux::{MuxStream, MuxerError}; - -pub use yamux; - -#[derive(Debug, Default)] -struct MuxState { - stream_ids: HashSet, - waiting_callers: HashMap>>, - waiting_streams: HashMap, -} - -/// A wrapper around [yamux] to facilitate multiplexing with unique stream ids. -pub struct UidYamux { - mode: yamux::Mode, - conn: Option>, - control: yamux::Control, - state: Arc>, -} - -impl std::fmt::Debug for UidYamux { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("UidYamux") - .field("mode", &self.mode) - .field("conn", &"{{ ... 
}}") - .field("control", &self.control) - .field("state", &self.state) - .finish() - } -} - -/// A muxer control for [opening streams](Self::get_stream) with the remote -#[derive(Debug, Clone)] -pub struct UidYamuxControl { - mode: yamux::Mode, - control: yamux::Control, - state: Arc>, -} - -impl UidYamuxControl { - /// Close the connection - pub async fn close(&mut self) -> Result<(), MuxerError> { - self.control - .close() - .await - .map_err(|err| MuxerError::InternalError(format!("shutdown error: {0:?}", err))) - } -} - -impl UidYamux -where - T: AsyncWrite + AsyncRead + Send + Unpin + 'static, -{ - /// Creates a new muxer with the provided config and socket - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(socket), ret) - )] - pub fn new(config: yamux::Config, socket: T, mode: yamux::Mode) -> Self { - let (control, conn) = yamux::Control::new(yamux::Connection::new(socket, config, mode)); - - Self { - mode, - conn: Some(conn), - control, - state: Arc::new(Mutex::new(MuxState::default())), - } - } - - /// Runs the muxer. - /// - /// This method will poll the underlying connection for new streams and - /// handle them appropriately. - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(self), err) - )] - pub async fn run(&mut self) -> Result<(), MuxerError> { - let mut conn = Box::pin( - self.conn - .take() - .ok_or_else(|| MuxerError::InternalError("connection shutdown".to_string()))? - .fuse(), - ); - - // The size of this buffer is bounded by yamux max stream config. - let mut pending_streams = FuturesUnordered::new(); - loop { - futures::select! 
{ - // Handle incoming streams - stream = conn.select_next_some() => { - if self.mode == yamux::Mode::Client { - return Err(MuxerError::InternalError( - "client mode cannot accept incoming streams".to_string(), - )); - } - - let mut stream = - stream.map_err(|e| MuxerError::InternalError(format!("connection error: {0:?}", e)))?; - - pending_streams.push(async move { - let stream_id = read_stream_id(&mut stream).await?; - - Ok::<_, MuxerError>((stream_id, stream)) - }); - } - // Handle streams for which we've received the id - stream = pending_streams.select_next_some() => { - let (stream_id, stream) = stream?; - - let mut state = self.state.lock().unwrap(); - state.stream_ids.insert(stream_id.clone()); - if let Some(sender) = state.waiting_callers.remove(&stream_id) { - // ignore if receiver dropped - _ = sender.send(Ok(stream)); - } else { - state.waiting_streams.insert(stream_id, stream); - } - } - complete => return Ok(()), - } - } - } - - /// Returns a [UidYamuxControl] that can be used to open streams. 
- pub fn control(&self) -> UidYamuxControl { - UidYamuxControl { - mode: self.mode, - control: self.control.clone(), - state: self.state.clone(), - } - } -} - -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(stream), err) -)] -async fn write_stream_id( - stream: &mut T, - id: &str, -) -> Result<(), std::io::Error> { - let id = id.as_bytes(); - - if id.len() > u32::MAX as usize { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "id too long", - )); - } - - stream.write_all(&(id.len() as u32).to_be_bytes()).await?; - stream.write_all(id).await?; - - Ok(()) -} - -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(stream), ret, err) -)] -async fn read_stream_id(stream: &mut T) -> Result { - let mut len = [0u8; 4]; - stream.read_exact(&mut len).await?; - - let len = u32::from_be_bytes(len) as usize; - - let mut id = vec![0u8; len]; - stream.read_exact(&mut id).await?; - - Ok(String::from_utf8_lossy(&id).to_string()) -} - -#[async_trait] -impl MuxStream for UidYamuxControl { - type Stream = yamux::Stream; - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(self), err) - )] - async fn get_stream(&mut self, id: &str) -> Result { - match self.mode { - yamux::Mode::Client => { - if !self.state.lock().unwrap().stream_ids.insert(id.to_string()) { - return Err(MuxerError::DuplicateStreamId(id.to_string())); - } - - let mut stream = self.control.open_stream().await.map_err(|e| { - MuxerError::InternalError(format!("failed to open stream: {}", e)) - })?; - - write_stream_id(&mut stream, id).await?; - - Ok(stream) - } - yamux::Mode::Server => { - let receiver = { - let mut state = self.state.lock().unwrap(); - - // If we already have the stream, return it - if let Some(stream) = state.waiting_streams.remove(id) { - return Ok(stream); - } - - // Prevent duplicate stream ids - if state.stream_ids.contains(id) { - return Err(MuxerError::DuplicateStreamId(id.to_string())); - 
} - - let (sender, receiver) = oneshot::channel(); - state.waiting_callers.insert(id.to_string(), sender); - - receiver - }; - - receiver - .await - .map_err(|_| MuxerError::InternalError("sender dropped".to_string()))? - } - } - } -} - -#[cfg(test)] -mod tests { - use futures::{AsyncReadExt, AsyncWriteExt, FutureExt}; - use tokio_util::compat::TokioAsyncReadCompatExt; - - use super::*; - - async fn create_pair() -> (UidYamuxControl, UidYamuxControl) { - let (socket_a, socket_b) = tokio::io::duplex(1024); - - let mut mux_a = UidYamux::new( - yamux::Config::default(), - socket_a.compat(), - yamux::Mode::Client, - ); - let mut mux_b = UidYamux::new( - yamux::Config::default(), - socket_b.compat(), - yamux::Mode::Server, - ); - - let control_a = mux_a.control(); - let control_b = mux_b.control(); - - tokio::spawn(async move { - mux_a.run().await.unwrap(); - }); - - tokio::spawn(async move { - mux_b.run().await.unwrap(); - }); - - (control_a, control_b) - } - - #[tokio::test] - async fn test_mux() { - let (mut control_a, mut control_b) = create_pair().await; - - let (mut stream_a, mut stream_b) = - tokio::try_join!(control_a.get_stream("test"), control_b.get_stream("test")).unwrap(); - - let msg = b"hello world"; - - stream_a.write_all(msg).await.unwrap(); - stream_a.flush().await.unwrap(); - - let mut buf = [0u8; 11]; - stream_b.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, msg); - } - - #[tokio::test] - async fn test_mux_multiple_streams() { - let (mut control_a, mut control_b) = create_pair().await; - - let (mut stream_a, mut stream_b) = - tokio::try_join!(control_a.get_stream("test"), control_b.get_stream("test")).unwrap(); - - let (mut stream_c, mut stream_d) = - tokio::try_join!(control_a.get_stream("test2"), control_b.get_stream("test2")).unwrap(); - - let msg = b"hello world"; - - stream_d.write_all(msg).await.unwrap(); - stream_d.flush().await.unwrap(); - - let mut buf = [0u8; 11]; - stream_c.read_exact(&mut buf).await.unwrap(); - - 
assert_eq!(&buf, msg); - - let msg = b"hello world2"; - - stream_a.write_all(msg).await.unwrap(); - stream_a.flush().await.unwrap(); - - let mut buf = [0u8; 12]; - stream_b.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, msg); - } - - #[tokio::test] - async fn test_mux_no_duplicates() { - let (mut control_a, mut control_b) = create_pair().await; - - let _ = - tokio::try_join!(control_a.get_stream("test"), control_b.get_stream("test")).unwrap(); - - let (err_a, err_b) = - tokio::join!(control_a.get_stream("test"), control_b.get_stream("test")); - - assert!(err_a.is_err()); - assert!(err_b.is_err()); - } - - #[tokio::test] - async fn test_mux_send_before_opened() { - let (mut control_a, mut control_b) = create_pair().await; - - let mut stream_a = control_a.get_stream("test").await.unwrap(); - - let msg = b"hello world"; - - stream_a.write_all(msg).await.unwrap(); - stream_a.flush().await.unwrap(); - - let mut stream_b = control_b.get_stream("test").await.unwrap(); - - let mut buf = [0u8; 11]; - let read = futures::select! 
{ - read = stream_b.read(&mut buf).fuse() => read.unwrap(), - _ = tokio::time::sleep(std::time::Duration::from_secs(5)).fuse() => panic!("timed out"), - }; - - assert_eq!(&buf[..read], msg); - } -} diff --git a/components/universal-hash/Cargo.toml b/components/universal-hash/Cargo.toml deleted file mode 100644 index 3044c34509..0000000000 --- a/components/universal-hash/Cargo.toml +++ /dev/null @@ -1,43 +0,0 @@ -[package] -name = "tlsn-universal-hash" -authors = ["TLSNotary Team"] -description = "A crate which implements different hash functions for two-party computation" -keywords = ["tls", "mpc", "2pc", "hash"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[features] -default = ["ghash", "mock"] -tracing = ["dep:tracing"] -ghash = [] -mock = [] - -[dependencies] -# tlsn -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } - -# async -async-trait = "0.1" -futures = "0.3" -futures-util = "0.3" - -# error/log -thiserror = "1" -opaque-debug = "0.3" -tracing = { version = "0.1", optional = true } - -# misc -derive_builder = "0.12" - -[dev-dependencies] -ghash_rc = { package = "ghash", version = "0.5" } -tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] } -criterion = "0.5" -rstest = "0.17" -rand_chacha = "0.3" -rand = "0.8" -generic-array = "0.14" diff --git a/crates/benches/.gitignore b/crates/benches/.gitignore new file mode 100644 index 0000000000..77699ee8a3 --- /dev/null +++ b/crates/benches/.gitignore @@ -0,0 +1,2 @@ +*.svg +*.html diff --git a/crates/benches/Cargo.toml b/crates/benches/Cargo.toml new file mode 100644 index 0000000000..c0991da4d5 --- /dev/null +++ b/crates/benches/Cargo.toml @@ -0,0 +1,46 @@ +[package] +edition 
= "2021" +name = "tlsn-benches" +publish = false +version = "0.0.0" + +[dependencies] +anyhow = { workspace = true } +charming = { version = "0.3.1", features = ["ssr"] } +csv = "1.3.0" +futures = { workspace = true } +serde = { workspace = true } +tlsn-common = { workspace = true } +tlsn-core = { workspace = true } +tlsn-prover = { workspace = true } +tlsn-server-fixture = { workspace = true } +tlsn-server-fixture-certs = { workspace = true } +tlsn-tls-core = { workspace = true } +tlsn-verifier = { workspace = true } +tokio = { workspace = true, features = [ + "rt", + "rt-multi-thread", + "macros", + "net", + "io-std", + "fs", +] } +tokio-util = { workspace = true } +toml = "0.8.11" +tracing-subscriber = { workspace = true, features = ["env-filter"] } + +[[bin]] +name = "bench" +path = "bin/bench.rs" + +[[bin]] +name = "prover" +path = "bin/prover.rs" + +[[bin]] +name = "verifier" +path = "bin/verifier.rs" + +[[bin]] +name = "plot" +path = "bin/plot.rs" diff --git a/crates/benches/README.md b/crates/benches/README.md new file mode 100644 index 0000000000..181aa00623 --- /dev/null +++ b/crates/benches/README.md @@ -0,0 +1,35 @@ +# TLSNotary bench utilities + +This crate provides utilities for benchmarking protocol performance under various network conditions and usage patterns. + +As the protocol is mostly IO bound, it's important to track how it performs in low bandwidth and/or high latency environments. To do this we set up temporary network namespaces and add virtual ethernet interfaces which we can control using the linux `tc` (Traffic Control) utility. + +## Configuration + +See the `bench.toml` file for benchmark configurations. + +## Preliminaries + +To run the benchmarks you will need `iproute2` installed, eg: +```sh +sudo apt-get install iproute2 -y +``` + +## Running benches + +Running the benches requires root privileges because they will set up virtual interfaces. 
The script is designed to fully clean up when the benches are done, but run them at your own risk. + +Make sure you're in the `crates/benches/` directory, build the binaries then run the script: + +```sh +cargo build --release +sudo ./bench.sh +``` + +## Metrics + +After you run the benches you will see a `metrics.csv` file in the working directory. It will be owned by `root`, so you probably want to run + +```sh +sudo chown $USER metrics.csv +``` \ No newline at end of file diff --git a/crates/benches/bench.sh b/crates/benches/bench.sh new file mode 100755 index 0000000000..d398db87f1 --- /dev/null +++ b/crates/benches/bench.sh @@ -0,0 +1,13 @@ +#! /bin/bash + +# Check if we are running as root +if [ "$EUID" -ne 0 ]; then + echo "This script must be run as root" + exit +fi + +# Run the benchmark binary +../../target/release/bench + +# Plot the results +../../target/release/plot metrics.csv diff --git a/crates/benches/bench.toml b/crates/benches/bench.toml new file mode 100644 index 0000000000..44f86f7d6a --- /dev/null +++ b/crates/benches/bench.toml @@ -0,0 +1,39 @@ +[[benches]] +name = "latency" +upload = 250 +upload-delay = [10, 25, 50] +download = 250 +download-delay = [10, 25, 50] +upload-size = 1024 +download-size = 4096 +defer-decryption = true + +[[benches]] +name = "download_bandwidth" +upload = 250 +upload-delay = 25 +download = [10, 25, 50, 100, 250] +download-delay = 25 +upload-size = 1024 +download-size = 4096 +defer-decryption = true + +[[benches]] +name = "upload_bandwidth" +upload = [10, 25, 50, 100, 250] +upload-delay = 25 +download = 250 +download-delay = 25 +upload-size = 1024 +download-size = 4096 +defer-decryption = [false, true] + +[[benches]] +name = "download_volume" +upload = 250 +upload-delay = 25 +download = 250 +download-delay = 25 +upload-size = 1024 +download-size = [1024, 4096, 16384, 65536] +defer-decryption = true diff --git a/crates/benches/benches.Dockerfile b/crates/benches/benches.Dockerfile new file mode 100644 index 
0000000000..a2e9967a0a --- /dev/null +++ b/crates/benches/benches.Dockerfile @@ -0,0 +1,21 @@ +FROM rust AS builder +WORKDIR /usr/src/tlsn +COPY . . +RUN cd crates/benches && cargo build --release + +FROM ubuntu:latest + +RUN apt-get update && apt-get -y upgrade && apt-get install -y --no-install-recommends \ + iproute2 \ + sudo \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=builder ["/usr/src/tlsn/target/release/bench", "/usr/src/tlsn/target/release/prover", "/usr/src/tlsn/target/release/verifier", "/usr/src/tlsn/target/release/plot", "/usr/local/bin/"] + +ENV PROVER_PATH="/usr/local/bin/prover" +ENV VERIFIER_PATH="/usr/local/bin/verifier" + +VOLUME [ "/benches" ] +WORKDIR "/benches" +CMD ["/bin/bash", "-c", "bench && plot /benches/metrics.csv && cat /benches/metrics.csv"] diff --git a/crates/benches/benches.Dockerfile.dockerignore b/crates/benches/benches.Dockerfile.dockerignore new file mode 100644 index 0000000000..1a3de888d1 --- /dev/null +++ b/crates/benches/benches.Dockerfile.dockerignore @@ -0,0 +1,2 @@ +# exclude any /target folders +**/target* diff --git a/crates/benches/bin/bench.rs b/crates/benches/bin/bench.rs new file mode 100644 index 0000000000..acdaa3eef1 --- /dev/null +++ b/crates/benches/bin/bench.rs @@ -0,0 +1,44 @@ +use std::process::Command; + +use tlsn_benches::{clean_up, set_up}; + +fn main() { + let prover_path = + std::env::var("PROVER_PATH").unwrap_or_else(|_| "../../target/release/prover".to_string()); + let verifier_path = std::env::var("VERIFIER_PATH") + .unwrap_or_else(|_| "../../target/release/verifier".to_string()); + + if let Err(e) = set_up() { + println!("Error setting up: {}", e); + clean_up(); + } + + // Run prover and verifier binaries in parallel + let Ok(mut verifier) = Command::new("ip") + .arg("netns") + .arg("exec") + .arg("verifier-ns") + .arg(verifier_path) + .spawn() + else { + println!("Failed to start verifier"); + return clean_up(); + }; + + let Ok(mut prover) = Command::new("ip") + 
.arg("netns") + .arg("exec") + .arg("prover-ns") + .arg(prover_path) + .spawn() + else { + println!("Failed to start prover"); + return clean_up(); + }; + + // Wait for both to finish + _ = prover.wait(); + _ = verifier.wait(); + + clean_up(); +} diff --git a/crates/benches/bin/plot.rs b/crates/benches/bin/plot.rs new file mode 100644 index 0000000000..8ede5f3952 --- /dev/null +++ b/crates/benches/bin/plot.rs @@ -0,0 +1,156 @@ +use charming::{ + component::{ + Axis, DataView, Feature, Legend, Restore, SaveAsImage, Title, Toolbox, ToolboxDataZoom, + }, + element::{NameLocation, Orient, Tooltip, Trigger}, + series::{Line, Scatter}, + theme::Theme, + Chart, HtmlRenderer, +}; +use tlsn_benches::metrics::Metrics; + +const THEME: Theme = Theme::Default; + +fn main() -> Result<(), Box> { + let csv_file = std::env::args() + .nth(1) + .expect("Usage: plot "); + + let mut rdr = csv::Reader::from_path(csv_file)?; + + // Prepare data for plotting + let all_data: Vec = rdr + .deserialize::() + .collect::, _>>()?; // Attempt to collect all results, return an error if any fail + + let _chart = runtime_vs_latency(&all_data)?; + let _chart = runtime_vs_bandwidth(&all_data)?; + + Ok(()) +} + +fn runtime_vs_latency(all_data: &[Metrics]) -> Result> { + const TITLE: &str = "Runtime vs Latency"; + + let data: Vec> = all_data + .iter() + .filter(|record| record.name == "latency") + .map(|record| { + let total_delay = record.upload_delay + record.download_delay; // Calculate the sum of upload and download delays. 
+ vec![total_delay as f32, record.runtime as f32] + }) + .collect(); + + // https://github.com/yuankunzhang/charming + let chart = Chart::new() + .title(Title::new().text(TITLE)) + .tooltip(Tooltip::new().trigger(Trigger::Axis)) + .legend(Legend::new().orient(Orient::Vertical)) + .toolbox( + Toolbox::new().show(true).feature( + Feature::new() + .save_as_image(SaveAsImage::new()) + .restore(Restore::new()) + .data_zoom(ToolboxDataZoom::new().y_axis_index("none")) + .data_view(DataView::new().read_only(false)), + ), + ) + .x_axis( + Axis::new() + .scale(true) + .name("Upload + Download Latency (ms)") + .name_location(NameLocation::Center), + ) + .y_axis( + Axis::new() + .scale(true) + .name("Runtime (s)") + .name_location(NameLocation::Middle), + ) + .series( + Scatter::new() + .name("Combined Latency") + .symbol_size(10) + .data(data), + ); + + // Save the chart as HTML file. + HtmlRenderer::new(TITLE, 1000, 800) + .theme(THEME) + .save(&chart, "runtime_vs_latency.html") + .unwrap(); + + Ok(chart) +} + +fn runtime_vs_bandwidth(all_data: &[Metrics]) -> Result> { + const TITLE: &str = "Runtime vs Bandwidth"; + + let download_data: Vec> = all_data + .iter() + .filter(|record| record.name == "download_bandwidth") + .map(|record| vec![record.download as f32, record.runtime as f32]) + .collect(); + let upload_deferred_data: Vec> = all_data + .iter() + .filter(|record| record.name == "upload_bandwidth" && record.defer_decryption) + .map(|record| vec![record.upload as f32, record.runtime as f32]) + .collect(); + let upload_non_deferred_data: Vec> = all_data + .iter() + .filter(|record| record.name == "upload_bandwidth" && !record.defer_decryption) + .map(|record| vec![record.upload as f32, record.runtime as f32]) + .collect(); + + // https://github.com/yuankunzhang/charming + let chart = Chart::new() + .title(Title::new().text(TITLE)) + .tooltip(Tooltip::new().trigger(Trigger::Axis)) + .legend(Legend::new().orient(Orient::Vertical)) + .toolbox( + 
Toolbox::new().show(true).feature( + Feature::new() + .save_as_image(SaveAsImage::new()) + .restore(Restore::new()) + .data_zoom(ToolboxDataZoom::new().y_axis_index("none")) + .data_view(DataView::new().read_only(false)), + ), + ) + .x_axis( + Axis::new() + .scale(true) + .name("Bandwidth (Mbps)") + .name_location(NameLocation::Center), + ) + .y_axis( + Axis::new() + .scale(true) + .name("Runtime (s)") + .name_location(NameLocation::Middle), + ) + .series( + Line::new() + .name("Download bandwidth") + .symbol_size(10) + .data(download_data), + ) + .series( + Line::new() + .name("Upload bandwidth (deferred decryption)") + .symbol_size(10) + .data(upload_deferred_data), + ) + .series( + Line::new() + .name("Upload bandwidth") + .symbol_size(10) + .data(upload_non_deferred_data), + ); + // Save the chart as HTML file. + HtmlRenderer::new(TITLE, 1000, 800) + .theme(THEME) + .save(&chart, "runtime_vs_bandwidth.html") + .unwrap(); + + Ok(chart) +} diff --git a/crates/benches/bin/prover.rs b/crates/benches/bin/prover.rs new file mode 100644 index 0000000000..5d737b1aff --- /dev/null +++ b/crates/benches/bin/prover.rs @@ -0,0 +1,197 @@ +use std::{ + io::Write, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Instant, +}; + +use anyhow::Context; +use futures::{AsyncReadExt, AsyncWriteExt}; +use tls_core::verify::WebPkiVerifier; +use tlsn_benches::{ + config::{BenchInstance, Config}, + metrics::Metrics, + set_interface, PROVER_INTERFACE, +}; +use tlsn_common::config::ProtocolConfig; +use tlsn_core::{transcript::Idx, CryptoProvider}; +use tlsn_server_fixture::bind; +use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::{ + compat::TokioAsyncReadCompatExt, + io::{InspectReader, InspectWriter}, +}; + +use tlsn_prover::{Prover, ProverConfig}; +use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let config_path = 
std::env::var("CFG").unwrap_or_else(|_| "bench.toml".to_string()); + let config: Config = toml::from_str( + &std::fs::read_to_string(config_path).context("failed to read config file")?, + ) + .context("failed to parse config")?; + + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .init(); + + let ip = std::env::var("VERIFIER_IP").unwrap_or_else(|_| "10.10.1.1".to_string()); + let port: u16 = std::env::var("VERIFIER_PORT") + .map(|port| port.parse().expect("port is valid u16")) + .unwrap_or(8000); + let verifier_host = (ip.as_str(), port); + + let mut file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open("metrics.csv") + .context("failed to open metrics file")?; + + { + let mut metric_wtr = csv::Writer::from_writer(&mut file); + for bench in config.benches { + let instances = bench.flatten(); + for instance in instances { + println!("{:?}", &instance); + + let io = tokio::net::TcpStream::connect(verifier_host) + .await + .context("failed to open tcp connection")?; + metric_wtr.serialize( + run_instance(instance, io) + .await + .context("failed to run instance")?, + )?; + metric_wtr.flush()?; + } + } + } + + file.flush()?; + + Ok(()) +} + +async fn run_instance( + instance: BenchInstance, + io: S, +) -> anyhow::Result { + let uploaded = Arc::new(AtomicU64::new(0)); + let downloaded = Arc::new(AtomicU64::new(0)); + let io = InspectWriter::new( + InspectReader::new(io, { + let downloaded = downloaded.clone(); + move |data| { + downloaded.fetch_add(data.len() as u64, Ordering::Relaxed); + } + }), + { + let uploaded = uploaded.clone(); + move |data| { + uploaded.fetch_add(data.len() as u64, Ordering::Relaxed); + } + }, + ); + + let BenchInstance { + name, + upload, + upload_delay, + download, + download_delay, + upload_size, + download_size, + defer_decryption, + } = instance.clone(); + + set_interface(PROVER_INTERFACE, upload, 1, upload_delay)?; + + let (client_conn, 
server_conn) = tokio::io::duplex(2 << 16); + tokio::spawn(bind(server_conn.compat())); + + let start_time = Instant::now(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store(), None), + ..Default::default() + }; + + let protocol_config = if defer_decryption { + ProtocolConfig::builder() + .max_sent_data(upload_size + 256) + .max_recv_data(download_size + 256) + .build() + .unwrap() + } else { + ProtocolConfig::builder() + .max_sent_data(upload_size + 256) + .max_recv_data(download_size + 256) + .max_recv_data_online(download_size + 256) + .build() + .unwrap() + }; + + let prover = Prover::new( + ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config(protocol_config) + .defer_decryption_from_start(defer_decryption) + .crypto_provider(provider) + .build() + .context("invalid prover config")?, + ) + .setup(io.compat()) + .await?; + + let (mut mpc_tls_connection, prover_fut) = prover.connect(client_conn.compat()).await.unwrap(); + + let prover_task = tokio::spawn(prover_fut); + + let request = format!( + "GET /bytes?size={} HTTP/1.1\r\nConnection: close\r\nData: {}\r\n\r\n", + download_size, + String::from_utf8(vec![0x42u8; upload_size]).unwrap(), + ); + + mpc_tls_connection.write_all(request.as_bytes()).await?; + mpc_tls_connection.close().await?; + + let mut response = vec![]; + mpc_tls_connection.read_to_end(&mut response).await?; + + let mut prover = prover_task.await??.start_prove(); + + let (sent_len, recv_len) = prover.transcript().len(); + prover + .prove_transcript(Idx::new(0..sent_len), Idx::new(0..recv_len)) + .await?; + prover.finalize().await?; + + Ok(Metrics { + name, + upload, + upload_delay, + download, + download_delay, + upload_size, + download_size, + defer_decryption, + runtime: Instant::now().duration_since(start_time).as_secs(), + uploaded: uploaded.load(Ordering::SeqCst), + downloaded: downloaded.load(Ordering::SeqCst), + }) +} + +fn root_store() -> tls_core::anchors::RootCertStore { + let mut root_store 
= tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + root_store +} diff --git a/crates/benches/bin/verifier.rs b/crates/benches/bin/verifier.rs new file mode 100644 index 0000000000..e4d5a62fb4 --- /dev/null +++ b/crates/benches/bin/verifier.rs @@ -0,0 +1,100 @@ +use anyhow::Context; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::compat::TokioAsyncReadCompatExt; +use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + +use tls_core::verify::WebPkiVerifier; +use tlsn_benches::{ + config::{BenchInstance, Config}, + set_interface, VERIFIER_INTERFACE, +}; +use tlsn_common::config::ProtocolConfigValidator; +use tlsn_core::CryptoProvider; +use tlsn_server_fixture_certs::CA_CERT_DER; +use tlsn_verifier::{Verifier, VerifierConfig}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let config_path = std::env::var("CFG").unwrap_or_else(|_| "bench.toml".to_string()); + let config: Config = toml::from_str( + &std::fs::read_to_string(config_path).context("failed to read config file")?, + ) + .context("failed to parse config")?; + + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .init(); + + let ip = std::env::var("VERIFIER_IP").unwrap_or_else(|_| "10.10.1.1".to_string()); + let port: u16 = std::env::var("VERIFIER_PORT") + .map(|port| port.parse().expect("port is valid u16")) + .unwrap_or(8000); + let host = (ip.as_str(), port); + + let listener = tokio::net::TcpListener::bind(host) + .await + .context("failed to bind to port")?; + + for bench in config.benches { + for instance in bench.flatten() { + let (io, _) = listener + .accept() + .await + .context("failed to accept connection")?; + run_instance(instance, io) + .await + .context("failed to run instance")?; + } + } + + Ok(()) +} + +async fn run_instance( + instance: BenchInstance, + io: S, +) -> anyhow::Result<()> { + let BenchInstance { + 
download, + download_delay, + upload_size, + download_size, + .. + } = instance; + + set_interface(VERIFIER_INTERFACE, download, 1, download_delay)?; + + let provider = CryptoProvider { + cert: cert_verifier(), + ..Default::default() + }; + + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(upload_size + 256) + .max_recv_data(download_size + 256) + .build() + .unwrap(); + + let verifier = Verifier::new( + VerifierConfig::builder() + .protocol_config_validator(config_validator) + .crypto_provider(provider) + .build()?, + ); + + _ = verifier.verify(io.compat()).await?; + + println!("verifier done"); + + Ok(()) +} + +fn cert_verifier() -> WebPkiVerifier { + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + WebPkiVerifier::new(root_store, None) +} diff --git a/crates/benches/docker.md b/crates/benches/docker.md new file mode 100644 index 0000000000..a8995be37d --- /dev/null +++ b/crates/benches/docker.md @@ -0,0 +1,12 @@ +# Run the TLSN benches with Docker + +In the root folder of this repository, run: +``` +docker build -t tlsn-bench . 
-f ./crates/benches/benches.Dockerfile +``` + +Next run the benches with: +``` +docker run -it --privileged -v ./crates/benches/:/benches tlsn-bench +``` +The `--privileged` parameter is required because this test bench needs permission to create networks with certain parameters \ No newline at end of file diff --git a/crates/benches/src/config.rs b/crates/benches/src/config.rs new file mode 100644 index 0000000000..87e07247b2 --- /dev/null +++ b/crates/benches/src/config.rs @@ -0,0 +1,111 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum Field { + Single(T), + Multiple(Vec), +} + +#[derive(Deserialize)] +pub struct Config { + pub benches: Vec, +} + +#[derive(Deserialize)] +pub struct Bench { + pub name: String, + pub upload: Field, + #[serde(rename = "upload-delay")] + pub upload_delay: Field, + pub download: Field, + #[serde(rename = "download-delay")] + pub download_delay: Field, + #[serde(rename = "upload-size")] + pub upload_size: Field, + #[serde(rename = "download-size")] + pub download_size: Field, + #[serde(rename = "defer-decryption")] + pub defer_decryption: Field, +} + +impl Bench { + /// Flattens the config into a list of instances + pub fn flatten(self) -> Vec { + let mut instances = vec![]; + + let upload = match self.upload { + Field::Single(u) => vec![u], + Field::Multiple(u) => u, + }; + + let upload_delay = match self.upload_delay { + Field::Single(u) => vec![u], + Field::Multiple(u) => u, + }; + + let download = match self.download { + Field::Single(u) => vec![u], + Field::Multiple(u) => u, + }; + + let download_latency = match self.download_delay { + Field::Single(u) => vec![u], + Field::Multiple(u) => u, + }; + + let upload_size = match self.upload_size { + Field::Single(u) => vec![u], + Field::Multiple(u) => u, + }; + + let download_size = match self.download_size { + Field::Single(u) => vec![u], + Field::Multiple(u) => u, + }; + + let defer_decryption = match self.defer_decryption { + 
Field::Single(u) => vec![u], + Field::Multiple(u) => u, + }; + + for u in upload { + for ul in &upload_delay { + for d in &download { + for dl in &download_latency { + for us in &upload_size { + for ds in &download_size { + for dd in &defer_decryption { + instances.push(BenchInstance { + name: self.name.clone(), + upload: u, + upload_delay: *ul, + download: *d, + download_delay: *dl, + upload_size: *us, + download_size: *ds, + defer_decryption: *dd, + }); + } + } + } + } + } + } + } + + instances + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct BenchInstance { + pub name: String, + pub upload: usize, + pub upload_delay: usize, + pub download: usize, + pub download_delay: usize, + pub upload_size: usize, + pub download_size: usize, + pub defer_decryption: bool, +} diff --git a/crates/benches/src/lib.rs b/crates/benches/src/lib.rs new file mode 100644 index 0000000000..8dedfd60f5 --- /dev/null +++ b/crates/benches/src/lib.rs @@ -0,0 +1,255 @@ +pub mod config; +pub mod metrics; + +use std::{io, process::Command}; + +pub const PROVER_NAMESPACE: &str = "prover-ns"; +pub const PROVER_INTERFACE: &str = "prover-veth"; +pub const PROVER_SUBNET: &str = "10.10.1.0/24"; +pub const VERIFIER_NAMESPACE: &str = "verifier-ns"; +pub const VERIFIER_INTERFACE: &str = "verifier-veth"; +pub const VERIFIER_SUBNET: &str = "10.10.1.1/24"; + +pub fn set_up() -> io::Result<()> { + // Create network namespaces + create_network_namespace(PROVER_NAMESPACE)?; + create_network_namespace(VERIFIER_NAMESPACE)?; + + // Create veth pair and attach to namespaces + create_veth_pair( + PROVER_NAMESPACE, + PROVER_INTERFACE, + VERIFIER_NAMESPACE, + VERIFIER_INTERFACE, + )?; + + // Set devices up + set_device_up(PROVER_NAMESPACE, PROVER_INTERFACE)?; + set_device_up(VERIFIER_NAMESPACE, VERIFIER_INTERFACE)?; + + // Assign IPs + assign_ip_to_interface(PROVER_NAMESPACE, PROVER_INTERFACE, PROVER_SUBNET)?; + assign_ip_to_interface(VERIFIER_NAMESPACE, VERIFIER_INTERFACE, VERIFIER_SUBNET)?; + + // Set 
default routes + set_default_route( + PROVER_NAMESPACE, + PROVER_INTERFACE, + PROVER_SUBNET.split('/').next().unwrap(), + )?; + set_default_route( + VERIFIER_NAMESPACE, + VERIFIER_INTERFACE, + VERIFIER_SUBNET.split('/').next().unwrap(), + )?; + + Ok(()) +} + +pub fn clean_up() { + // Delete interface pair + if let Err(e) = Command::new("ip") + .args([ + "netns", + "exec", + PROVER_NAMESPACE, + "ip", + "link", + "delete", + PROVER_INTERFACE, + ]) + .status() + { + println!("Error deleting interface {}: {}", PROVER_INTERFACE, e); + } + + // Delete namespaces + if let Err(e) = Command::new("ip") + .args(["netns", "del", PROVER_NAMESPACE]) + .status() + { + println!("Error deleting namespace {}: {}", PROVER_NAMESPACE, e); + } + + if let Err(e) = Command::new("ip") + .args(["netns", "del", VERIFIER_NAMESPACE]) + .status() + { + println!("Error deleting namespace {}: {}", VERIFIER_NAMESPACE, e); + } +} + +/// Sets the interface parameters. +/// +/// Must be run in the correct namespace. +/// +/// # Arguments +/// +/// * `egress` - The egress bandwidth in mbps. +/// * `burst` - The burst in mbps. +/// * `delay` - The delay in ms. 
+pub fn set_interface(interface: &str, egress: usize, burst: usize, delay: usize) -> io::Result<()> { + // Clear rules + _ = Command::new("tc") + .arg("qdisc") + .arg("del") + .arg("dev") + .arg(interface) + .arg("root") + .status(); + + // Egress + Command::new("tc") + .arg("qdisc") + .arg("add") + .arg("dev") + .arg(interface) + .arg("root") + .arg("handle") + .arg("1:") + .arg("tbf") + .arg("rate") + .arg(format!("{}mbit", egress)) + .arg("burst") + .arg(format!("{}mbit", burst)) + .arg("latency") + .arg("60s") + .status()?; + + // Delay + Command::new("tc") + .arg("qdisc") + .arg("add") + .arg("dev") + .arg(interface) + .arg("parent") + .arg("1:1") + .arg("handle") + .arg("10:") + .arg("netem") + .arg("delay") + .arg(format!("{}ms", delay)) + .status()?; + + Ok(()) +} + +/// Create a network namespace with the given name if it does not already exist. +fn create_network_namespace(name: &str) -> io::Result<()> { + // Check if namespace already exists + if Command::new("ip") + .args(["netns", "list"]) + .output()? + .stdout + .windows(name.len()) + .any(|ns| ns == name.as_bytes()) + { + println!("Namespace {} already exists", name); + return Ok(()); + } else { + println!("Creating namespace {}", name); + Command::new("ip").args(["netns", "add", name]).status()?; + } + + Ok(()) +} + +fn create_veth_pair( + left_namespace: &str, + left_interface: &str, + right_namespace: &str, + right_interface: &str, +) -> io::Result<()> { + // Check if interfaces are already present in namespaces + if is_interface_present_in_namespace(left_namespace, left_interface)? + || is_interface_present_in_namespace(right_namespace, right_interface)? 
+ { + println!("Virtual interface already exists."); + return Ok(()); + } + + // Create veth pair + Command::new("ip") + .args([ + "link", + "add", + left_interface, + "type", + "veth", + "peer", + "name", + right_interface, + ]) + .status()?; + + println!( + "Created veth pair {} and {}", + left_interface, right_interface + ); + + // Attach veth pair to namespaces + attach_interface_to_namespace(left_namespace, left_interface)?; + attach_interface_to_namespace(right_namespace, right_interface)?; + + Ok(()) +} + +fn attach_interface_to_namespace(namespace: &str, interface: &str) -> io::Result<()> { + Command::new("ip") + .args(["link", "set", interface, "netns", namespace]) + .status()?; + + println!("Attached {} to namespace {}", interface, namespace); + + Ok(()) +} + +fn set_default_route(namespace: &str, interface: &str, ip: &str) -> io::Result<()> { + Command::new("ip") + .args([ + "netns", "exec", namespace, "ip", "route", "add", "default", "via", ip, "dev", + interface, + ]) + .status()?; + + println!( + "Set default route for namespace {} ip {} to {}", + namespace, ip, interface + ); + + Ok(()) +} + +fn is_interface_present_in_namespace( + namespace: &str, + interface: &str, +) -> Result { + Ok(Command::new("ip") + .args([ + "netns", "exec", namespace, "ip", "link", "list", "dev", interface, + ]) + .output()? 
+ .stdout + .windows(interface.len()) + .any(|ns| ns == interface.as_bytes())) +} + +fn set_device_up(namespace: &str, interface: &str) -> io::Result<()> { + Command::new("ip") + .args([ + "netns", "exec", namespace, "ip", "link", "set", interface, "up", + ]) + .status()?; + + Ok(()) +} + +fn assign_ip_to_interface(namespace: &str, interface: &str, ip: &str) -> io::Result<()> { + Command::new("ip") + .args([ + "netns", "exec", namespace, "ip", "addr", "add", ip, "dev", interface, + ]) + .status()?; + + Ok(()) +} diff --git a/crates/benches/src/metrics.rs b/crates/benches/src/metrics.rs new file mode 100644 index 0000000000..aba096b6a6 --- /dev/null +++ b/crates/benches/src/metrics.rs @@ -0,0 +1,26 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Metrics { + pub name: String, + /// Upload bandwidth in Mbps. + pub upload: usize, + /// Upload latency in ms. + pub upload_delay: usize, + /// Download bandwidth in Mbps. + pub download: usize, + /// Download latency in ms. + pub download_delay: usize, + /// Total bytes sent to the server. + pub upload_size: usize, + /// Total bytes received from the server. + pub download_size: usize, + /// Whether deferred decryption was used. + pub defer_decryption: bool, + /// The total runtime of the benchmark in seconds. + pub runtime: u64, + /// The total amount of data uploaded to the verifier in bytes. + pub uploaded: u64, + /// The total amount of data downloaded from the verifier in bytes. + pub downloaded: u64, +} diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml new file mode 100644 index 0000000000..5f90753518 --- /dev/null +++ b/crates/common/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "tlsn-common" +description = "Common code shared between tlsn-prover and tlsn-verifier" +version = "0.1.0-alpha.7" +edition = "2021" + +[features] +default = [] +# Enables common types and config parameters used in the AuthDecode protocol. 
+# This feature can only be used if the prover and the verifier enable their `authdecode_unsafe` feature. +authdecode_unsafe_common = [] + +[dependencies] +tlsn-core = { workspace = true } +mpz-common = { workspace = true } +mpz-garble = { workspace = true } +mpz-ot = { workspace = true } + +derive_builder = { workspace = true } +futures = { workspace = true } +once_cell = { workspace = true } +serio = { workspace = true, features = ["codec", "bincode"] } +thiserror = { workspace = true } +tracing = { workspace = true } +uid-mux = { workspace = true, features = ["serio"] } +serde = { workspace = true, features = ["derive"] } +tlsn-utils = { workspace = true } +semver = { version = "1.0", features = ["serde"] } + +[dev-dependencies] +rstest = { workspace = true } diff --git a/crates/common/src/config.rs b/crates/common/src/config.rs new file mode 100644 index 0000000000..a86678e4d7 --- /dev/null +++ b/crates/common/src/config.rs @@ -0,0 +1,356 @@ +//! TLSNotary protocol config and config utilities. +use core::fmt; +use once_cell::sync::Lazy; +use semver::Version; +use serde::{Deserialize, Serialize}; +use std::error::Error; + +use crate::Role; + +// Extra cushion room, eg. for sharing J0 blocks. +const EXTRA_OTS: usize = 16384; + +const OTS_PER_BYTE_SENT: usize = 8; + +// Without deferred decryption we use 16, with it we use 8. +const OTS_PER_BYTE_RECV_ONLINE: usize = 16; +const OTS_PER_BYTE_RECV_DEFER: usize = 8; + +// Current version that is running. +static VERSION: Lazy = Lazy::new(|| { + Version::parse(env!("CARGO_PKG_VERSION")) + .map_err(|err| ProtocolConfigError::new(ErrorKind::Version, err)) + .unwrap() +}); + +/// Protocol configuration to be set up initially by prover and verifier. +#[derive(derive_builder::Builder, Clone, Debug, Deserialize, Serialize)] +#[builder(build_fn(validate = "Self::validate"))] +pub struct ProtocolConfig { + /// Maximum number of bytes that can be sent. 
+ max_sent_data: usize, + /// Maximum number of bytes that can be decrypted online, i.e. while the + /// MPC-TLS connection is active. + #[builder(default = "0")] + max_recv_data_online: usize, + /// Maximum number of bytes that can be received. + max_recv_data: usize, + #[cfg(feature = "authdecode_unsafe_common")] + /// Maximum number of plaintext bytes which can be authenticated using the AuthDecode + /// protocol. + #[builder(default = "0")] + max_authdecode_data: usize, + /// Version that is being run by prover/verifier. + #[builder(setter(skip), default = "VERSION.clone()")] + version: Version, +} + +impl ProtocolConfigBuilder { + fn validate(&self) -> Result<(), String> { + if self.max_recv_data_online > self.max_recv_data { + return Err( + "max_recv_data_online must be smaller or equal to max_recv_data".to_string(), + ); + } + Ok(()) + } +} + +impl ProtocolConfig { + /// Creates a new builder for `ProtocolConfig`. + pub fn builder() -> ProtocolConfigBuilder { + ProtocolConfigBuilder::default() + } + + /// Returns the maximum number of bytes that can be sent. + pub fn max_sent_data(&self) -> usize { + self.max_sent_data + } + + /// Returns the maximum number of bytes that can be decrypted online. + pub fn max_recv_data_online(&self) -> usize { + self.max_recv_data_online + } + + /// Returns the maximum number of bytes that can be received. + pub fn max_recv_data(&self) -> usize { + self.max_recv_data + } + + #[cfg(feature = "authdecode_unsafe_common")] + /// Returns the maximum number of plaintext bytes which can be authenticated using the AuthDecode + /// protocol. + pub fn max_authdecode_data(&self) -> usize { + self.max_authdecode_data + } + + /// Returns OT sender setup count. + pub fn ot_sender_setup_count(&self, role: Role) -> usize { + ot_send_estimate( + role, + self.max_sent_data, + self.max_recv_data_online, + self.max_recv_data, + ) + } + + /// Returns OT receiver setup count. 
+ pub fn ot_receiver_setup_count(&self, role: Role) -> usize { + ot_recv_estimate( + role, + self.max_sent_data, + self.max_recv_data_online, + self.max_recv_data, + ) + } +} + +/// Protocol configuration validator used by checker (i.e. verifier) to perform +/// compatibility check with the peer's (i.e. the prover's) configuration. +#[derive(derive_builder::Builder, Clone, Debug)] +pub struct ProtocolConfigValidator { + /// Maximum number of bytes that can be sent. + max_sent_data: usize, + /// Maximum number of bytes that can be received. + max_recv_data: usize, + /// Maximum number of plaintext bytes which can be authenticated using the AuthDecode + /// protocol. + #[builder(default = "0")] + #[cfg(feature = "authdecode_unsafe_common")] + max_authdecode_data: usize, + /// Version that is being run by checker. + #[builder(setter(skip), default = "VERSION.clone()")] + version: Version, +} + +impl ProtocolConfigValidator { + /// Creates a new builder for `ProtocolConfigValidator`. + pub fn builder() -> ProtocolConfigValidatorBuilder { + ProtocolConfigValidatorBuilder::default() + } + + /// Returns the maximum number of bytes that can be sent. + pub fn max_sent_data(&self) -> usize { + self.max_sent_data + } + + /// Returns the maximum number of bytes that can be received. + pub fn max_recv_data(&self) -> usize { + self.max_recv_data + } + + #[cfg(feature = "authdecode_unsafe_common")] + /// Returns the maximum number of plaintext bytes which can be authenticated using the AuthDecode + /// protocol. + pub fn max_authdecode_data(&self) -> usize { + self.max_authdecode_data + } + + /// Performs compatibility check of the protocol configuration between prover and verifier. 
+ pub fn validate(&self, config: &ProtocolConfig) -> Result<(), ProtocolConfigError> { + self.check_max_transcript_size(config.max_sent_data, config.max_recv_data)?; + #[cfg(feature = "authdecode_unsafe_common")] + self.check_max_authdecode_data(config.max_authdecode_data)?; + self.check_version(&config.version)?; + Ok(()) + } + + // Checks if both the sent and recv data are within limits. + fn check_max_transcript_size( + &self, + max_sent_data: usize, + max_recv_data: usize, + ) -> Result<(), ProtocolConfigError> { + if max_sent_data > self.max_sent_data { + return Err(ProtocolConfigError::max_transcript_size(format!( + "max_sent_data {:?} is greater than the configured limit {:?}", + max_sent_data, self.max_sent_data, + ))); + } + + if max_recv_data > self.max_recv_data { + return Err(ProtocolConfigError::max_transcript_size(format!( + "max_recv_data {:?} is greater than the configured limit {:?}", + max_recv_data, self.max_recv_data, + ))); + } + + Ok(()) + } + + #[cfg(feature = "authdecode_unsafe_common")] + // Checks if the number of authenticated bytes is within limits. + fn check_max_authdecode_data( + &self, + max_authdecode_data: usize, + ) -> Result<(), ProtocolConfigError> { + if max_authdecode_data > self.max_authdecode_data { + return Err(ProtocolConfigError::version(format!( + "max_authdecode_data {:?} is greater than the configured limit {:?}", + max_authdecode_data, self.max_authdecode_data + ))); + } + + Ok(()) + } + + // Checks if both versions are the same (might support check for different but compatible versions + // in the future). + fn check_version(&self, peer_version: &Version) -> Result<(), ProtocolConfigError> { + if *peer_version != self.version { + return Err(ProtocolConfigError::version(format!( + "prover's version {:?} is different from verifier's version {:?}", + peer_version, self.version + ))); + } + + Ok(()) + } +} + +/// A ProtocolConfig error. 
+#[derive(thiserror::Error, Debug)] +pub struct ProtocolConfigError { + kind: ErrorKind, + #[source] + source: Option>, +} + +impl ProtocolConfigError { + fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + fn max_transcript_size(msg: impl Into) -> Self { + Self { + kind: ErrorKind::MaxTranscriptSize, + source: Some(msg.into().into()), + } + } + + fn version(msg: impl Into) -> Self { + Self { + kind: ErrorKind::Version, + source: Some(msg.into().into()), + } + } +} + +impl fmt::Display for ProtocolConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + ErrorKind::MaxTranscriptSize => write!(f, "max transcript size error")?, + ErrorKind::Version => write!(f, "version error")?, + } + + if let Some(ref source) = self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +#[derive(Debug)] +enum ErrorKind { + MaxTranscriptSize, + Version, +} + +/// Returns an estimate of the number of OTs that will be sent. +pub fn ot_send_estimate( + role: Role, + max_sent_data: usize, + max_recv_data_online: usize, + max_recv_data: usize, +) -> usize { + match role { + Role::Prover => EXTRA_OTS, + Role::Verifier => { + EXTRA_OTS + + (max_sent_data * OTS_PER_BYTE_SENT) + + (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE) + + ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER) + } + } +} + +/// Returns an estimate of the number of OTs that will be received. 
+pub fn ot_recv_estimate( + role: Role, + max_sent_data: usize, + max_recv_data_online: usize, + max_recv_data: usize, +) -> usize { + match role { + Role::Prover => { + EXTRA_OTS + + (max_sent_data * OTS_PER_BYTE_SENT) + + (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE) + + ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER) + } + Role::Verifier => EXTRA_OTS, + } +} + +#[cfg(test)] +mod test { + use super::*; + use rstest::{fixture, rstest}; + + const TEST_MAX_SENT_LIMIT: usize = 1 << 12; + const TEST_MAX_RECV_LIMIT: usize = 1 << 14; + + #[fixture] + #[once] + fn config_validator() -> ProtocolConfigValidator { + ProtocolConfigValidator::builder() + .max_sent_data(TEST_MAX_SENT_LIMIT) + .max_recv_data(TEST_MAX_RECV_LIMIT) + .build() + .unwrap() + } + + #[rstest] + #[case::same_max_sent_recv_data(TEST_MAX_SENT_LIMIT, TEST_MAX_RECV_LIMIT)] + #[case::smaller_max_sent_data(1 << 11, TEST_MAX_RECV_LIMIT)] + #[case::smaller_max_recv_data(TEST_MAX_SENT_LIMIT, 1 << 13)] + #[case::smaller_max_sent_recv_data(1 << 7, 1 << 9)] + fn test_check_success( + config_validator: &ProtocolConfigValidator, + #[case] max_sent_data: usize, + #[case] max_recv_data: usize, + ) { + let peer_config = ProtocolConfig::builder() + .max_sent_data(max_sent_data) + .max_recv_data(max_recv_data) + .build() + .unwrap(); + + assert!(config_validator.validate(&peer_config).is_ok()) + } + + #[rstest] + #[case::bigger_max_sent_data(1 << 13, TEST_MAX_RECV_LIMIT)] + #[case::bigger_max_recv_data(1 << 10, 1 << 16)] + #[case::bigger_max_sent_recv_data(1 << 14, 1 << 21)] + fn test_check_fail( + config_validator: &ProtocolConfigValidator, + #[case] max_sent_data: usize, + #[case] max_recv_data: usize, + ) { + let peer_config = ProtocolConfig::builder() + .max_sent_data(max_sent_data) + .max_recv_data(max_recv_data) + .build() + .unwrap(); + + assert!(config_validator.validate(&peer_config).is_err()) + } +} diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs new file mode 100644 
index 0000000000..7d2aef28f8 --- /dev/null +++ b/crates/common/src/lib.rs @@ -0,0 +1,40 @@ +//! Common code shared between `tlsn-prover` and `tlsn-verifier`. + +#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(clippy::all)] +#![forbid(unsafe_code)] + +pub mod config; +pub mod msg; +pub mod mux; + +use serio::codec::Codec; + +use crate::mux::MuxControl; + +/// IO type. +pub type Io = >::Framed; +/// Base OT sender. +pub type BaseOTSender = mpz_ot::chou_orlandi::Sender; +/// Base OT receiver. +pub type BaseOTReceiver = mpz_ot::chou_orlandi::Receiver; +/// OT sender. +pub type OTSender = mpz_ot::kos::SharedSender; +/// OT receiver. +pub type OTReceiver = mpz_ot::kos::SharedReceiver; +/// MPC executor. +pub type Executor = mpz_common::executor::MTExecutor; +/// MPC thread context. +pub type Context = mpz_common::executor::MTContext; +/// DEAP thread. +pub type DEAPThread = mpz_garble::protocol::deap::DEAPThread; + +/// The party's role in the TLSN protocol. +/// +/// A Notary is classified as a Verifier. +pub enum Role { + /// The prover. + Prover, + /// The verifier. + Verifier, +} diff --git a/crates/common/src/msg.rs b/crates/common/src/msg.rs new file mode 100644 index 0000000000..912b0b4f01 --- /dev/null +++ b/crates/common/src/msg.rs @@ -0,0 +1,14 @@ +//! Message types. + +use serde::{Deserialize, Serialize}; + +use tlsn_core::connection::{ServerCertData, ServerName}; + +/// Message sent from Prover to Verifier to prove the server identity. +#[derive(Debug, Serialize, Deserialize)] +pub struct ServerIdentityProof { + /// Server name. + pub name: ServerName, + /// Server identity data. + pub data: ServerCertData, +} diff --git a/crates/common/src/mux.rs b/crates/common/src/mux.rs new file mode 100644 index 0000000000..6562d47c8d --- /dev/null +++ b/crates/common/src/mux.rs @@ -0,0 +1,91 @@ +//! Multiplexer used in the TLSNotary protocol. 
+ +use std::future::IntoFuture; + +use futures::{ + future::{FusedFuture, FutureExt}, + AsyncRead, AsyncWrite, Future, +}; +use serio::codec::Bincode; +use tracing::error; +use uid_mux::{yamux, FramedMux}; + +use crate::Role; + +/// Multiplexer supporting unique deterministic stream IDs. +pub type Mux = yamux::Yamux; +/// Multiplexer controller providing streams with a codec attached. +pub type MuxControl = FramedMux; + +/// Multiplexer future which must be polled for the muxer to make progress. +pub struct MuxFuture( + Box> + Send + Unpin>, +); + +impl MuxFuture { + /// Returns true if the muxer is complete. + pub fn is_complete(&self) -> bool { + self.0.is_terminated() + } + + /// Awaits a future, polling the muxer future concurrently. + pub async fn poll_with(&mut self, fut: F) -> R + where + F: Future, + { + let mut fut = Box::pin(fut.fuse()); + // Poll the future concurrently with the muxer future. + // If the muxer returns an error, continue polling the future + // until it completes. + loop { + futures::select! { + res = fut => return res, + res = &mut self.0 => if let Err(e) = res { + error!("mux error: {:?}", e); + }, + } + } + } +} + +impl Future for MuxFuture { + type Output = Result<(), yamux::ConnectionError>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + self.0.as_mut().poll_unpin(cx) + } +} + +/// Attaches a multiplexer to the provided socket. +/// +/// Returns the multiplexer and a controller for creating streams with a codec +/// attached. +/// +/// # Arguments +/// +/// * `socket` - The socket to attach the multiplexer to. +/// * `role` - The role of the party using the multiplexer. 
+pub fn attach_mux( + socket: T, + role: Role, +) -> (MuxFuture, MuxControl) { + let mut mux_config = yamux::Config::default(); + mux_config.set_max_num_streams(64); + + let mux_role = match role { + Role::Prover => yamux::Mode::Client, + Role::Verifier => yamux::Mode::Server, + }; + + let mux = Mux::new(socket, mux_config, mux_role); + let ctrl = FramedMux::new(mux.control(), Bincode); + + if let Role::Prover = role { + ctrl.mux().alloc(64); + } + + (MuxFuture(Box::new(mux.into_future().fuse())), ctrl) +} diff --git a/crates/components/aead/Cargo.toml b/crates/components/aead/Cargo.toml new file mode 100644 index 0000000000..a92abd5cd6 --- /dev/null +++ b/crates/components/aead/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "tlsn-aead" +authors = ["TLSNotary Team"] +description = "This crate provides an implementation of a two-party version of AES-GCM behind an AEAD trait" +keywords = ["tls", "mpc", "2pc", "aead", "aes", "aes-gcm"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0-alpha.7" +edition = "2021" + +[lib] +name = "aead" + +[features] +default = ["mock"] +mock = ["mpz-common/test-utils", "dep:mpz-ot"] + +[dependencies] +tlsn-block-cipher = { workspace = true } +tlsn-stream-cipher = { workspace = true } +tlsn-universal-hash = { workspace = true } + +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", optional = true, features = [ + "ideal", +] } + +serio = { workspace = true } + +async-trait = { workspace = true } +derive_builder = { workspace = true } +futures = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { 
version = "1", features = ["macros", "rt", "rt-multi-thread"] } +aes-gcm = { workspace = true } diff --git a/components/aead/src/aes_gcm/config.rs b/crates/components/aead/src/aes_gcm/config.rs similarity index 70% rename from components/aead/src/aes_gcm/config.rs rename to crates/components/aead/src/aes_gcm/config.rs index 5bc908c7b4..26a92c298d 100644 --- a/components/aead/src/aes_gcm/config.rs +++ b/crates/components/aead/src/aes_gcm/config.rs @@ -1,6 +1,6 @@ use derive_builder::Builder; -/// Protocol role +/// Protocol role. #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[allow(missing_docs)] pub enum Role { @@ -11,25 +11,25 @@ pub enum Role { /// Configuration for AES-GCM. #[derive(Debug, Clone, Builder)] pub struct AesGcmConfig { - /// The id of this instance + /// The id of this instance. #[builder(setter(into))] id: String, - /// The protocol role + /// The protocol role. role: Role, } impl AesGcmConfig { - /// Creates a new builder for the AES-GCM configuration + /// Creates a new builder for the AES-GCM configuration. pub fn builder() -> AesGcmConfigBuilder { AesGcmConfigBuilder::default() } - /// Returns the id of this instance + /// Returns the id of this instance. pub fn id(&self) -> &str { &self.id } - /// Returns the protocol role + /// Returns the protocol role. pub fn role(&self) -> &Role { &self.role } diff --git a/crates/components/aead/src/aes_gcm/error.rs b/crates/components/aead/src/aes_gcm/error.rs new file mode 100644 index 0000000000..21ee698af4 --- /dev/null +++ b/crates/components/aead/src/aes_gcm/error.rs @@ -0,0 +1,102 @@ +use std::fmt::Display; + +/// AES-GCM error. 
+#[derive(Debug, thiserror::Error)] +pub struct AesGcmError { + kind: ErrorKind, + #[source] + source: Option>, +} + +impl AesGcmError { + pub(crate) fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + #[cfg(test)] + pub(crate) fn kind(&self) -> ErrorKind { + self.kind + } + + pub(crate) fn invalid_tag() -> Self { + Self { + kind: ErrorKind::Tag, + source: None, + } + } + + pub(crate) fn peer(reason: impl Into) -> Self { + Self { + kind: ErrorKind::PeerMisbehaved, + source: Some(reason.into().into()), + } + } + + pub(crate) fn payload(reason: impl Into) -> Self { + Self { + kind: ErrorKind::Payload, + source: Some(reason.into().into()), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) enum ErrorKind { + Io, + BlockCipher, + StreamCipher, + Ghash, + Tag, + PeerMisbehaved, + Payload, +} + +impl Display for AesGcmError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.kind { + ErrorKind::Io => write!(f, "io error")?, + ErrorKind::BlockCipher => write!(f, "block cipher error")?, + ErrorKind::StreamCipher => write!(f, "stream cipher error")?, + ErrorKind::Ghash => write!(f, "ghash error")?, + ErrorKind::Tag => write!(f, "payload has corrupted tag")?, + ErrorKind::PeerMisbehaved => write!(f, "peer misbehaved")?, + ErrorKind::Payload => write!(f, "payload error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for AesGcmError { + fn from(err: std::io::Error) -> Self { + Self::new(ErrorKind::Io, err) + } +} + +impl From for AesGcmError { + fn from(err: block_cipher::BlockCipherError) -> Self { + Self::new(ErrorKind::BlockCipher, err) + } +} + +impl From for AesGcmError { + fn from(err: tlsn_stream_cipher::StreamCipherError) -> Self { + Self::new(ErrorKind::StreamCipher, err) + } +} + +impl From for AesGcmError { + fn from(err: tlsn_universal_hash::UniversalHashError) -> 
Self { + Self::new(ErrorKind::Ghash, err) + } +} diff --git a/components/aead/src/aes_gcm/mock.rs b/crates/components/aead/src/aes_gcm/mock.rs similarity index 56% rename from components/aead/src/aes_gcm/mock.rs rename to crates/components/aead/src/aes_gcm/mock.rs index 48e23252bc..b94dcae311 100644 --- a/components/aead/src/aes_gcm/mock.rs +++ b/crates/components/aead/src/aes_gcm/mock.rs @@ -1,10 +1,12 @@ //! Mock implementation of AES-GCM for testing purposes. use block_cipher::{BlockCipherConfig, MpcBlockCipher}; -use mpz_garble::{Decode, DecodePrivate, Execute, Memory, Prove, Verify, Vm}; +use mpz_common::executor::{test_st_executor, STExecutor}; +use mpz_garble::protocol::deap::mock::{MockFollower, MockLeader}; +use mpz_ot::ideal::ot::ideal_ot; +use serio::channel::MemoryDuplex; use tlsn_stream_cipher::{MpcStreamCipher, StreamCipherConfig}; -use tlsn_universal_hash::ghash::{mock_ghash_pair, GhashConfig}; -use utils_aio::duplex::MemoryDuplex; +use tlsn_universal_hash::ghash::ideal_ghash; use super::*; @@ -13,35 +15,45 @@ use super::*; /// # Arguments /// /// * `id` - The id of the AES-GCM instances. -/// * `leader_vm` - The VM of the leader. -/// * `follower_vm` - The VM of the follower. +/// * `(leader, follower)` - The leader and follower vms. /// * `leader_config` - The configuration of the leader. /// * `follower_config` - The configuration of the follower. 
-pub async fn create_mock_aes_gcm_pair( +pub async fn create_mock_aes_gcm_pair( id: &str, - leader_vm: &mut T, - follower_vm: &mut T, + (leader, follower): (MockLeader, MockFollower), leader_config: AesGcmConfig, follower_config: AesGcmConfig, -) -> (MpcAesGcm, MpcAesGcm) -where - T: Vm + Send, - ::Thread: Memory + Execute + Decode + DecodePrivate + Prove + Verify + Send + Sync, -{ +) -> ( + MpcAesGcm>, + MpcAesGcm>, +) { let block_cipher_id = format!("{}/block_cipher", id); + let (ctx_leader, ctx_follower) = test_st_executor(128); + + let (leader_ot_send, follower_ot_recv) = ideal_ot(); + let (follower_ot_send, leader_ot_recv) = ideal_ot(); + + let block_leader = leader + .new_thread(ctx_leader, leader_ot_send, leader_ot_recv) + .unwrap(); + + let block_follower = follower + .new_thread(ctx_follower, follower_ot_send, follower_ot_recv) + .unwrap(); + let leader_block_cipher = MpcBlockCipher::new( BlockCipherConfig::builder() .id(block_cipher_id.clone()) .build() .unwrap(), - leader_vm.new_thread(&block_cipher_id).await.unwrap(), + block_leader, ); let follower_block_cipher = MpcBlockCipher::new( BlockCipherConfig::builder() .id(block_cipher_id.clone()) .build() .unwrap(), - follower_vm.new_thread(&block_cipher_id).await.unwrap(), + block_follower, ); let stream_cipher_id = format!("{}/stream_cipher", id); @@ -50,40 +62,23 @@ where .id(stream_cipher_id.clone()) .build() .unwrap(), - leader_vm - .new_thread_pool(&stream_cipher_id, 4) - .await - .unwrap(), + leader, ); let follower_stream_cipher = MpcStreamCipher::new( StreamCipherConfig::builder() .id(stream_cipher_id.clone()) .build() .unwrap(), - follower_vm - .new_thread_pool(&stream_cipher_id, 4) - .await - .unwrap(), - ); - - let (leader_ghash, follower_ghash) = mock_ghash_pair( - GhashConfig::builder() - .id(format!("{}/ghash", id)) - .initial_block_count(64) - .build() - .unwrap(), - GhashConfig::builder() - .id(format!("{}/ghash", id)) - .initial_block_count(64) - .build() - .unwrap(), + follower, ); - let 
(leader_channel, follower_channel) = MemoryDuplex::new(); + let (ctx_a, ctx_b) = test_st_executor(128); + let (leader_ghash, follower_ghash) = ideal_ghash(ctx_a, ctx_b); + let (ctx_a, ctx_b) = test_st_executor(128); let leader = MpcAesGcm::new( leader_config, - Box::new(leader_channel), + ctx_a, Box::new(leader_block_cipher), Box::new(leader_stream_cipher), Box::new(leader_ghash), @@ -91,7 +86,7 @@ where let follower = MpcAesGcm::new( follower_config, - Box::new(follower_channel), + ctx_b, Box::new(follower_block_cipher), Box::new(follower_stream_cipher), Box::new(follower_ghash), diff --git a/components/aead/src/aes_gcm/mod.rs b/crates/components/aead/src/aes_gcm/mod.rs similarity index 51% rename from components/aead/src/aes_gcm/mod.rs rename to crates/components/aead/src/aes_gcm/mod.rs index c9ef641f2c..eb52dc191c 100644 --- a/components/aead/src/aes_gcm/mod.rs +++ b/crates/components/aead/src/aes_gcm/mod.rs @@ -1,244 +1,146 @@ //! This module provides an implementation of 2PC AES-GCM. mod config; +mod error; #[cfg(feature = "mock")] pub mod mock; mod tag; pub use config::{AesGcmConfig, AesGcmConfigBuilder, AesGcmConfigBuilderError, Role}; - -use crate::{ - msg::{AeadMessage, TagShare}, - Aead, AeadChannel, AeadError, -}; +pub use error::AesGcmError; use async_trait::async_trait; -use futures::{SinkExt, StreamExt, TryFutureExt}; - use block_cipher::{Aes128, BlockCipher}; -use mpz_core::commit::HashCommit; +use futures::TryFutureExt; +use mpz_common::Context; use mpz_garble::value::ValueRef; use tlsn_stream_cipher::{Aes128Ctr, StreamCipher}; use tlsn_universal_hash::UniversalHash; -use utils_aio::expect_msg_or_err; +use tracing::instrument; -pub(crate) use tag::AesGcmTagShare; -use tag::{build_ghash_data, AES_GCM_TAG_LEN}; +use crate::{ + aes_gcm::tag::{compute_tag, verify_tag, TAG_LEN}, + Aead, +}; -/// An implementation of 2PC AES-GCM. -pub struct MpcAesGcm { +/// MPC AES-GCM. 
+pub struct MpcAesGcm { config: AesGcmConfig, - channel: AeadChannel, + ctx: Ctx, aes_block: Box>, aes_ctr: Box>, ghash: Box, } -impl std::fmt::Debug for MpcAesGcm { +impl std::fmt::Debug for MpcAesGcm { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MpcAesGcm") .field("config", &self.config) - .field("channel", &"AeadChannel {{ ... }}") - .field("aes_block", &"BlockCipher {{ ... }}") - .field("aes_ctr", &"StreamCipher {{ ... }}") - .field("ghash", &"UniversalHash {{ ... }}") .finish() } } -impl MpcAesGcm { +impl MpcAesGcm { /// Creates a new instance of [`MpcAesGcm`]. - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(channel, aes_block, aes_ctr, ghash), ret) - )] pub fn new( config: AesGcmConfig, - channel: AeadChannel, + context: Ctx, aes_block: Box>, aes_ctr: Box>, ghash: Box, ) -> Self { Self { config, - channel, + ctx: context, aes_block, aes_ctr, ghash, } } - - #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", err))] - async fn compute_j0_share(&mut self, explicit_nonce: Vec) -> Result, AeadError> { - let j0_share = self - .aes_ctr - .share_keystream_block(explicit_nonce.clone(), 1) - .await?; - - Ok(j0_share) - } - - #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", err, ret))] - async fn compute_tag_share( - &mut self, - explicit_nonce: Vec, - aad: Vec, - ciphertext: Vec, - ) -> Result { - let j0_share = self.compute_j0_share(explicit_nonce.clone()).await?; - - let hash = self - .ghash - .finalize(build_ghash_data(aad, ciphertext)) - .await?; - - let mut tag_share = [0u8; 16]; - tag_share.copy_from_slice(&hash[..]); - for i in 0..16 { - tag_share[i] ^= j0_share[i]; - } - - Ok(AesGcmTagShare(tag_share)) - } - - #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", err, ret))] - async fn compute_tag( - &mut self, - explicit_nonce: Vec, - ciphertext: Vec, - aad: Vec, - ) -> Result, AeadError> { - let tag_share = self - 
.compute_tag_share(explicit_nonce, aad, ciphertext.clone()) - .await?; - - let tag = match self.config.role() { - Role::Leader => { - // Send commitment of tag share to follower - let (tag_share_decommitment, tag_share_commitment) = - TagShare::from(tag_share).hash_commit(); - - self.channel - .send(AeadMessage::TagShareCommitment(tag_share_commitment)) - .await?; - - // Expect tag share from follower - let msg = expect_msg_or_err!(self.channel, AeadMessage::TagShare)?; - - let other_tag_share = AesGcmTagShare::from_unchecked(&msg.share)?; - - // Send decommitment (tag share) to follower - self.channel - .send(AeadMessage::TagShareDecommitment(tag_share_decommitment)) - .await?; - - tag_share + other_tag_share - } - Role::Follower => { - // Wait for commitment from leader - let commitment = expect_msg_or_err!(self.channel, AeadMessage::TagShareCommitment)?; - - // Send tag share to leader - self.channel - .send(AeadMessage::TagShare(tag_share.into())) - .await?; - - // Expect decommitment (tag share) from leader - let decommitment = - expect_msg_or_err!(self.channel, AeadMessage::TagShareDecommitment)?; - - // Verify decommitment - decommitment.verify(&commitment).map_err(|_| { - AeadError::ValidationError( - "Leader tag share commitment verification failed".to_string(), - ) - })?; - - let other_tag_share = - AesGcmTagShare::from_unchecked(&decommitment.into_inner().share)?; - - tag_share + other_tag_share - } - }; - - Ok(tag) - } - - /// Splits off the tag from the end of the payload and verifies it. - async fn _verify_tag( - &mut self, - explicit_nonce: Vec, - payload: &mut Vec, - aad: Vec, - ) -> Result<(), AeadError> { - let purported_tag = payload.split_off(payload.len() - AES_GCM_TAG_LEN); - - let tag = self - .compute_tag(explicit_nonce, payload.clone(), aad) - .await?; - - // Reject if tag is incorrect. 
- if tag != purported_tag { - return Err(AeadError::CorruptedTag); - } - - Ok(()) - } } #[async_trait] -impl Aead for MpcAesGcm { - #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", err))] - async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AeadError> { +impl Aead for MpcAesGcm { + type Error = AesGcmError; + + #[instrument(level = "info", skip_all, err)] + async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AesGcmError> { self.aes_block.set_key(key.clone()); self.aes_ctr.set_key(key, iv); - // Share zero block - let h_share = self.aes_block.encrypt_share(vec![0u8; 16]).await?; - - self.ghash.set_key(h_share).await?; - Ok(()) } - #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", err))] - async fn decode_key_private(&mut self) -> Result<(), AeadError> { + #[instrument(level = "info", skip_all, err)] + async fn decode_key_private(&mut self) -> Result<(), AesGcmError> { self.aes_ctr .decode_key_private() .await - .map_err(AeadError::from) + .map_err(AesGcmError::from) } - #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", err))] - async fn decode_key_blind(&mut self) -> Result<(), AeadError> { + #[instrument(level = "info", skip_all, err)] + async fn decode_key_blind(&mut self) -> Result<(), AesGcmError> { self.aes_ctr .decode_key_blind() .await - .map_err(AeadError::from) + .map_err(AesGcmError::from) } fn set_transcript_id(&mut self, id: &str) { self.aes_ctr.set_transcript_id(id) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(plaintext), err) - )] + #[instrument(level = "debug", skip(self), err)] + async fn setup(&mut self) -> Result<(), AesGcmError> { + self.ghash.setup().await?; + + Ok(()) + } + + #[instrument(level = "debug", skip(self), err)] + async fn preprocess(&mut self, len: usize) -> Result<(), AesGcmError> { + futures::try_join!( + // Preprocess the GHASH key block. 
+ self.aes_block + .preprocess(block_cipher::Visibility::Public, 1) + .map_err(AesGcmError::from), + self.aes_ctr.preprocess(len).map_err(AesGcmError::from), + self.ghash.preprocess().map_err(AesGcmError::from), + )?; + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn start(&mut self) -> Result<(), AesGcmError> { + let h_share = self.aes_block.encrypt_share(vec![0u8; 16]).await?; + self.ghash.set_key(h_share).await?; + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] async fn encrypt_public( &mut self, explicit_nonce: Vec, plaintext: Vec, aad: Vec, - ) -> Result, AeadError> { + ) -> Result, AesGcmError> { let ciphertext = self .aes_ctr .encrypt_public(explicit_nonce.clone(), plaintext) .await?; - let tag = self - .compute_tag(explicit_nonce, ciphertext.clone(), aad) - .await?; + let tag = compute_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + explicit_nonce, + ciphertext.clone(), + aad, + ) + .await?; let mut payload = ciphertext; payload.extend(tag); @@ -246,24 +148,27 @@ impl Aead for MpcAesGcm { Ok(payload) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(plaintext), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn encrypt_private( &mut self, explicit_nonce: Vec, plaintext: Vec, aad: Vec, - ) -> Result, AeadError> { + ) -> Result, AesGcmError> { let ciphertext = self .aes_ctr .encrypt_private(explicit_nonce.clone(), plaintext) .await?; - let tag = self - .compute_tag(explicit_nonce, ciphertext.clone(), aad) - .await?; + let tag = compute_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + explicit_nonce, + ciphertext.clone(), + aad, + ) + .await?; let mut payload = ciphertext; payload.extend(tag); @@ -271,21 +176,27 @@ impl Aead for MpcAesGcm { Ok(payload) } - #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", err))] + #[instrument(level = "debug", skip_all, err)] async fn encrypt_blind( &mut self, 
explicit_nonce: Vec, plaintext_len: usize, aad: Vec, - ) -> Result, AeadError> { + ) -> Result, AesGcmError> { let ciphertext = self .aes_ctr .encrypt_blind(explicit_nonce.clone(), plaintext_len) .await?; - let tag = self - .compute_tag(explicit_nonce, ciphertext.clone(), aad) - .await?; + let tag = compute_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + explicit_nonce, + ciphertext.clone(), + aad, + ) + .await?; let mut payload = ciphertext; payload.extend(tag); @@ -293,154 +204,234 @@ impl Aead for MpcAesGcm { Ok(payload) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(payload), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn decrypt_public( &mut self, explicit_nonce: Vec, mut payload: Vec, aad: Vec, - ) -> Result, AeadError> { - self._verify_tag(explicit_nonce.clone(), &mut payload, aad) + ) -> Result, AesGcmError> { + let purported_tag: [u8; TAG_LEN] = payload + .split_off(payload.len() - TAG_LEN) + .try_into() + .map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?; + let ciphertext = payload; + + verify_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + *self.config.role(), + explicit_nonce.clone(), + ciphertext.clone(), + aad, + purported_tag, + ) + .await?; + + let plaintext = self + .aes_ctr + .decrypt_public(explicit_nonce, ciphertext) .await?; - self.aes_ctr - .decrypt_public(explicit_nonce, payload) - .map_err(AeadError::from) - .await + Ok(plaintext) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(payload), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn decrypt_private( &mut self, explicit_nonce: Vec, mut payload: Vec, aad: Vec, - ) -> Result, AeadError> { - self._verify_tag(explicit_nonce.clone(), &mut payload, aad) + ) -> Result, AesGcmError> { + let purported_tag: [u8; TAG_LEN] = payload + .split_off(payload.len() - TAG_LEN) + .try_into() + .map_err(|_| 
AesGcmError::payload("payload is not long enough to contain tag"))?; + let ciphertext = payload; + + verify_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + *self.config.role(), + explicit_nonce.clone(), + ciphertext.clone(), + aad, + purported_tag, + ) + .await?; + + let plaintext = self + .aes_ctr + .decrypt_private(explicit_nonce, ciphertext) .await?; - self.aes_ctr - .decrypt_private(explicit_nonce, payload) - .map_err(AeadError::from) - .await + Ok(plaintext) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(payload), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn decrypt_blind( &mut self, explicit_nonce: Vec, mut payload: Vec, aad: Vec, - ) -> Result<(), AeadError> { - self._verify_tag(explicit_nonce.clone(), &mut payload, aad) - .await?; + ) -> Result<(), AesGcmError> { + let purported_tag: [u8; TAG_LEN] = payload + .split_off(payload.len() - TAG_LEN) + .try_into() + .map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?; + let ciphertext = payload; + + verify_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + *self.config.role(), + explicit_nonce.clone(), + ciphertext.clone(), + aad, + purported_tag, + ) + .await?; self.aes_ctr - .decrypt_blind(explicit_nonce, payload) - .map_err(AeadError::from) - .await + .decrypt_blind(explicit_nonce, ciphertext) + .await?; + + Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(payload), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn verify_tag( &mut self, explicit_nonce: Vec, mut payload: Vec, aad: Vec, - ) -> Result<(), AeadError> { - self._verify_tag(explicit_nonce.clone(), &mut payload, aad) - .await + ) -> Result<(), AesGcmError> { + let purported_tag: [u8; TAG_LEN] = payload + .split_off(payload.len() - TAG_LEN) + .try_into() + .map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?; + let ciphertext = payload; 
+ + verify_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + *self.config.role(), + explicit_nonce, + ciphertext, + aad, + purported_tag, + ) + .await?; + + Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(payload), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn prove_plaintext( &mut self, explicit_nonce: Vec, mut payload: Vec, aad: Vec, - ) -> Result, AeadError> { - self._verify_tag(explicit_nonce.clone(), &mut payload, aad) + ) -> Result, AesGcmError> { + let purported_tag: [u8; TAG_LEN] = payload + .split_off(payload.len() - TAG_LEN) + .try_into() + .map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?; + let ciphertext = payload; + + verify_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + *self.config.role(), + explicit_nonce.clone(), + ciphertext.clone(), + aad, + purported_tag, + ) + .await?; + + let plaintext = self + .aes_ctr + .prove_plaintext(explicit_nonce, ciphertext) .await?; - self.prove_plaintext_no_tag(explicit_nonce, payload).await + Ok(plaintext) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(ciphertext), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn prove_plaintext_no_tag( &mut self, explicit_nonce: Vec, ciphertext: Vec, - ) -> Result, AeadError> { + ) -> Result, AesGcmError> { self.aes_ctr .prove_plaintext(explicit_nonce, ciphertext) - .map_err(AeadError::from) + .map_err(AesGcmError::from) .await } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(payload), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn verify_plaintext( &mut self, explicit_nonce: Vec, mut payload: Vec, aad: Vec, - ) -> Result<(), AeadError> { - self._verify_tag(explicit_nonce.clone(), &mut payload, aad) + ) -> Result<(), AesGcmError> { + let purported_tag: [u8; TAG_LEN] = payload + .split_off(payload.len() - TAG_LEN) + .try_into() + 
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?; + let ciphertext = payload; + + verify_tag( + &mut self.ctx, + self.aes_ctr.as_mut(), + self.ghash.as_mut(), + *self.config.role(), + explicit_nonce.clone(), + ciphertext.clone(), + aad, + purported_tag, + ) + .await?; + + self.aes_ctr + .verify_plaintext(explicit_nonce, ciphertext) .await?; - self.verify_plaintext_no_tag(explicit_nonce, payload).await + Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(ciphertext), err) - )] + #[instrument(level = "debug", skip_all, err)] async fn verify_plaintext_no_tag( &mut self, explicit_nonce: Vec, ciphertext: Vec, - ) -> Result<(), AeadError> { + ) -> Result<(), AesGcmError> { self.aes_ctr .verify_plaintext(explicit_nonce, ciphertext) - .map_err(AeadError::from) + .map_err(AesGcmError::from) .await } } #[cfg(test)] mod tests { - use super::{mock::create_mock_aes_gcm_pair, *}; - use crate::Aead; - - use mpz_garble::{ - protocol::deap::mock::{create_mock_deap_vm, MockFollower, MockLeader}, - Memory, Vm, - }; - - use ::aes_gcm::{ - aead::{AeadInPlace, KeyInit}, - Aes128Gcm, Nonce, + use super::*; + use crate::{ + aes_gcm::{mock::create_mock_aes_gcm_pair, AesGcmConfigBuilder, Role}, + Aead, }; + use ::aes_gcm::{aead::AeadInPlace, Aes128Gcm, NewAead, Nonce}; + use error::ErrorKind; + use mpz_common::executor::STExecutor; + use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory}; + use serio::channel::MemoryDuplex; fn reference_impl( key: &[u8], @@ -464,30 +455,31 @@ mod tests { async fn setup_pair( key: Vec, iv: Vec, - ) -> ((MpcAesGcm, MpcAesGcm), (MockLeader, MockFollower)) { - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test_vm").await; + ) -> ( + MpcAesGcm>, + MpcAesGcm>, + ) { + let (leader_vm, follower_vm) = create_mock_deap_vm(); - let leader_thread = leader_vm.new_thread("test_thread").await.unwrap(); - let leader_key = leader_thread + let leader_key = leader_vm 
.new_public_array_input::("key", key.len()) .unwrap(); - let leader_iv = leader_thread + let leader_iv = leader_vm .new_public_array_input::("iv", iv.len()) .unwrap(); - leader_thread.assign(&leader_key, key.clone()).unwrap(); - leader_thread.assign(&leader_iv, iv.clone()).unwrap(); + leader_vm.assign(&leader_key, key.clone()).unwrap(); + leader_vm.assign(&leader_iv, iv.clone()).unwrap(); - let follower_thread = follower_vm.new_thread("test_thread").await.unwrap(); - let follower_key = follower_thread + let follower_key = follower_vm .new_public_array_input::("key", key.len()) .unwrap(); - let follower_iv = follower_thread + let follower_iv = follower_vm .new_public_array_input::("iv", iv.len()) .unwrap(); - follower_thread.assign(&follower_key, key.clone()).unwrap(); - follower_thread.assign(&follower_iv, iv.clone()).unwrap(); + follower_vm.assign(&follower_key, key.clone()).unwrap(); + follower_vm.assign(&follower_iv, iv.clone()).unwrap(); let leader_config = AesGcmConfigBuilder::default() .id("test".to_string()) @@ -502,8 +494,7 @@ mod tests { let (mut leader, mut follower) = create_mock_aes_gcm_pair( "test", - &mut leader_vm, - &mut follower_vm, + (leader_vm, follower_vm), leader_config, follower_config, ) @@ -515,10 +506,14 @@ mod tests { ) .unwrap(); - ((leader, follower), (leader_vm, follower_vm)) + futures::try_join!(leader.setup(), follower.setup()).unwrap(); + futures::try_join!(leader.start(), follower.start()).unwrap(); + + (leader, follower) } #[tokio::test] + #[ignore = "expensive"] async fn test_aes_gcm_encrypt_private() { let key = vec![0u8; 16]; let iv = vec![0u8; 4]; @@ -526,8 +521,7 @@ mod tests { let plaintext = vec![1u8; 32]; let aad = vec![2u8; 12]; - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; let (leader_ciphertext, follower_ciphertext) = tokio::try_join!( leader.encrypt_private(explicit_nonce.clone(), 
plaintext.clone(), aad.clone(),), @@ -543,6 +537,7 @@ mod tests { } #[tokio::test] + #[ignore = "expensive"] async fn test_aes_gcm_encrypt_public() { let key = vec![0u8; 16]; let iv = vec![0u8; 4]; @@ -550,8 +545,7 @@ mod tests { let plaintext = vec![1u8; 32]; let aad = vec![2u8; 12]; - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; let (leader_ciphertext, follower_ciphertext) = tokio::try_join!( leader.encrypt_public(explicit_nonce.clone(), plaintext.clone(), aad.clone(),), @@ -567,6 +561,7 @@ mod tests { } #[tokio::test] + #[ignore = "expensive"] async fn test_aes_gcm_decrypt_private() { let key = vec![0u8; 16]; let iv = vec![0u8; 4]; @@ -575,8 +570,7 @@ mod tests { let aad = vec![2u8; 12]; let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad); - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; let (leader_plaintext, _) = tokio::try_join!( leader.decrypt_private(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),), @@ -588,6 +582,7 @@ mod tests { } #[tokio::test] + #[ignore = "expensive"] async fn test_aes_gcm_decrypt_private_bad_tag() { let key = vec![0u8; 16]; let iv = vec![0u8; 4]; @@ -602,8 +597,7 @@ mod tests { let mut corrupted = ciphertext.clone(); corrupted[len - 1] -= 1; - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; // leader receives corrupted tag let err = tokio::try_join!( @@ -611,10 +605,9 @@ mod tests { follower.decrypt_blind(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),) ) .unwrap_err(); - assert!(matches!(err, AeadError::CorruptedTag)); + assert_eq!(err.kind(), ErrorKind::Tag); - let ((mut leader, 
mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; // follower receives corrupted tag let err = tokio::try_join!( @@ -622,10 +615,11 @@ mod tests { follower.decrypt_blind(explicit_nonce.clone(), corrupted.clone(), aad.clone(),) ) .unwrap_err(); - assert!(matches!(err, AeadError::CorruptedTag)); + assert_eq!(err.kind(), ErrorKind::Tag); } #[tokio::test] + #[ignore = "expensive"] async fn test_aes_gcm_decrypt_public() { let key = vec![0u8; 16]; let iv = vec![0u8; 4]; @@ -634,8 +628,7 @@ mod tests { let aad = vec![2u8; 12]; let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad); - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; let (leader_plaintext, follower_plaintext) = tokio::try_join!( leader.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),), @@ -648,6 +641,7 @@ mod tests { } #[tokio::test] + #[ignore = "expensive"] async fn test_aes_gcm_decrypt_public_bad_tag() { let key = vec![0u8; 16]; let iv = vec![0u8; 4]; @@ -658,34 +652,33 @@ mod tests { let len = ciphertext.len(); - // corrupt tag + // Corrupt tag. let mut corrupted = ciphertext.clone(); corrupted[len - 1] -= 1; - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; - // leader receives corrupted tag + // Leader receives corrupted tag. 
let err = tokio::try_join!( leader.decrypt_public(explicit_nonce.clone(), corrupted.clone(), aad.clone(),), follower.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),) ) .unwrap_err(); - assert!(matches!(err, AeadError::CorruptedTag)); + assert_eq!(err.kind(), ErrorKind::Tag); - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; - // follower receives corrupted tag + // Follower receives corrupted tag. let err = tokio::try_join!( leader.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),), follower.decrypt_public(explicit_nonce.clone(), corrupted.clone(), aad.clone(),) ) .unwrap_err(); - assert!(matches!(err, AeadError::CorruptedTag)); + assert_eq!(err.kind(), ErrorKind::Tag); } #[tokio::test] + #[ignore = "expensive"] async fn test_aes_gcm_verify_tag() { let key = vec![0u8; 16]; let iv = vec![0u8; 4]; @@ -696,8 +689,7 @@ mod tests { let len = ciphertext.len(); - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - setup_pair(key.clone(), iv.clone()).await; + let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await; tokio::try_join!( leader.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone()), @@ -705,7 +697,7 @@ mod tests { ) .unwrap(); - // corrupt tag + //Corrupt tag. 
let mut corrupted = ciphertext.clone(); corrupted[len - 1] -= 1; @@ -714,7 +706,7 @@ mod tests { follower.verify_tag(explicit_nonce.clone(), corrupted, aad.clone()) ); - assert!(matches!(leader_res.unwrap_err(), AeadError::CorruptedTag)); - assert!(matches!(follower_res.unwrap_err(), AeadError::CorruptedTag)); + assert_eq!(leader_res.unwrap_err().kind(), ErrorKind::Tag); + assert_eq!(follower_res.unwrap_err().kind(), ErrorKind::Tag); } } diff --git a/crates/components/aead/src/aes_gcm/tag.rs b/crates/components/aead/src/aes_gcm/tag.rs new file mode 100644 index 0000000000..6c5034629a --- /dev/null +++ b/crates/components/aead/src/aes_gcm/tag.rs @@ -0,0 +1,179 @@ +use futures::TryFutureExt; +use mpz_common::Context; +use mpz_core::{ + commit::{Decommitment, HashCommit}, + hash::Hash, +}; +use serde::{Deserialize, Serialize}; +use serio::{stream::IoStreamExt, SinkExt}; +use std::ops::Add; +use tlsn_stream_cipher::{Aes128Ctr, StreamCipher}; +use tlsn_universal_hash::UniversalHash; +use tracing::instrument; + +use crate::aes_gcm::{AesGcmError, Role}; + +pub(crate) const TAG_LEN: usize = 16; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct TagShare([u8; TAG_LEN]); + +impl AsRef<[u8]> for TagShare { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl Add for TagShare { + type Output = [u8; TAG_LEN]; + + fn add(self, rhs: Self) -> Self::Output { + core::array::from_fn(|i| self.0[i] ^ rhs.0[i]) + } +} + +#[instrument(level = "trace", skip_all, err)] +async fn compute_tag_share + ?Sized, H: UniversalHash + ?Sized>( + aes_ctr: &mut C, + hasher: &mut H, + explicit_nonce: Vec, + ciphertext: Vec, + aad: Vec, +) -> Result { + let (j0, hash) = futures::try_join!( + aes_ctr + .share_keystream_block(explicit_nonce, 1) + .map_err(AesGcmError::from), + hasher + .finalize(build_ghash_data(aad, ciphertext)) + .map_err(AesGcmError::from) + )?; + + debug_assert!(j0.len() == TAG_LEN); + debug_assert!(hash.len() == TAG_LEN); + + let tag_share = core::array::from_fn(|i| j0[i] ^ 
hash[i]); + + Ok(TagShare(tag_share)) +} + +/// Computes the tag for a ciphertext and additional data. +/// +/// The commit-reveal step is not required for computing a tag sent to the +/// Server, as it will be able to detect if the tag is incorrect. +#[instrument(level = "debug", skip_all, err)] +pub(crate) async fn compute_tag< + Ctx: Context, + C: StreamCipher + ?Sized, + H: UniversalHash + ?Sized, +>( + ctx: &mut Ctx, + aes_ctr: &mut C, + hasher: &mut H, + explicit_nonce: Vec, + ciphertext: Vec, + aad: Vec, +) -> Result<[u8; TAG_LEN], AesGcmError> { + let tag_share = compute_tag_share(aes_ctr, hasher, explicit_nonce, ciphertext, aad).await?; + + // TODO: The follower doesn't really need to learn the tag, + // we could reduce some latency by not sending it. + let io = ctx.io_mut(); + io.send(tag_share.clone()).await?; + let other_tag_share: TagShare = io.expect_next().await?; + + let tag = tag_share + other_tag_share; + + Ok(tag) +} + +/// Verifies a purported tag against the ciphertext and additional data. +/// +/// Verifying a tag requires a commit-reveal protocol between the leader and +/// follower. Without it, the party which receives the other's tag share first +/// could trivially compute a tag share which would cause an invalid message to +/// be accepted. +#[instrument(level = "debug", skip_all, err)] +#[allow(clippy::too_many_arguments)] +pub(crate) async fn verify_tag< + Ctx: Context, + C: StreamCipher + ?Sized, + H: UniversalHash + ?Sized, +>( + ctx: &mut Ctx, + aes_ctr: &mut C, + hasher: &mut H, + role: Role, + explicit_nonce: Vec, + ciphertext: Vec, + aad: Vec, + purported_tag: [u8; TAG_LEN], +) -> Result<(), AesGcmError> { + let tag_share = compute_tag_share(aes_ctr, hasher, explicit_nonce, ciphertext, aad).await?; + + let io = ctx.io_mut(); + let tag = match role { + Role::Leader => { + // Send commitment of tag share to follower. 
+ let (tag_share_decommitment, tag_share_commitment) = tag_share.clone().hash_commit(); + + io.send(tag_share_commitment).await?; + + let follower_tag_share: TagShare = io.expect_next().await?; + + // Send decommitment (tag share) to follower. + io.send(tag_share_decommitment).await?; + + tag_share + follower_tag_share + } + Role::Follower => { + // Wait for commitment from leader. + let commitment: Hash = io.expect_next().await?; + + // Send tag share to leader. + io.send(tag_share.clone()).await?; + + // Expect decommitment (tag share) from leader. + let decommitment: Decommitment = io.expect_next().await?; + + // Verify decommitment. + decommitment.verify(&commitment).map_err(|_| { + AesGcmError::peer("leader tag share commitment verification failed") + })?; + + let leader_tag_share = decommitment.into_inner(); + + tag_share + leader_tag_share + } + }; + + // Reject if tag is incorrect. + if tag != purported_tag { + return Err(AesGcmError::invalid_tag()); + } + + Ok(()) +} + +/// Builds padded data for GHASH. +fn build_ghash_data(mut aad: Vec, mut ciphertext: Vec) -> Vec { + let associated_data_bitlen = (aad.len() as u64) * 8; + let text_bitlen = (ciphertext.len() as u64) * 8; + + let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128); + + // Pad data to be a multiple of 16 bytes. 
+ let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize; + aad.resize(aad_padded_block_count * 16, 0); + + let ciphertext_padded_block_count = + (ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize; + ciphertext.resize(ciphertext_padded_block_count * 16, 0); + + let mut data: Vec = Vec::with_capacity(aad.len() + ciphertext.len() + 16); + data.extend(aad); + data.extend(ciphertext); + data.extend_from_slice(&len_block.to_be_bytes()); + + data +} diff --git a/components/aead/src/lib.rs b/crates/components/aead/src/lib.rs similarity index 63% rename from components/aead/src/lib.rs rename to crates/components/aead/src/lib.rs index 77b17d2d7a..0b9cfe0fb2 100644 --- a/components/aead/src/lib.rs +++ b/crates/components/aead/src/lib.rs @@ -1,79 +1,72 @@ -//! This crate provides implementations of 2PC AEADs for authenticated encryption with -//! a shared key. +//! This crate provides implementations of 2PC AEADs for authenticated +//! encryption with a shared key. //! -//! Both parties can work together to encrypt and decrypt messages with different visibility -//! configurations. See [`Aead`] for more information on the interface. +//! Both parties can work together to encrypt and decrypt messages with +//! different visibility configurations. See [`Aead`] for more information on +//! the interface. //! -//! For example, one party can privately provide the plaintext to encrypt, while both parties -//! can see the ciphertext and the tag. Or, both parties can cooperate to decrypt a ciphertext -//! and verify the tag, while only one party can see the plaintext. +//! For example, one party can privately provide the plaintext to encrypt, while +//! both parties can see the ciphertext and the tag. Or, both parties can +//! cooperate to decrypt a ciphertext and verify the tag, while only one party +//! can see the plaintext. 
#![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] #![forbid(unsafe_code)] pub mod aes_gcm; -pub mod msg; - -pub use msg::AeadMessage; use async_trait::async_trait; - use mpz_garble::value::ValueRef; -use utils_aio::duplex::Duplex; - -/// A channel for sending and receiving AEAD messages. -pub type AeadChannel = Box>; - -/// An error that can occur during AEAD operations. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum AeadError { - #[error(transparent)] - BlockCipherError(#[from] block_cipher::BlockCipherError), - #[error(transparent)] - StreamCipherError(#[from] tlsn_stream_cipher::StreamCipherError), - #[error(transparent)] - UniversalHashError(#[from] tlsn_universal_hash::UniversalHashError), - #[error("Corrupted Tag")] - CorruptedTag, - #[error("Validation Error: {0}")] - ValidationError(String), - #[error(transparent)] - IoError(#[from] std::io::Error), -} /// This trait defines the interface for AEADs. #[async_trait] pub trait Aead: Send { + /// The error type for the AEAD. + type Error: std::error::Error + Send + Sync + 'static; + /// Sets the key for the AEAD. - async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AeadError>; + async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), Self::Error>; /// Decodes the key for the AEAD, revealing it to this party. - async fn decode_key_private(&mut self) -> Result<(), AeadError>; + async fn decode_key_private(&mut self) -> Result<(), Self::Error>; /// Decodes the key for the AEAD, revealing it to the other party(s). - async fn decode_key_blind(&mut self) -> Result<(), AeadError>; + async fn decode_key_blind(&mut self) -> Result<(), Self::Error>; - /// Sets the transcript id + /// Sets the transcript id. /// /// The AEAD assigns unique identifiers to each byte of plaintext /// during encryption and decryption. 
/// - /// For example, if the transcript id is set to `foo`, then the first byte will - /// be assigned the id `foo/0`, the second byte `foo/1`, and so on. + /// For example, if the transcript id is set to `foo`, then the first byte + /// will be assigned the id `foo/0`, the second byte `foo/1`, and so on. /// /// Each transcript id has an independent counter. /// /// # Note /// - /// The state of a transcript counter is preserved between calls to `set_transcript_id`. + /// The state of a transcript counter is preserved between calls to + /// `set_transcript_id`. fn set_transcript_id(&mut self, id: &str); + /// Performs any necessary one-time setup for the AEAD. + async fn setup(&mut self) -> Result<(), Self::Error>; + + /// Preprocesses for the given number of bytes. + async fn preprocess(&mut self, len: usize) -> Result<(), Self::Error>; + + /// Starts the AEAD. + /// + /// This method performs initialization for the AEAD after setting the key. + async fn start(&mut self) -> Result<(), Self::Error>; + /// Encrypts a plaintext message, returning the ciphertext and tag. /// /// The plaintext is provided by both parties. /// + /// # Arguments + /// /// * `explicit_nonce` - The explicit nonce to use for encryption. /// * `plaintext` - The plaintext to encrypt. /// * `aad` - Additional authenticated data. @@ -82,9 +75,12 @@ pub trait Aead: Send { explicit_nonce: Vec, plaintext: Vec, aad: Vec, - ) -> Result, AeadError>; + ) -> Result, Self::Error>; - /// Encrypts a plaintext message, hiding it from the other party, returning the ciphertext and tag. + /// Encrypts a plaintext message, hiding it from the other party, returning + /// the ciphertext and tag. + /// + /// # Arguments /// /// * `explicit_nonce` - The explicit nonce to use for encryption. /// * `plaintext` - The plaintext to encrypt. 
@@ -94,11 +90,13 @@ pub trait Aead: Send { explicit_nonce: Vec, plaintext: Vec, aad: Vec, - ) -> Result, AeadError>; + ) -> Result, Self::Error>; /// Encrypts a plaintext message provided by the other party, returning /// the ciphertext and tag. /// + /// # Arguments + /// /// * `explicit_nonce` - The explicit nonce to use for encryption. /// * `plaintext_len` - The length of the plaintext to encrypt. /// * `aad` - Additional authenticated data. @@ -107,11 +105,14 @@ pub trait Aead: Send { explicit_nonce: Vec, plaintext_len: usize, aad: Vec, - ) -> Result, AeadError>; + ) -> Result, Self::Error>; /// Decrypts a ciphertext message, returning the plaintext to both parties. /// - /// This method checks the authenticity of the ciphertext, tag and additional data. + /// This method checks the authenticity of the ciphertext, tag and + /// additional data. + /// + /// # Arguments /// /// * `explicit_nonce` - The explicit nonce to use for decryption. /// * `payload` - The ciphertext and tag to authenticate and decrypt. @@ -121,11 +122,15 @@ pub trait Aead: Send { explicit_nonce: Vec, payload: Vec, aad: Vec, - ) -> Result, AeadError>; + ) -> Result, Self::Error>; - /// Decrypts a ciphertext message, returning the plaintext only to this party. + /// Decrypts a ciphertext message, returning the plaintext only to this + /// party. + /// + /// This method checks the authenticity of the ciphertext, tag and + /// additional data. /// - /// This method checks the authenticity of the ciphertext, tag and additional data. + /// # Arguments /// /// * `explicit_nonce` - The explicit nonce to use for decryption. /// * `payload` - The ciphertext and tag to authenticate and decrypt. @@ -135,11 +140,15 @@ pub trait Aead: Send { explicit_nonce: Vec, payload: Vec, aad: Vec, - ) -> Result, AeadError>; + ) -> Result, Self::Error>; - /// Decrypts a ciphertext message, returning the plaintext only to the other party. 
+ /// Decrypts a ciphertext message, returning the plaintext only to the other + /// party. /// - /// This method checks the authenticity of the ciphertext, tag and additional data. + /// This method checks the authenticity of the ciphertext, tag and + /// additional data. + /// + /// # Arguments /// /// * `explicit_nonce` - The explicit nonce to use for decryption. /// * `payload` - The ciphertext and tag to authenticate and decrypt. @@ -149,11 +158,14 @@ pub trait Aead: Send { explicit_nonce: Vec, payload: Vec, aad: Vec, - ) -> Result<(), AeadError>; + ) -> Result<(), Self::Error>; /// Verifies the tag of a ciphertext message. /// - /// This method checks the authenticity of the ciphertext, tag and additional data. + /// This method checks the authenticity of the ciphertext, tag and + /// additional data. + /// + /// # Arguments /// /// * `explicit_nonce` - The explicit nonce to use for decryption. /// * `payload` - The ciphertext and tag to authenticate and decrypt. @@ -163,77 +175,81 @@ pub trait Aead: Send { explicit_nonce: Vec, payload: Vec, aad: Vec, - ) -> Result<(), AeadError>; + ) -> Result<(), Self::Error>; - /// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the - /// plaintext is correct. + /// Locally decrypts the provided ciphertext and then proves in ZK to the + /// other party(s) that the plaintext is correct. /// /// Returns the plaintext. /// - /// This method requires this party to know the encryption key, which can be achieved by calling - /// the `decode_key_private` method. + /// This method requires this party to know the encryption key, which can be + /// achieved by calling the `decode_key_private` method. /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `payload`: The ciphertext and tag to decrypt and prove. - /// * `aad`: Additional authenticated data. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. 
+ /// * `payload` - The ciphertext and tag to decrypt and prove. + /// * `aad` - Additional authenticated data. async fn prove_plaintext( &mut self, explicit_nonce: Vec, payload: Vec, aad: Vec, - ) -> Result, AeadError>; + ) -> Result, Self::Error>; - /// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the - /// plaintext is correct. + /// Locally decrypts the provided ciphertext and then proves in ZK to the + /// other party(s) that the plaintext is correct. /// /// Returns the plaintext. /// - /// This method requires this party to know the encryption key, which can be achieved by calling - /// the `decode_key_private` method. + /// This method requires this party to know the encryption key, which can be + /// achieved by calling the `decode_key_private` method. /// /// # WARNING /// - /// This method does not verify the tag of the ciphertext. Only use this if you know what you're doing. + /// This method does not verify the tag of the ciphertext. Only use this if + /// you know what you're doing. /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `ciphertext`: The ciphertext to decrypt and prove. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `ciphertext` - The ciphertext to decrypt and prove. async fn prove_plaintext_no_tag( &mut self, explicit_nonce: Vec, ciphertext: Vec, - ) -> Result, AeadError>; + ) -> Result, Self::Error>; - /// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext. + /// Verifies the other party(s) can prove they know a plaintext which + /// encrypts to the given ciphertext. /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `payload`: The ciphertext and tag to verify. - /// * `aad`: Additional authenticated data. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. 
+ /// * `payload` - The ciphertext and tag to verify. + /// * `aad` - Additional authenticated data. async fn verify_plaintext( &mut self, explicit_nonce: Vec, payload: Vec, aad: Vec, - ) -> Result<(), AeadError>; + ) -> Result<(), Self::Error>; - /// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext. + /// Verifies the other party(s) can prove they know a plaintext which + /// encrypts to the given ciphertext. /// /// # WARNING /// - /// This method does not verify the tag of the ciphertext. Only use this if you know what you're doing. + /// This method does not verify the tag of the ciphertext. Only use this if + /// you know what you're doing. /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `ciphertext`: The ciphertext to verify. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `ciphertext` - The ciphertext to verify. async fn verify_plaintext_no_tag( &mut self, explicit_nonce: Vec, ciphertext: Vec, - ) -> Result<(), AeadError>; + ) -> Result<(), Self::Error>; } diff --git a/crates/components/authdecode/transcript/Cargo.toml b/crates/components/authdecode/transcript/Cargo.toml new file mode 100644 index 0000000000..ee5290ca4e --- /dev/null +++ b/crates/components/authdecode/transcript/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "tlsn-authdecode-transcript" +authors = ["TLSNotary Team"] +description = "A convenience type for using AuthDecode with transcript data" +keywords = ["tls", "mpc", "2pc"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0" +edition = "2021" + +[lib] +name = "authdecode_transcript" + +[dependencies] +tlsn-authdecode-core = { workspace = true } +tlsn-core = { workspace = true } +mpz-circuits = { workspace = true } +mpz-garble-core = { workspace = true } +mpz-core = { workspace = true } + +getset = "0.1.2" +serde = { workspace = true } diff --git 
a/crates/components/authdecode/transcript/src/lib.rs b/crates/components/authdecode/transcript/src/lib.rs new file mode 100644 index 0000000000..e01271723d --- /dev/null +++ b/crates/components/authdecode/transcript/src/lib.rs @@ -0,0 +1,150 @@ +//! A convenience type for using AuthDecode with transcript data. +use core::ops::Range; +use getset::Getters; +use serde::{Deserialize, Serialize}; + +use authdecode_core::{ + backend::halo2::CHUNK_SIZE, + encodings::{Encoding, EncodingProvider, EncodingProviderError, FullEncodings}, + id::{Id, IdCollection}, + SSP, +}; +use mpz_circuits::types::ValueType; +use mpz_core::{utils::blake3, Block}; +use mpz_garble_core::ChaChaEncoder; +use tlsn_core::transcript::{Direction, RX_TRANSCRIPT_ID, TX_TRANSCRIPT_ID}; + +#[derive(Clone, PartialEq, Serialize, Deserialize, Getters)] +/// Information about a subset of transcript data. +/// +/// The data is treated as a big-endian bytestring with MSB0 bit ordering. +pub struct TranscriptData { + /// The direction in which the data was transmitted. + #[getset(get = "pub")] + direction: Direction, + /// The byterange in the transcript where the data is located. + #[getset(get = "pub")] + range: Range, +} + +impl TranscriptData { + /// Creates a new `TranscriptData`. + /// + /// # Panics + /// + /// Panics if the range length exceeds the maximim allowed length. + pub fn new(direction: Direction, range: &Range) -> Self { + assert!(range.len() <= CHUNK_SIZE); + + Self { + direction, + range: range.clone(), + } + } +} + +impl Default for TranscriptData { + fn default() -> Self { + Self { + direction: Direction::Sent, + range: Range::default(), + } + } +} + +impl IdCollection for TranscriptData { + fn drain_front(&mut self, count: usize) -> Self { + assert!(count % 8 == 0); + assert!(count <= CHUNK_SIZE * 8); + // We will never need to drain since the collection spans a single chunk. 
+ self.clone() + } + + fn id(&self, _index: usize) -> Id { + unimplemented!() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + self.range.len() * 8 + } + + fn new_from_iter>(_iter: I) -> Self { + unimplemented!() + } +} + +/// An encoder of a TLS transcript. +pub struct TranscriptEncoder { + encoder: ChaChaEncoder, +} + +impl TranscriptEncoder { + /// Creates a new encoder from the `seed`. + /// + /// # Arguments + /// + /// * `seed` - The seed to create the encoder from. + pub fn new(seed: [u8; 32]) -> Self { + Self { + encoder: ChaChaEncoder::new(seed), + } + } + + /// Encodes a byte at the given position and direction in the transcript. + fn encode_byte(&self, dir: Direction, pos: usize) -> Vec<[Encoding; 2]> { + let id = match dir { + Direction::Sent => TX_TRANSCRIPT_ID, + Direction::Received => RX_TRANSCRIPT_ID, + }; + + let id_hash = blake3(format!("{}/{}", id, pos).as_bytes()); + let id = u64::from_be_bytes(id_hash[..8].try_into().unwrap()); + + let mut encodings = ::encode_by_type( + &self.encoder, + id, + &ValueType::U8, + ) + .iter_blocks() + .map(|blocks| { + // Hash the encodings to break the correlation and truncate them. + [ + Encoding::new( + blake3(&Block::to_bytes(blocks[0]))[0..SSP / 8] + .try_into() + .unwrap(), + false, + ), + Encoding::new( + blake3(&Block::to_bytes(blocks[1]))[0..SSP / 8] + .try_into() + .unwrap(), + true, + ), + ] + }) + .collect::>(); + // Reverse byte encodings to MSB0. 
+ encodings.reverse(); + encodings + } +} + +impl EncodingProvider for TranscriptEncoder { + fn get_by_ids( + &self, + ids: &TranscriptData, + ) -> Result, EncodingProviderError> { + let mut full_encoding = Vec::with_capacity(ids.range().len() * 8); + + for pos in ids.range().clone() { + full_encoding.extend(self.encode_byte(*ids.direction(), pos)); + } + + Ok(FullEncodings::new(full_encoding, ids.clone())) + } +} diff --git a/components/cipher/block-cipher/Cargo.toml b/crates/components/block-cipher/Cargo.toml similarity index 56% rename from components/cipher/block-cipher/Cargo.toml rename to crates/components/block-cipher/Cargo.toml index 8059f331f7..079a331ce6 100644 --- a/components/cipher/block-cipher/Cargo.toml +++ b/crates/components/block-cipher/Cargo.toml @@ -5,7 +5,7 @@ description = "2PC block cipher implementation" keywords = ["tls", "mpc", "2pc", "block-cipher"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" [lib] @@ -13,19 +13,18 @@ name = "block_cipher" [features] default = ["mock"] -tracing = ["dep:tracing"] mock = [] [dependencies] -mpz-circuits.workspace = true -mpz-garble.workspace = true -tlsn-utils.workspace = true -async-trait.workspace = true -thiserror.workspace = true -derive_builder.workspace = true -tracing = { workspace = true, optional = true } +mpz-circuits = { workspace = true } +mpz-garble = { workspace = true } +tlsn-utils = { workspace = true } +async-trait = { workspace = true } +thiserror = { workspace = true } +derive_builder = { workspace = true } +tracing = { workspace = true } [dev-dependencies] -aes.workspace = true -cipher.workspace = true +aes = { workspace = true } +cipher = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } diff --git a/crates/components/block-cipher/src/cipher.rs b/crates/components/block-cipher/src/cipher.rs new file mode 100644 index 0000000000..347fe85ca1 --- /dev/null 
+++ b/crates/components/block-cipher/src/cipher.rs @@ -0,0 +1,277 @@ +use std::{collections::VecDeque, marker::PhantomData}; + +use async_trait::async_trait; + +use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Load, Memory}; +use tracing::instrument; +use utils::id::NestedId; + +use crate::{BlockCipher, BlockCipherCircuit, BlockCipherConfig, BlockCipherError, Visibility}; + +#[derive(Debug)] +struct State { + private_execution_id: NestedId, + public_execution_id: NestedId, + preprocessed_private: VecDeque, + preprocessed_public: VecDeque, + key: Option, +} + +#[derive(Debug)] +struct BlockVars { + msg: ValueRef, + ciphertext: ValueRef, +} + +/// An MPC block cipher. +#[derive(Debug)] +pub struct MpcBlockCipher +where + C: BlockCipherCircuit, + E: Memory + Execute + Decode + DecodePrivate + Send + Sync, +{ + state: State, + + executor: E, + + _cipher: PhantomData, +} + +impl MpcBlockCipher +where + C: BlockCipherCircuit, + E: Memory + Execute + Decode + DecodePrivate + Send + Sync, +{ + /// Creates a new MPC block cipher. + /// + /// # Arguments + /// + /// * `config` - The configuration for the block cipher. + /// * `executor` - The executor to use for the MPC. 
+ pub fn new(config: BlockCipherConfig, executor: E) -> Self { + let private_execution_id = NestedId::new(&config.id) + .append_string("private") + .append_counter(); + let public_execution_id = NestedId::new(&config.id) + .append_string("public") + .append_counter(); + Self { + state: State { + private_execution_id, + public_execution_id, + preprocessed_private: VecDeque::new(), + preprocessed_public: VecDeque::new(), + key: None, + }, + executor, + _cipher: PhantomData, + } + } + + fn define_block(&mut self, vis: Visibility) -> BlockVars { + let (id, msg) = match vis { + Visibility::Private => { + let id = self + .state + .private_execution_id + .increment_in_place() + .to_string(); + let msg = self + .executor + .new_private_input::(&format!("{}/msg", &id)) + .expect("message is not defined"); + (id, msg) + } + Visibility::Blind => { + let id = self + .state + .private_execution_id + .increment_in_place() + .to_string(); + let msg = self + .executor + .new_blind_input::(&format!("{}/msg", &id)) + .expect("message is not defined"); + (id, msg) + } + Visibility::Public => { + let id = self + .state + .public_execution_id + .increment_in_place() + .to_string(); + let msg = self + .executor + .new_public_input::(&format!("{}/msg", &id)) + .expect("message is not defined"); + (id, msg) + } + }; + + let ciphertext = self + .executor + .new_output::(&format!("{}/ciphertext", &id)) + .expect("message is not defined"); + + BlockVars { msg, ciphertext } + } +} + +#[async_trait] +impl BlockCipher for MpcBlockCipher +where + C: BlockCipherCircuit, + E: Memory + Load + Execute + Decode + DecodePrivate + Send + Sync + Send, +{ + #[instrument(level = "trace", skip_all)] + fn set_key(&mut self, key: ValueRef) { + self.state.key = Some(key); + } + + #[instrument(level = "debug", skip_all, err)] + async fn preprocess( + &mut self, + visibility: Visibility, + count: usize, + ) -> Result<(), BlockCipherError> { + let key = self + .state + .key + .clone() + 
.ok_or_else(BlockCipherError::key_not_set)?; + + for _ in 0..count { + let vars = self.define_block(visibility); + + self.executor + .load( + C::circuit(), + &[key.clone(), vars.msg.clone()], + &[vars.ciphertext.clone()], + ) + .await?; + + match visibility { + Visibility::Private | Visibility::Blind => { + self.state.preprocessed_private.push_back(vars) + } + Visibility::Public => self.state.preprocessed_public.push_back(vars), + } + } + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn encrypt_private(&mut self, plaintext: Vec) -> Result, BlockCipherError> { + let len = plaintext.len(); + let block: C::BLOCK = plaintext + .try_into() + .map_err(|_| BlockCipherError::invalid_message_length::(len))?; + + let key = self + .state + .key + .clone() + .ok_or_else(BlockCipherError::key_not_set)?; + + let BlockVars { msg, ciphertext } = + if let Some(vars) = self.state.preprocessed_private.pop_front() { + vars + } else { + self.define_block(Visibility::Private) + }; + + self.executor.assign(&msg, block)?; + + self.executor + .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) + .await?; + + let mut outputs = self.executor.decode(&[ciphertext]).await?; + + let ciphertext: C::BLOCK = if let Ok(ciphertext) = outputs + .pop() + .expect("ciphertext should be present") + .try_into() + { + ciphertext + } else { + panic!("ciphertext should be a block") + }; + + Ok(ciphertext.into()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn encrypt_blind(&mut self) -> Result, BlockCipherError> { + let key = self + .state + .key + .clone() + .ok_or_else(BlockCipherError::key_not_set)?; + + let BlockVars { msg, ciphertext } = + if let Some(vars) = self.state.preprocessed_private.pop_front() { + vars + } else { + self.define_block(Visibility::Blind) + }; + + self.executor + .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) + .await?; + + let mut outputs = self.executor.decode(&[ciphertext]).await?; + + let ciphertext: C::BLOCK = if let 
Ok(ciphertext) = outputs + .pop() + .expect("ciphertext should be present") + .try_into() + { + ciphertext + } else { + panic!("ciphertext should be a block") + }; + + Ok(ciphertext.into()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn encrypt_share(&mut self, plaintext: Vec) -> Result, BlockCipherError> { + let len = plaintext.len(); + let block: C::BLOCK = plaintext + .try_into() + .map_err(|_| BlockCipherError::invalid_message_length::(len))?; + + let key = self + .state + .key + .clone() + .ok_or_else(BlockCipherError::key_not_set)?; + + let BlockVars { msg, ciphertext } = + if let Some(vars) = self.state.preprocessed_public.pop_front() { + vars + } else { + self.define_block(Visibility::Public) + }; + + self.executor.assign(&msg, block)?; + + self.executor + .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) + .await?; + + let mut outputs = self.executor.decode_shared(&[ciphertext]).await?; + + let share: C::BLOCK = + if let Ok(share) = outputs.pop().expect("share should be present").try_into() { + share + } else { + panic!("share should be a block") + }; + + Ok(share.into()) + } +} diff --git a/components/cipher/block-cipher/src/circuit.rs b/crates/components/block-cipher/src/circuit.rs similarity index 83% rename from components/cipher/block-cipher/src/circuit.rs rename to crates/components/block-cipher/src/circuit.rs index fefe978865..4146f4cfb8 100644 --- a/components/cipher/block-cipher/src/circuit.rs +++ b/crates/components/block-cipher/src/circuit.rs @@ -8,17 +8,17 @@ use mpz_circuits::{ /// A block cipher circuit. pub trait BlockCipherCircuit: Default + Clone + Send + Sync { - /// The key type + /// The key type. type KEY: StaticValueType + Send + Sync; - /// The block type + /// The block type. type BLOCK: StaticValueType + TryFrom> + TryFrom + Into> + Send + Sync; - /// The length of the key + /// The length of the key. const KEY_LEN: usize; - /// The length of the block + /// The length of the block. 
const BLOCK_LEN: usize; - /// Returns the circuit of the cipher + /// Returns the circuit of the cipher. fn circuit() -> Arc; } diff --git a/components/cipher/block-cipher/src/config.rs b/crates/components/block-cipher/src/config.rs similarity index 81% rename from components/cipher/block-cipher/src/config.rs rename to crates/components/block-cipher/src/config.rs index 813a8a3df4..f21f9623b2 100644 --- a/components/cipher/block-cipher/src/config.rs +++ b/crates/components/block-cipher/src/config.rs @@ -1,15 +1,15 @@ use derive_builder::Builder; -/// Configuration for a block cipher +/// Configuration for a block cipher. #[derive(Debug, Clone, Builder)] pub struct BlockCipherConfig { - /// The ID of the block cipher + /// The ID of the block cipher. #[builder(setter(into))] pub(crate) id: String, } impl BlockCipherConfig { - /// Creates a new builder for the block cipher configuration + /// Creates a new builder for the block cipher configuration. pub fn builder() -> BlockCipherConfigBuilder { BlockCipherConfigBuilder::default() } diff --git a/crates/components/block-cipher/src/error.rs b/crates/components/block-cipher/src/error.rs new file mode 100644 index 0000000000..1612484a40 --- /dev/null +++ b/crates/components/block-cipher/src/error.rs @@ -0,0 +1,92 @@ +use core::fmt; +use std::error::Error; + +use crate::BlockCipherCircuit; + +/// A block cipher error. 
+#[derive(Debug, thiserror::Error)] +pub struct BlockCipherError { + kind: ErrorKind, + #[source] + source: Option>, +} + +impl BlockCipherError { + pub(crate) fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + pub(crate) fn key_not_set() -> Self { + Self { + kind: ErrorKind::Key, + source: Some("key not set".into()), + } + } + + pub(crate) fn invalid_message_length(len: usize) -> Self { + Self { + kind: ErrorKind::Msg, + source: Some( + format!( + "message length does not equal block length: {} != {}", + len, + C::BLOCK_LEN + ) + .into(), + ), + } + } +} + +#[derive(Debug)] +pub(crate) enum ErrorKind { + Vm, + Key, + Msg, +} + +impl fmt::Display for BlockCipherError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + ErrorKind::Vm => write!(f, "vm error")?, + ErrorKind::Key => write!(f, "key error")?, + ErrorKind::Msg => write!(f, "message error")?, + } + + if let Some(ref source) = self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for BlockCipherError { + fn from(error: mpz_garble::MemoryError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for BlockCipherError { + fn from(error: mpz_garble::LoadError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for BlockCipherError { + fn from(error: mpz_garble::ExecutionError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for BlockCipherError { + fn from(error: mpz_garble::DecodeError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} diff --git a/crates/components/block-cipher/src/lib.rs b/crates/components/block-cipher/src/lib.rs new file mode 100644 index 0000000000..5fde927fd3 --- /dev/null +++ b/crates/components/block-cipher/src/lib.rs @@ -0,0 +1,236 @@ +//! This crate provides a 2PC block cipher implementation. +//! +//! Both parties work together to encrypt or share an encrypted block using a +//! shared key. 
+ +#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(clippy::all)] +#![deny(unsafe_code)] + +mod cipher; +mod circuit; +mod config; +mod error; + +use async_trait::async_trait; + +use mpz_garble::value::ValueRef; + +pub use crate::{ + cipher::MpcBlockCipher, + circuit::{Aes128, BlockCipherCircuit}, +}; +pub use config::{BlockCipherConfig, BlockCipherConfigBuilder, BlockCipherConfigBuilderError}; +pub use error::BlockCipherError; + +/// Visibility of a message plaintext. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Visibility { + /// Private message. + Private, + /// Blind message. + Blind, + /// Public message. + Public, +} + +/// A trait for MPC block ciphers. +#[async_trait] +pub trait BlockCipher: Send + Sync +where + Cipher: BlockCipherCircuit, +{ + /// Sets the key for the block cipher. + fn set_key(&mut self, key: ValueRef); + + /// Preprocesses `count` blocks. + /// + /// # Arguments + /// + /// * `visibility` - The visibility of the plaintext. + /// * `count` - The number of blocks to preprocess. + async fn preprocess( + &mut self, + visibility: Visibility, + count: usize, + ) -> Result<(), BlockCipherError>; + + /// Encrypts the given plaintext keeping it hidden from the other party(s). + /// + /// Returns the ciphertext. + /// + /// # Arguments + /// + /// * `plaintext` - The plaintext to encrypt. + async fn encrypt_private(&mut self, plaintext: Vec) -> Result, BlockCipherError>; + + /// Encrypts a plaintext provided by the other party(s). + /// + /// Returns the ciphertext. + async fn encrypt_blind(&mut self) -> Result, BlockCipherError>; + + /// Encrypts a plaintext provided by both parties. Fails if the + /// plaintext provided by both parties does not match. + /// + /// Returns an additive share of the ciphertext. + /// + /// # Arguments + /// + /// * `plaintext` - The plaintext to encrypt. 
+ async fn encrypt_share(&mut self, plaintext: Vec) -> Result, BlockCipherError>; +} + +#[cfg(test)] +mod tests { + use super::*; + + use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory}; + + use crate::circuit::Aes128; + + use ::aes::Aes128 as TestAes128; + use ::cipher::{BlockEncrypt, KeyInit}; + + fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] { + let mut msg = msg.into(); + let cipher = TestAes128::new(&key.into()); + cipher.encrypt_block(&mut msg); + msg.into() + } + + #[tokio::test] + #[ignore = "expensive"] + async fn test_block_cipher_blind() { + let leader_config = BlockCipherConfig::builder().id("test").build().unwrap(); + let follower_config = BlockCipherConfig::builder().id("test").build().unwrap(); + + let key = [0u8; 16]; + + let (leader_vm, follower_vm) = create_mock_deap_vm(); + + // Key is public just for this test, typically it is private. + let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap(); + + leader_vm.assign(&leader_key, key).unwrap(); + follower_vm.assign(&follower_key, key).unwrap(); + + let mut leader = MpcBlockCipher::::new(leader_config, leader_vm); + leader.set_key(leader_key); + + let mut follower = MpcBlockCipher::::new(follower_config, follower_vm); + follower.set_key(follower_key); + + let plaintext = [0u8; 16]; + + let (leader_ciphertext, follower_ciphertext) = tokio::try_join!( + leader.encrypt_private(plaintext.to_vec()), + follower.encrypt_blind() + ) + .unwrap(); + + let expected = aes128(key, plaintext); + + assert_eq!(leader_ciphertext, expected.to_vec()); + assert_eq!(leader_ciphertext, follower_ciphertext); + } + + #[tokio::test] + #[ignore = "expensive"] + async fn test_block_cipher_share() { + let leader_config = BlockCipherConfig::builder().id("test").build().unwrap(); + let follower_config = BlockCipherConfig::builder().id("test").build().unwrap(); + + let key = [0u8; 16]; + + let (leader_vm, 
follower_vm) = create_mock_deap_vm(); + + // Key is public just for this test, typically it is private. + let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap(); + + leader_vm.assign(&leader_key, key).unwrap(); + follower_vm.assign(&follower_key, key).unwrap(); + + let mut leader = MpcBlockCipher::::new(leader_config, leader_vm); + leader.set_key(leader_key); + + let mut follower = MpcBlockCipher::::new(follower_config, follower_vm); + follower.set_key(follower_key); + + let plaintext = [0u8; 16]; + + let (leader_share, follower_share) = tokio::try_join!( + leader.encrypt_share(plaintext.to_vec()), + follower.encrypt_share(plaintext.to_vec()) + ) + .unwrap(); + + let expected = aes128(key, plaintext); + + let result: [u8; 16] = std::array::from_fn(|i| leader_share[i] ^ follower_share[i]); + + assert_eq!(result, expected); + } + + #[tokio::test] + #[ignore = "expensive"] + async fn test_block_cipher_preprocess() { + let leader_config = BlockCipherConfig::builder().id("test").build().unwrap(); + let follower_config = BlockCipherConfig::builder().id("test").build().unwrap(); + + let key = [0u8; 16]; + + let (leader_vm, follower_vm) = create_mock_deap_vm(); + + // Key is public just for this test, typically it is private. 
+ let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap(); + + leader_vm.assign(&leader_key, key).unwrap(); + follower_vm.assign(&follower_key, key).unwrap(); + + let mut leader = MpcBlockCipher::::new(leader_config, leader_vm); + leader.set_key(leader_key); + + let mut follower = MpcBlockCipher::::new(follower_config, follower_vm); + follower.set_key(follower_key); + + let plaintext = [0u8; 16]; + + tokio::try_join!( + leader.preprocess(Visibility::Private, 1), + follower.preprocess(Visibility::Blind, 1) + ) + .unwrap(); + + let (leader_ciphertext, follower_ciphertext) = tokio::try_join!( + leader.encrypt_private(plaintext.to_vec()), + follower.encrypt_blind() + ) + .unwrap(); + + let expected = aes128(key, plaintext); + + assert_eq!(leader_ciphertext, expected.to_vec()); + assert_eq!(leader_ciphertext, follower_ciphertext); + + tokio::try_join!( + leader.preprocess(Visibility::Public, 1), + follower.preprocess(Visibility::Public, 1) + ) + .unwrap(); + + let (leader_share, follower_share) = tokio::try_join!( + leader.encrypt_share(plaintext.to_vec()), + follower.encrypt_share(plaintext.to_vec()) + ) + .unwrap(); + + let expected = aes128(key, plaintext); + + let result: [u8; 16] = std::array::from_fn(|i| leader_share[i] ^ follower_share[i]); + + assert_eq!(result, expected); + } +} diff --git a/components/prf/hmac-sha256-circuits/Cargo.toml b/crates/components/hmac-sha256-circuits/Cargo.toml similarity index 68% rename from components/prf/hmac-sha256-circuits/Cargo.toml rename to crates/components/hmac-sha256-circuits/Cargo.toml index fec1214f57..5d232c554d 100644 --- a/components/prf/hmac-sha256-circuits/Cargo.toml +++ b/crates/components/hmac-sha256-circuits/Cargo.toml @@ -5,18 +5,15 @@ description = "The 2PC circuits for TLS HMAC-SHA256 PRF" keywords = ["tls", "mpc", "2pc", "hmac", "sha256"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = 
"0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" [lib] name = "hmac_sha256_circuits" -[features] -tracing = ["dep:tracing"] - [dependencies] -mpz-circuits.workspace = true -tracing = { workspace = true, optional = true } +mpz-circuits = { workspace = true } +tracing = { workspace = true } [dev-dependencies] -ring = "0.17" +ring = { workspace = true } diff --git a/components/prf/hmac-sha256-circuits/src/hmac_sha256.rs b/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs similarity index 82% rename from components/prf/hmac-sha256-circuits/src/hmac_sha256.rs rename to crates/components/hmac-sha256-circuits/src/hmac_sha256.rs index b42352a024..5f135cea68 100644 --- a/components/prf/hmac-sha256-circuits/src/hmac_sha256.rs +++ b/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs @@ -18,12 +18,8 @@ static SHA256_INITIAL_STATE: [u32; 8] = [ /// /// # Arguments /// -/// * `builder_state` - Reference to builder state -/// * `key` - N-byte key (must be <= 64 bytes) -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(key, builder_state)) -)] +/// * `builder_state` - Reference to builder state. +/// * `key` - N-byte key (must be <= 64 bytes). pub fn hmac_sha256_partial_trace<'a>( builder_state: &'a RefCell, key: &[Tracer<'a, U8>], @@ -64,8 +60,7 @@ pub fn hmac_sha256_partial_trace<'a>( /// /// # Arguments /// -/// * `key` - N-byte key (must be <= 64 bytes) -#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(key)))] +/// * `key` - N-byte key (must be <= 64 bytes). pub fn hmac_sha256_partial(key: &[u8]) -> ([u32; 8], [u32; 8]) { assert!(key.len() <= 64); @@ -85,17 +80,14 @@ pub fn hmac_sha256_partial(key: &[u8]) -> ([u32; 8], [u32; 8]) { /// HMAC-SHA256 finalization function. /// -/// Returns the HMAC-SHA256 digest of the provided message using existing outer and inner states. +/// Returns the HMAC-SHA256 digest of the provided message using existing outer +/// and inner states. 
/// /// # Arguments /// -/// * `outer_state` - 256-bit outer state -/// * `inner_state` - 256-bit inner state -/// * `msg` - N-byte message -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(builder_state, outer_state, inner_state, msg)) -)] +/// * `outer_state` - 256-bit outer state. +/// * `inner_state` - 256-bit inner state. +/// * `msg` - N-byte message. pub fn hmac_sha256_finalize_trace<'a>( builder_state: &'a RefCell, outer_state: [Tracer<'a, U32>; 8], @@ -112,17 +104,14 @@ pub fn hmac_sha256_finalize_trace<'a>( /// Reference implementation of the HMAC-SHA256 finalization function. /// -/// Returns the HMAC-SHA256 digest of the provided message using existing outer and inner states. +/// Returns the HMAC-SHA256 digest of the provided message using existing outer +/// and inner states. /// /// # Arguments /// -/// * `outer_state` - 256-bit outer state -/// * `inner_state` - 256-bit inner state -/// * `msg` - N-byte message -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(outer_state, inner_state, msg)) -)] +/// * `outer_state` - 256-bit outer state. +/// * `inner_state` - 256-bit inner state. +/// * `msg` - N-byte message. pub fn hmac_sha256_finalize(outer_state: [u32; 8], inner_state: [u32; 8], msg: &[u8]) -> [u8; 32] { sha256(outer_state, 64, &sha256(inner_state, 64, msg)) } diff --git a/components/prf/hmac-sha256-circuits/src/lib.rs b/crates/components/hmac-sha256-circuits/src/lib.rs similarity index 92% rename from components/prf/hmac-sha256-circuits/src/lib.rs rename to crates/components/hmac-sha256-circuits/src/lib.rs index 901f009bcb..2892a5bcf0 100644 --- a/components/prf/hmac-sha256-circuits/src/lib.rs +++ b/crates/components/hmac-sha256-circuits/src/lib.rs @@ -22,7 +22,7 @@ use mpz_circuits::{Circuit, CircuitBuilder, Tracer}; use std::sync::Arc; /// Builds session key derivation circuit. 
-#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))] +#[tracing::instrument(level = "trace")] pub fn build_session_keys() -> Arc { let builder = CircuitBuilder::new(); let pms = builder.add_array_input::(); @@ -40,7 +40,7 @@ pub fn build_session_keys() -> Arc { } /// Builds a verify data circuit. -#[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip(label)))] +#[tracing::instrument(level = "trace")] pub fn build_verify_data(label: &[u8]) -> Arc { let builder = CircuitBuilder::new(); let outer_state = builder.add_array_input::(); diff --git a/components/prf/hmac-sha256-circuits/src/prf.rs b/crates/components/hmac-sha256-circuits/src/prf.rs similarity index 85% rename from components/prf/hmac-sha256-circuits/src/prf.rs rename to crates/components/hmac-sha256-circuits/src/prf.rs index a6c79acc7d..664a93393f 100644 --- a/components/prf/hmac-sha256-circuits/src/prf.rs +++ b/crates/components/hmac-sha256-circuits/src/prf.rs @@ -9,10 +9,6 @@ use mpz_circuits::{ use crate::hmac_sha256::{hmac_sha256_finalize, hmac_sha256_finalize_trace}; -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(builder_state, outer_state, inner_state, seed)) -)] fn p_hash_trace<'a>( builder_state: &'a RefCell, outer_state: [Tracer<'a, U32>; 8], @@ -45,10 +41,6 @@ fn p_hash_trace<'a>( output } -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(outer_state, inner_state, seed)) -)] fn p_hash(outer_state: [u32; 8], inner_state: [u32; 8], seed: &[u8], iterations: usize) -> Vec { // A() is defined as: // @@ -75,23 +67,16 @@ fn p_hash(outer_state: [u32; 8], inner_state: [u32; 8], seed: &[u8], iterations: output } -/// Computes PRF(secret, label, seed) +/// Computes PRF(secret, label, seed). /// /// # Arguments /// /// * `builder_state` - Reference to builder state. 
-/// * `outer_state` - The outer state of HMAC-SHA256 -/// * `inner_state` - The inner state of HMAC-SHA256 -/// * `seed` - The seed to use -/// * `label` - The label to use -/// * `bytes` - The number of bytes to output -#[cfg_attr( - feature = "tracing", - tracing::instrument( - level = "trace", - skip(builder_state, outer_state, inner_state, seed, label) - ) -)] +/// * `outer_state` - The outer state of HMAC-SHA256. +/// * `inner_state` - The inner state of HMAC-SHA256. +/// * `seed` - The seed to use. +/// * `label` - The label to use. +/// * `bytes` - The number of bytes to output. pub fn prf_trace<'a>( builder_state: &'a RefCell, outer_state: [Tracer<'a, U32>; 8], @@ -116,19 +101,15 @@ pub fn prf_trace<'a>( output } -/// Reference implementation of PRF(secret, label, seed) +/// Reference implementation of PRF(secret, label, seed). /// /// # Arguments /// -/// * `outer_state` - The outer state of HMAC-SHA256 -/// * `inner_state` - The inner state of HMAC-SHA256 -/// * `seed` - The seed to use -/// * `label` - The label to use -/// * `bytes` - The number of bytes to output -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(outer_state, inner_state, seed, label)) -)] +/// * `outer_state` - The outer state of HMAC-SHA256. +/// * `inner_state` - The inner state of HMAC-SHA256. +/// * `seed` - The seed to use. +/// * `label` - The label to use. +/// * `bytes` - The number of bytes to output. 
pub fn prf( outer_state: [u32; 8], inner_state: [u32; 8], diff --git a/components/prf/hmac-sha256-circuits/src/session_keys.rs b/crates/components/hmac-sha256-circuits/src/session_keys.rs similarity index 86% rename from components/prf/hmac-sha256-circuits/src/session_keys.rs rename to crates/components/hmac-sha256-circuits/src/session_keys.rs index 15a667d833..4ebe7f0841 100644 --- a/components/prf/hmac-sha256-circuits/src/session_keys.rs +++ b/crates/components/hmac-sha256-circuits/src/session_keys.rs @@ -10,30 +10,26 @@ use crate::{ prf::{prf, prf_trace}, }; -/// Session Keys +/// Session Keys. /// -/// Compute expanded p1 which consists of client_write_key + server_write_key -/// Compute expanded p2 which consists of client_IV + server_IV +/// Computes expanded p1 which consists of client_write_key + server_write_key. +/// Computes expanded p2 which consists of client_IV + server_IV. /// /// # Arguments /// -/// * `builder_state` - Reference to builder state -/// * `pms` - 32-byte premaster secret -/// * `client_random` - 32-byte client random -/// * `server_random` - 32-byte server random +/// * `builder_state` - Reference to builder state. +/// * `pms` - 32-byte premaster secret. +/// * `client_random` - 32-byte client random. +/// * `server_random` - 32-byte server random. /// /// # Returns /// -/// * `client_write_key` - 16-byte client write key -/// * `server_write_key` - 16-byte server write key -/// * `client_IV` - 4-byte client IV -/// * `server_IV` - 4-byte server IV -/// * `outer_hash_state` - 256-bit master-secret outer HMAC state -/// * `inner_hash_state` - 256-bit master-secret inner HMAC state -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(builder_state, pms)) -)] +/// * `client_write_key` - 16-byte client write key. +/// * `server_write_key` - 16-byte server write key. +/// * `client_IV` - 4-byte client IV. +/// * `server_IV` - 4-byte server IV. +/// * `outer_hash_state` - 256-bit master-secret outer HMAC state. 
+/// * `inner_hash_state` - 256-bit master-secret inner HMAC state. #[allow(clippy::type_complexity)] pub fn session_keys_trace<'a>( builder_state: &'a RefCell, @@ -109,7 +105,6 @@ pub fn session_keys_trace<'a>( } /// Reference implementation of session keys derivation. -#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(pms)))] pub fn session_keys( pms: [u8; 32], client_random: [u8; 32], diff --git a/components/prf/hmac-sha256-circuits/src/verify_data.rs b/crates/components/hmac-sha256-circuits/src/verify_data.rs similarity index 74% rename from components/prf/hmac-sha256-circuits/src/verify_data.rs rename to crates/components/hmac-sha256-circuits/src/verify_data.rs index 9f107b77fa..01f10b7469 100644 --- a/components/prf/hmac-sha256-circuits/src/verify_data.rs +++ b/crates/components/hmac-sha256-circuits/src/verify_data.rs @@ -10,19 +10,16 @@ use crate::prf::{prf, prf_trace}; /// Computes verify_data as specified in RFC 5246, Section 7.4.9. /// /// verify_data -/// PRF(master_secret, finished_label, Hash(handshake_messages))[0..verify_data_length-1]; +/// PRF(master_secret, finished_label, +/// Hash(handshake_messages))[0..verify_data_length-1]; /// /// # Arguments /// -/// * `builder_state` - The builder state -/// * `outer_state` - The outer HMAC state of the master secret -/// * `inner_state` - The inner HMAC state of the master secret -/// * `label` - The label to use -/// * `hs_hash` - The handshake hash -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(builder_state, outer_state, inner_state, label)) -)] +/// * `builder_state` - The builder state. +/// * `outer_state` - The outer HMAC state of the master secret. +/// * `inner_state` - The inner HMAC state of the master secret. +/// * `label` - The label to use. +/// * `hs_hash` - The handshake hash. 
pub fn verify_data_trace<'a>( builder_state: &'a RefCell, outer_state: [Tracer<'a, U32>; 8], @@ -35,18 +32,15 @@ pub fn verify_data_trace<'a>( vd.try_into().expect("vd is 12 bytes") } -/// Reference implementation of verify_data as specified in RFC 5246, Section 7.4.9. +/// Reference implementation of verify_data as specified in RFC 5246, Section +/// 7.4.9. /// /// # Arguments /// -/// * `outer_state` - The outer HMAC state of the master secret -/// * `inner_state` - The inner HMAC state of the master secret -/// * `label` - The label to use -/// * `hs_hash` - The handshake hash -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(outer_state, inner_state, label)) -)] +/// * `outer_state` - The outer HMAC state of the master secret. +/// * `inner_state` - The inner HMAC state of the master secret. +/// * `label` - The label to use. +/// * `hs_hash` - The handshake hash. pub fn verify_data( outer_state: [u32; 8], inner_state: [u32; 8], diff --git a/components/prf/hmac-sha256/Cargo.toml b/crates/components/hmac-sha256/Cargo.toml similarity index 52% rename from components/prf/hmac-sha256/Cargo.toml rename to crates/components/hmac-sha256/Cargo.toml index b8c3a69d97..c6ff0dfa17 100644 --- a/components/prf/hmac-sha256/Cargo.toml +++ b/crates/components/hmac-sha256/Cargo.toml @@ -5,7 +5,7 @@ description = "A 2PC implementation of TLS HMAC-SHA256 PRF" keywords = ["tls", "mpc", "2pc", "hmac", "sha256"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" [lib] @@ -13,24 +13,26 @@ name = "hmac_sha256" [features] default = ["mock"] -tracing = ["dep:tracing", "tlsn-hmac-sha256-circuits/tracing"] +rayon = ["mpz-common/rayon"] mock = [] [dependencies] -tlsn-hmac-sha256-circuits = { path = "../hmac-sha256-circuits" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } -mpz-garble.workspace = true -mpz-circuits.workspace = true 
+tlsn-hmac-sha256-circuits = { workspace = true } -async-trait.workspace = true -futures.workspace = true -thiserror.workspace = true -tracing = { workspace = true, optional = true } -derive_builder = "0.12" -enum-try-as-inner = "0.1" +mpz-garble = { workspace = true } +mpz-circuits = { workspace = true } +mpz-common = { workspace = true } + +async-trait = { workspace = true } +derive_builder = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } [dev-dependencies] criterion = { workspace = true, features = ["async_tokio"] } +mpz-common = { workspace = true, features = ["test-utils"] } +mpz-ot = { workspace = true, features = ["ideal"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } [[bench]] diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs new file mode 100644 index 0000000000..1a31aba397 --- /dev/null +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -0,0 +1,188 @@ +use criterion::{criterion_group, criterion_main, Criterion}; + +use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role}; +use mpz_common::executor::test_mt_executor; +use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Memory}; +use mpz_ot::ideal::ot::ideal_ot; + +#[allow(clippy::unit_arg)] +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("prf"); + group.sample_size(10); + let rt = tokio::runtime::Runtime::new().unwrap(); + + group.bench_function("prf_preprocess", |b| b.to_async(&rt).iter(preprocess)); + group.bench_function("prf", |b| b.to_async(&rt).iter(prf)); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + +async fn preprocess() { + let (mut leader_exec, mut follower_exec) = test_mt_executor(128); + + let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); + let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); + let (leader_ot_send_1, follower_ot_recv_1) = 
ideal_ot(); + let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); + + let leader_thread_0 = DEAPThread::new( + DEAPRole::Leader, + [0u8; 32], + leader_exec.new_thread().await.unwrap(), + leader_ot_send_0, + leader_ot_recv_0, + ); + let leader_thread_1 = leader_thread_0 + .new_thread( + leader_exec.new_thread().await.unwrap(), + leader_ot_send_1, + leader_ot_recv_1, + ) + .unwrap(); + + let follower_thread_0 = DEAPThread::new( + DEAPRole::Follower, + [0u8; 32], + follower_exec.new_thread().await.unwrap(), + follower_ot_send_0, + follower_ot_recv_0, + ); + let follower_thread_1 = follower_thread_0 + .new_thread( + follower_exec.new_thread().await.unwrap(), + follower_ot_send_1, + follower_ot_recv_1, + ) + .unwrap(); + + let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap(); + let follower_pms = follower_thread_0 + .new_public_input::<[u8; 32]>("pms") + .unwrap(); + + let mut leader = MpcPrf::new( + PrfConfig::builder().role(Role::Leader).build().unwrap(), + leader_thread_0, + leader_thread_1, + ); + let mut follower = MpcPrf::new( + PrfConfig::builder().role(Role::Follower).build().unwrap(), + follower_thread_0, + follower_thread_1, + ); + + futures::join!( + async { + leader.setup(leader_pms).await.unwrap(); + leader.set_client_random(Some([0u8; 32])).await.unwrap(); + leader.preprocess().await.unwrap(); + }, + async { + follower.setup(follower_pms).await.unwrap(); + follower.set_client_random(None).await.unwrap(); + follower.preprocess().await.unwrap(); + } + ); +} + +async fn prf() { + let (mut leader_exec, mut follower_exec) = test_mt_executor(128); + + let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); + let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); + let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot(); + let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); + + let leader_thread_0 = DEAPThread::new( + DEAPRole::Leader, + [0u8; 32], + leader_exec.new_thread().await.unwrap(), + leader_ot_send_0, + 
leader_ot_recv_0, + ); + let leader_thread_1 = leader_thread_0 + .new_thread( + leader_exec.new_thread().await.unwrap(), + leader_ot_send_1, + leader_ot_recv_1, + ) + .unwrap(); + + let follower_thread_0 = DEAPThread::new( + DEAPRole::Follower, + [0u8; 32], + follower_exec.new_thread().await.unwrap(), + follower_ot_send_0, + follower_ot_recv_0, + ); + let follower_thread_1 = follower_thread_0 + .new_thread( + follower_exec.new_thread().await.unwrap(), + follower_ot_send_1, + follower_ot_recv_1, + ) + .unwrap(); + + let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap(); + let follower_pms = follower_thread_0 + .new_public_input::<[u8; 32]>("pms") + .unwrap(); + + let mut leader = MpcPrf::new( + PrfConfig::builder().role(Role::Leader).build().unwrap(), + leader_thread_0, + leader_thread_1, + ); + let mut follower = MpcPrf::new( + PrfConfig::builder().role(Role::Follower).build().unwrap(), + follower_thread_0, + follower_thread_1, + ); + + let pms = [42u8; 32]; + let client_random = [0u8; 32]; + let server_random = [1u8; 32]; + let cf_hs_hash = [2u8; 32]; + let sf_hs_hash = [3u8; 32]; + + futures::join!( + async { + leader.setup(leader_pms.clone()).await.unwrap(); + leader.set_client_random(Some(client_random)).await.unwrap(); + leader.preprocess().await.unwrap(); + }, + async { + follower.setup(follower_pms.clone()).await.unwrap(); + follower.set_client_random(None).await.unwrap(); + follower.preprocess().await.unwrap(); + } + ); + + leader.thread_mut().assign(&leader_pms, pms).unwrap(); + follower.thread_mut().assign(&follower_pms, pms).unwrap(); + + let (_leader_keys, _follower_keys) = futures::try_join!( + leader.compute_session_keys(server_random), + follower.compute_session_keys(server_random) + ) + .unwrap(); + + let _ = futures::try_join!( + leader.compute_client_finished_vd(cf_hs_hash), + follower.compute_client_finished_vd(cf_hs_hash) + ) + .unwrap(); + + let _ = futures::try_join!( + leader.compute_server_finished_vd(sf_hs_hash), + 
follower.compute_server_finished_vd(sf_hs_hash) + ) + .unwrap(); + + futures::try_join!( + leader.thread_mut().finalize(), + follower.thread_mut().finalize() + ) + .unwrap(); +} diff --git a/components/prf/hmac-sha256/src/config.rs b/crates/components/hmac-sha256/src/config.rs similarity index 100% rename from components/prf/hmac-sha256/src/config.rs rename to crates/components/hmac-sha256/src/config.rs diff --git a/crates/components/hmac-sha256/src/error.rs b/crates/components/hmac-sha256/src/error.rs new file mode 100644 index 0000000000..ec81638256 --- /dev/null +++ b/crates/components/hmac-sha256/src/error.rs @@ -0,0 +1,83 @@ +use core::fmt; +use std::error::Error; + +/// A PRF error. +#[derive(Debug, thiserror::Error)] +pub struct PrfError { + kind: ErrorKind, + #[source] + source: Option>, +} + +impl PrfError { + pub(crate) fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: ErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn role(msg: impl Into) -> Self { + Self { + kind: ErrorKind::Role, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +pub(crate) enum ErrorKind { + Vm, + State, + Role, +} + +impl fmt::Display for PrfError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + ErrorKind::Vm => write!(f, "vm error")?, + ErrorKind::State => write!(f, "state error")?, + ErrorKind::Role => write!(f, "role error")?, + } + + if let Some(ref source) = self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for PrfError { + fn from(error: mpz_garble::MemoryError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for PrfError { + fn from(error: mpz_garble::LoadError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for PrfError { + fn from(error: mpz_garble::ExecutionError) -> Self { + 
Self::new(ErrorKind::Vm, error) + } +} + +impl From for PrfError { + fn from(error: mpz_garble::DecodeError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} diff --git a/components/prf/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs similarity index 59% rename from components/prf/hmac-sha256/src/lib.rs rename to crates/components/hmac-sha256/src/lib.rs index 57e450f89f..6550463cd4 100644 --- a/components/prf/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -20,7 +20,7 @@ pub(crate) static CF_LABEL: &[u8] = b"client finished"; pub(crate) static SF_LABEL: &[u8] = b"server finished"; /// Session keys computed by the PRF. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SessionKeys { /// Client write key. pub client_write_key: ValueRef, @@ -35,47 +35,61 @@ pub struct SessionKeys { /// PRF trait for computing TLS PRF. #[async_trait] pub trait Prf { - /// Performs any necessary one-time setup. + /// Sets up the PRF. /// /// # Arguments /// /// * `pms` - The pre-master secret. - async fn setup(&mut self, pms: ValueRef) -> Result<(), PrfError>; + async fn setup(&mut self, pms: ValueRef) -> Result; - /// Computes the session keys using the provided client random, server random and PMS. - async fn compute_session_keys_private( - &mut self, - client_random: [u8; 32], - server_random: [u8; 32], - ) -> Result; + /// Sets the client random. + /// + /// This must be set after calling [`Prf::setup`]. + /// + /// Only the leader can provide the client random. + async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError>; + + /// Preprocesses the PRF. + async fn preprocess(&mut self) -> Result<(), PrfError>; - /// Computes the client finished verify data using the provided handshake hash. - async fn compute_client_finished_vd_private( + /// Computes the client finished verify data. + /// + /// # Arguments + /// + /// * `handshake_hash` - The handshake transcript hash. 
+ async fn compute_client_finished_vd( &mut self, handshake_hash: [u8; 32], ) -> Result<[u8; 12], PrfError>; - /// Computes the server finished verify data using the provided handshake hash. - async fn compute_server_finished_vd_private( + /// Computes the server finished verify data. + /// + /// # Arguments + /// + /// * `handshake_hash` - The handshake transcript hash. + async fn compute_server_finished_vd( &mut self, handshake_hash: [u8; 32], ) -> Result<[u8; 12], PrfError>; - /// Computes the session keys using randoms provided by the other party. - async fn compute_session_keys_blind(&mut self) -> Result; - - /// Computes the client finished verify data using the handshake hash provided by the other party. - async fn compute_client_finished_vd_blind(&mut self) -> Result<(), PrfError>; - - /// Computes the server finished verify data using the handshake hash provided by the other party. - async fn compute_server_finished_vd_blind(&mut self) -> Result<(), PrfError>; + /// Computes the session keys. + /// + /// # Arguments + /// + /// * `server_random` - The server random. 
+ async fn compute_session_keys( + &mut self, + server_random: [u8; 32], + ) -> Result; } #[cfg(test)] mod tests { - use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Decode, Memory, Vm}; + use mpz_common::executor::test_st_executor; + use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Decode, Memory}; use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys}; + use mpz_ot::ideal::ot::ideal_ot; use super::*; @@ -104,38 +118,72 @@ mod tests { let server_random: [u8; 32] = [96u8; 32]; let ms = compute_ms(pms, client_random, server_random); - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; + let (leader_ctx_0, follower_ctx_0) = test_st_executor(128); + let (leader_ctx_1, follower_ctx_1) = test_st_executor(128); - let mut leader_test_thread = leader_vm.new_thread("test").await.unwrap(); - let mut follower_test_thread = follower_vm.new_thread("test").await.unwrap(); + let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); + let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); + let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot(); + let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); - // Setup public PMS for testing - let leader_pms = leader_test_thread - .new_public_input::<[u8; 32]>("pms") + let leader_thread_0 = DEAPThread::new( + DEAPRole::Leader, + [0u8; 32], + leader_ctx_0, + leader_ot_send_0, + leader_ot_recv_0, + ); + let leader_thread_1 = leader_thread_0 + .new_thread(leader_ctx_1, leader_ot_send_1, leader_ot_recv_1) .unwrap(); - let follower_pms = follower_test_thread + + let follower_thread_0 = DEAPThread::new( + DEAPRole::Follower, + [0u8; 32], + follower_ctx_0, + follower_ot_send_0, + follower_ot_recv_0, + ); + let follower_thread_1 = follower_thread_0 + .new_thread(follower_ctx_1, follower_ot_send_1, follower_ot_recv_1) + .unwrap(); + + // Set up public PMS for testing. 
+ let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap(); + let follower_pms = follower_thread_0 .new_public_input::<[u8; 32]>("pms") .unwrap(); - leader_test_thread.assign(&leader_pms, pms).unwrap(); - follower_test_thread.assign(&follower_pms, pms).unwrap(); + leader_thread_0.assign(&leader_pms, pms).unwrap(); + follower_thread_0.assign(&follower_pms, pms).unwrap(); let mut leader = MpcPrf::new( PrfConfig::builder().role(Role::Leader).build().unwrap(), - leader_vm.new_thread("prf/0").await.unwrap(), - leader_vm.new_thread("prf/1").await.unwrap(), + leader_thread_0, + leader_thread_1, ); let mut follower = MpcPrf::new( PrfConfig::builder().role(Role::Follower).build().unwrap(), - follower_vm.new_thread("prf/0").await.unwrap(), - follower_vm.new_thread("prf/1").await.unwrap(), + follower_thread_0, + follower_thread_1, ); - futures::try_join!(leader.setup(leader_pms), follower.setup(follower_pms)).unwrap(); + futures::join!( + async { + leader.setup(leader_pms).await.unwrap(); + leader.set_client_random(Some(client_random)).await.unwrap(); + leader.preprocess().await.unwrap(); + }, + async { + follower.setup(follower_pms).await.unwrap(); + follower.set_client_random(None).await.unwrap(); + follower.preprocess().await.unwrap(); + } + ); let (leader_session_keys, follower_session_keys) = futures::try_join!( - leader.compute_session_keys_private(client_random, server_random), - follower.compute_session_keys_blind() + leader.compute_session_keys(server_random), + follower.compute_session_keys(server_random) ) .unwrap(); @@ -155,13 +203,15 @@ mod tests { // Decode session keys let (leader_session_keys, follower_session_keys) = futures::try_join!( - async move { - leader_test_thread + async { + leader + .thread_mut() .decode(&[leader_cwk, leader_swk, leader_civ, leader_siv]) .await }, - async move { - follower_test_thread + async { + follower + .thread_mut() .decode(&[follower_cwk, follower_swk, follower_civ, follower_siv]) .await } @@ -195,8 +245,8 
@@ mod tests { let sf_hs_hash = [2u8; 32]; let (cf_vd, _) = futures::try_join!( - leader.compute_client_finished_vd_private(cf_hs_hash), - follower.compute_client_finished_vd_blind() + leader.compute_client_finished_vd(cf_hs_hash), + follower.compute_client_finished_vd(cf_hs_hash) ) .unwrap(); @@ -205,8 +255,8 @@ mod tests { assert_eq!(cf_vd, expected_cf_vd); let (sf_vd, _) = futures::try_join!( - leader.compute_server_finished_vd_private(sf_hs_hash), - follower.compute_server_finished_vd_blind() + leader.compute_server_finished_vd(sf_hs_hash), + follower.compute_server_finished_vd(sf_hs_hash) ) .unwrap(); diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs new file mode 100644 index 0000000000..77afab3eaf --- /dev/null +++ b/crates/components/hmac-sha256/src/prf.rs @@ -0,0 +1,443 @@ +use std::{ + fmt::Debug, + sync::{Arc, OnceLock}, +}; + +use async_trait::async_trait; + +use hmac_sha256_circuits::{build_session_keys, build_verify_data}; +use mpz_circuits::Circuit; +use mpz_common::cpu::CpuBackend; +use mpz_garble::{config::Visibility, value::ValueRef, Decode, Execute, Load, Memory}; +use tracing::instrument; + +use crate::{Prf, PrfConfig, PrfError, Role, SessionKeys, CF_LABEL, SF_LABEL}; + +/// Circuit for computing TLS session keys. +static SESSION_KEYS_CIRC: OnceLock> = OnceLock::new(); +/// Circuit for computing TLS client verify data. +static CLIENT_VD_CIRC: OnceLock> = OnceLock::new(); +/// Circuit for computing TLS server verify data. 
+static SERVER_VD_CIRC: OnceLock> = OnceLock::new(); + +#[derive(Debug)] +pub(crate) struct Randoms { + pub(crate) client_random: ValueRef, + pub(crate) server_random: ValueRef, +} + +#[derive(Debug, Clone)] +pub(crate) struct HashState { + pub(crate) ms_outer_hash_state: ValueRef, + pub(crate) ms_inner_hash_state: ValueRef, +} + +#[derive(Debug)] +pub(crate) struct VerifyData { + pub(crate) handshake_hash: ValueRef, + pub(crate) vd: ValueRef, +} + +#[derive(Debug)] +pub(crate) enum State { + Initialized, + SessionKeys { + pms: ValueRef, + randoms: Randoms, + hash_state: HashState, + keys: crate::SessionKeys, + cf_vd: VerifyData, + sf_vd: VerifyData, + }, + ClientFinished { + hash_state: HashState, + cf_vd: VerifyData, + sf_vd: VerifyData, + }, + ServerFinished { + hash_state: HashState, + sf_vd: VerifyData, + }, + Complete, + Error, +} + +impl State { + fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } +} + +/// MPC PRF for computing TLS HMAC-SHA256 PRF. +pub struct MpcPrf { + config: PrfConfig, + state: State, + thread_0: E, + thread_1: E, +} + +impl Debug for MpcPrf { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MpcPrf") + .field("config", &self.config) + .field("state", &self.state) + .finish() + } +} + +impl MpcPrf +where + E: Load + Memory + Execute + Decode + Send, +{ + /// Creates a new instance of the PRF. + pub fn new(config: PrfConfig, thread_0: E, thread_1: E) -> MpcPrf { + MpcPrf { + config, + state: State::Initialized, + thread_0, + thread_1, + } + } + + /// Returns a mutable reference to the MPC thread. + pub fn thread_mut(&mut self) -> &mut E { + &mut self.thread_0 + } + + /// Executes a circuit which computes TLS session keys. 
+ #[instrument(level = "debug", skip_all, err)] + async fn execute_session_keys( + &mut self, + server_random: [u8; 32], + ) -> Result { + let State::SessionKeys { + pms, + randoms: randoms_refs, + hash_state, + keys, + cf_vd, + sf_vd, + } = self.state.take() + else { + return Err(PrfError::state("session keys not initialized")); + }; + + let circ = SESSION_KEYS_CIRC + .get() + .expect("session keys circuit is set"); + + self.thread_0 + .assign(&randoms_refs.server_random, server_random)?; + + self.thread_0 + .execute( + circ.clone(), + &[pms, randoms_refs.client_random, randoms_refs.server_random], + &[ + keys.client_write_key.clone(), + keys.server_write_key.clone(), + keys.client_iv.clone(), + keys.server_iv.clone(), + hash_state.ms_outer_hash_state.clone(), + hash_state.ms_inner_hash_state.clone(), + ], + ) + .await?; + + self.state = State::ClientFinished { + hash_state, + cf_vd, + sf_vd, + }; + + Ok(keys) + } + + #[instrument(level = "debug", skip_all, err)] + async fn execute_cf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> { + let State::ClientFinished { + hash_state, + cf_vd, + sf_vd, + } = self.state.take() + else { + return Err(PrfError::state("PRF not in client finished state")); + }; + + let circ = CLIENT_VD_CIRC.get().expect("client vd circuit is set"); + + self.thread_0 + .assign(&cf_vd.handshake_hash, handshake_hash)?; + + self.thread_0 + .execute( + circ.clone(), + &[ + hash_state.ms_outer_hash_state.clone(), + hash_state.ms_inner_hash_state.clone(), + cf_vd.handshake_hash, + ], + &[cf_vd.vd.clone()], + ) + .await?; + + let mut outputs = self.thread_0.decode(&[cf_vd.vd]).await?; + let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); + + self.state = State::ServerFinished { hash_state, sf_vd }; + + Ok(vd) + } + + #[instrument(level = "debug", skip_all, err)] + async fn execute_sf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> { + let State::ServerFinished { hash_state, sf_vd } = 
self.state.take() else { + return Err(PrfError::state("PRF not in server finished state")); + }; + + let circ = SERVER_VD_CIRC.get().expect("server vd circuit is set"); + + self.thread_0 + .assign(&sf_vd.handshake_hash, handshake_hash)?; + + self.thread_0 + .execute( + circ.clone(), + &[ + hash_state.ms_outer_hash_state, + hash_state.ms_inner_hash_state, + sf_vd.handshake_hash, + ], + &[sf_vd.vd.clone()], + ) + .await?; + + let mut outputs = self.thread_0.decode(&[sf_vd.vd]).await?; + let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); + + self.state = State::Complete; + + Ok(vd) + } +} + +#[async_trait] +impl Prf for MpcPrf +where + E: Memory + Load + Execute + Decode + Send, +{ + #[instrument(level = "debug", skip_all, err)] + async fn setup(&mut self, pms: ValueRef) -> Result { + let State::Initialized = self.state.take() else { + return Err(PrfError::state("PRF not in initialized state")); + }; + + let thread = &mut self.thread_0; + + let randoms = Randoms { + // The client random is kept private so that the handshake transcript + // hashes do not leak information about the server's identity. 
+ client_random: thread.new_input::<[u8; 32]>( + "client_random", + match self.config.role { + Role::Leader => Visibility::Private, + Role::Follower => Visibility::Blind, + }, + )?, + server_random: thread.new_input::<[u8; 32]>("server_random", Visibility::Public)?, + }; + + let keys = SessionKeys { + client_write_key: thread.new_output::<[u8; 16]>("client_write_key")?, + server_write_key: thread.new_output::<[u8; 16]>("server_write_key")?, + client_iv: thread.new_output::<[u8; 4]>("client_write_iv")?, + server_iv: thread.new_output::<[u8; 4]>("server_write_iv")?, + }; + + let hash_state = HashState { + ms_outer_hash_state: thread.new_output::<[u32; 8]>("ms_outer_hash_state")?, + ms_inner_hash_state: thread.new_output::<[u32; 8]>("ms_inner_hash_state")?, + }; + + let cf_vd = VerifyData { + handshake_hash: thread.new_input::<[u8; 32]>("cf_hash", Visibility::Public)?, + vd: thread.new_output::<[u8; 12]>("cf_vd")?, + }; + + let sf_vd = VerifyData { + handshake_hash: thread.new_input::<[u8; 32]>("sf_hash", Visibility::Public)?, + vd: thread.new_output::<[u8; 12]>("sf_vd")?, + }; + + self.state = State::SessionKeys { + pms, + randoms, + hash_state, + keys: keys.clone(), + cf_vd, + sf_vd, + }; + + Ok(keys) + } + + #[instrument(level = "debug", skip_all, err)] + async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError> { + let State::SessionKeys { randoms, .. 
} = &self.state else { + return Err(PrfError::state("PRF not set up")); + }; + + if self.config.role == Role::Leader { + let Some(client_random) = client_random else { + return Err(PrfError::role("leader must provide client random")); + }; + + self.thread_0 + .assign(&randoms.client_random, client_random)?; + } else if client_random.is_some() { + return Err(PrfError::role("only leader can set client random")); + } + + self.thread_0 + .commit(&[randoms.client_random.clone()]) + .await?; + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn preprocess(&mut self) -> Result<(), PrfError> { + let State::SessionKeys { + pms, + randoms, + hash_state, + keys, + cf_vd, + sf_vd, + } = self.state.take() + else { + return Err(PrfError::state("PRF not set up")); + }; + + // Builds all circuits in parallel and preprocesses the session keys circuit. + futures::try_join!( + async { + if SESSION_KEYS_CIRC.get().is_none() { + _ = SESSION_KEYS_CIRC.set(CpuBackend::blocking(build_session_keys).await); + } + + let circ = SESSION_KEYS_CIRC + .get() + .expect("session keys circuit should be built"); + + self.thread_0 + .load( + circ.clone(), + &[ + pms.clone(), + randoms.client_random.clone(), + randoms.server_random.clone(), + ], + &[ + keys.client_write_key.clone(), + keys.server_write_key.clone(), + keys.client_iv.clone(), + keys.server_iv.clone(), + hash_state.ms_outer_hash_state.clone(), + hash_state.ms_inner_hash_state.clone(), + ], + ) + .await?; + + Ok::<_, PrfError>(()) + }, + async { + if CLIENT_VD_CIRC.get().is_none() { + _ = CLIENT_VD_CIRC + .set(CpuBackend::blocking(move || build_verify_data(CF_LABEL)).await); + } + + Ok::<_, PrfError>(()) + }, + async { + if SERVER_VD_CIRC.get().is_none() { + _ = SERVER_VD_CIRC + .set(CpuBackend::blocking(move || build_verify_data(SF_LABEL)).await); + } + + Ok::<_, PrfError>(()) + } + )?; + + // Finishes preprocessing the verify data circuits. 
+ futures::try_join!( + async { + self.thread_0 + .load( + CLIENT_VD_CIRC + .get() + .expect("client finished circuit should be built") + .clone(), + &[ + hash_state.ms_outer_hash_state.clone(), + hash_state.ms_inner_hash_state.clone(), + cf_vd.handshake_hash.clone(), + ], + &[cf_vd.vd.clone()], + ) + .await + }, + async { + self.thread_1 + .load( + SERVER_VD_CIRC + .get() + .expect("server finished circuit should be built") + .clone(), + &[ + hash_state.ms_outer_hash_state.clone(), + hash_state.ms_inner_hash_state.clone(), + sf_vd.handshake_hash.clone(), + ], + &[sf_vd.vd.clone()], + ) + .await + } + )?; + + self.state = State::SessionKeys { + pms, + randoms, + hash_state, + keys, + cf_vd, + sf_vd, + }; + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn compute_client_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PrfError> { + self.execute_cf_vd(handshake_hash).await + } + + #[instrument(level = "debug", skip_all, err)] + async fn compute_server_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PrfError> { + self.execute_sf_vd(handshake_hash).await + } + + #[instrument(level = "debug", skip_all, err)] + async fn compute_session_keys( + &mut self, + server_random: [u8; 32], + ) -> Result { + self.execute_session_keys(server_random).await + } +} diff --git a/crates/components/key-exchange/Cargo.toml b/crates/components/key-exchange/Cargo.toml new file mode 100644 index 0000000000..1f5f26357f --- /dev/null +++ b/crates/components/key-exchange/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "tlsn-key-exchange" +authors = ["TLSNotary Team"] +description = "Implementation of the 3-party key-exchange protocol" +keywords = ["tls", "mpc", "2pc", "pms", "key-exchange"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0-alpha.7" +edition = "2021" + +[lib] +name = "key_exchange" + +[features] +default = ["mock"] +mock = [] + +[dependencies] +mpz-garble = { git = 
"https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [ + "ideal", +] } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } + +p256 = { workspace = true, features = ["ecdh", "serde"] } +async-trait = { workspace = true } +thiserror = { workspace = true } +serde = { workspace = true } +futures = { workspace = true } +serio = { workspace = true } +derive_builder = { workspace = true } +tracing = { workspace = true } +rand = { workspace = true } + +[dev-dependencies] +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [ + "ideal", +] } + +rand_chacha = { workspace = true } +rand_core = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } diff --git a/crates/components/key-exchange/src/circuit.rs b/crates/components/key-exchange/src/circuit.rs new file mode 100644 index 0000000000..28fdc42714 --- /dev/null +++ b/crates/components/key-exchange/src/circuit.rs @@ -0,0 +1,43 @@ +//! This module provides the circuits used in the key exchange protocol. + +use std::sync::Arc; + +use mpz_circuits::{circuits::big_num::nbyte_add_mod_trace, Circuit, CircuitBuilder}; + +/// NIST P-256 prime big-endian. +static P: [u8; 32] = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, +]; + +/// Circuit for combining additive shares of the PMS, twice +/// +/// # Inputs +/// +/// 0. 
PMS_SHARE_A0: 32 bytes PMS Additive Share +/// 1. PMS_SHARE_B0: 32 bytes PMS Additive Share +/// 2. PMS_SHARE_A1: 32 bytes PMS Additive Share +/// 3. PMS_SHARE_B1: 32 bytes PMS Additive Share +/// +/// # Outputs +/// 0. PMS_0: Pre-master Secret = PMS_SHARE_A0 + PMS_SHARE_B0 +/// 1. PMS_1: Pre-master Secret = PMS_SHARE_A1 + PMS_SHARE_B1 +/// 2. EQ: Equality check of PMS_0 and PMS_1 +pub(crate) fn build_pms_circuit() -> Arc { + let builder = CircuitBuilder::new(); + let share_a0 = builder.add_array_input::(); + let share_b0 = builder.add_array_input::(); + let share_a1 = builder.add_array_input::(); + let share_b1 = builder.add_array_input::(); + + let pms_0 = nbyte_add_mod_trace(builder.state(), share_a0, share_b0, P); + let pms_1 = nbyte_add_mod_trace(builder.state(), share_a1, share_b1, P); + + let eq: [_; 32] = std::array::from_fn(|i| pms_0[i] ^ pms_1[i]); + + builder.add_output(pms_0); + builder.add_output(pms_1); + builder.add_output(eq); + + Arc::new(builder.build().expect("pms circuit is valid")) +} diff --git a/components/key-exchange/src/config.rs b/crates/components/key-exchange/src/config.rs similarity index 50% rename from components/key-exchange/src/config.rs rename to crates/components/key-exchange/src/config.rs index 497b299e52..5e32f55c73 100644 --- a/components/key-exchange/src/config.rs +++ b/crates/components/key-exchange/src/config.rs @@ -1,37 +1,28 @@ -//! This module provides the [KeyExchangeConfig] for configuration of the key exchange instance - use derive_builder::Builder; -/// Role in the key exchange protocol +/// Role in the key exchange protocol. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(missing_docs)] pub enum Role { + /// Leader. Leader, + /// Follower. Follower, } -/// A config used for [KeyExchangeCore](super::KeyExchangeCore) +/// A config used for [MpcKeyExchange](super::MpcKeyExchange). 
#[derive(Debug, Clone, Builder)] pub struct KeyExchangeConfig { - /// The id of this instance - #[builder(setter(into))] - id: String, - /// Protocol role + /// Protocol role. role: Role, } impl KeyExchangeConfig { - /// Creates a new builder for the key exchange configuration + /// Creates a new builder for the key exchange configuration. pub fn builder() -> KeyExchangeConfigBuilder { KeyExchangeConfigBuilder::default() } - /// Get the id of this instance - pub fn id(&self) -> &str { - &self.id - } - - /// Get the role of this instance + /// Get the role of this instance. pub fn role(&self) -> &Role { &self.role } diff --git a/crates/components/key-exchange/src/error.rs b/crates/components/key-exchange/src/error.rs new file mode 100644 index 0000000000..cf9e7da4d0 --- /dev/null +++ b/crates/components/key-exchange/src/error.rs @@ -0,0 +1,120 @@ +use core::fmt; +use std::error::Error; + +/// A key exchange error. +#[derive(Debug, thiserror::Error)] +pub struct KeyExchangeError { + kind: ErrorKind, + #[source] + source: Option>, +} + +impl KeyExchangeError { + pub(crate) fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + #[cfg(test)] + pub(crate) fn kind(&self) -> &ErrorKind { + &self.kind + } + + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: ErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn role(msg: impl Into) -> Self { + Self { + kind: ErrorKind::Role, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +pub(crate) enum ErrorKind { + Io, + Context, + Vm, + ShareConversion, + Key, + State, + Role, +} + +impl fmt::Display for KeyExchangeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + ErrorKind::Io => write!(f, "io error")?, + ErrorKind::Context => write!(f, "context error")?, + ErrorKind::Vm => write!(f, "vm error")?, + ErrorKind::ShareConversion => write!(f, "share conversion error")?, + 
ErrorKind::Key => write!(f, "key error")?, + ErrorKind::State => write!(f, "state error")?, + ErrorKind::Role => write!(f, "role error")?, + } + + if let Some(ref source) = self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for KeyExchangeError { + fn from(error: mpz_common::ContextError) -> Self { + Self::new(ErrorKind::Context, error) + } +} + +impl From for KeyExchangeError { + fn from(error: mpz_garble::MemoryError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for KeyExchangeError { + fn from(error: mpz_garble::LoadError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for KeyExchangeError { + fn from(error: mpz_garble::ExecutionError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for KeyExchangeError { + fn from(error: mpz_garble::DecodeError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for KeyExchangeError { + fn from(error: mpz_share_conversion::ShareConversionError) -> Self { + Self::new(ErrorKind::ShareConversion, error) + } +} + +impl From for KeyExchangeError { + fn from(error: p256::elliptic_curve::Error) -> Self { + Self::new(ErrorKind::Key, error) + } +} + +impl From for KeyExchangeError { + fn from(error: std::io::Error) -> Self { + Self::new(ErrorKind::Io, error) + } +} diff --git a/crates/components/key-exchange/src/exchange.rs b/crates/components/key-exchange/src/exchange.rs new file mode 100644 index 0000000000..76cc53047e --- /dev/null +++ b/crates/components/key-exchange/src/exchange.rs @@ -0,0 +1,656 @@ +//! This module implements the key exchange logic. 
+ +use async_trait::async_trait; +use mpz_common::{scoped_futures::ScopedFutureExt, Allocate, Context, Preprocess}; +use mpz_garble::{value::ValueRef, Decode, Execute, Load, Memory}; + +use mpz_fields::{p256::P256, Field}; +use mpz_share_conversion::{ShareConversionError, ShareConvert}; +use p256::{EncodedPoint, PublicKey, SecretKey}; +use serio::{stream::IoStreamExt, SinkExt}; +use std::fmt::Debug; +use tracing::{debug, instrument}; + +use crate::{ + circuit::build_pms_circuit, + config::{KeyExchangeConfig, Role}, + error::ErrorKind, + point_addition::derive_x_coord_share, + KeyExchange, KeyExchangeError, Pms, +}; + +#[derive(Debug)] +enum State { + Initialized, + Setup { + share_a0: ValueRef, + share_b0: ValueRef, + share_a1: ValueRef, + share_b1: ValueRef, + pms_0: ValueRef, + pms_1: ValueRef, + eq: ValueRef, + }, + Preprocessed { + share_a0: ValueRef, + share_b0: ValueRef, + share_a1: ValueRef, + share_b1: ValueRef, + pms_0: ValueRef, + pms_1: ValueRef, + eq: ValueRef, + }, + Complete, + Error, +} + +impl State { + fn is_preprocessed(&self) -> bool { + matches!(self, Self::Preprocessed { .. }) + } + + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Error) + } +} + +/// An MPC key exchange protocol. +/// +/// Can be either a leader or a follower depending on the `role` field in +/// [`KeyExchangeConfig`]. +#[derive(Debug)] +pub struct MpcKeyExchange { + ctx: Ctx, + /// Share conversion protocol 0. + converter_0: C0, + /// Share conversion protocol 1. + converter_1: C1, + /// MPC executor. + executor: E, + /// The private key of the party behind this instance, either follower or + /// leader. + private_key: Option, + /// The public key of the server. + server_key: Option, + /// The config used for the key exchange protocol. + config: KeyExchangeConfig, + /// The state of the protocol. + state: State, +} + +impl MpcKeyExchange { + /// Creates a new [`MpcKeyExchange`]. + /// + /// # Arguments + /// + /// * `config` - Key exchange configuration. 
+ /// * `ctx` - Thread context. + /// * `converter_0` - Share conversion protocol instance 0. + /// * `converter_1` - Share conversion protocol instance 1. + /// * `executor` - MPC executor. + pub fn new( + config: KeyExchangeConfig, + ctx: Ctx, + converter_0: C0, + converter_1: C1, + executor: E, + ) -> Self { + Self { + ctx, + converter_0, + converter_1, + executor, + private_key: None, + server_key: None, + config, + state: State::Initialized, + } + } +} + +impl MpcKeyExchange +where + Ctx: Context, + E: Execute + Load + Memory + Decode + Send, + C0: ShareConvert + Send, + C1: ShareConvert + Send, +{ + async fn compute_pms_shares( + &mut self, + server_key: PublicKey, + private_key: SecretKey, + ) -> Result<(P256, P256), KeyExchangeError> { + compute_pms_shares( + &mut self.ctx, + *self.config.role(), + &mut self.converter_0, + &mut self.converter_1, + server_key, + private_key, + ) + .await + } + + // Computes the PMS using both parties' shares, performing an equality check + // to ensure the shares are equal. 
+ async fn compute_pms_with( + &mut self, + share_0: P256, + share_1: P256, + ) -> Result { + let State::Preprocessed { + share_a0, + share_b0, + share_a1, + share_b1, + pms_0, + pms_1, + eq, + } = self.state.take() + else { + return Err(KeyExchangeError::state("not in preprocessed state")); + }; + + let share_0_bytes: [u8; 32] = share_0 + .to_be_bytes() + .try_into() + .expect("pms share is 32 bytes"); + let share_1_bytes: [u8; 32] = share_1 + .to_be_bytes() + .try_into() + .expect("pms share is 32 bytes"); + + match self.config.role() { + Role::Leader => { + self.executor.assign(&share_a0, share_0_bytes)?; + self.executor.assign(&share_a1, share_1_bytes)?; + } + Role::Follower => { + self.executor.assign(&share_b0, share_0_bytes)?; + self.executor.assign(&share_b1, share_1_bytes)?; + } + } + + self.executor + .execute( + build_pms_circuit(), + &[share_a0, share_b0, share_a1, share_b1], + &[pms_0.clone(), pms_1, eq.clone()], + ) + .await?; + + let eq: [u8; 32] = self + .executor + .decode(&[eq]) + .await? + .pop() + .expect("output 0 is eq") + .try_into() + .expect("eq is 32 bytes"); + + // Eq should be all zeros if pms_1 == pms_2. + if eq != [0u8; 32] { + return Err(KeyExchangeError::new( + ErrorKind::ShareConversion, + "PMS values not equal", + )); + } + + // Both parties use pms_0 as the pre-master secret. + Ok(Pms::new(pms_0)) + } +} + +#[async_trait] +impl KeyExchange for MpcKeyExchange +where + Ctx: Context, + E: Execute + Load + Memory + Decode + Send, + C0: Allocate + Preprocess + ShareConvert + Send, + C1: Allocate + Preprocess + ShareConvert + Send, +{ + fn server_key(&self) -> Option { + self.server_key + } + + async fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError> { + let Role::Leader = self.config.role() else { + return Err(KeyExchangeError::role("follower cannot set server key")); + }; + + // Send server public key to follower. 
+ self.ctx.io_mut().send(server_key).await?; + + self.server_key = Some(server_key); + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn setup(&mut self) -> Result { + let State::Initialized = self.state.take() else { + return Err(KeyExchangeError::state("not in initialized state")); + }; + + // 2 A2M, 1 M2A. + self.converter_0.alloc(3); + self.converter_1.alloc(3); + + let (share_a0, share_b0, share_a1, share_b1) = match self.config.role() { + Role::Leader => { + let share_a0 = self + .executor + .new_private_input::<[u8; 32]>("pms/share_a0")?; + let share_b0 = self.executor.new_blind_input::<[u8; 32]>("pms/share_b0")?; + let share_a1 = self + .executor + .new_private_input::<[u8; 32]>("pms/share_a1")?; + let share_b1 = self.executor.new_blind_input::<[u8; 32]>("pms/share_b1")?; + + (share_a0, share_b0, share_a1, share_b1) + } + Role::Follower => { + let share_a0 = self.executor.new_blind_input::<[u8; 32]>("pms/share_a0")?; + let share_b0 = self + .executor + .new_private_input::<[u8; 32]>("pms/share_b0")?; + let share_a1 = self.executor.new_blind_input::<[u8; 32]>("pms/share_a1")?; + let share_b1 = self + .executor + .new_private_input::<[u8; 32]>("pms/share_b1")?; + + (share_a0, share_b0, share_a1, share_b1) + } + }; + + let pms_0 = self.executor.new_output::<[u8; 32]>("pms_0")?; + let pms_1 = self.executor.new_output::<[u8; 32]>("pms_1")?; + let eq = self.executor.new_output::<[u8; 32]>("eq")?; + + self.state = State::Setup { + share_a0, + share_b0, + share_a1, + share_b1, + pms_0: pms_0.clone(), + pms_1, + eq, + }; + + Ok(Pms::new(pms_0)) + } + + #[instrument(level = "debug", skip_all, err)] + async fn preprocess(&mut self) -> Result<(), KeyExchangeError> { + let State::Setup { + share_a0, + share_b0, + share_a1, + share_b1, + pms_0, + pms_1, + eq, + } = self.state.take() + else { + return Err(KeyExchangeError::state("not in setup state")); + }; + + // Preprocess share conversion and garbled circuits concurrently. 
+ futures::try_join!( + async { + self.ctx + .try_join( + |ctx| self.converter_0.preprocess(ctx).scope_boxed(), + |ctx| self.converter_1.preprocess(ctx).scope_boxed(), + ) + .await??; + + Ok::<_, KeyExchangeError>(()) + }, + async { + self.executor + .load( + build_pms_circuit(), + &[ + share_a0.clone(), + share_b0.clone(), + share_a1.clone(), + share_b1.clone(), + ], + &[pms_0.clone(), pms_1.clone(), eq.clone()], + ) + .await?; + + Ok::<_, KeyExchangeError>(()) + } + )?; + + // Follower can forward their key share immediately. + if let Role::Follower = self.config.role() { + let private_key = self + .private_key + .get_or_insert_with(|| SecretKey::random(&mut rand::rngs::OsRng)); + + self.ctx.io_mut().send(private_key.public_key()).await?; + + debug!("sent public key share to leader"); + } + + self.state = State::Preprocessed { + share_a0, + share_b0, + share_a1, + share_b1, + pms_0, + pms_1, + eq, + }; + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn client_key(&mut self) -> Result { + if let Role::Leader = self.config.role() { + let private_key = self + .private_key + .get_or_insert_with(|| SecretKey::random(&mut rand::rngs::OsRng)); + let public_key = private_key.public_key(); + + // Receive public key share from follower. + let follower_public_key: PublicKey = self.ctx.io_mut().expect_next().await?; + + debug!("received public key share from follower"); + + // Combine public keys. 
+ let client_public_key = PublicKey::from_affine( + (public_key.to_projective() + follower_public_key.to_projective()).to_affine(), + )?; + + Ok(client_public_key) + } else { + Err(KeyExchangeError::role("follower does not learn client key")) + } + } + + #[instrument(level = "debug", skip_all, err)] + async fn compute_pms(&mut self) -> Result { + if !self.state.is_preprocessed() { + return Err(KeyExchangeError::state("not in preprocessed state")); + } + + let server_key = match self.config.role() { + Role::Leader => self + .server_key + .ok_or_else(|| KeyExchangeError::state("server public key not set"))?, + Role::Follower => { + // Receive server public key from leader. + let server_key = self.ctx.io_mut().expect_next().await?; + + self.server_key = Some(server_key); + + server_key + } + }; + + let private_key = self + .private_key + .take() + .ok_or(KeyExchangeError::state("private key not set"))?; + + let (pms_share_0, pms_share_1) = self.compute_pms_shares(server_key, private_key).await?; + let pms = self.compute_pms_with(pms_share_0, pms_share_1).await?; + + self.state = State::Complete; + + Ok(pms) + } +} + +async fn compute_pms_shares< + Ctx: Context, + C0: ShareConvert + Send, + C1: ShareConvert + Send, +>( + ctx: &mut Ctx, + role: Role, + converter_0: &mut C0, + converter_1: &mut C1, + server_key: PublicKey, + private_key: SecretKey, +) -> Result<(P256, P256), KeyExchangeError> { + // Compute the leader's/follower's share of the pre-master secret. + // + // We need to mimic the [diffie-hellman](p256::ecdh::diffie_hellman) function + // without the [SharedSecret](p256::ecdh::SharedSecret) wrapper, because + // this makes it harder to get the result as an EC curve point. 
+ let shared_secret = { + let public_projective = server_key.to_projective(); + (public_projective * private_key.to_nonzero_scalar().as_ref()).to_affine() + }; + + let encoded_point = EncodedPoint::from(PublicKey::from_affine(shared_secret)?); + + let (pms_share_0, pms_share_1) = ctx + .try_join( + |ctx| { + async { derive_x_coord_share(role, ctx, converter_0, encoded_point).await } + .scope_boxed() + }, + |ctx| { + async { derive_x_coord_share(role, ctx, converter_1, encoded_point).await } + .scope_boxed() + }, + ) + .await??; + + Ok((pms_share_0, pms_share_1)) +} + +#[cfg(test)] +mod tests { + use super::*; + + use mpz_common::executor::{test_st_executor, STExecutor}; + use mpz_garble::protocol::deap::mock::{create_mock_deap_vm, MockFollower, MockLeader}; + use mpz_share_conversion::ideal::{ideal_share_converter, IdealShareConverter}; + use p256::{NonZeroScalar, PublicKey, SecretKey}; + use rand_chacha::ChaCha12Rng; + use rand_core::SeedableRng; + use serio::channel::MemoryDuplex; + + #[allow(clippy::type_complexity)] + fn create_pair() -> ( + MpcKeyExchange< + STExecutor, + IdealShareConverter, + IdealShareConverter, + MockLeader, + >, + MpcKeyExchange< + STExecutor, + IdealShareConverter, + IdealShareConverter, + MockFollower, + >, + ) { + let (leader_ctx, follower_ctx) = test_st_executor(8); + let (leader_converter_0, follower_converter_0) = ideal_share_converter(); + let (follower_converter_1, leader_converter_1) = ideal_share_converter(); + let (leader_vm, follower_vm) = create_mock_deap_vm(); + + let leader = MpcKeyExchange::new( + KeyExchangeConfig::builder() + .role(Role::Leader) + .build() + .unwrap(), + leader_ctx, + leader_converter_0, + leader_converter_1, + leader_vm, + ); + + let follower = MpcKeyExchange::new( + KeyExchangeConfig::builder() + .role(Role::Follower) + .build() + .unwrap(), + follower_ctx, + follower_converter_0, + follower_converter_1, + follower_vm, + ); + + (leader, follower) + } + + #[tokio::test] + async fn test_key_exchange() { 
+ let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + + let leader_private_key = SecretKey::random(&mut rng); + let follower_private_key = SecretKey::random(&mut rng); + let server_public_key = PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng)); + + let (mut leader, mut follower) = create_pair(); + + leader.private_key = Some(leader_private_key.clone()); + follower.private_key = Some(follower_private_key.clone()); + + tokio::try_join!(leader.setup(), follower.setup()).unwrap(); + tokio::try_join!(leader.preprocess(), follower.preprocess()).unwrap(); + + let client_public_key = leader.client_key().await.unwrap(); + leader.set_server_key(server_public_key).await.unwrap(); + + let expected_client_public_key = PublicKey::from_affine( + (leader_private_key.public_key().to_projective() + + follower_private_key.public_key().to_projective()) + .to_affine(), + ) + .unwrap(); + + assert_eq!(client_public_key, expected_client_public_key); + } + + #[tokio::test] + async fn test_compute_pms() { + let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + + let leader_private_key = SecretKey::random(&mut rng); + let follower_private_key = SecretKey::random(&mut rng); + let server_private_key = NonZeroScalar::random(&mut rng); + let server_public_key = PublicKey::from_secret_scalar(&server_private_key); + + let (mut leader, mut follower) = create_pair(); + + leader.private_key = Some(leader_private_key); + follower.private_key = Some(follower_private_key); + + tokio::try_join!(leader.setup(), follower.setup()).unwrap(); + tokio::try_join!(leader.preprocess(), follower.preprocess()).unwrap(); + + leader.set_server_key(server_public_key).await.unwrap(); + + let (_leader_pms, _follower_pms) = + tokio::try_join!(leader.compute_pms(), follower.compute_pms()).unwrap(); + + assert_eq!(leader.server_key.unwrap(), server_public_key); + assert_eq!(follower.server_key.unwrap(), server_public_key); + } + + #[tokio::test] + async fn test_compute_pms_shares() { + let mut rng = 
ChaCha12Rng::from_seed([0_u8; 32]); + let (mut ctx_leader, mut ctx_follower) = test_st_executor(8); + let (mut leader_converter_0, mut follower_converter_0) = ideal_share_converter(); + let (mut follower_converter_1, mut leader_converter_1) = ideal_share_converter(); + + let leader_private_key = SecretKey::random(&mut rng); + let follower_private_key = SecretKey::random(&mut rng); + let server_private_key = NonZeroScalar::random(&mut rng); + let server_public_key = PublicKey::from_secret_scalar(&server_private_key); + + let client_public_key = PublicKey::from_affine( + (leader_private_key.public_key().to_projective() + + follower_private_key.public_key().to_projective()) + .to_affine(), + ) + .unwrap(); + + let ((leader_share_0, leader_share_1), (follower_share_0, follower_share_1)) = + tokio::try_join!( + compute_pms_shares( + &mut ctx_leader, + Role::Leader, + &mut leader_converter_0, + &mut leader_converter_1, + server_public_key, + leader_private_key + ), + compute_pms_shares( + &mut ctx_follower, + Role::Follower, + &mut follower_converter_0, + &mut follower_converter_1, + server_public_key, + follower_private_key + ) + ) + .unwrap(); + + let expected_ecdh_x = + p256::ecdh::diffie_hellman(server_private_key, client_public_key.as_affine()); + + assert_eq!( + (leader_share_0 + follower_share_0).to_be_bytes(), + expected_ecdh_x.raw_secret_bytes().to_vec() + ); + assert_eq!( + (leader_share_1 + follower_share_1).to_be_bytes(), + expected_ecdh_x.raw_secret_bytes().to_vec() + ); + + assert_ne!(leader_share_0, follower_share_0); + assert_ne!(leader_share_1, follower_share_1); + } + + #[tokio::test] + async fn test_compute_pms_fail() { + let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + + let leader_private_key = SecretKey::random(&mut rng); + let follower_private_key = SecretKey::random(&mut rng); + let server_private_key = NonZeroScalar::random(&mut rng); + let server_public_key = PublicKey::from_secret_scalar(&server_private_key); + + let (mut leader, mut 
follower) = create_pair(); + + leader.private_key = Some(leader_private_key.clone()); + follower.private_key = Some(follower_private_key.clone()); + + tokio::try_join!(leader.setup(), follower.setup()).unwrap(); + tokio::try_join!(leader.preprocess(), follower.preprocess()).unwrap(); + + leader.set_server_key(server_public_key).await.unwrap(); + + let ((mut share_a0, share_a1), (share_b0, share_b1)) = tokio::try_join!( + leader.compute_pms_shares(server_public_key, leader_private_key), + follower.compute_pms_shares(server_public_key, follower_private_key) + ) + .unwrap(); + + share_a0 = share_a0 + P256::one(); + + let (leader_res, follower_res) = tokio::join!( + leader.compute_pms_with(share_a0, share_a1), + follower.compute_pms_with(share_b0, share_b1) + ); + + let leader_err = leader_res.unwrap_err(); + let follower_err = follower_res.unwrap_err(); + + assert!(matches!(leader_err.kind(), ErrorKind::ShareConversion)); + assert!(matches!(follower_err.kind(), ErrorKind::ShareConversion)); + } +} diff --git a/crates/components/key-exchange/src/lib.rs b/crates/components/key-exchange/src/lib.rs new file mode 100644 index 0000000000..1215ebfe9d --- /dev/null +++ b/crates/components/key-exchange/src/lib.rs @@ -0,0 +1,77 @@ +//! # The Key Exchange Protocol +//! +//! This crate implements a key exchange protocol with 3 parties, namely server, +//! leader and follower. The goal is to end up with a shared secret (ECDH) +//! between the server and the client. The client in this context is leader and +//! follower combined, which means that each of them will end up with a share of +//! the shared secret. The leader will do all the necessary communication +//! with the server alone and forward all messages from and to the follower. +//! +//! A detailed description of this protocol can be found in our documentation +//! . 
+ +#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(clippy::all)] +#![forbid(unsafe_code)] + +mod circuit; +mod config; +pub(crate) mod error; +mod exchange; +#[cfg(feature = "mock")] +pub mod mock; +pub(crate) mod point_addition; + +pub use config::{ + KeyExchangeConfig, KeyExchangeConfigBuilder, KeyExchangeConfigBuilderError, Role, +}; +pub use error::KeyExchangeError; +pub use exchange::MpcKeyExchange; + +use async_trait::async_trait; +use mpz_garble::value::ValueRef; +use p256::PublicKey; + +/// Pre-master secret. +#[derive(Debug, Clone)] +pub struct Pms(ValueRef); + +impl Pms { + /// Creates a new PMS. + pub fn new(value: ValueRef) -> Self { + Self(value) + } + + /// Gets the value of the PMS. + pub fn into_value(self) -> ValueRef { + self.0 + } +} + +/// A trait for the 3-party key exchange protocol. +#[async_trait] +pub trait KeyExchange { + /// Gets the server's public key. + fn server_key(&self) -> Option; + + /// Sets the server's public key. + async fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>; + + /// Computes the client's public key. + /// + /// The client's public key in this context is the combined public key (EC + /// point addition) of the leader's public key and the follower's public + /// key. + async fn client_key(&mut self) -> Result; + + /// Performs any necessary one-time setup, returning a reference to the PMS. + /// + /// The PMS will not be assigned until `compute_pms` is called. + async fn setup(&mut self) -> Result; + + /// Preprocesses the key exchange. + async fn preprocess(&mut self) -> Result<(), KeyExchangeError>; + + /// Computes the PMS. + async fn compute_pms(&mut self) -> Result; +} diff --git a/crates/components/key-exchange/src/mock.rs b/crates/components/key-exchange/src/mock.rs new file mode 100644 index 0000000000..00d9fe9970 --- /dev/null +++ b/crates/components/key-exchange/src/mock.rs @@ -0,0 +1,71 @@ +//! 
This module provides mock types for key exchange leader and follower and a +//! function to create such a pair. + +use crate::{KeyExchangeConfig, MpcKeyExchange, Role}; + +use mpz_common::executor::{test_st_executor, STExecutor}; +use mpz_garble::{Decode, Execute, Memory}; +use mpz_share_conversion::ideal::{ideal_share_converter, IdealShareConverter}; +use serio::channel::MemoryDuplex; + +/// A mock key exchange instance. +pub type MockKeyExchange = + MpcKeyExchange, IdealShareConverter, IdealShareConverter, E>; + +/// Creates a mock pair of key exchange leader and follower. +pub fn create_mock_key_exchange_pair( + leader_executor: E, + follower_executor: E, +) -> (MockKeyExchange, MockKeyExchange) { + let (leader_ctx, follower_ctx) = test_st_executor(8); + let (leader_converter_0, follower_converter_0) = ideal_share_converter(); + let (leader_converter_1, follower_converter_1) = ideal_share_converter(); + + let key_exchange_config_leader = KeyExchangeConfig::builder() + .role(Role::Leader) + .build() + .unwrap(); + + let key_exchange_config_follower = KeyExchangeConfig::builder() + .role(Role::Follower) + .build() + .unwrap(); + + let leader = MpcKeyExchange::new( + key_exchange_config_leader, + leader_ctx, + leader_converter_0, + leader_converter_1, + leader_executor, + ); + + let follower = MpcKeyExchange::new( + key_exchange_config_follower, + follower_ctx, + follower_converter_0, + follower_converter_1, + follower_executor, + ); + + (leader, follower) +} + +#[cfg(test)] +mod tests { + use mpz_garble::protocol::deap::mock::create_mock_deap_vm; + + use crate::KeyExchange; + + use super::*; + + #[test] + fn test_mock_is_ke() { + let (leader_vm, follower_vm) = create_mock_deap_vm(); + let (leader, follower) = create_mock_key_exchange_pair(leader_vm, follower_vm); + + fn is_key_exchange(_: T) {} + + is_key_exchange(leader); + is_key_exchange(follower); + } +} diff --git a/components/key-exchange/src/msg.rs b/crates/components/key-exchange/src/msg.rs similarity 
index 75% rename from components/key-exchange/src/msg.rs rename to crates/components/key-exchange/src/msg.rs index 8aa0d9c634..ddb1f24699 100644 --- a/components/key-exchange/src/msg.rs +++ b/crates/components/key-exchange/src/msg.rs @@ -1,11 +1,12 @@ -//! This module contains the message types exchanged between user and notary +//! This module contains the message types exchanged between the prover and the TLS verifier. use std::fmt::{self, Display, Formatter}; use p256::{elliptic_curve::sec1::ToEncodedPoint, PublicKey as P256PublicKey}; use serde::{Deserialize, Serialize}; -/// A type for messages exchanged between user and notary during the key exchange protocol +/// A type for messages exchanged between the prover and the TLS verifier during the key exchange +/// protocol. #[derive(Debug, Clone, Serialize, Deserialize)] #[allow(missing_docs)] pub enum KeyExchangeMessage { @@ -13,14 +14,14 @@ pub enum KeyExchangeMessage { ServerPublicKey(PublicKey), } -/// A wrapper for a serialized public key +/// A wrapper for a serialized public key. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PublicKey { - /// The sec1 serialized public key + /// The sec1 serialized public key. pub key: Vec, } -/// An error that can occur during parsing of a public key +/// An error that can occur during parsing of a public key. #[derive(Debug, thiserror::Error)] pub struct KeyParseError(#[from] p256::elliptic_curve::Error); diff --git a/crates/components/key-exchange/src/point_addition.rs b/crates/components/key-exchange/src/point_addition.rs new file mode 100644 index 0000000000..d95056d8a6 --- /dev/null +++ b/crates/components/key-exchange/src/point_addition.rs @@ -0,0 +1,143 @@ +//! This module implements a secure two-party computation protocol for adding +//! two private EC points and secret-sharing the resulting x coordinate (the +//! shares are field elements of the field underlying the elliptic curve). +//! This protocol has semi-honest security. +//! +//! 
The protocol is described in + +use mpz_common::Context; +use mpz_fields::{p256::P256, Field}; +use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive}; +use p256::EncodedPoint; + +use crate::{config::Role, error::ErrorKind, KeyExchangeError}; + +/// Derives the x-coordinate share of an elliptic curve point. +pub(crate) async fn derive_x_coord_share( + role: Role, + ctx: &mut Ctx, + converter: &mut C, + share: EncodedPoint, +) -> Result +where + Ctx: Context, + C: AdditiveToMultiplicative + MultiplicativeToAdditive, +{ + let [x, y] = decompose_point(share)?; + + // Follower negates their share coordinates. + let inputs = match role { + Role::Leader => vec![y, x], + Role::Follower => vec![-y, -x], + }; + + let [a, b] = converter + .to_multiplicative(ctx, inputs) + .await? + .try_into() + .expect("output is same length as input"); + + let c = a * b.inverse(); + let c = c * c; + + let d = converter.to_additive(ctx, vec![c]).await?[0]; + + let x_r = d + -x; + + Ok(x_r) +} + +/// Decomposes the x and y coordinates of a SEC1 encoded point. +fn decompose_point(point: EncodedPoint) -> Result<[P256; 2], KeyExchangeError> { + // Coordinates are stored as big-endian bytes. + let mut x: [u8; 32] = (*point.x().ok_or(KeyExchangeError::new( + ErrorKind::Key, + "key share is an identity point", + ))?) + .into(); + let mut y: [u8; 32] = (*point.y().ok_or(KeyExchangeError::new( + ErrorKind::Key, + "key share is an identity point or compressed", + ))?) + .into(); + + // Reverse to little endian. 
+ x.reverse(); + y.reverse(); + + let x = P256::try_from(x).unwrap(); + let y = P256::try_from(y).unwrap(); + + Ok([x, y]) +} + +#[cfg(test)] +mod tests { + use super::*; + + use mpz_common::executor::test_st_executor; + use mpz_fields::{p256::P256, Field}; + use mpz_share_conversion::ideal::ideal_share_converter; + use p256::{ + elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}, + EncodedPoint, NonZeroScalar, ProjectivePoint, PublicKey, + }; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha12Rng; + + #[tokio::test] + async fn test_point_addition() { + let (mut ctx_a, mut ctx_b) = test_st_executor(8); + let mut rng = ChaCha12Rng::from_seed([0u8; 32]); + + let p1: [u8; 32] = rng.gen(); + let p2: [u8; 32] = rng.gen(); + + let p1 = curve_point_from_be_bytes(p1); + let p2 = curve_point_from_be_bytes(p2); + + let p = add_curve_points(&p1, &p2); + + let (mut c_a, mut c_b) = ideal_share_converter(); + + let (a, b) = tokio::try_join!( + derive_x_coord_share(Role::Leader, &mut ctx_a, &mut c_a, p1), + derive_x_coord_share(Role::Follower, &mut ctx_b, &mut c_b, p2) + ) + .unwrap(); + + let [expected_x, _] = decompose_point(p).unwrap(); + + assert_eq!(expected_x, a + b); + } + + #[test] + fn test_decompose_point() { + let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + + let p_expected: [u8; 32] = rng.gen(); + let p_expected = curve_point_from_be_bytes(p_expected); + + let p256: [P256; 2] = decompose_point(p_expected).unwrap(); + + let x: [u8; 32] = p256[0].to_be_bytes().try_into().unwrap(); + let y: [u8; 32] = p256[1].to_be_bytes().try_into().unwrap(); + + let p = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), false); + + assert_eq!(p_expected, p); + } + + fn curve_point_from_be_bytes(bytes: [u8; 32]) -> EncodedPoint { + let scalar = NonZeroScalar::from_repr(bytes.into()).unwrap(); + let pk = PublicKey::from_secret_scalar(&scalar); + pk.to_encoded_point(false) + } + + fn add_curve_points(p1: &EncodedPoint, p2: &EncodedPoint) -> EncodedPoint { + let 
p1 = ProjectivePoint::from_encoded_point(p1).unwrap(); + let p2 = ProjectivePoint::from_encoded_point(p2).unwrap(); + let p = p1 + p2; + p.to_encoded_point(false) + } +} diff --git a/components/cipher/stream-cipher/Cargo.toml b/crates/components/stream-cipher/Cargo.toml similarity index 55% rename from components/cipher/stream-cipher/Cargo.toml rename to crates/components/stream-cipher/Cargo.toml index 850c6a232c..4a670aac89 100644 --- a/components/cipher/stream-cipher/Cargo.toml +++ b/crates/components/stream-cipher/Cargo.toml @@ -5,28 +5,29 @@ description = "2PC stream cipher implementation" keywords = ["tls", "mpc", "2pc", "stream-cipher"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" [features] default = ["mock"] -tracing = ["dep:tracing"] +rayon = ["mpz-garble/rayon"] mock = [] [dependencies] -mpz-circuits.workspace = true -mpz-garble.workspace = true -tlsn-utils.workspace = true -aes.workspace = true -ctr.workspace = true -cipher.workspace = true -async-trait.workspace = true -thiserror.workspace = true -derive_builder.workspace = true -tracing = { workspace = true, optional = true } +mpz-circuits = { workspace = true } +mpz-garble = { workspace = true } +tlsn-utils = { workspace = true } +aes = { workspace = true } +ctr = { workspace = true } +cipher = { workspace = true } +async-trait = { workspace = true } +thiserror = { workspace = true } +derive_builder = { workspace = true } +tracing = { workspace = true } +opaque-debug = { workspace = true } [dev-dependencies] -futures.workspace = true +futures = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } rstest = { workspace = true, features = ["async-timeout"] } criterion = { workspace = true, features = ["async_tokio"] } diff --git a/crates/components/stream-cipher/benches/mock.rs b/crates/components/stream-cipher/benches/mock.rs new file mode 100644 index 0000000000..3f782257f9 
--- /dev/null +++ b/crates/components/stream-cipher/benches/mock.rs @@ -0,0 +1,132 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; + +use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory}; +use tlsn_stream_cipher::{ + Aes128Ctr, CtrCircuit, MpcStreamCipher, StreamCipher, StreamCipherConfigBuilder, +}; + +async fn bench_stream_cipher_encrypt(len: usize) { + let (leader_vm, follower_vm) = create_mock_deap_vm(); + + let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap(); + + leader_vm.assign(&leader_key, [0u8; 16]).unwrap(); + leader_vm.assign(&leader_iv, [0u8; 4]).unwrap(); + + let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap(); + + follower_vm.assign(&follower_key, [0u8; 16]).unwrap(); + follower_vm.assign(&follower_iv, [0u8; 4]).unwrap(); + + let leader_config = StreamCipherConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + + let follower_config = StreamCipherConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + + let mut leader = MpcStreamCipher::::new(leader_config, leader_vm); + leader.set_key(leader_key, leader_iv); + + let mut follower = MpcStreamCipher::::new(follower_config, follower_vm); + follower.set_key(follower_key, follower_iv); + + let plaintext = vec![0u8; len]; + let explicit_nonce = vec![0u8; 8]; + + _ = tokio::try_join!( + leader.encrypt_private(explicit_nonce.clone(), plaintext), + follower.encrypt_blind(explicit_nonce, len) + ) + .unwrap(); + + _ = tokio::try_join!( + leader.thread_mut().finalize(), + follower.thread_mut().finalize() + ) + .unwrap(); +} + +async fn bench_stream_cipher_zk(len: usize) { + let (leader_vm, follower_vm) = create_mock_deap_vm(); + + let key = [0u8; 16]; + let iv = [0u8; 4]; + + let leader_key = leader_vm.new_public_input::<[u8; 
16]>("key").unwrap(); + let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap(); + + leader_vm.assign(&leader_key, key).unwrap(); + leader_vm.assign(&leader_iv, iv).unwrap(); + + let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap(); + + follower_vm.assign(&follower_key, key).unwrap(); + follower_vm.assign(&follower_iv, iv).unwrap(); + + let leader_config = StreamCipherConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + + let follower_config = StreamCipherConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + + let mut leader = MpcStreamCipher::::new(leader_config, leader_vm); + leader.set_key(leader_key, leader_iv); + + let mut follower = MpcStreamCipher::::new(follower_config, follower_vm); + follower.set_key(follower_key, follower_iv); + + futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap(); + + let plaintext = vec![0u8; len]; + let explicit_nonce = [0u8; 8]; + let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext).unwrap(); + + _ = tokio::try_join!( + leader.prove_plaintext(explicit_nonce.to_vec(), plaintext), + follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext) + ) + .unwrap(); + + _ = tokio::try_join!( + leader.thread_mut().finalize(), + follower.thread_mut().finalize() + ) + .unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let len = 1024; + + let mut group = c.benchmark_group("stream_cipher/encrypt_private"); + group.throughput(Throughput::Bytes(len as u64)); + group.bench_function(BenchmarkId::from_parameter(len), |b| { + b.to_async(&rt) + .iter(|| async { bench_stream_cipher_encrypt(len).await }) + }); + + drop(group); + + let mut group = c.benchmark_group("stream_cipher/zk"); + group.throughput(Throughput::Bytes(len as u64)); + 
group.bench_function(BenchmarkId::from_parameter(len), |b| { + b.to_async(&rt) + .iter(|| async { bench_stream_cipher_zk(len).await }) + }); + + drop(group); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/components/cipher/stream-cipher/src/cipher.rs b/crates/components/stream-cipher/src/cipher.rs similarity index 75% rename from components/cipher/stream-cipher/src/cipher.rs rename to crates/components/stream-cipher/src/cipher.rs index f9e75e3601..376b3c77a6 100644 --- a/components/cipher/stream-cipher/src/cipher.rs +++ b/crates/components/stream-cipher/src/cipher.rs @@ -9,9 +9,9 @@ use crate::{circuit::AES_CTR, StreamCipherError}; /// A counter-mode block cipher circuit. pub trait CtrCircuit: Default + Clone + Send + Sync + 'static { - /// The key type + /// The key type. type KEY: StaticValueType + TryFrom> + Send + Sync + 'static; - /// The block type + /// The block type. type BLOCK: StaticValueType + TryFrom> + TryFrom @@ -20,7 +20,7 @@ pub trait CtrCircuit: Default + Clone + Send + Sync + 'static { + Send + Sync + 'static; - /// The IV type + /// The IV type. type IV: StaticValueType + TryFrom> + TryFrom @@ -28,7 +28,7 @@ pub trait CtrCircuit: Default + Clone + Send + Sync + 'static { + Send + Sync + 'static; - /// The nonce type + /// The nonce type. type NONCE: StaticValueType + TryFrom> + TryFrom @@ -40,19 +40,19 @@ pub trait CtrCircuit: Default + Clone + Send + Sync + 'static { + std::fmt::Debug + 'static; - /// The length of the key + /// The length of the key. const KEY_LEN: usize; - /// The length of the block + /// The length of the block. const BLOCK_LEN: usize; - /// The length of the IV + /// The length of the IV. const IV_LEN: usize; - /// The length of the nonce + /// The length of the nonce. const NONCE_LEN: usize; - /// Returns the circuit of the cipher + /// Returns the circuit of the cipher. 
fn circuit() -> Arc; - /// Applies the keystream to the message + /// Applies the keystream to the message. fn apply_keystream( key: &[u8], iv: &[u8], @@ -94,22 +94,13 @@ impl CtrCircuit for Aes128Ctr { let key: &[u8; 16] = key .try_into() - .map_err(|_| StreamCipherError::InvalidKeyLength { - expected: 16, - actual: key.len(), - })?; + .map_err(|_| StreamCipherError::key_len::(key.len()))?; let iv: &[u8; 4] = iv .try_into() - .map_err(|_| StreamCipherError::InvalidIvLength { - expected: 4, - actual: iv.len(), - })?; - let explicit_nonce: &[u8; 8] = explicit_nonce.try_into().map_err(|_| { - StreamCipherError::InvalidExplicitNonceLength { - expected: 8, - actual: explicit_nonce.len(), - } - })?; + .map_err(|_| StreamCipherError::iv_len::(iv.len()))?; + let explicit_nonce: &[u8; 8] = explicit_nonce + .try_into() + .map_err(|_| StreamCipherError::explicit_nonce_len::(explicit_nonce.len()))?; let mut full_iv = [0u8; 16]; full_iv[0..4].copy_from_slice(iv); diff --git a/components/cipher/stream-cipher/src/circuit.rs b/crates/components/stream-cipher/src/circuit.rs similarity index 98% rename from components/cipher/stream-cipher/src/circuit.rs rename to crates/components/stream-cipher/src/circuit.rs index bf1ea0a78a..b3ae66f487 100644 --- a/components/cipher/stream-cipher/src/circuit.rs +++ b/crates/components/stream-cipher/src/circuit.rs @@ -1,7 +1,7 @@ use mpz_circuits::{circuits::aes128_trace, once_cell::sync::Lazy, trace, Circuit, CircuitBuilder}; use std::sync::Arc; -/// AES encrypt counter block. +/// AES encrypts a counter block. 
/// /// # Inputs /// diff --git a/components/cipher/stream-cipher/src/config.rs b/crates/components/stream-cipher/src/config.rs similarity index 62% rename from components/cipher/stream-cipher/src/config.rs rename to crates/components/stream-cipher/src/config.rs index 270d80c565..4af771da21 100644 --- a/components/cipher/stream-cipher/src/config.rs +++ b/crates/components/stream-cipher/src/config.rs @@ -1,11 +1,6 @@ -use std::marker::PhantomData; - use derive_builder::Builder; -use mpz_garble::value::ValueRef; use std::fmt::Debug; -use crate::CtrCircuit; - /// Configuration for a stream cipher. #[derive(Debug, Clone, Builder)] pub struct StreamCipherConfig { @@ -28,37 +23,6 @@ impl StreamCipherConfig { } } -pub(crate) struct KeyBlockConfig { - pub(crate) key: ValueRef, - pub(crate) iv: ValueRef, - pub(crate) explicit_nonce: C::NONCE, - pub(crate) ctr: u32, - _pd: PhantomData, -} - -impl Debug for KeyBlockConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("KeyBlockConfig") - .field("key", &self.key) - .field("iv", &self.iv) - .field("explicit_nonce", &self.explicit_nonce) - .field("ctr", &self.ctr) - .finish() - } -} - -impl KeyBlockConfig { - pub(crate) fn new(key: ValueRef, iv: ValueRef, explicit_nonce: C::NONCE, ctr: u32) -> Self { - Self { - key, - iv, - explicit_nonce, - ctr, - _pd: PhantomData, - } - } -} - pub(crate) enum InputText { Public { ids: Vec, text: Vec }, Private { ids: Vec, text: Vec }, @@ -82,3 +46,23 @@ impl std::fmt::Debug for InputText { } } } + +/// The mode of execution. +#[derive(Debug, Clone, Copy)] +pub(crate) enum ExecutionMode { + /// Computes either the plaintext or the ciphertext. + Mpc, + /// Computes the ciphertext and proves its authenticity and correctness. + Prove, + /// Computes the ciphertext and verifies its authenticity and correctness. 
+ Verify, +} + +pub(crate) fn is_valid_mode(mode: &ExecutionMode, input_text: &InputText) -> bool { + matches!( + (mode, input_text), + (ExecutionMode::Mpc, _) + | (ExecutionMode::Prove, InputText::Private { .. }) + | (ExecutionMode::Verify, InputText::Blind { .. }) + ) +} diff --git a/crates/components/stream-cipher/src/error.rs b/crates/components/stream-cipher/src/error.rs new file mode 100644 index 0000000000..a57773609b --- /dev/null +++ b/crates/components/stream-cipher/src/error.rs @@ -0,0 +1,122 @@ +use core::fmt; +use std::error::Error; + +use crate::CtrCircuit; + +/// A stream cipher error. +#[derive(Debug, thiserror::Error)] +pub struct StreamCipherError { + kind: ErrorKind, + #[source] + source: Option>, +} + +impl StreamCipherError { + pub(crate) fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + pub(crate) fn key_len(len: usize) -> Self { + Self { + kind: ErrorKind::Key, + source: Some( + format!("invalid key length: expected {}, got {}", C::KEY_LEN, len).into(), + ), + } + } + + pub(crate) fn iv_len(len: usize) -> Self { + Self { + kind: ErrorKind::Iv, + source: Some(format!("invalid iv length: expected {}, got {}", C::IV_LEN, len).into()), + } + } + + pub(crate) fn explicit_nonce_len(len: usize) -> Self { + Self { + kind: ErrorKind::ExplicitNonce, + source: Some( + format!( + "invalid explicit nonce length: expected {}, got {}", + C::NONCE_LEN, + len + ) + .into(), + ), + } + } + + pub(crate) fn key_not_set() -> Self { + Self { + kind: ErrorKind::Key, + source: Some("key not set".into()), + } + } +} + +#[derive(Debug)] +pub(crate) enum ErrorKind { + Vm, + Key, + Iv, + ExplicitNonce, +} + +impl fmt::Display for StreamCipherError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + ErrorKind::Vm => write!(f, "vm error")?, + ErrorKind::Key => write!(f, "key error")?, + ErrorKind::Iv => write!(f, "iv error")?, + ErrorKind::ExplicitNonce => 
write!(f, "explicit nonce error")?, + } + + if let Some(ref source) = self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for StreamCipherError { + fn from(error: mpz_garble::MemoryError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for StreamCipherError { + fn from(error: mpz_garble::LoadError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for StreamCipherError { + fn from(error: mpz_garble::ExecutionError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for StreamCipherError { + fn from(error: mpz_garble::ProveError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for StreamCipherError { + fn from(error: mpz_garble::VerifyError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} + +impl From for StreamCipherError { + fn from(error: mpz_garble::DecodeError) -> Self { + Self::new(ErrorKind::Vm, error) + } +} diff --git a/crates/components/stream-cipher/src/keystream.rs b/crates/components/stream-cipher/src/keystream.rs new file mode 100644 index 0000000000..1c020da90d --- /dev/null +++ b/crates/components/stream-cipher/src/keystream.rs @@ -0,0 +1,216 @@ +use std::{collections::VecDeque, marker::PhantomData}; + +use mpz_garble::{value::ValueRef, Execute, Load, Memory, Prove, Thread, Verify}; +use tracing::instrument; +use utils::id::NestedId; + +use crate::{config::ExecutionMode, CtrCircuit, StreamCipherError}; + +pub(crate) struct KeyStream { + block_counter: NestedId, + preprocessed: BlockVars, + _pd: PhantomData, +} + +#[derive(Default)] +struct BlockVars { + blocks: VecDeque, + nonces: VecDeque, + ctrs: VecDeque, +} + +impl BlockVars { + fn is_empty(&self) -> bool { + self.blocks.is_empty() + } + + fn len(&self) -> usize { + self.blocks.len() + } + + fn drain(&mut self, count: usize) -> BlockVars { + let blocks = self.blocks.drain(0..count).collect(); + let nonces = self.nonces.drain(0..count).collect(); + let ctrs = self.ctrs.drain(0..count).collect(); + + 
BlockVars { + blocks, + nonces, + ctrs, + } + } + + fn extend(&mut self, vars: BlockVars) { + self.blocks.extend(vars.blocks); + self.nonces.extend(vars.nonces); + self.ctrs.extend(vars.ctrs); + } + + fn iter(&self) -> impl Iterator { + self.blocks + .iter() + .zip(self.nonces.iter()) + .zip(self.ctrs.iter()) + .map(|((block, nonce), ctr)| (block, nonce, ctr)) + } + + fn flatten(&self, len: usize) -> Vec { + self.blocks + .iter() + .flat_map(|block| block.iter()) + .take(len) + .cloned() + .map(|byte| ValueRef::Value { id: byte }) + .collect() + } +} + +impl KeyStream { + pub(crate) fn new(id: &str) -> Self { + let block_counter = NestedId::new(id).append_counter(); + Self { + block_counter, + preprocessed: BlockVars::default(), + _pd: PhantomData, + } + } + + fn define_vars( + &mut self, + mem: &mut impl Memory, + count: usize, + ) -> Result { + let mut vars = BlockVars::default(); + for _ in 0..count { + let block_id = self.block_counter.increment_in_place(); + let block = mem.new_output::(&block_id.to_string())?; + let nonce = + mem.new_public_input::(&block_id.append_string("nonce").to_string())?; + let ctr = + mem.new_public_input::<[u8; 4]>(&block_id.append_string("ctr").to_string())?; + + vars.blocks.push_back(block); + vars.nonces.push_back(nonce); + vars.ctrs.push_back(ctr); + } + + Ok(vars) + } + + #[instrument(level = "debug", skip_all, err)] + pub(crate) async fn preprocess( + &mut self, + thread: &mut T, + key: &ValueRef, + iv: &ValueRef, + len: usize, + ) -> Result<(), StreamCipherError> + where + T: Thread + Memory + Load + Send + 'static, + { + let block_count = (len / C::BLOCK_LEN) + (len % C::BLOCK_LEN != 0) as usize; + let vars = self.define_vars(thread, block_count)?; + + let calls = vars + .iter() + .map(|(block, nonce, ctr)| { + ( + C::circuit(), + vec![key.clone(), iv.clone(), nonce.clone(), ctr.clone()], + vec![block.clone()], + ) + }) + .collect::>(); + + for (circ, inputs, outputs) in calls { + thread.load(circ, &inputs, &outputs).await?; 
+ } + + self.preprocessed.extend(vars); + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + #[allow(clippy::too_many_arguments)] + pub(crate) async fn compute( + &mut self, + thread: &mut T, + mode: ExecutionMode, + key: &ValueRef, + iv: &ValueRef, + explicit_nonce: Vec, + start_ctr: usize, + len: usize, + ) -> Result + where + T: Thread + Memory + Execute + Prove + Verify + Send + 'static, + { + let block_count = (len / C::BLOCK_LEN) + (len % C::BLOCK_LEN != 0) as usize; + let explicit_nonce_len = explicit_nonce.len(); + let explicit_nonce: C::NONCE = explicit_nonce + .try_into() + .map_err(|_| StreamCipherError::explicit_nonce_len::(explicit_nonce_len))?; + + // Take any preprocessed blocks if available, and define new ones if needed. + let vars = if !self.preprocessed.is_empty() { + let mut vars = self + .preprocessed + .drain(block_count.min(self.preprocessed.len())); + if vars.len() < block_count { + vars.extend(self.define_vars(thread, block_count - vars.len())?) + } + vars + } else { + self.define_vars(thread, block_count)? 
+ }; + + let mut calls = Vec::with_capacity(vars.len()); + let mut inputs = Vec::with_capacity(vars.len() * 4); + for (i, (block, nonce_ref, ctr_ref)) in vars.iter().enumerate() { + thread.assign(nonce_ref, explicit_nonce)?; + thread.assign(ctr_ref, ((start_ctr + i) as u32).to_be_bytes())?; + + inputs.push(key.clone()); + inputs.push(iv.clone()); + inputs.push(nonce_ref.clone()); + inputs.push(ctr_ref.clone()); + + calls.push(( + C::circuit(), + vec![key.clone(), iv.clone(), nonce_ref.clone(), ctr_ref.clone()], + vec![block.clone()], + )); + } + + match mode { + ExecutionMode::Mpc => { + thread.commit(&inputs).await?; + for (circ, inputs, outputs) in calls { + thread.execute(circ, &inputs, &outputs).await?; + } + } + ExecutionMode::Prove => { + // Note that after the circuit execution, the value of `block` can be considered + // as implicitly authenticated since `key` and `iv` have already + // been authenticated earlier and `nonce_ref` and `ctr_ref` are + // public. [Prove::prove] will **not** be called on `block` at + // any later point. + thread.commit_prove(&inputs).await?; + for (circ, inputs, outputs) in calls { + thread.execute_prove(circ, &inputs, &outputs).await?; + } + } + ExecutionMode::Verify => { + thread.commit_verify(&inputs).await?; + for (circ, inputs, outputs) in calls { + thread.execute_verify(circ, &inputs, &outputs).await?; + } + } + } + + let keystream = thread.array_from_values(&vars.flatten(len))?; + + Ok(keystream) + } +} diff --git a/components/cipher/stream-cipher/src/lib.rs b/crates/components/stream-cipher/src/lib.rs similarity index 66% rename from components/cipher/stream-cipher/src/lib.rs rename to crates/components/stream-cipher/src/lib.rs index 1de38ff869..1acdb82063 100644 --- a/components/cipher/stream-cipher/src/lib.rs +++ b/crates/components/stream-cipher/src/lib.rs @@ -1,14 +1,19 @@ -//! This crate provides a 2PC stream cipher implementation using a block cipher in counter mode. +//! 
This crate provides a 2PC stream cipher implementation using a block cipher +//! in counter mode. //! -//! Each party plays a specific role, either the `StreamCipherLeader` or the `StreamCipherFollower`. Both parties -//! work together to encrypt and decrypt messages using a shared key. +//! Each party plays a specific role, either the `StreamCipherLeader` or the +//! `StreamCipherFollower`. Both parties work together to encrypt and decrypt +//! messages using a shared key. //! //! # Transcript //! -//! Using the `record` flag, the `StreamCipherFollower` can optionally use a dedicated stream when encoding the plaintext labels, which -//! allows the `StreamCipherLeader` to build a transcript of active labels which are pushed to the provided `TranscriptSink`. +//! Using the `record` flag, the `StreamCipherFollower` can optionally use a +//! dedicated stream when encoding the plaintext labels, which allows the +//! `StreamCipherLeader` to build a transcript of active labels which are pushed +//! to the provided `TranscriptSink`. //! -//! Afterwards, the `StreamCipherLeader` can create commitments to the transcript which can be used in a selective disclosure protocol. +//! Afterwards, the `StreamCipherLeader` can create commitments to the +//! transcript which can be used in a selective disclosure protocol. 
#![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] @@ -17,41 +22,18 @@ mod cipher; mod circuit; mod config; +pub(crate) mod error; +pub(crate) mod keystream; mod stream_cipher; pub use self::cipher::{Aes128Ctr, CtrCircuit}; pub use config::{StreamCipherConfig, StreamCipherConfigBuilder, StreamCipherConfigBuilderError}; +pub use error::StreamCipherError; pub use stream_cipher::MpcStreamCipher; use async_trait::async_trait; use mpz_garble::value::ValueRef; -/// Error that can occur when using a stream cipher -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum StreamCipherError { - #[error(transparent)] - MemoryError(#[from] mpz_garble::MemoryError), - #[error(transparent)] - ExecutionError(#[from] mpz_garble::ExecutionError), - #[error(transparent)] - DecodeError(#[from] mpz_garble::DecodeError), - #[error(transparent)] - ProveError(#[from] mpz_garble::ProveError), - #[error(transparent)] - VerifyError(#[from] mpz_garble::VerifyError), - #[error("key and iv is not set")] - KeyIvNotSet, - #[error("invalid key length: expected {expected}, got {actual}")] - InvalidKeyLength { expected: usize, actual: usize }, - #[error("invalid iv length: expected {expected}, got {actual}")] - InvalidIvLength { expected: usize, actual: usize }, - #[error("invalid explicit nonce length: expected {expected}, got {actual}")] - InvalidExplicitNonceLength { expected: usize, actual: usize }, - #[error("missing value for {0}")] - MissingValue(String), -} - /// A trait for MPC stream ciphers. #[async_trait] pub trait StreamCipher: Send + Sync @@ -64,7 +46,8 @@ where /// Decodes the key for the stream cipher, revealing it to this party. async fn decode_key_private(&mut self) -> Result<(), StreamCipherError>; - /// Decodes the key for the stream cipher, revealing it to the other party(s). + /// Decodes the key for the stream cipher, revealing it to the other + /// party(s). 
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError>; /// Sets the transcript id @@ -72,23 +55,27 @@ where /// The stream cipher assigns unique identifiers to each byte of plaintext /// during encryption and decryption. /// - /// For example, if the transcript id is set to `foo`, then the first byte will - /// be assigned the id `foo/0`, the second byte `foo/1`, and so on. + /// For example, if the transcript id is set to `foo`, then the first byte + /// will be assigned the id `foo/0`, the second byte `foo/1`, and so on. /// /// Each transcript id has an independent counter. /// /// # Note /// - /// The state of a transcript counter is preserved between calls to `set_transcript_id`. + /// The state of a transcript counter is preserved between calls to + /// `set_transcript_id`. fn set_transcript_id(&mut self, id: &str); + /// Preprocesses the keystream for the given number of bytes. + async fn preprocess(&mut self, len: usize) -> Result<(), StreamCipherError>; + /// Applies the keystream to the given plaintext, where all parties /// provide the plaintext as an input. /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `plaintext`: The message to apply the keystream to. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `plaintext` - The message to apply the keystream to. async fn encrypt_public( &mut self, explicit_nonce: Vec, @@ -100,8 +87,8 @@ where /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `plaintext`: The message to apply the keystream to. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `plaintext` - The message to apply the keystream to. async fn encrypt_private( &mut self, explicit_nonce: Vec, @@ -112,8 +99,8 @@ where /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. 
- /// * `len`: The length of the plaintext provided by another party. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `len` - The length of the plaintext provided by another party. async fn encrypt_blind( &mut self, explicit_nonce: Vec, @@ -125,8 +112,8 @@ where /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `ciphertext`: The ciphertext to decrypt. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `ciphertext` - The ciphertext to decrypt. async fn decrypt_public( &mut self, explicit_nonce: Vec, @@ -138,8 +125,8 @@ where /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `ciphertext`: The ciphertext to decrypt. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `ciphertext` - The ciphertext to decrypt. async fn decrypt_private( &mut self, explicit_nonce: Vec, @@ -151,50 +138,52 @@ where /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `ciphertext`: The ciphertext to decrypt. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `ciphertext` - The ciphertext to decrypt. async fn decrypt_blind( &mut self, explicit_nonce: Vec, ciphertext: Vec, ) -> Result<(), StreamCipherError>; - /// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the - /// plaintext is correct. + /// Locally decrypts the provided ciphertext and then proves in ZK to the + /// other party(s) that the plaintext is correct. /// /// Returns the plaintext. /// - /// This method requires this party to know the encryption key, which can be achieved by calling - /// the `decode_key_private` method. + /// This method requires this party to know the encryption key, which can be + /// achieved by calling the `decode_key_private` method. 
/// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `ciphertext`: The ciphertext to decrypt and prove. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `ciphertext` - The ciphertext to decrypt and prove. async fn prove_plaintext( &mut self, explicit_nonce: Vec, ciphertext: Vec, ) -> Result, StreamCipherError>; - /// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext. + /// Verifies the other party(s) can prove they know a plaintext which + /// encrypts to the given ciphertext. /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `ciphertext`: The ciphertext to verify. + /// * `explicit_nonce` - The explicit nonce to use for the keystream. + /// * `ciphertext` - The ciphertext to verify. async fn verify_plaintext( &mut self, explicit_nonce: Vec, ciphertext: Vec, ) -> Result<(), StreamCipherError>; - /// Returns an additive share of the keystream block for the given explicit nonce and counter. + /// Returns an additive share of the keystream block for the given explicit + /// nonce and counter. /// /// # Arguments /// - /// * `explicit_nonce`: The explicit nonce to use for the keystream block. - /// * `ctr`: The counter to use for the keystream block. + /// * `explicit_nonce` - The explicit nonce to use for the keystream block. + /// * `ctr` - The counter to use for the keystream block. 
async fn share_keystream_block( &mut self, explicit_nonce: Vec, @@ -211,10 +200,8 @@ mod tests { use super::*; use mpz_garble::{ - protocol::deap::mock::{ - create_mock_deap_vm, MockFollower, MockFollowerThread, MockLeader, MockLeaderThread, - }, - Memory, Vm, + protocol::deap::mock::{create_mock_deap_vm, MockFollower, MockLeader}, + Memory, }; use rstest::*; @@ -222,38 +209,23 @@ mod tests { start_ctr: usize, key: [u8; 16], iv: [u8; 4], - thread_count: usize, ) -> ( - ( - MpcStreamCipher, - MpcStreamCipher, - ), - (MockLeader, MockFollower), + MpcStreamCipher, + MpcStreamCipher, ) { - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; + let (leader_vm, follower_vm) = create_mock_deap_vm(); - let leader_thread = leader_vm.new_thread("key_config").await.unwrap(); - let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let leader_iv = leader_thread.new_public_input::<[u8; 4]>("iv").unwrap(); + let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap(); - leader_thread.assign(&leader_key, key).unwrap(); - leader_thread.assign(&leader_iv, iv).unwrap(); + leader_vm.assign(&leader_key, key).unwrap(); + leader_vm.assign(&leader_iv, iv).unwrap(); - let follower_thread = follower_vm.new_thread("key_config").await.unwrap(); - let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); - let follower_iv = follower_thread.new_public_input::<[u8; 4]>("iv").unwrap(); + let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap(); - follower_thread.assign(&follower_key, key).unwrap(); - follower_thread.assign(&follower_iv, iv).unwrap(); - - let leader_thread_pool = leader_vm - .new_thread_pool("mock", thread_count) - .await - .unwrap(); - let follower_thread_pool = follower_vm - .new_thread_pool("mock", thread_count) - .await - .unwrap(); 
+ follower_vm.assign(&follower_key, key).unwrap(); + follower_vm.assign(&follower_iv, iv).unwrap(); let leader_config = StreamCipherConfig::builder() .id("test") @@ -267,18 +239,19 @@ mod tests { .build() .unwrap(); - let mut leader = MpcStreamCipher::::new(leader_config, leader_thread_pool); + let mut leader = MpcStreamCipher::::new(leader_config, leader_vm); leader.set_key(leader_key, leader_iv); - let mut follower = MpcStreamCipher::::new(follower_config, follower_thread_pool); + let mut follower = MpcStreamCipher::::new(follower_config, follower_vm); follower.set_key(follower_key, follower_iv); - ((leader, follower), (leader_vm, follower_vm)) + (leader, follower) } #[rstest] #[timeout(Duration::from_millis(10000))] #[tokio::test] + #[ignore = "expensive"] async fn test_stream_cipher_public() { let key = [0u8; 16]; let iv = [0u8; 4]; @@ -286,8 +259,7 @@ mod tests { let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec(); - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - create_test_pair::(1, key, iv, 8).await; + let (mut leader, mut follower) = create_test_pair::(1, key, iv).await; let leader_fut = async { let leader_encrypted_msg = leader @@ -333,6 +305,7 @@ mod tests { #[rstest] #[timeout(Duration::from_millis(10000))] #[tokio::test] + #[ignore = "expensive"] async fn test_stream_cipher_private() { let key = [0u8; 16]; let iv = [0u8; 4]; @@ -342,8 +315,7 @@ mod tests { let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap(); - let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) = - create_test_pair::(1, key, iv, 8).await; + let (mut leader, mut follower) = create_test_pair::(1, key, iv).await; let leader_fut = async { let leader_decrypted_msg = leader @@ -378,19 +350,23 @@ mod tests { assert_eq!(leader_decrypted_msg, msg); assert_eq!(follower_encrypted_msg, ciphertext); - futures::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); + futures::try_join!( 
+ leader.thread_mut().finalize(), + follower.thread_mut().finalize() + ) + .unwrap(); } #[rstest] #[timeout(Duration::from_millis(10000))] #[tokio::test] + #[ignore = "expensive"] async fn test_stream_cipher_share_key_block() { let key = [0u8; 16]; let iv = [0u8; 4]; let explicit_nonce = [0u8; 8]; - let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = - create_test_pair::(1, key, iv, 8).await; + let (mut leader, mut follower) = create_test_pair::(1, key, iv).await; let leader_fut = async { leader @@ -423,6 +399,7 @@ mod tests { #[rstest] #[timeout(Duration::from_millis(10000))] #[tokio::test] + #[ignore = "expensive"] async fn test_stream_cipher_zk() { let key = [0u8; 16]; let iv = [0u8; 4]; @@ -432,8 +409,7 @@ mod tests { let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg).unwrap(); - let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) = - create_test_pair::(2, key, iv, 8).await; + let (mut leader, mut follower) = create_test_pair::(2, key, iv).await; futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap(); @@ -442,6 +418,57 @@ mod tests { follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext) ) .unwrap(); - futures::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); + futures::try_join!( + leader.thread_mut().finalize(), + follower.thread_mut().finalize() + ) + .unwrap(); + } + + #[rstest] + #[case::one_block(16)] + #[case::partial(17)] + #[case::extra(128)] + #[timeout(Duration::from_millis(10000))] + #[tokio::test] + #[ignore = "expensive"] + async fn test_stream_cipher_preprocess(#[case] len: usize) { + let key = [0u8; 16]; + let iv = [0u8; 4]; + let explicit_nonce = [1u8; 8]; + + let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec(); + + let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap(); + + let (mut leader, mut follower) = create_test_pair::(1, key, iv).await; + + let leader_fut = 
async { + leader.preprocess(len).await.unwrap(); + + leader + .decrypt_private(explicit_nonce.to_vec(), ciphertext.clone()) + .await + .unwrap() + }; + + let follower_fut = async { + follower.preprocess(len).await.unwrap(); + + follower + .decrypt_blind(explicit_nonce.to_vec(), ciphertext.clone()) + .await + .unwrap(); + }; + + let (leader_decrypted_msg, _) = futures::join!(leader_fut, follower_fut); + + assert_eq!(leader_decrypted_msg, msg); + + futures::try_join!( + leader.thread_mut().finalize(), + follower.thread_mut().finalize() + ) + .unwrap(); } } diff --git a/crates/components/stream-cipher/src/stream_cipher.rs b/crates/components/stream-cipher/src/stream_cipher.rs new file mode 100644 index 0000000000..449ed0329f --- /dev/null +++ b/crates/components/stream-cipher/src/stream_cipher.rs @@ -0,0 +1,671 @@ +use async_trait::async_trait; +use mpz_circuits::types::Value; +use std::collections::HashMap; +use tracing::instrument; + +use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Load, Prove, Thread, Verify}; +use utils::id::NestedId; + +use crate::{ + cipher::CtrCircuit, + circuit::build_array_xor, + config::{is_valid_mode, ExecutionMode, InputText, StreamCipherConfig}, + keystream::KeyStream, + StreamCipher, StreamCipherError, +}; + +/// An MPC stream cipher. +#[derive(Debug)] +pub struct MpcStreamCipher +where + C: CtrCircuit, + E: Thread + Execute + Decode + DecodePrivate + Send + Sync, +{ + config: StreamCipherConfig, + state: State, + thread: E, +} + +struct State { + /// Encoded key and IV for the cipher. + encoded_key_iv: Option, + /// Key and IV for the cipher. + key_iv: Option, + /// Keystream state. + keystream: KeyStream, + /// Current transcript. + transcript: Transcript, + /// Maps a transcript ID to the corresponding transcript. + transcripts: HashMap, + /// Number of messages operated on. 
+ counter: usize, +} + +opaque_debug::implement!(State); + +#[derive(Clone)] +struct EncodedKeyAndIv { + key: ValueRef, + iv: ValueRef, +} + +#[derive(Clone)] +struct KeyAndIv { + key: Vec, + iv: Vec, +} + +/// A subset of plaintext bytes processed by the stream cipher. +/// +/// Note that `Transcript` does not store the actual bytes. Instead, it provides +/// IDs which are assigned to plaintext bytes of the stream cipher. +struct Transcript { + /// The ID of this transcript. + id: String, + /// The ID for the next plaintext byte. + plaintext: NestedId, +} + +impl Transcript { + fn new(id: &str) -> Self { + Self { + id: id.to_string(), + plaintext: NestedId::new(id).append_counter(), + } + } + + /// Returns unique identifiers for the next plaintext bytes in the + /// transcript. + fn extend_plaintext(&mut self, len: usize) -> Vec { + (0..len) + .map(|_| self.plaintext.increment_in_place().to_string()) + .collect() + } +} + +impl MpcStreamCipher +where + C: CtrCircuit, + E: Thread + Execute + Load + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static, +{ + /// Creates a new counter-mode cipher. + pub fn new(config: StreamCipherConfig, thread: E) -> Self { + let keystream = KeyStream::new(&config.id); + let transcript = Transcript::new(&config.transcript_id); + Self { + config, + state: State { + encoded_key_iv: None, + key_iv: None, + keystream, + transcript, + transcripts: HashMap::new(), + counter: 0, + }, + thread, + } + } + + /// Returns a mutable reference to the underlying thread. + pub fn thread_mut(&mut self) -> &mut E { + &mut self.thread + } + + /// Computes a keystream of the given length. 
+ async fn compute_keystream( + &mut self, + explicit_nonce: Vec, + start_ctr: usize, + len: usize, + mode: ExecutionMode, + ) -> Result { + let EncodedKeyAndIv { key, iv } = self + .state + .encoded_key_iv + .as_ref() + .ok_or_else(StreamCipherError::key_not_set)?; + + let keystream = self + .state + .keystream + .compute( + &mut self.thread, + mode, + key, + iv, + explicit_nonce, + start_ctr, + len, + ) + .await?; + + self.state.counter += 1; + + Ok(keystream) + } + + /// Applies the keystream to the provided input text. + async fn apply_keystream( + &mut self, + mode: ExecutionMode, + input_text: InputText, + keystream: ValueRef, + ) -> Result { + debug_assert!( + is_valid_mode(&mode, &input_text), + "invalid execution mode for input text" + ); + + let input_text = match input_text { + InputText::Public { ids, text } => { + let refs = text + .into_iter() + .zip(ids) + .map(|(byte, id)| { + let value_ref = self.thread.new_public_input::(&id)?; + self.thread.assign(&value_ref, byte)?; + + Ok::<_, StreamCipherError>(value_ref) + }) + .collect::, _>>()?; + self.thread.array_from_values(&refs)? + } + InputText::Private { ids, text } => { + let refs = text + .into_iter() + .zip(ids) + .map(|(byte, id)| { + let value_ref = self.thread.new_private_input::(&id)?; + self.thread.assign(&value_ref, byte)?; + + Ok::<_, StreamCipherError>(value_ref) + }) + .collect::, _>>()?; + self.thread.array_from_values(&refs)? + } + InputText::Blind { ids } => { + let refs = ids + .into_iter() + .map(|id| self.thread.new_blind_input::(&id)) + .collect::, _>>()?; + self.thread.array_from_values(&refs)? 
+ } + }; + + let output_text = self.thread.new_array_output::( + &format!("{}/out/{}", self.config.id, self.state.counter), + input_text.len(), + )?; + + let circ = build_array_xor(input_text.len()); + + match mode { + ExecutionMode::Mpc => { + self.thread + .execute(circ, &[input_text, keystream], &[output_text.clone()]) + .await?; + } + ExecutionMode::Prove => { + self.thread + .execute_prove(circ, &[input_text, keystream], &[output_text.clone()]) + .await?; + } + ExecutionMode::Verify => { + self.thread + .execute_verify(circ, &[input_text, keystream], &[output_text.clone()]) + .await?; + } + } + + Ok(output_text) + } + + async fn decode_public(&mut self, value: ValueRef) -> Result { + self.thread + .decode(&[value]) + .await + .map_err(StreamCipherError::from) + .map(|mut output| output.pop().unwrap()) + } + + async fn decode_shared(&mut self, value: ValueRef) -> Result { + self.thread + .decode_shared(&[value]) + .await + .map_err(StreamCipherError::from) + .map(|mut output| output.pop().unwrap()) + } + + async fn decode_private(&mut self, value: ValueRef) -> Result { + self.thread + .decode_private(&[value]) + .await + .map_err(StreamCipherError::from) + .map(|mut output| output.pop().unwrap()) + } + + async fn decode_blind(&mut self, value: ValueRef) -> Result<(), StreamCipherError> { + self.thread.decode_blind(&[value]).await?; + Ok(()) + } + + async fn prove(&mut self, value: ValueRef) -> Result<(), StreamCipherError> { + self.thread.prove(&[value]).await?; + Ok(()) + } + + async fn verify(&mut self, value: ValueRef, expected: Value) -> Result<(), StreamCipherError> { + self.thread.verify(&[value], &[expected]).await?; + Ok(()) + } +} + +#[async_trait] +impl StreamCipher for MpcStreamCipher +where + C: CtrCircuit, + E: Thread + Execute + Load + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static, +{ + fn set_key(&mut self, key: ValueRef, iv: ValueRef) { + self.state.encoded_key_iv = Some(EncodedKeyAndIv { key, iv }); + } + + #[instrument(level 
= "debug", skip_all, err)] + async fn decode_key_private(&mut self) -> Result<(), StreamCipherError> { + let EncodedKeyAndIv { key, iv } = self + .state + .encoded_key_iv + .clone() + .ok_or_else(StreamCipherError::key_not_set)?; + + let [key, iv]: [_; 2] = self + .thread + .decode_private(&[key, iv]) + .await? + .try_into() + .expect("decoded 2 values"); + + let key: Vec = key.try_into().expect("key is an array"); + let iv: Vec = iv.try_into().expect("iv is an array"); + + self.state.key_iv = Some(KeyAndIv { key, iv }); + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError> { + let EncodedKeyAndIv { key, iv } = self + .state + .encoded_key_iv + .clone() + .ok_or_else(StreamCipherError::key_not_set)?; + + self.thread.decode_blind(&[key, iv]).await?; + + Ok(()) + } + + fn set_transcript_id(&mut self, id: &str) { + if id == self.state.transcript.id { + return; + } + + let transcript = self + .state + .transcripts + .remove(id) + .unwrap_or_else(|| Transcript::new(id)); + let old_transcript = std::mem::replace(&mut self.state.transcript, transcript); + self.state + .transcripts + .insert(old_transcript.id.clone(), old_transcript); + } + + #[instrument(level = "debug", skip_all, err)] + async fn preprocess(&mut self, len: usize) -> Result<(), StreamCipherError> { + let EncodedKeyAndIv { key, iv } = self + .state + .encoded_key_iv + .as_ref() + .ok_or_else(StreamCipherError::key_not_set)?; + + self.state + .keystream + .preprocess(&mut self.thread, key, iv, len) + .await + } + + #[instrument(level = "debug", skip_all, err)] + async fn encrypt_public( + &mut self, + explicit_nonce: Vec, + plaintext: Vec, + ) -> Result, StreamCipherError> { + let keystream = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + plaintext.len(), + ExecutionMode::Mpc, + ) + .await?; + + let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len()); + let ciphertext = self + 
.apply_keystream( + ExecutionMode::Mpc, + InputText::Public { + ids: plaintext_ids, + text: plaintext, + }, + keystream, + ) + .await?; + + let ciphertext: Vec = self + .decode_public(ciphertext) + .await? + .try_into() + .expect("ciphertext is array"); + + Ok(ciphertext) + } + + #[instrument(level = "debug", skip_all, err)] + async fn encrypt_private( + &mut self, + explicit_nonce: Vec, + plaintext: Vec, + ) -> Result, StreamCipherError> { + let keystream = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + plaintext.len(), + ExecutionMode::Mpc, + ) + .await?; + + let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len()); + let ciphertext = self + .apply_keystream( + ExecutionMode::Mpc, + InputText::Private { + ids: plaintext_ids, + text: plaintext, + }, + keystream, + ) + .await?; + + let ciphertext: Vec = self + .decode_public(ciphertext) + .await? + .try_into() + .expect("ciphertext is array"); + + Ok(ciphertext) + } + + #[instrument(level = "debug", skip_all, err)] + async fn encrypt_blind( + &mut self, + explicit_nonce: Vec, + len: usize, + ) -> Result, StreamCipherError> { + let keystream = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + len, + ExecutionMode::Mpc, + ) + .await?; + + let plaintext_ids = self.state.transcript.extend_plaintext(len); + let ciphertext = self + .apply_keystream( + ExecutionMode::Mpc, + InputText::Blind { ids: plaintext_ids }, + keystream, + ) + .await?; + + let ciphertext: Vec = self + .decode_public(ciphertext) + .await? + .try_into() + .expect("ciphertext is array"); + + Ok(ciphertext) + } + + #[instrument(level = "debug", skip_all, err)] + async fn decrypt_public( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + ) -> Result, StreamCipherError> { + // TODO: We may want to support writing to the transcript when decrypting + // in public mode. 
+ let keystream = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + ciphertext.len(), + ExecutionMode::Mpc, + ) + .await?; + + let ciphertext_ids = (0..ciphertext.len()) + .map(|i| format!("ct/{}/{}", self.state.counter, i)) + .collect(); + let plaintext = self + .apply_keystream( + ExecutionMode::Mpc, + InputText::Public { + ids: ciphertext_ids, + text: ciphertext, + }, + keystream, + ) + .await?; + + let plaintext: Vec = self + .decode_public(plaintext) + .await? + .try_into() + .expect("plaintext is array"); + + Ok(plaintext) + } + + #[instrument(level = "debug", skip_all, err)] + async fn decrypt_private( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + ) -> Result, StreamCipherError> { + let keystream_ref = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + ciphertext.len(), + ExecutionMode::Mpc, + ) + .await?; + + let keystream: Vec = self + .decode_private(keystream_ref.clone()) + .await? + .try_into() + .expect("keystream is array"); + + let plaintext = ciphertext + .into_iter() + .zip(keystream) + .map(|(c, k)| c ^ k) + .collect::>(); + + // Prove plaintext encrypts back to ciphertext. + let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len()); + let ciphertext = self + .apply_keystream( + ExecutionMode::Prove, + InputText::Private { + ids: plaintext_ids, + text: plaintext.clone(), + }, + keystream_ref, + ) + .await?; + + self.prove(ciphertext).await?; + + Ok(plaintext) + } + + #[instrument(level = "debug", skip_all, err)] + async fn decrypt_blind( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + ) -> Result<(), StreamCipherError> { + let keystream_ref = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + ciphertext.len(), + ExecutionMode::Mpc, + ) + .await?; + + self.decode_blind(keystream_ref.clone()).await?; + + // Verify the plaintext encrypts back to ciphertext. 
+ let plaintext_ids = self.state.transcript.extend_plaintext(ciphertext.len()); + let ciphertext_ref = self + .apply_keystream( + ExecutionMode::Verify, + InputText::Blind { ids: plaintext_ids }, + keystream_ref, + ) + .await?; + + self.verify(ciphertext_ref, ciphertext.into()).await?; + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn prove_plaintext( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + ) -> Result, StreamCipherError> { + let KeyAndIv { key, iv } = self + .state + .key_iv + .clone() + .ok_or_else(StreamCipherError::key_not_set)?; + + let plaintext = C::apply_keystream( + &key, + &iv, + self.config.start_ctr, + &explicit_nonce, + &ciphertext, + )?; + + // Prove plaintext encrypts back to ciphertext. + let keystream = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + plaintext.len(), + ExecutionMode::Prove, + ) + .await?; + + let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len()); + let ciphertext = self + .apply_keystream( + ExecutionMode::Prove, + InputText::Private { + ids: plaintext_ids, + text: plaintext.clone(), + }, + keystream, + ) + .await?; + + self.prove(ciphertext).await?; + + Ok(plaintext) + } + + #[instrument(level = "debug", skip_all, err)] + async fn verify_plaintext( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + ) -> Result<(), StreamCipherError> { + let keystream = self + .compute_keystream( + explicit_nonce, + self.config.start_ctr, + ciphertext.len(), + ExecutionMode::Verify, + ) + .await?; + + let plaintext_ids = self.state.transcript.extend_plaintext(ciphertext.len()); + let ciphertext_ref = self + .apply_keystream( + ExecutionMode::Verify, + InputText::Blind { ids: plaintext_ids }, + keystream, + ) + .await?; + + self.verify(ciphertext_ref, ciphertext.into()).await?; + + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + async fn share_keystream_block( + &mut self, + explicit_nonce: Vec, + ctr: usize, + ) -> Result, 
StreamCipherError> { + let EncodedKeyAndIv { key, iv } = self + .state + .encoded_key_iv + .as_ref() + .ok_or_else(StreamCipherError::key_not_set)?; + + let key_block = self + .state + .keystream + .compute( + &mut self.thread, + ExecutionMode::Mpc, + key, + iv, + explicit_nonce, + ctr, + C::BLOCK_LEN, + ) + .await?; + + let share = self + .decode_shared(key_block) + .await? + .try_into() + .expect("key block is array"); + + Ok(share) + } +} diff --git a/crates/components/universal-hash/Cargo.toml b/crates/components/universal-hash/Cargo.toml new file mode 100644 index 0000000000..31cb512c7d --- /dev/null +++ b/crates/components/universal-hash/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "tlsn-universal-hash" +authors = ["TLSNotary Team"] +description = "A crate which implements different hash functions for two-party computation" +keywords = ["tls", "mpc", "2pc", "hash"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0-alpha.7" +edition = "2021" + +[features] +default = ["ghash", "ideal"] +ghash = [] +ideal = ["dep:ghash_rc"] + +[dependencies] +# tlsn +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [ + "ideal", +] } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } + +ghash_rc = { package = "ghash", version = "0.5", optional = true } + +async-trait = { workspace = true } +futures = { workspace = true } +futures-util = { workspace = true } +thiserror = { workspace = true } +opaque-debug = { workspace = true } +tracing = { workspace = true } +derive_builder = { workspace = true } + +[dev-dependencies] +mpz-common = { git = 
"https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [ + "test-utils", +] } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [ + "ideal", +] } + +ghash_rc = { package = "ghash", version = "0.5" } +tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } +criterion = { workspace = true } +rstest = { workspace = true } +rand_chacha = { workspace = true } +rand = { workspace = true } +generic-array = { workspace = true } diff --git a/components/universal-hash/src/ghash/ghash_core/core.rs b/crates/components/universal-hash/src/ghash/ghash_core/core.rs similarity index 66% rename from components/universal-hash/src/ghash/ghash_core/core.rs rename to crates/components/universal-hash/src/ghash/ghash_core/core.rs index ead214d87e..29a4ececfe 100644 --- a/components/universal-hash/src/ghash/ghash_core/core.rs +++ b/crates/components/universal-hash/src/ghash/ghash_core/core.rs @@ -1,5 +1,6 @@ use mpz_core::Block; -use mpz_share_conversion_core::fields::{gf2_128::Gf2_128, Field}; +use mpz_fields::{gf2_128::Gf2_128, Field}; +use tracing::instrument; use super::{ compute_missing_mul_shares, compute_new_add_shares, @@ -7,23 +8,24 @@ use super::{ GhashError, }; -/// The core logic for our 2PC Ghash implementation +/// The core logic for the 2PC Ghash implementation. /// -/// `GhashCore` will do all the necessary computation +/// `GhashCore` will do all the necessary computation. #[derive(Debug)] pub(crate) struct GhashCore { - /// Inner state + /// Inner state. state: T, - /// Maximum number of message blocks we want to authenticate + /// Maximum number of message blocks we want to authenticate. max_block_count: usize, } impl GhashCore { - /// Create a new `GhashCore` + /// Creates a new `GhashCore`. /// - /// * `max_block_count` - Determines the maximum number of 128-bit message blocks we want to - /// authenticate. Panics if `max_block_count` is 0. 
- #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug"))] + /// # Arguments + /// + /// * `max_block_count` - Determines the maximum number of 128-bit message + /// blocks we want to authenticate. Panics if `max_block_count` is 0. pub(crate) fn new(max_block_count: usize) -> Self { assert!(max_block_count > 0); @@ -33,14 +35,12 @@ impl GhashCore { } } - /// Transform `self` into a `GhashCore`, holding multiplicative shares of - /// powers of `H` + /// Transforms `self` into a `GhashCore`, holding + /// multiplicative shares of powers of `H`. /// - /// Converts `H` into `H`, `H^3`, `H^5`, ... depending on `self.max_block_count` - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(mul_share)) - )] + /// Converts `H` into `H`, `H^3`, `H^5`, ... depending on + /// `self.max_block_count`. + #[instrument(level = "trace", skip_all)] pub(crate) fn compute_odd_mul_powers(self, mul_share: Gf2_128) -> GhashCore { let mut hashkey_powers = vec![mul_share]; @@ -57,28 +57,25 @@ impl GhashCore { } impl GhashCore { - /// Return odd multiplicative shares of the hashkey + /// Returns odd multiplicative shares of the hashkey. /// /// Takes into account cached additive shares, so that /// multiplicative ones for which already an additive one /// exists, are not returned. - #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace"))] + #[instrument(level = "trace", skip_all)] pub(crate) fn odd_mul_shares(&self) -> Vec { - // If we already have some cached additive sharings, we do not need to compute new powers. - // So we compute an offset to ignore them. We divide by 2 because - // `self.state.cached_add_shares` contain even and odd powers, while - // `self.state.odd_mul_shares` only have odd powers. + // If we already have some cached additive sharings, we do not need to compute + // new powers. So we compute an offset to ignore them. 
We divide by 2 + // because `self.state.cached_add_shares` contain even and odd powers, + // while `self.state.odd_mul_shares` only have odd powers. let offset = self.state.cached_add_shares.len() / 2; self.state.odd_mul_shares[offset..].to_vec() } - /// Adds new additive shares of hashkey powers by also computing the even ones - /// and transforms `self` into a `GhashCore` - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(new_additive_odd_shares)) - )] + /// Adds new additive shares of hashkey powers by also computing the even + /// ones and transforms `self` into a `GhashCore`. + #[instrument(level = "trace", skip_all)] pub(crate) fn add_new_add_shares( mut self, new_additive_odd_shares: &[Gf2_128], @@ -96,15 +93,15 @@ impl GhashCore { } impl GhashCore { - /// Returns the currently configured maximum message length + /// Returns the currently configured maximum message length. pub(crate) fn get_max_blocks(&self) -> usize { self.max_block_count } - /// Generate the GHASH output + /// Generates the GHASH output. /// - /// Computes the 2PC additive share of the GHASH output - #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", err, ret))] + /// Computes the 2PC additive share of the GHASH output. + #[instrument(level = "debug", skip_all, err)] pub(crate) fn finalize(&self, message: &[Block]) -> Result { if message.len() > self.max_block_count { return Err(GhashError::InvalidMessageLength); @@ -122,11 +119,12 @@ impl GhashCore { Ok(output.reverse_bits()) } - /// Change the maximum hashkey power + /// Changes the maximum hashkey power. /// - /// If we want to create a GHASH output for a new message, which is longer than the old one, we need - /// to compute the missing shares of the powers of `H`. 
- #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace"))] + /// If we want to create a GHASH output for a new message, which is longer + /// than the old one, we need to compute the missing shares of the + /// powers of `H`. + #[instrument(level = "debug", skip(self))] pub(crate) fn change_max_hashkey( self, new_highest_hashkey_power: usize, diff --git a/components/universal-hash/src/ghash/ghash_core/mod.rs b/crates/components/universal-hash/src/ghash/ghash_core/mod.rs similarity index 76% rename from components/universal-hash/src/ghash/ghash_core/mod.rs rename to crates/components/universal-hash/src/ghash/ghash_core/mod.rs index 8ea28b1c5d..c4eae87e51 100644 --- a/components/universal-hash/src/ghash/ghash_core/mod.rs +++ b/crates/components/universal-hash/src/ghash/ghash_core/mod.rs @@ -1,30 +1,33 @@ -//! This module implements the AES-GCM's GHASH function in a secure two-party computation (2PC) -//! setting. The parties start with their secret XOR shares of H (the GHASH key) and at the end -//! each gets their XOR share of the GHASH output. The method is described here: -//! . +//! This module implements the AES-GCM's GHASH function in a secure two-party +//! computation (2PC) setting. The parties start with their secret XOR shares of +//! H (the GHASH key) and at the end each gets their XOR share of the GHASH +//! output. The method is described here: . //! -//! At first we will convert the XOR (additive) share of `H`, into a multiplicative share. This -//! allows us to compute all the necessary powers of `H^n` locally. Note, that it is only required -//! to compute the odd multiplicative powers, because of free squaring. Then each of these -//! multiplicative shares will be converted back into additive shares. The even additive shares can -//! then locally be built by using the odd ones. This way, we can batch nearly all oblivious -//! transfers and reduce the round complexity of the protocol. +//! 
At first we will convert the XOR (additive) share of `H`, into a +//! multiplicative share. This allows us to compute all the necessary powers of +//! `H^n` locally. Note, that it is only required to compute the odd +//! multiplicative powers, because of free squaring. Then each of these +//! multiplicative shares will be converted back into additive shares. The even +//! additive shares can then locally be built by using the odd ones. This way, +//! we can batch nearly all oblivious transfers and reduce the round complexity +//! of the protocol. //! -//! On the whole, we need a single additive-to-multiplicative (A2M) and `n/2`, where `n` is the -//! number of blocks of message, multiplicative-to-additive (M2A) conversions. Finally, having -//! additive shares of `H^n` for all needed `n`, we can compute an additive share of the GHASH -//! output. +//! On the whole, we need a single additive-to-multiplicative (A2M) and `n/2`, +//! where `n` is the number of blocks of message, multiplicative-to-additive +//! (M2A) conversions. Finally, having additive shares of `H^n` for all needed +//! `n`, we can compute an additive share of the GHASH output. -/// Contains the core logic for ghash +/// Contains the core logic for ghash. mod core; -/// Contains the different states +/// Contains the different states. pub(crate) mod state; pub(crate) use self::core::GhashCore; -use mpz_share_conversion_core::fields::{compute_product_repeated, gf2_128::Gf2_128}; +use mpz_fields::{compute_product_repeated, gf2_128::Gf2_128}; use thiserror::Error; +use tracing::instrument; #[derive(Debug, Error)] pub(crate) enum GhashError { @@ -32,20 +35,21 @@ pub(crate) enum GhashError { InvalidMessageLength, } -/// Computes missing odd multiplicative shares of the hashkey powers +/// Computes missing odd multiplicative shares of the hashkey powers. /// -/// Checks if depending on the number of `needed` shares, we need more odd multiplicative shares and -/// computes them. 
Notice that we only need odd multiplicative shares for the OT, because we can -/// derive even additive shares from odd additive shares, which we call free squaring. +/// Checks if depending on the number of `needed` shares, we need more odd +/// multiplicative shares and computes them. Notice that we only need odd +/// multiplicative shares for the OT, because we can derive even additive shares +/// from odd additive shares, which we call free squaring. /// -/// * `present_odd_mul_shares` - multiplicative odd shares already present -/// * `needed` - how many powers we need including odd and even -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(present_odd_mul_shares)) -)] +/// # Arguments +/// +/// * `present_odd_mul_shares` - Multiplicative odd shares already present. +/// * `needed` - How many powers we need including odd and +/// even. +#[instrument(level = "trace", skip(present_odd_mul_shares))] fn compute_missing_mul_shares(present_odd_mul_shares: &mut Vec, needed: usize) { - // divide by 2 and round up + // Divide by 2 and round up. let needed_odd_powers: usize = needed / 2 + (needed & 1); let present_odd_len = present_odd_mul_shares.len(); @@ -59,34 +63,35 @@ fn compute_missing_mul_shares(present_odd_mul_shares: &mut Vec, needed: } } -/// Computes new even (additive) shares from new odd (additive) shares and saves both the new odd shares -/// and the new even shares. +/// Computes new even (additive) shares from new odd (additive) shares and saves +/// both the new odd shares and the new even shares. +/// +/// This function implements the derivation of even additive shares from odd +/// additive shares, which we refer to as free squaring. Every additive share of +/// an even power of `H` can be computed without an OT interaction by squaring +/// the corresponding additive share of an odd power of `H`, e.g. if we have a +/// share of H^3, we can derive the share of H^6 by doing (H^3)^2. 
/// -/// This function implements the derivation of even additive shares from odd additive shares, -/// which we refer to as free squaring. Every additive share of an even power of -/// `H` can be computed without an OT interaction by squaring the corresponding additive share -/// of an odd power of `H`, e.g. if we have a share of H^3, we can derive the share of H^6 by doing -/// (H^3)^2 +/// # Arguments /// -/// * `new_add_odd_shares` - new odd additive shares we got as a result of doing an OT on odd -/// multiplicative shares -/// * `add_shares` - all additive shares (even and odd) we already have. This is a mutable -/// reference to cached_add_shares in [crate::ghash::state::Intermediate] -#[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip(new_add_odd_shares, add_shares)) -)] +/// * `new_add_odd_shares` - New odd additive shares we got as a result of doing +/// an OT on odd multiplicative shares. +/// * `add_shares` - All additive shares (even and odd) we already have. +/// This is a mutable reference to cached_add_shares in +/// [crate::ghash::state::Intermediate]. +#[instrument(level = "trace", skip_all)] fn compute_new_add_shares(new_add_odd_shares: &[Gf2_128], add_shares: &mut Vec) { for (odd_share, current_odd_power) in new_add_odd_shares .iter() .zip((add_shares.len() + 1..).step_by(2)) { - // `add_shares` always have an even number of shares so we simply add the next odd share + // `add_shares` always have an even number of shares so we simply add the next + // odd share. add_shares.push(*odd_share); - // now we need to compute the next even share and add it - // note that the n-th index corresponds to the (n+1)-th power, e.g. add_shares[4] - // is the share of H^5 + // Now we need to compute the next even share and add it. + // Note that the n-th index corresponds to the (n+1)-th power, e.g. + // add_shares[4] is the share of H^5. 
let mut base_share = add_shares[current_odd_power / 2]; base_share = base_share * base_share; add_shares.push(base_share); @@ -101,7 +106,7 @@ mod tests { GHash, }; use mpz_core::Block; - use mpz_share_conversion_core::fields::{gf2_128::Gf2_128, Field}; + use mpz_fields::{gf2_128::Gf2_128, Field}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha12Rng; @@ -115,7 +120,7 @@ mod tests { fn test_ghash_product_sharing() { let mut rng = ChaCha12Rng::from_seed([0; 32]); - // The Ghash key + // The Ghash key. let h: Gf2_128 = rng.gen(); let message = Block::random_vec(&mut rng, 10); let message_len = message.len(); @@ -126,14 +131,14 @@ mod tests { let mut powers_h = vec![h]; compute_product_repeated(&mut powers_h, h * h, number_of_powers_needed); - // Length check + // Length check. assert_eq!(sender.state().odd_mul_shares.len(), number_of_powers_needed); assert_eq!( receiver.state().odd_mul_shares.len(), number_of_powers_needed ); - // Product check + // Product check. for (k, (sender_share, receiver_share)) in std::iter::zip( sender.state().odd_mul_shares.iter(), receiver.state().odd_mul_shares.iter(), @@ -148,7 +153,7 @@ mod tests { fn test_ghash_sum_sharing() { let mut rng = ChaCha12Rng::from_seed([0; 32]); - // The Ghash key + // The Ghash key. let h: Gf2_128 = rng.gen(); let message = Block::random_vec(&mut rng, 10); let message_len = message.len(); @@ -159,7 +164,7 @@ mod tests { let mut powers_h = vec![h]; compute_product_repeated(&mut powers_h, h, message_len); - // Length check + // Length check. assert_eq!( sender.state().add_shares.len(), message_len + (message_len & 1) @@ -169,11 +174,11 @@ mod tests { message_len + (message_len & 1) ); - // Sum check - for k in 0..message_len { + // Sum check. 
+ for (k, item) in powers_h.iter().enumerate().take(message_len) { assert_eq!( sender.state().add_shares[k] + receiver.state().add_shares[k], - powers_h[k] + *item ); } } @@ -182,7 +187,7 @@ mod tests { fn test_ghash_output() { let mut rng = ChaCha12Rng::from_seed([0; 32]); - // The Ghash key + // The Ghash key. let h: Gf2_128 = rng.gen(); let message = Block::random_vec(&mut rng, 10); @@ -201,7 +206,7 @@ mod tests { fn test_ghash_change_message_short() { let mut rng = ChaCha12Rng::from_seed([0; 32]); - // The Ghash key + // The Ghash key. let h: Gf2_128 = rng.gen(); let message = Block::random_vec(&mut rng, 10); @@ -230,7 +235,7 @@ mod tests { fn test_ghash_change_message_long() { let mut rng = ChaCha12Rng::from_seed([0; 32]); - // The Ghash key + // The Ghash key. let h: Gf2_128 = rng.gen(); let message = Block::random_vec(&mut rng, 10); @@ -266,14 +271,14 @@ mod tests { compute_missing_mul_shares(&mut powers, needed); - // Check length + // Check length. if needed / 2 + (needed & 1) <= powers_len { assert_eq!(powers.len(), powers_len); } else { assert_eq!(powers.len(), needed / 2 + (needed & 1)) } - // Check shares + // Check shares. let first = *powers.first().unwrap(); let factor = first * first; @@ -291,7 +296,7 @@ mod tests { let new_add_odd_shares: Vec = gen_gf2_128_vec(); let mut add_shares: Vec = gen_gf2_128_vec(); - // We have the invariant that len of add_shares is always even + // We have the invariant that len of add_shares is always even. if add_shares.len() & 1 == 1 { add_shares.push(rng.gen()); } @@ -300,13 +305,13 @@ mod tests { compute_new_add_shares(&new_add_odd_shares, &mut add_shares); - // Check new length + // Check new length. assert_eq!( add_shares.len(), original_len + 2 * new_add_odd_shares.len() ); - // Check odd shares + // Check odd shares. 
for (k, l) in (original_len..add_shares.len()) .step_by(2) .zip(0..original_len) @@ -314,7 +319,7 @@ mod tests { assert_eq!(add_shares[k], new_add_odd_shares[l]); } - // Check even shares + // Check even shares. for k in (original_len + 1..add_shares.len()).step_by(2) { assert_eq!(add_shares[k], add_shares[k / 2] * add_shares[k / 2]); } @@ -323,7 +328,7 @@ mod tests { fn gen_gf2_128_vec() -> Vec { let mut rng = ChaCha12Rng::from_seed([0; 32]); - // Sample some message + // Sample some message. let message_len: usize = rng.gen_range(16..128); let mut message: Vec = vec![Gf2_128::zero(); message_len]; message.iter_mut().for_each(|x| *x = rng.gen()); @@ -346,7 +351,7 @@ mod tests { ) -> (GhashCore, GhashCore) { let mut rng = ChaCha12Rng::from_seed([0; 32]); - // Create a multiplicative sharing + // Create a multiplicative sharing. let h1_multiplicative: Gf2_128 = rng.gen(); let h2_multiplicative: Gf2_128 = hashkey * h1_multiplicative.inverse(); diff --git a/components/universal-hash/src/ghash/ghash_core/state.rs b/crates/components/universal-hash/src/ghash/ghash_core/state.rs similarity index 75% rename from components/universal-hash/src/ghash/ghash_core/state.rs rename to crates/components/universal-hash/src/ghash/ghash_core/state.rs index 3db90bb4a9..3d792b84c1 100644 --- a/components/universal-hash/src/ghash/ghash_core/state.rs +++ b/crates/components/universal-hash/src/ghash/ghash_core/state.rs @@ -1,4 +1,4 @@ -use mpz_share_conversion_core::fields::gf2_128::Gf2_128; +use mpz_fields::gf2_128::Gf2_128; mod sealed { pub(crate) trait Sealed {} @@ -14,18 +14,18 @@ impl State for Init {} impl State for Intermediate {} impl State for Finalized {} -/// Init state for Ghash protocol +/// Init state for Ghash protocol. /// -/// This is before any OT has taken place +/// This is before any OT has taken place. #[derive(Clone)] pub(crate) struct Init; opaque_debug::implement!(Init); -/// Intermediate state for Ghash protocol +/// Intermediate state for Ghash protocol. 
/// -/// This is when the additive share has been converted into a multiplicative share and all the -/// needed powers have been computed +/// This is when the additive share has been converted into a multiplicative +/// share and all the needed powers have been computed. #[derive(Clone)] pub(crate) struct Intermediate { pub(super) odd_mul_shares: Vec, @@ -33,16 +33,16 @@ pub(crate) struct Intermediate { // (In order to simplify the code) the n-th index of the vec corresponds to the additive share // of the (n+1)-th power of H, e.g. the share of H^1 is located at the 0-th index of the vec // It always contains an even number of consecutive shares starting from the share of H^1 up to - // the share of H^(cached_add_shares.len()) + // the share of H^(cached_add_shares.len()). pub(super) cached_add_shares: Vec, } opaque_debug::implement!(Intermediate); -/// Final state for Ghash protocol +/// Final state for Ghash protocol. /// -/// This is when each party can compute a final share of the ghash output, because both now have -/// additive shares of all the powers of `H` +/// This is when each party can compute a final share of the ghash output, +/// because both now have additive shares of all the powers of `H`. #[derive(Clone)] pub(crate) struct Finalized { pub(super) odd_mul_shares: Vec, diff --git a/components/universal-hash/src/ghash/ghash_inner/config.rs b/crates/components/universal-hash/src/ghash/ghash_inner/config.rs similarity index 54% rename from components/universal-hash/src/ghash/ghash_inner/config.rs rename to crates/components/universal-hash/src/ghash/ghash_inner/config.rs index ed69004276..33027550cc 100644 --- a/components/universal-hash/src/ghash/ghash_inner/config.rs +++ b/crates/components/universal-hash/src/ghash/ghash_inner/config.rs @@ -1,21 +1,18 @@ use derive_builder::Builder; #[derive(Debug, Clone, Builder)] -/// Configuration struct for [Ghash](crate::ghash::Ghash) +/// Configuration struct for [Ghash](crate::ghash::Ghash). 
pub struct GhashConfig { - /// The instance ID - #[builder(setter(into))] - pub id: String, - /// Initial number of block shares to provision + /// Initial number of block shares to provision. #[builder(default = "1026")] pub initial_block_count: usize, - /// Maximum number of blocks supported + /// Maximum number of blocks supported. #[builder(default = "1026")] pub max_block_count: usize, } impl GhashConfig { - /// Creates a new builder for the [GhashConfig] + /// Creates a new builder for the [GhashConfig]. pub fn builder() -> GhashConfigBuilder { GhashConfigBuilder::default() } diff --git a/crates/components/universal-hash/src/ghash/ghash_inner/ideal.rs b/crates/components/universal-hash/src/ghash/ghash_inner/ideal.rs new file mode 100644 index 0000000000..900a6d9f04 --- /dev/null +++ b/crates/components/universal-hash/src/ghash/ghash_inner/ideal.rs @@ -0,0 +1,183 @@ +//! Ideal GHASH functionality. + +use async_trait::async_trait; +use ghash_rc::{ + universal_hash::{KeyInit, UniversalHash as UniversalHashReference}, + GHash, +}; +use mpz_common::{ + ideal::{ideal_f2p, Alice, Bob}, + Context, +}; + +use crate::{UniversalHash, UniversalHashError}; + +/// An ideal GHASH functionality. 
+#[derive(Debug)] +pub struct IdealGhash { + role: Role, + context: Ctx, +} + +#[derive(Debug)] +enum Role { + Alice(Alice), + Bob(Bob), +} + +#[async_trait] +impl UniversalHash for IdealGhash { + async fn set_key(&mut self, key: Vec) -> Result<(), UniversalHashError> { + match &mut self.role { + Role::Alice(alice) => { + alice + .call( + &mut self.context, + key, + |ghash, alice_key, bob_key: Vec| { + let key = alice_key + .into_iter() + .zip(bob_key) + .map(|(a, b)| a ^ b) + .collect::>(); + *ghash = GHash::new_from_slice(&key).unwrap(); + ((), ()) + }, + ) + .await + } + Role::Bob(bob) => { + bob.call( + &mut self.context, + key, + |ghash, alice_key: Vec, bob_key| { + let key = alice_key + .into_iter() + .zip(bob_key) + .map(|(a, b)| a ^ b) + .collect::>(); + *ghash = GHash::new_from_slice(&key).unwrap(); + ((), ()) + }, + ) + .await + } + } + + Ok(()) + } + + async fn setup(&mut self) -> Result<(), UniversalHashError> { + Ok(()) + } + + async fn preprocess(&mut self) -> Result<(), UniversalHashError> { + Ok(()) + } + + async fn finalize(&mut self, input: Vec) -> Result, UniversalHashError> { + Ok(match &mut self.role { + Role::Alice(alice) => { + alice + .call( + &mut self.context, + input, + |ghash, alice_input, bob_input: Vec| { + assert_eq!(&alice_input, &bob_input); + + let mut ghash = ghash.clone(); + ghash.update_padded(&alice_input); + let output = ghash.finalize().to_vec(); + + let output_bob = vec![0; output.len()]; + let output_alice: Vec = output + .iter() + .zip(output_bob.iter().copied()) + .map(|(o, b)| o ^ b) + .collect(); + (output_alice, output_bob) + }, + ) + .await + } + Role::Bob(bob) => { + bob.call( + &mut self.context, + input, + |ghash, alice_input: Vec, bob_input| { + assert_eq!(&alice_input, &bob_input); + + let mut ghash = ghash.clone(); + ghash.update_padded(&alice_input); + let output = ghash.finalize(); + + let output_bob = vec![0; output.len()]; + let output_alice: Vec = output + .iter() + .zip(output_bob.iter().copied()) + 
.map(|(o, b)| o ^ b) + .collect(); + (output_alice, output_bob) + }, + ) + .await + } + }) + } +} + +/// Creates an ideal GHASH pair. +pub fn ideal_ghash( + context_alice: Ctx, + context_bob: Ctx, +) -> (IdealGhash, IdealGhash) { + let (alice, bob) = ideal_f2p(GHash::new_from_slice(&[0u8; 16]).unwrap()); + ( + IdealGhash { + role: Role::Alice(alice), + context: context_alice, + }, + IdealGhash { + role: Role::Bob(bob), + context: context_bob, + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use mpz_common::executor::test_st_executor; + + #[tokio::test] + async fn test_ideal_ghash() { + let (ctx_a, ctx_b) = test_st_executor(8); + let (mut alice, mut bob) = ideal_ghash(ctx_a, ctx_b); + + let alice_key = vec![42u8; 16]; + let bob_key = vec![69u8; 16]; + let key = alice_key + .iter() + .zip(bob_key.iter()) + .map(|(a, b)| a ^ b) + .collect::>(); + + tokio::try_join!( + alice.set_key(alice_key.clone()), + bob.set_key(bob_key.clone()) + ) + .unwrap(); + + let input = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + let (output_a, output_b) = + tokio::try_join!(alice.finalize(input.clone()), bob.finalize(input.clone())).unwrap(); + + let mut ghash = GHash::new_from_slice(&key).unwrap(); + ghash.update_padded(&input); + let expected_output = ghash.finalize(); + + let output: Vec = output_a.iter().zip(output_b).map(|(a, b)| a ^ b).collect(); + assert_eq!(output, expected_output.to_vec()); + } +} diff --git a/components/universal-hash/src/ghash/ghash_inner/mock.rs b/crates/components/universal-hash/src/ghash/ghash_inner/mock.rs similarity index 93% rename from components/universal-hash/src/ghash/ghash_inner/mock.rs rename to crates/components/universal-hash/src/ghash/ghash_inner/mock.rs index 0db00ac54b..a0d5e9d607 100644 --- a/components/universal-hash/src/ghash/ghash_inner/mock.rs +++ b/crates/components/universal-hash/src/ghash/ghash_inner/mock.rs @@ -5,7 +5,7 @@ use mpz_share_conversion::{ use super::{Ghash, GhashConfig}; -/// Create a Ghash sender/receiver pair for 
testing purpose +/// Create a Ghash sender/receiver pair for testing purpose. pub fn mock_ghash_pair( sender_config: GhashConfig, receiver_config: GhashConfig, diff --git a/components/universal-hash/src/ghash/ghash_inner/mod.rs b/crates/components/universal-hash/src/ghash/ghash_inner/mod.rs similarity index 56% rename from components/universal-hash/src/ghash/ghash_inner/mod.rs rename to crates/components/universal-hash/src/ghash/ghash_inner/mod.rs index 8380c5bd9d..b571d4e1b8 100644 --- a/components/universal-hash/src/ghash/ghash_inner/mod.rs +++ b/crates/components/universal-hash/src/ghash/ghash_inner/mod.rs @@ -6,13 +6,16 @@ use crate::{ UniversalHash, UniversalHashError, }; use async_trait::async_trait; +use mpz_common::{Context, Preprocess}; use mpz_core::Block; -use mpz_share_conversion::{Gf2_128, ShareConversion}; +use mpz_fields::gf2_128::Gf2_128; +use mpz_share_conversion::{ShareConversionError, ShareConvert}; use std::fmt::Debug; +use tracing::instrument; mod config; -#[cfg(feature = "mock")] -pub(crate) mod mock; +#[cfg(feature = "ideal")] +pub(crate) mod ideal; pub use config::{GhashConfig, GhashConfigBuilder, GhashConfigBuilderError}; @@ -23,60 +26,75 @@ enum State { Error, } -/// This is the common instance used by both sender and receiver +/// This is the common instance used by both sender and receiver. /// -/// It is an aio wrapper which mostly uses [GhashCore] for computation -#[derive(Debug)] -pub struct Ghash { +/// It is an aio wrapper which mostly uses [GhashCore] for computation. +pub struct Ghash { state: State, config: GhashConfig, converter: C, + context: Ctx, } -impl Ghash +impl Ghash where - C: ShareConversion + Send + Sync + Debug, + Ctx: Context, + C: ShareConvert, { - /// Creates a new instance + /// Creates a new instance. 
+ /// + /// # Arguments /// - /// * `config` - The configuration for this Ghash instance - /// * `converter` - An instance which allows to convert multiplicative into additive shares - /// and vice versa - #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", ret))] - pub fn new(config: GhashConfig, converter: C) -> Self { + /// * `config` - The configuration for this Ghash instance. + /// * `converter` - An instance which allows to convert multiplicative + /// into additive shares and vice versa. + /// * `context` - The context. + pub fn new(config: GhashConfig, converter: C, context: Ctx) -> Self { Self { state: State::Init, config, converter, + context, } } - /// Computes all the additive shares of the hashkey powers + /// Computes all the additive shares of the hashkey powers. /// /// We need this when the max block count changes. - #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))] + #[instrument(level = "debug", skip_all, err)] async fn compute_add_shares( &mut self, core: GhashCore, ) -> Result, UniversalHashError> { let odd_mul_shares = core.odd_mul_shares(); - let add_shares = self.converter.to_additive(odd_mul_shares).await?; + let add_shares = self + .converter + .to_additive(&mut self.context, odd_mul_shares) + .await?; let core = core.add_new_add_shares(&add_shares); Ok(core) } } +impl Debug for Ghash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Ghash") + .field("state", &self.state) + .field("config", &self.config) + .field("converter", &"{{ .. 
}}".to_string()) + .finish() + } +} + #[async_trait] -impl UniversalHash for Ghash +impl UniversalHash for Ghash where - C: ShareConversion + Send + Sync + Debug, + Ctx: Context, + C: Preprocess + ShareConvert + Send, { - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(key), err) - )] + #[instrument(level = "info", fields(thread = %self.context.id()), skip_all, err)] async fn set_key(&mut self, key: Vec) -> Result<(), UniversalHashError> { if key.len() != 16 { return Err(UniversalHashError::KeyLengthError(16, key.len())); @@ -91,10 +109,13 @@ where let mut h_additive = [0u8; 16]; h_additive.copy_from_slice(key.as_slice()); - // GHASH reflects the bits of the key + // GHASH reflects the bits of the key. let h_additive = Gf2_128::new(u128::from_be_bytes(h_additive).reverse_bits()); - let h_multiplicative = self.converter.to_multiplicative(vec![h_additive]).await?; + let h_multiplicative = self + .converter + .to_multiplicative(&mut self.context, vec![h_additive]) + .await?; let core = GhashCore::new(self.config.initial_block_count); let core = core.compute_odd_mul_powers(h_multiplicative[0]); @@ -105,12 +126,27 @@ where Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip_all, err) - )] + #[instrument(level = "debug", fields(thread = %self.context.id()), skip_all, err)] + async fn setup(&mut self) -> Result<(), UniversalHashError> { + // We need only half the number of `max_block_count` M2As because of the free + // squaring trick and we need one extra A2M conversion in the beginning. + // Both M2A and A2M, each require a single OLE. 
+ let ole_count = self.config.max_block_count / 2 + 1; + self.converter.alloc(ole_count); + + Ok(()) + } + + #[instrument(level = "debug", fields(thread = %self.context.id()), skip_all, err)] + async fn preprocess(&mut self) -> Result<(), UniversalHashError> { + self.converter.preprocess(&mut self.context).await?; + + Ok(()) + } + + #[instrument(level = "debug", fields(thread = %self.context.id()), skip_all, err)] async fn finalize(&mut self, mut input: Vec) -> Result, UniversalHashError> { - // Divide by block length and round up + // Divide by block length and round up. let block_count = input.len() / 16 + (input.len() % 16 != 0) as usize; if block_count > self.config.max_block_count { @@ -119,12 +155,12 @@ where let state = std::mem::replace(&mut self.state, State::Error); - // Calling finalize when not setup is a fatal error + // Calling finalize when not setup is a fatal error. let State::Ready { core } = state else { return Err(UniversalHashError::InvalidState("Key not set".to_string())); }; - // Compute new shares if the block count increased + // Compute new shares if the block count increased. let core = if block_count > core.get_max_blocks() { self.compute_add_shares(core.change_max_hashkey(block_count)) .await? @@ -132,10 +168,10 @@ where core }; - // Pad input to a multiple of 16 bytes + // Pad input to a multiple of 16 bytes. input.resize(block_count * 16, 0); - // Convert input to blocks + // Convert input to blocks. let blocks = input .chunks_exact(16) .map(|chunk| { @@ -149,7 +185,7 @@ where .finalize(&blocks) .expect("Input length should be valid"); - // Reinsert state + // Reinsert state. 
self.state = State::Ready { core }; Ok(tag.to_bytes().to_vec()) @@ -158,49 +194,62 @@ where #[cfg(test)] mod tests { - use super::{mock::mock_ghash_pair, GhashConfig, UniversalHash}; + use crate::{ + ghash::{Ghash, GhashConfig}, + UniversalHash, + }; use ghash_rc::{ universal_hash::{KeyInit, UniversalHash as UniversalHashReference}, GHash as GhashReference, }; + use mpz_common::{executor::test_st_executor, Context}; + use mpz_share_conversion::ideal::{ideal_share_converter, IdealShareConverter}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha12Rng; - fn create_pair(id: &str, block_count: usize) -> (impl UniversalHash, impl UniversalHash) { + fn create_pair( + block_count: usize, + context_alice: Ctx, + context_bob: Ctx, + ) -> ( + Ghash, + Ghash, + ) { + let (convert_a, convert_b) = ideal_share_converter(); + let config = GhashConfig::builder() - .id(id) .initial_block_count(block_count) .build() .unwrap(); - mock_ghash_pair(config.clone(), config) + ( + Ghash::new(config.clone(), convert_a, context_alice), + Ghash::new(config, convert_b, context_bob), + ) } #[tokio::test] async fn test_ghash_output() { + let (ctx_a, ctx_b) = test_st_executor(8); let mut rng = ChaCha12Rng::from_seed([0; 32]); let h: u128 = rng.gen(); let sender_key: u128 = rng.gen(); let receiver_key: u128 = h ^ sender_key; let message: Vec = (0..128).map(|_| rng.gen()).collect(); - let (mut sender, mut receiver) = create_pair("test", 1); + let (mut sender, mut receiver) = create_pair(1, ctx_a, ctx_b); - let (sender_setup_fut, receiver_setup_fut) = ( + tokio::try_join!( sender.set_key(sender_key.to_be_bytes().to_vec()), - receiver.set_key(receiver_key.to_be_bytes().to_vec()), - ); - - let (sender_result, receiver_result) = tokio::join!(sender_setup_fut, receiver_setup_fut); - sender_result.unwrap(); - receiver_result.unwrap(); + receiver.set_key(receiver_key.to_be_bytes().to_vec()) + ) + .unwrap(); - let sender_share_fut = sender.finalize(message.clone()); - let receiver_share_fut = 
receiver.finalize(message.clone()); - - let (sender_share, receiver_share) = tokio::join!(sender_share_fut, receiver_share_fut); - let sender_share = sender_share.unwrap(); - let receiver_share = receiver_share.unwrap(); + let (sender_share, receiver_share) = tokio::try_join!( + sender.finalize(message.clone()), + receiver.finalize(message.clone()) + ) + .unwrap(); let tag = sender_share .iter() @@ -213,6 +262,7 @@ mod tests { #[tokio::test] async fn test_ghash_output_padded() { + let (ctx_a, ctx_b) = test_st_executor(8); let mut rng = ChaCha12Rng::from_seed([0; 32]); let h: u128 = rng.gen(); let sender_key: u128 = rng.gen(); @@ -221,23 +271,19 @@ mod tests { // Message length is not a multiple of the block length let message: Vec = (0..126).map(|_| rng.gen()).collect(); - let (mut sender, mut receiver) = create_pair("test", 1); + let (mut sender, mut receiver) = create_pair(1, ctx_a, ctx_b); - let (sender_setup_fut, receiver_setup_fut) = ( + tokio::try_join!( sender.set_key(sender_key.to_be_bytes().to_vec()), - receiver.set_key(receiver_key.to_be_bytes().to_vec()), - ); - - let (sender_result, receiver_result) = tokio::join!(sender_setup_fut, receiver_setup_fut); - sender_result.unwrap(); - receiver_result.unwrap(); - - let sender_share_fut = sender.finalize(message.clone()); - let receiver_share_fut = receiver.finalize(message.clone()); + receiver.set_key(receiver_key.to_be_bytes().to_vec()) + ) + .unwrap(); - let (sender_share, receiver_share) = tokio::join!(sender_share_fut, receiver_share_fut); - let sender_share = sender_share.unwrap(); - let receiver_share = receiver_share.unwrap(); + let (sender_share, receiver_share) = tokio::try_join!( + sender.finalize(message.clone()), + receiver.finalize(message.clone()) + ) + .unwrap(); let tag = sender_share .iter() @@ -250,40 +296,38 @@ mod tests { #[tokio::test] async fn test_ghash_long_message() { + let (ctx_a, ctx_b) = test_st_executor(8); let mut rng = ChaCha12Rng::from_seed([0; 32]); let h: u128 = rng.gen(); 
let sender_key: u128 = rng.gen(); let receiver_key: u128 = h ^ sender_key; let short_message: Vec = (0..128).map(|_| rng.gen()).collect(); - // A longer message + // A longer message. let long_message: Vec = (0..192).map(|_| rng.gen()).collect(); - // Create and setup sender and receiver for short message length - let (mut sender, mut receiver) = create_pair("test", 1); + // Create and setup sender and receiver for short message length. + let (mut sender, mut receiver) = create_pair(1, ctx_a, ctx_b); - let (sender_setup_fut, receiver_setup_fut) = ( + tokio::try_join!( sender.set_key(sender_key.to_be_bytes().to_vec()), - receiver.set_key(receiver_key.to_be_bytes().to_vec()), - ); - - let (sender_result, receiver_result) = tokio::join!(sender_setup_fut, receiver_setup_fut); - sender_result.unwrap(); - receiver_result.unwrap(); - - // Compute the shares for the short message - let sender_share_fut = sender.finalize(short_message.clone()); - let receiver_share_fut = receiver.finalize(short_message.clone()); - - let (sender_result, receiver_result) = tokio::join!(sender_share_fut, receiver_share_fut); - let (_, _) = (sender_result.unwrap(), receiver_result.unwrap()); - - // Now compute the shares for the longer message - let sender_share_fut = sender.finalize(long_message.clone()); - let receiver_share_fut = receiver.finalize(long_message.clone()); - - let (sender_result, receiver_result) = tokio::join!(sender_share_fut, receiver_share_fut); - let (sender_share, receiver_share) = (sender_result.unwrap(), receiver_result.unwrap()); + receiver.set_key(receiver_key.to_be_bytes().to_vec()) + ) + .unwrap(); + + // Compute the shares for the short message. + tokio::try_join!( + sender.finalize(short_message.clone()), + receiver.finalize(short_message.clone()) + ) + .unwrap(); + + // Now compute the shares for the longer message. 
+ let (sender_share, receiver_share) = tokio::try_join!( + sender.finalize(long_message.clone()), + receiver.finalize(long_message.clone()) + ) + .unwrap(); let tag = sender_share .iter() @@ -293,12 +337,12 @@ mod tests { assert_eq!(tag, ghash_reference_impl(h, &long_message)); - // We should still be able to generate a Ghash output for the shorter message - let sender_share_fut = sender.finalize(short_message.clone()); - let receiver_share_fut = receiver.finalize(short_message.clone()); - - let (sender_result, receiver_result) = tokio::join!(sender_share_fut, receiver_share_fut); - let (sender_share, receiver_share) = (sender_result.unwrap(), receiver_result.unwrap()); + // We should still be able to generate a Ghash output for the shorter message. + let (sender_share, receiver_share) = tokio::try_join!( + sender.finalize(short_message.clone()), + receiver.finalize(short_message.clone()) + ) + .unwrap(); let tag = sender_share .iter() diff --git a/components/universal-hash/src/ghash/mod.rs b/crates/components/universal-hash/src/ghash/mod.rs similarity index 60% rename from components/universal-hash/src/ghash/mod.rs rename to crates/components/universal-hash/src/ghash/mod.rs index 8a9ef449d6..0bb71ddf75 100644 --- a/components/universal-hash/src/ghash/mod.rs +++ b/crates/components/universal-hash/src/ghash/mod.rs @@ -1,6 +1,6 @@ mod ghash_core; mod ghash_inner; -#[cfg(feature = "mock")] -pub use ghash_inner::mock::*; +#[cfg(feature = "ideal")] +pub use ghash_inner::ideal::{ideal_ghash, IdealGhash}; pub use ghash_inner::{Ghash, GhashConfig, GhashConfigBuilder, GhashConfigBuilderError}; diff --git a/components/universal-hash/src/lib.rs b/crates/components/universal-hash/src/lib.rs similarity index 63% rename from components/universal-hash/src/lib.rs rename to crates/components/universal-hash/src/lib.rs index 0082109f5d..8c364e0a97 100644 --- a/components/universal-hash/src/lib.rs +++ b/crates/components/universal-hash/src/lib.rs @@ -1,16 +1,16 @@ -//! 
A library for computing different kinds of hash functions in a 2PC setting +//! A library for computing different kinds of hash functions in a 2PC setting. #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] #![forbid(unsafe_code)] -/// This module implements [UniversalHash] for Ghash +/// This module implements [UniversalHash] for Ghash. #[cfg(feature = "ghash")] pub mod ghash; use async_trait::async_trait; -/// Errors for [UniversalHash] +/// Errors for [UniversalHash]. #[allow(missing_docs)] #[derive(Debug, thiserror::Error)] pub enum UniversalHashError { @@ -25,16 +25,26 @@ pub enum UniversalHashError { } #[async_trait] -/// A trait supporting different kinds of hash functions +/// A trait supporting different kinds of hash functions. pub trait UniversalHash: Send { - /// Set the key for the hash function + /// Sets the key for the hash function /// - /// * `key` - Key to use for the hash function + /// # Arguments + /// + /// * `key` - Key to use for the hash function. async fn set_key(&mut self, key: Vec) -> Result<(), UniversalHashError>; + /// Performs any necessary one-time setup. + async fn setup(&mut self) -> Result<(), UniversalHashError>; + + /// Preprocesses the hash function. + async fn preprocess(&mut self) -> Result<(), UniversalHashError>; + /// Computes hash of the input, padding the input to the block size /// if needed. /// - /// * `input` - Input to hash + /// # Arguments + /// + /// * `input` - Input to hash. 
async fn finalize(&mut self, input: Vec) -> Result, UniversalHashError>; } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml new file mode 100644 index 0000000000..c3b9bc794e --- /dev/null +++ b/crates/core/Cargo.toml @@ -0,0 +1,63 @@ +[package] +name = "tlsn-core" +authors = ["TLSNotary Team"] +description = "Core types for TLSNotary" +keywords = ["tls", "mpc", "2pc", "types"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0-alpha.7" +edition = "2021" + +[features] +default = [] +fixtures = ["dep:hex", "dep:tlsn-data-fixtures"] +# Enables support for Poseidon hashes as specified in the `poseidon-halo2` crate. +# This feature can only be used if the prover and the verifier enable the `authdecode_unsafe` feature. +# This feature is EXPERIMENTAL and will be removed in future releases without prior notice. +poseidon = ["poseidon-halo2"] + +[dependencies] +tlsn-data-fixtures = { workspace = true, optional = true } +tlsn-tls-core = { workspace = true, features = ["serde"] } +tlsn-utils = { workspace = true } +tlsn-utils-aio = { workspace = true } + +mpz-core = { workspace = true } +mpz-garble-core = { workspace = true } +mpz-circuits = { workspace = true } + +bcs = { workspace = true } +bimap = { version = "0.6", features = ["serde"] } +blake3 = { workspace = true } +derive_builder = { workspace = true } +getset = "0.1.2" +hex = { workspace = true, optional = true } +k256 = { workspace = true } +opaque-debug = { workspace = true } +p256 = { workspace = true, features = ["serde"] } +poseidon-halo2 = { workspace = true, optional = true } +rand = { workspace = true } +rand_core = { workspace = true } +rs_merkle = { workspace = true, features = ["serde"] } +rstest = { workspace = true, optional = true } +serde = { workspace = true } +sha2 = { workspace = true } +thiserror = { workspace = true } +tiny-keccak = { version = "2.0", features = ["keccak"] } +web-time = { workspace = true } +webpki-roots = { workspace = true } + 
+[dev-dependencies] +rstest = { workspace = true } +hex = { workspace = true } +rand_chacha = { workspace = true } +bincode = { workspace = true } +tlsn-data-fixtures = { workspace = true } + +[[test]] +name = "api" +required-features = ["fixtures"] + +[target.'cfg(target_arch = "wasm32")'.dependencies] +ring = { version = "0.17", features = ["wasm32_unknown_unknown_js"] } +getrandom = { version = "0.2", features = ["js"] } diff --git a/crates/core/src/attestation.rs b/crates/core/src/attestation.rs new file mode 100644 index 0000000000..0ae8a67b61 --- /dev/null +++ b/crates/core/src/attestation.rs @@ -0,0 +1,274 @@ +//! Attestation types. +//! +//! An attestation is a cryptographically signed document issued by a Notary who +//! witnessed a TLS connection. It contains various fields which can be used to +//! verify statements about the connection and the associated application data. +//! +//! Attestations are comprised of two parts: a [`Header`] and a [`Body`]. +//! +//! The header is the data structure which is signed by a Notary. It +//! contains a unique identifier, the protocol version, and a Merkle root +//! of the body fields. +//! +//! The body contains the fields of the attestation. These fields include data +//! which can be used to verify aspects of a TLS connection, such as the +//! server's identity, and facts about the transcript. 
+ +mod builder; +mod config; +mod proof; + +use std::fmt; + +use rand::distributions::{Distribution, Standard}; +use serde::{Deserialize, Serialize}; + +use crate::{ + connection::{ConnectionInfo, ServerCertCommitment, ServerEphemKey}, + hash::{impl_domain_separator, Hash, HashAlgorithm, HashAlgorithmExt, TypedHash}, + merkle::MerkleTree, + presentation::PresentationBuilder, + signing::{Signature, VerifyingKey}, + transcript::{encoding::EncodingCommitment, PlaintextHash}, + CryptoProvider, +}; + +pub use builder::{AttestationBuilder, AttestationBuilderError}; +pub use config::{AttestationConfig, AttestationConfigBuilder, AttestationConfigError}; +pub use proof::{AttestationError, AttestationProof}; + +/// Current version of attestations. +pub const VERSION: Version = Version(0); + +/// The maximum total number of fields allowed in the attestation. +pub const MAX_TOTAL_FIELDS: u32 = 1024; + +/// The maximum total number of plaintext hash commitments allowed in the attestation. +pub const MAX_TOTAL_PLAINTEXT_HASH: u32 = 512; + +/// The initial id for a plaintext hash commitment field. +pub const PLAINTEXT_HASH_INITIAL_FIELD_ID: u32 = MAX_TOTAL_FIELDS - MAX_TOTAL_PLAINTEXT_HASH - 1; + +/// Unique identifier for an attestation. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Uid(pub [u8; 16]); + +impl From<[u8; 16]> for Uid { + fn from(id: [u8; 16]) -> Self { + Self(id) + } +} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> Uid { + Uid(self.sample(rng)) + } +} + +/// Version of an attestation. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Version(u32); + +impl_domain_separator!(Version); + +/// Public attestation field. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Field { + /// Identifier of the field. + pub id: FieldId, + /// Field data. + pub data: T, +} + +/// Identifier for a field. 
+#[derive( + Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, +)] +pub struct FieldId(pub u32); + +impl FieldId { + /// Creates a new `FieldId` with the given initial id value. + pub(crate) fn new(initial_id: u32) -> Self { + Self(initial_id) + } + + pub(crate) fn next(&mut self, data: T) -> Field { + let id = *self; + self.0 += 1; + + Field { id, data } + } +} + +impl fmt::Display for FieldId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Kind of an attestation field. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[repr(u8)] +pub enum FieldKind { + /// Connection information. + ConnectionInfo = 0x01, + /// Server ephemeral key. + ServerEphemKey = 0x02, + /// Server identity commitment. + ServerIdentityCommitment = 0x03, + /// Encoding commitment. + EncodingCommitment = 0x04, + /// Plaintext hash commitment. + PlaintextHash = 0x05, +} + +/// Attestation header. +/// +/// See [module level documentation](crate::attestation) for more information. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Header { + /// An identifier for the attestation. + pub id: Uid, + /// Version of the attestation. + pub version: Version, + /// Merkle root of the attestation fields. + pub root: TypedHash, +} + +impl_domain_separator!(Header); + +/// Attestation body. +/// +/// See [module level documentation](crate::attestation) for more information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Body { + verifying_key: Field, + connection_info: Field, + server_ephemeral_key: Field, + cert_commitment: Field, + encoding_commitment: Option>, + plaintext_hashes: Option>>, +} + +impl Body { + /// Returns the attestation verifying key. + pub fn verifying_key(&self) -> &VerifyingKey { + &self.verifying_key.data + } + + /// Computes the Merkle root of the attestation fields. 
+ /// + /// This is only used when building an attestation. + pub(crate) fn root(&self, hasher: &dyn HashAlgorithm) -> TypedHash { + let mut tree = MerkleTree::new(hasher.id()); + let fields = self + .hash_fields(hasher) + .into_iter() + .map(|(_, hash)| hash) + .collect::>(); + tree.insert(hasher, fields); + tree.root() + } + + /// Returns the fields of the body hashed and sorted by id. + /// + /// Each field is hashed with a domain separator to mitigate type confusion + /// attacks. + /// + /// # Note + /// + /// The order of fields is not stable across versions. + pub(crate) fn hash_fields(&self, hasher: &dyn HashAlgorithm) -> Vec<(FieldId, Hash)> { + // CRITICAL: ensure all fields are included! If a new field is added to the + // struct without including it here it will not be verified to be + // included in the attestation. + let Self { + verifying_key, + connection_info: conn_info, + server_ephemeral_key, + cert_commitment, + encoding_commitment, + plaintext_hashes, + } = self; + + let mut fields: Vec<(FieldId, Hash)> = vec![ + (verifying_key.id, hasher.hash_separated(&verifying_key.data)), + (conn_info.id, hasher.hash_separated(&conn_info.data)), + ( + server_ephemeral_key.id, + hasher.hash_separated(&server_ephemeral_key.data), + ), + ( + cert_commitment.id, + hasher.hash_separated(&cert_commitment.data), + ), + ]; + + if let Some(encoding_commitment) = encoding_commitment { + fields.push(( + encoding_commitment.id, + hasher.hash_separated(&encoding_commitment.data), + )); + } + + if let Some(plaintext_hashes) = plaintext_hashes { + for field in plaintext_hashes.iter() { + fields.push((field.id, hasher.hash_separated(&field.data))); + } + } + + fields.sort_by_key(|(id, _)| *id); + fields + } + + /// Returns the connection information. 
+ pub(crate) fn connection_info(&self) -> &ConnectionInfo { + &self.connection_info.data + } + + pub(crate) fn server_ephemeral_key(&self) -> &ServerEphemKey { + &self.server_ephemeral_key.data + } + + pub(crate) fn cert_commitment(&self) -> &ServerCertCommitment { + &self.cert_commitment.data + } + + /// Returns the encoding commitment. + pub(crate) fn encoding_commitment(&self) -> Option<&EncodingCommitment> { + self.encoding_commitment.as_ref().map(|field| &field.data) + } + + /// Returns the plaintext hash commitments. + pub(crate) fn plaintext_hashes(&self) -> &Option>> { + &self.plaintext_hashes + } +} + +/// An attestation. +/// +/// See [module level documentation](crate::attestation) for more information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Attestation { + /// The signature of the attestation. + pub signature: Signature, + /// The attestation header. + pub header: Header, + /// The attestation body. + pub body: Body, +} + +impl Attestation { + /// Returns an attestation builder. + pub fn builder(config: &AttestationConfig) -> AttestationBuilder<'_> { + AttestationBuilder::new(config) + } + + /// Returns a presentation builder. 
+ pub fn presentation_builder<'a>( + &'a self, + provider: &'a CryptoProvider, + ) -> PresentationBuilder<'a> { + PresentationBuilder::new(provider, self) + } +} diff --git a/crates/core/src/attestation/builder.rs b/crates/core/src/attestation/builder.rs new file mode 100644 index 0000000000..ce0bfd4b55 --- /dev/null +++ b/crates/core/src/attestation/builder.rs @@ -0,0 +1,537 @@ +use std::error::Error; + +use rand::{thread_rng, Rng}; + +use crate::{ + attestation::{ + Attestation, AttestationConfig, Body, EncodingCommitment, FieldId, FieldKind, Header, + ServerCertCommitment, MAX_TOTAL_PLAINTEXT_HASH, PLAINTEXT_HASH_INITIAL_FIELD_ID, VERSION, + }, + connection::{ConnectionInfo, ServerEphemKey}, + hash::{HashAlgId, TypedHash}, + request::Request, + serialize::CanonicalSerialize, + signing::SignatureAlgId, + transcript::PlaintextHash, + CryptoProvider, +}; + +use super::Field; + +/// Attestation builder state for accepting a request. +pub struct Accept { + /// A collection of authenticated plaintext hashes. + /// + /// The request must contain plaintext hashes only from this collection. + plaintext_hashes: Option>, +} + +pub struct Sign { + signature_alg: SignatureAlgId, + hash_alg: HashAlgId, + connection_info: Option, + server_ephemeral_key: Option, + cert_commitment: ServerCertCommitment, + encoding_commitment_root: Option, + encoding_seed: Option>, + /// Plaintext hash commitments sorted by field id. + /// + /// The field ids start from [PLAINTEXT_HASH_INITIAL_FIELD_ID]. + plaintext_hashes: Option>>, +} + +/// An attestation builder. +pub struct AttestationBuilder<'a, T = Accept> { + config: &'a AttestationConfig, + state: T, +} + +impl<'a> AttestationBuilder<'a, Accept> { + /// Creates a new attestation builder. + pub fn new(config: &'a AttestationConfig) -> Self { + Self { + config, + state: Accept { + plaintext_hashes: None, + }, + } + } + + /// Accepts the attestation request. 
+ pub fn accept_request( + self, + request: Request, + ) -> Result, AttestationBuilderError> { + let config = self.config; + + let Request { + signature_alg, + hash_alg, + server_cert_commitment: cert_commitment, + encoding_commitment_root, + plaintext_hashes, + } = request; + + if !config.supported_signature_algs().contains(&signature_alg) { + return Err(AttestationBuilderError::new( + ErrorKind::Request, + format!("unsupported signature algorithm: {signature_alg}"), + )); + } + + if !config.supported_hash_algs().contains(&hash_alg) { + return Err(AttestationBuilderError::new( + ErrorKind::Request, + format!("unsupported hash algorithm: {hash_alg}"), + )); + } + + if encoding_commitment_root.is_some() + && !config + .supported_fields() + .contains(&FieldKind::EncodingCommitment) + { + return Err(AttestationBuilderError::new( + ErrorKind::Request, + "encoding commitment is not supported", + )); + } + + match (&plaintext_hashes, &self.state.plaintext_hashes) { + (Some(request_hashes), Some(authed_hashes)) => { + if request_hashes.len() > MAX_TOTAL_PLAINTEXT_HASH as usize { + return Err(AttestationBuilderError::new( + ErrorKind::Request, + "exceeded maximum allowed number of plaintext hashes", + )); + } + + for (index, hash) in request_hashes.iter().enumerate() { + if hash.id.0 != PLAINTEXT_HASH_INITIAL_FIELD_ID + index as u32 { + return Err(AttestationBuilderError::new( + ErrorKind::Request, + "unexpected field id for a plaintext hash", + )); + } + + if !authed_hashes.contains(&hash.data) { + return Err(AttestationBuilderError::new( + ErrorKind::Request, + "unexpected plaintext hash value", + )); + } + } + } + + (Some(_request_hashes), None) => { + return Err(AttestationBuilderError::new( + ErrorKind::Request, + "cannot accept plaintext hashes", + )); + } + + _ => {} + }; + + Ok(AttestationBuilder { + config: self.config, + state: Sign { + signature_alg, + hash_alg, + connection_info: None, + server_ephemeral_key: None, + cert_commitment, + encoding_commitment_root, 
+ encoding_seed: None, + plaintext_hashes, + }, + }) + } + + /// Sets the authenticated plaintext hashes. + pub fn plaintext_hashes(&mut self, hashes: Vec) -> &mut Self { + self.state.plaintext_hashes = Some(hashes); + self + } +} + +impl AttestationBuilder<'_, Sign> { + /// Sets the connection information. + pub fn connection_info(&mut self, connection_info: ConnectionInfo) -> &mut Self { + self.state.connection_info = Some(connection_info); + self + } + + /// Sets the server ephemeral key. + pub fn server_ephemeral_key(&mut self, key: ServerEphemKey) -> &mut Self { + self.state.server_ephemeral_key = Some(key); + self + } + + /// Sets the encoding seed. + pub fn encoding_seed(&mut self, seed: Vec) -> &mut Self { + self.state.encoding_seed = Some(seed); + self + } + + /// Builds the attestation. + pub fn build(self, provider: &CryptoProvider) -> Result { + let Sign { + signature_alg, + hash_alg, + connection_info, + server_ephemeral_key, + cert_commitment, + encoding_commitment_root, + encoding_seed, + plaintext_hashes, + } = self.state; + + let hasher = provider.hash.get(&hash_alg).map_err(|_| { + AttestationBuilderError::new( + ErrorKind::Config, + format!("accepted hash algorithm {hash_alg} but it's missing in the provider"), + ) + })?; + let signer = provider.signer.get(&signature_alg).map_err(|_| { + AttestationBuilderError::new( + ErrorKind::Config, + format!( + "accepted signature algorithm {signature_alg} but it's missing in the provider" + ), + ) + })?; + + let encoding_commitment = if let Some(root) = encoding_commitment_root { + let Some(seed) = encoding_seed else { + return Err(AttestationBuilderError::new( + ErrorKind::Field, + "encoding commitment requested but seed was not set", + )); + }; + + Some(EncodingCommitment { root, seed }) + } else { + None + }; + + let mut field_id = FieldId::default(); + + let body = Body { + verifying_key: field_id.next(signer.verifying_key()), + connection_info: field_id.next(connection_info.ok_or_else(|| { + 
AttestationBuilderError::new(ErrorKind::Field, "connection info was not set") + })?), + server_ephemeral_key: field_id.next(server_ephemeral_key.ok_or_else(|| { + AttestationBuilderError::new(ErrorKind::Field, "handshake data was not set") + })?), + cert_commitment: field_id.next(cert_commitment), + encoding_commitment: encoding_commitment.map(|commitment| field_id.next(commitment)), + plaintext_hashes, + }; + + // Make sure there was no collision with plaintext hash field ids. + if field_id.next(()).id.0 > PLAINTEXT_HASH_INITIAL_FIELD_ID { + return Err(AttestationBuilderError::new( + ErrorKind::Field, + "plaintext hash field id collision detected", + )); + } + + let header = Header { + id: thread_rng().gen(), + version: VERSION, + root: body.root(hasher), + }; + + let signature = signer + .sign(&CanonicalSerialize::serialize(&header)) + .map_err(|err| AttestationBuilderError::new(ErrorKind::Signature, err))?; + + Ok(Attestation { + signature, + header, + body, + }) + } +} + +/// Error for [`AttestationBuilder`]. +#[derive(Debug, thiserror::Error)] +pub struct AttestationBuilderError { + kind: ErrorKind, + source: Option>, +} + +#[derive(Debug)] +enum ErrorKind { + Request, + Config, + Field, + Signature, +} + +impl AttestationBuilderError { + fn new(kind: ErrorKind, error: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(error.into()), + } + } + + /// Returns whether the error originates from a bad request. 
+ pub fn is_request(&self) -> bool { + matches!(self.kind, ErrorKind::Request) + } +} + +impl std::fmt::Display for AttestationBuilderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.kind { + ErrorKind::Request => f.write_str("request error")?, + ErrorKind::Config => f.write_str("config error")?, + ErrorKind::Field => f.write_str("field error")?, + ErrorKind::Signature => f.write_str("signature error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use rstest::{fixture, rstest}; + use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON}; + + use crate::{ + connection::{HandshakeData, HandshakeDataV1_2}, + fixtures::{encoder_seed, encoding_provider, ConnectionFixture}, + hash::Blake3, + request::RequestConfig, + transcript::{encoding::EncodingTree, Transcript, TranscriptCommitConfigBuilder}, + }; + + use super::*; + + fn request_and_connection() -> (Request, ConnectionFixture) { + let provider = CryptoProvider::default(); + + let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON); + let (sent_len, recv_len) = transcript.len(); + // Plaintext encodings which the Prover obtained from GC evaluation + let encodings_provider = encoding_provider(GET_WITH_HEADER, OK_JSON); + + // At the end of the TLS connection the Prover holds the: + let ConnectionFixture { + server_name, + server_cert_data, + .. + } = ConnectionFixture::tlsnotary(transcript.length()); + + // Prover specifies the ranges it wants to commit to. + let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript); + transcript_commitment_builder + .commit_sent(&(0..sent_len)) + .unwrap() + .commit_recv(&(0..recv_len)) + .unwrap(); + + let transcripts_commitment_config = transcript_commitment_builder.build().unwrap(); + + // Prover constructs encoding tree. 
+ let encoding_tree = EncodingTree::new( + &Blake3::default(), + transcripts_commitment_config.iter_encoding(), + &encodings_provider, + &transcript.length(), + ) + .unwrap(); + + let request_config = RequestConfig::default(); + let mut request_builder = Request::builder(&request_config); + + request_builder + .server_name(server_name.clone()) + .server_cert_data(server_cert_data) + .transcript(transcript.clone()) + .encoding_tree(encoding_tree); + let (request, _) = request_builder.build(&provider).unwrap(); + + (request, ConnectionFixture::tlsnotary(transcript.length())) + } + + #[fixture] + #[once] + fn default_attestation_config() -> AttestationConfig { + AttestationConfig::builder() + .supported_signature_algs([SignatureAlgId::SECP256K1]) + .build() + .unwrap() + } + + #[fixture] + #[once] + fn crypto_provider() -> CryptoProvider { + let mut provider = CryptoProvider::default(); + provider.signer.set_secp256k1(&[42u8; 32]).unwrap(); + provider + } + + #[rstest] + fn test_attestation_builder_accept_unsupported_signer() { + let (request, _) = request_and_connection(); + let attestation_config = AttestationConfig::builder() + .supported_signature_algs([SignatureAlgId::SECP256R1]) + .build() + .unwrap(); + + let err = Attestation::builder(&attestation_config) + .accept_request(request) + .err() + .unwrap(); + assert!(err.is_request()); + } + + #[rstest] + fn test_attestation_builder_accept_unsupported_hasher() { + let (request, _) = request_and_connection(); + + let attestation_config = AttestationConfig::builder() + .supported_signature_algs([SignatureAlgId::SECP256K1]) + .supported_hash_algs([HashAlgId::KECCAK256]) + .build() + .unwrap(); + + let err = Attestation::builder(&attestation_config) + .accept_request(request) + .err() + .unwrap(); + assert!(err.is_request()); + } + + #[rstest] + fn test_attestation_builder_accept_unsupported_encoding_commitment() { + let (request, _) = request_and_connection(); + + let attestation_config = AttestationConfig::builder() 
+ .supported_signature_algs([SignatureAlgId::SECP256K1]) + .supported_fields([ + FieldKind::ConnectionInfo, + FieldKind::ServerEphemKey, + FieldKind::ServerIdentityCommitment, + ]) + .build() + .unwrap(); + + let err = Attestation::builder(&attestation_config) + .accept_request(request) + .err() + .unwrap(); + assert!(err.is_request()); + } + + #[rstest] + fn test_attestation_builder_sign_missing_signer( + default_attestation_config: &AttestationConfig, + ) { + let (request, _) = request_and_connection(); + + let attestation_builder = Attestation::builder(default_attestation_config) + .accept_request(request.clone()) + .unwrap(); + + let mut provider = CryptoProvider::default(); + provider.signer.set_secp256r1(&[42u8; 32]).unwrap(); + + let err = attestation_builder.build(&provider).err().unwrap(); + assert!(matches!(err.kind, ErrorKind::Config)); + } + + #[rstest] + fn test_attestation_builder_sign_missing_encoding_seed( + default_attestation_config: &AttestationConfig, + crypto_provider: &CryptoProvider, + ) { + let (request, connection) = request_and_connection(); + + let mut attestation_builder = Attestation::builder(default_attestation_config) + .accept_request(request.clone()) + .unwrap(); + + let ConnectionFixture { + connection_info, + server_cert_data, + .. + } = connection; + + let HandshakeData::V1_2(HandshakeDataV1_2 { + server_ephemeral_key, + .. 
+ }) = server_cert_data.handshake.clone(); + + attestation_builder + .connection_info(connection_info.clone()) + .server_ephemeral_key(server_ephemeral_key); + + let err = attestation_builder.build(crypto_provider).err().unwrap(); + assert!(matches!(err.kind, ErrorKind::Field)); + } + + #[rstest] + fn test_attestation_builder_sign_missing_server_ephemeral_key( + default_attestation_config: &AttestationConfig, + crypto_provider: &CryptoProvider, + ) { + let (request, connection) = request_and_connection(); + + let mut attestation_builder = Attestation::builder(default_attestation_config) + .accept_request(request.clone()) + .unwrap(); + + let ConnectionFixture { + connection_info, .. + } = connection; + + attestation_builder + .connection_info(connection_info.clone()) + .encoding_seed(encoder_seed().to_vec()); + + let err = attestation_builder.build(crypto_provider).err().unwrap(); + assert!(matches!(err.kind, ErrorKind::Field)); + } + + #[rstest] + fn test_attestation_builder_sign_missing_connection_info( + default_attestation_config: &AttestationConfig, + crypto_provider: &CryptoProvider, + ) { + let (request, connection) = request_and_connection(); + + let mut attestation_builder = Attestation::builder(default_attestation_config) + .accept_request(request.clone()) + .unwrap(); + + let ConnectionFixture { + server_cert_data, .. + } = connection; + + let HandshakeData::V1_2(HandshakeDataV1_2 { + server_ephemeral_key, + .. 
+ }) = server_cert_data.handshake.clone(); + + attestation_builder + .server_ephemeral_key(server_ephemeral_key) + .encoding_seed(encoder_seed().to_vec()); + + let err = attestation_builder.build(crypto_provider).err().unwrap(); + assert!(matches!(err.kind, ErrorKind::Field)); + } +} diff --git a/crates/core/src/attestation/config.rs b/crates/core/src/attestation/config.rs new file mode 100644 index 0000000000..ae1619d5d9 --- /dev/null +++ b/crates/core/src/attestation/config.rs @@ -0,0 +1,124 @@ +use crate::{ + attestation::FieldKind, + hash::{HashAlgId, DEFAULT_SUPPORTED_HASH_ALGS}, + signing::SignatureAlgId, +}; + +const DEFAULT_SUPPORTED_FIELDS: &[FieldKind] = &[ + FieldKind::ConnectionInfo, + FieldKind::ServerEphemKey, + FieldKind::ServerIdentityCommitment, + FieldKind::EncodingCommitment, +]; + +#[derive(Debug)] +#[allow(dead_code)] +enum ErrorKind { + Builder, +} + +impl std::fmt::Display for ErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorKind::Builder => write!(f, "builder"), + } + } +} + +/// Error for [`AttestationConfig`]. +#[derive(Debug, thiserror::Error)] +#[error("attestation config error: kind: {kind}, reason: {reason}")] +pub struct AttestationConfigError { + kind: ErrorKind, + reason: String, +} + +impl AttestationConfigError { + #[allow(dead_code)] + fn builder(reason: impl Into) -> Self { + Self { + kind: ErrorKind::Builder, + reason: reason.into(), + } + } +} + +/// Attestation configuration. +#[derive(Debug, Clone)] +pub struct AttestationConfig { + supported_signature_algs: Vec, + supported_hash_algs: Vec, + supported_fields: Vec, +} + +impl AttestationConfig { + /// Creates a new builder. 
+ pub fn builder() -> AttestationConfigBuilder { + AttestationConfigBuilder::default() + } + + pub(crate) fn supported_signature_algs(&self) -> &[SignatureAlgId] { + &self.supported_signature_algs + } + + pub(crate) fn supported_hash_algs(&self) -> &[HashAlgId] { + &self.supported_hash_algs + } + + pub(crate) fn supported_fields(&self) -> &[FieldKind] { + &self.supported_fields + } +} + +/// Builder for [`AttestationConfig`]. +#[derive(Debug)] +pub struct AttestationConfigBuilder { + supported_signature_algs: Vec, + supported_hash_algs: Vec, + supported_fields: Vec, +} + +impl Default for AttestationConfigBuilder { + fn default() -> Self { + Self { + supported_signature_algs: Vec::default(), + supported_hash_algs: DEFAULT_SUPPORTED_HASH_ALGS.to_vec(), + supported_fields: DEFAULT_SUPPORTED_FIELDS.to_vec(), + } + } +} + +impl AttestationConfigBuilder { + /// Sets the supported signature algorithms. + pub fn supported_signature_algs( + &mut self, + supported_signature_algs: impl Into>, + ) -> &mut Self { + self.supported_signature_algs = supported_signature_algs.into(); + self + } + + /// Sets the supported hash algorithms. + pub fn supported_hash_algs( + &mut self, + supported_hash_algs: impl Into>, + ) -> &mut Self { + self.supported_hash_algs = supported_hash_algs.into(); + self + } + + /// Sets the supported attestation fields. + pub fn supported_fields(&mut self, supported_fields: impl Into>) -> &mut Self { + self.supported_fields = supported_fields.into(); + self + } + + /// Builds the configuration. 
+ pub fn build(&self) -> Result { + Ok(AttestationConfig { + supported_signature_algs: self.supported_signature_algs.clone(), + supported_hash_algs: self.supported_hash_algs.clone(), + supported_fields: self.supported_fields.clone(), + }) + } +} diff --git a/crates/core/src/attestation/proof.rs b/crates/core/src/attestation/proof.rs new file mode 100644 index 0000000000..11c6eaeeed --- /dev/null +++ b/crates/core/src/attestation/proof.rs @@ -0,0 +1,178 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +use crate::{ + attestation::{Attestation, Body, Header}, + hash::HashAlgorithm, + merkle::{MerkleProof, MerkleTree}, + serialize::CanonicalSerialize, + signing::{Signature, VerifyingKey}, + CryptoProvider, +}; + +/// Proof of an attestation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttestationProof { + signature: Signature, + header: Header, + body: BodyProof, +} + +impl AttestationProof { + pub(crate) fn new( + provider: &CryptoProvider, + attestation: &Attestation, + ) -> Result { + let hasher = provider + .hash + .get(&attestation.header.root.alg) + .map_err(|e| AttestationError::new(ErrorKind::Provider, e))?; + + let body = BodyProof::new(hasher, attestation.body.clone())?; + + Ok(Self { + signature: attestation.signature.clone(), + header: attestation.header.clone(), + body, + }) + } + + /// Returns the verifying key. + pub fn verifying_key(&self) -> &VerifyingKey { + self.body.verifying_key() + } + + /// Verifies the attestation proof. + /// + /// # Arguments + /// + /// * `provider` - Cryptography provider. + /// * `verifying_key` - Verifying key for the Notary signature. + pub fn verify(self, provider: &CryptoProvider) -> Result { + let signature_verifier = provider + .signature + .get(&self.signature.alg) + .map_err(|e| AttestationError::new(ErrorKind::Provider, e))?; + + // Verify body corresponding to the header. + let body = self.body.verify_with_provider(provider, &self.header)?; + + // Verify signature of the header. 
+ signature_verifier + .verify( + &body.verifying_key.data, + &CanonicalSerialize::serialize(&self.header), + &self.signature.data, + ) + .map_err(|e| AttestationError::new(ErrorKind::Signature, e))?; + + Ok(Attestation { + signature: self.signature, + header: self.header, + body, + }) + } +} + +/// Proof of an attestation body. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct BodyProof { + body: Body, + proof: MerkleProof, +} + +impl BodyProof { + /// Returns a new body proof. + // TODO: Support including a subset of fields instead of the entire body. + pub(crate) fn new( + hasher: &dyn HashAlgorithm, + body: Body, + ) -> Result { + let (indices, leaves): (Vec<_>, Vec<_>) = body + .hash_fields(hasher) + .into_iter() + .map(|(id, hash)| (id.0 as usize, hash)) + .unzip(); + + let mut tree = MerkleTree::new(hasher.id()); + tree.insert(hasher, leaves); + + let proof = tree.proof(&indices); + + Ok(BodyProof { body, proof }) + } + + pub(crate) fn verifying_key(&self) -> &VerifyingKey { + &self.body.verifying_key.data + } + + /// Verifies the proof against the attestation header. + pub(crate) fn verify_with_provider( + self, + provider: &CryptoProvider, + header: &Header, + ) -> Result { + let hasher = provider + .hash + .get(&header.root.alg) + .map_err(|e| AttestationError::new(ErrorKind::Provider, e))?; + + let fields = self + .body + .hash_fields(hasher) + .into_iter() + .enumerate() + .map(|(idx, (_, hash))| (idx, hash)); + + self.proof + .verify(hasher, &header.root, fields) + .map_err(|e| AttestationError::new(ErrorKind::Body, e))?; + + Ok(self.body) + } +} + +/// Error for [`AttestationProof`]. 
+#[derive(Debug, thiserror::Error)] +pub struct AttestationError { + kind: ErrorKind, + source: Option>, +} + +impl AttestationError { + fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } +} + +impl fmt::Display for AttestationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("attestation proof error: ")?; + + match self.kind { + ErrorKind::Provider => f.write_str("provider error")?, + ErrorKind::Signature => f.write_str("signature error")?, + ErrorKind::Body => f.write_str("body proof error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +#[derive(Debug)] +enum ErrorKind { + Provider, + Signature, + Body, +} diff --git a/crates/core/src/connection.rs b/crates/core/src/connection.rs new file mode 100644 index 0000000000..30a30324b2 --- /dev/null +++ b/crates/core/src/connection.rs @@ -0,0 +1,666 @@ +//! TLS connection types. +//! +//! ## Commitment +//! +//! During the TLS handshake the Notary receives the Server's ephemeral public +//! key, and this key serves as a binding commitment to the identity of the +//! Server. The ephemeral key itself does not reveal the Server's identity, but +//! it is bound to it via a signature created using the Server's +//! X.509 certificate. +//! +//! A Prover can withhold the Server's signature and certificate chain from the +//! Notary to improve privacy and censorship resistance. +//! +//! ## Proving the Server's identity +//! +//! A Prover can prove the Server's identity to a Verifier by sending a +//! [`ServerIdentityProof`]. This proof contains all the information required to +//! establish the link between the TLS connection and the Server's X.509 +//! certificate. A Verifier checks the Server's certificate against their own +//! trust anchors, the same way a typical TLS client would. 
+ +mod commit; +mod proof; + +use std::fmt; + +use serde::{Deserialize, Serialize}; +use tls_core::{ + msgs::{ + codec::Codec, + enums::NamedGroup, + handshake::{DigitallySignedStruct, ServerECDHParams}, + }, + verify::ServerCertVerifier as _, +}; +use web_time::{Duration, UNIX_EPOCH}; + +use crate::{hash::impl_domain_separator, CryptoProvider}; + +pub use commit::{ServerCertCommitment, ServerCertOpening}; +pub use proof::{ServerIdentityProof, ServerIdentityProofError}; + +/// TLS version. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TlsVersion { + /// TLS 1.2. + V1_2, + /// TLS 1.3. + V1_3, +} + +impl TryFrom for TlsVersion { + type Error = &'static str; + + fn try_from(value: tls_core::msgs::enums::ProtocolVersion) -> Result { + Ok(match value { + tls_core::msgs::enums::ProtocolVersion::TLSv1_2 => TlsVersion::V1_2, + tls_core::msgs::enums::ProtocolVersion::TLSv1_3 => TlsVersion::V1_3, + _ => return Err("unsupported TLS version"), + }) + } +} + +/// Server's name, a.k.a. the DNS name. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ServerName(String); + +impl ServerName { + /// Creates a new server name. + pub fn new(name: String) -> Self { + Self(name) + } + + /// Returns the name as a string. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl From<&str> for ServerName { + fn from(name: &str) -> Self { + Self(name.to_string()) + } +} + +impl AsRef for ServerName { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for ServerName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Type of a public key. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[non_exhaustive] +#[allow(non_camel_case_types)] +pub enum KeyType { + /// secp256r1. + SECP256R1 = 0x0017, +} + +/// Signature scheme on the key exchange parameters. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[allow(non_camel_case_types, missing_docs)] +pub enum SignatureScheme { + RSA_PKCS1_SHA1 = 0x0201, + ECDSA_SHA1_Legacy = 0x0203, + RSA_PKCS1_SHA256 = 0x0401, + ECDSA_NISTP256_SHA256 = 0x0403, + RSA_PKCS1_SHA384 = 0x0501, + ECDSA_NISTP384_SHA384 = 0x0503, + RSA_PKCS1_SHA512 = 0x0601, + ECDSA_NISTP521_SHA512 = 0x0603, + RSA_PSS_SHA256 = 0x0804, + RSA_PSS_SHA384 = 0x0805, + RSA_PSS_SHA512 = 0x0806, + ED25519 = 0x0807, +} + +impl TryFrom for SignatureScheme { + type Error = &'static str; + + fn try_from(value: tls_core::msgs::enums::SignatureScheme) -> Result { + use tls_core::msgs::enums::SignatureScheme as Core; + use SignatureScheme::*; + Ok(match value { + Core::RSA_PKCS1_SHA1 => RSA_PKCS1_SHA1, + Core::ECDSA_SHA1_Legacy => ECDSA_SHA1_Legacy, + Core::RSA_PKCS1_SHA256 => RSA_PKCS1_SHA256, + Core::ECDSA_NISTP256_SHA256 => ECDSA_NISTP256_SHA256, + Core::RSA_PKCS1_SHA384 => RSA_PKCS1_SHA384, + Core::ECDSA_NISTP384_SHA384 => ECDSA_NISTP384_SHA384, + Core::RSA_PKCS1_SHA512 => RSA_PKCS1_SHA512, + Core::ECDSA_NISTP521_SHA512 => ECDSA_NISTP521_SHA512, + Core::RSA_PSS_SHA256 => RSA_PSS_SHA256, + Core::RSA_PSS_SHA384 => RSA_PSS_SHA384, + Core::RSA_PSS_SHA512 => RSA_PSS_SHA512, + Core::ED25519 => ED25519, + _ => return Err("unsupported signature scheme"), + }) + } +} + +impl From for tls_core::msgs::enums::SignatureScheme { + fn from(value: SignatureScheme) -> Self { + use tls_core::msgs::enums::SignatureScheme::*; + match value { + SignatureScheme::RSA_PKCS1_SHA1 => RSA_PKCS1_SHA1, + SignatureScheme::ECDSA_SHA1_Legacy => ECDSA_SHA1_Legacy, + SignatureScheme::RSA_PKCS1_SHA256 => RSA_PKCS1_SHA256, + SignatureScheme::ECDSA_NISTP256_SHA256 => ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PKCS1_SHA384 => RSA_PKCS1_SHA384, + SignatureScheme::ECDSA_NISTP384_SHA384 => ECDSA_NISTP384_SHA384, + SignatureScheme::RSA_PKCS1_SHA512 => RSA_PKCS1_SHA512, + 
SignatureScheme::ECDSA_NISTP521_SHA512 => ECDSA_NISTP521_SHA512, + SignatureScheme::RSA_PSS_SHA256 => RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384 => RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512 => RSA_PSS_SHA512, + SignatureScheme::ED25519 => ED25519, + } + } +} + +/// X.509 certificate, DER encoded. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Certificate(pub Vec); + +impl From for Certificate { + fn from(cert: tls_core::key::Certificate) -> Self { + Self(cert.0) + } +} + +/// Server's signature of the key exchange parameters. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerSignature { + /// Signature scheme. + pub scheme: SignatureScheme, + /// Signature data. + pub sig: Vec, +} + +/// Server's ephemeral public key. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ServerEphemKey { + /// Type of the public key. + #[serde(rename = "type")] + pub typ: KeyType, + /// Public key data. + pub key: Vec, +} + +impl_domain_separator!(ServerEphemKey); + +impl ServerEphemKey { + /// Encodes the key exchange parameters as in TLS. + pub(crate) fn kx_params(&self) -> Vec { + let group = match self.typ { + KeyType::SECP256R1 => NamedGroup::secp256r1, + }; + + let mut kx_params = Vec::new(); + ServerECDHParams::new(group, &self.key).encode(&mut kx_params); + + kx_params + } +} + +impl TryFrom for ServerEphemKey { + type Error = &'static str; + + fn try_from(value: tls_core::key::PublicKey) -> Result { + let tls_core::msgs::enums::NamedGroup::secp256r1 = value.group else { + return Err("unsupported key type"); + }; + + Ok(ServerEphemKey { + typ: KeyType::SECP256R1, + key: value.key, + }) + } +} + +/// TLS session information. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConnectionInfo { + /// UNIX time when the TLS connection started. + pub time: u64, + /// TLS version used in the connection. + pub version: TlsVersion, + /// Transcript length. 
+ pub transcript_length: TranscriptLength, +} + +impl_domain_separator!(ConnectionInfo); + +/// Transcript length information. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TranscriptLength { + /// Number of bytes sent by the Prover to the Server. + pub sent: u32, + /// Number of bytes received by the Prover from the Server. + pub received: u32, +} + +/// TLS 1.2 handshake data. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HandshakeDataV1_2 { + /// Client random. + pub client_random: [u8; 32], + /// Server random. + pub server_random: [u8; 32], + /// Server's ephemeral public key. + pub server_ephemeral_key: ServerEphemKey, +} + +/// TLS handshake data. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[non_exhaustive] +pub enum HandshakeData { + /// TLS 1.2 handshake data. + V1_2(HandshakeDataV1_2), +} + +impl_domain_separator!(HandshakeData); + +/// Server certificate and handshake data. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerCertData { + /// Certificate chain. + pub certs: Vec, + /// Server signature of the key exchange parameters. + pub sig: ServerSignature, + /// TLS handshake data. + pub handshake: HandshakeData, +} + +impl ServerCertData { + /// Verifies the server certificate data. + /// + /// # Arguments + /// + /// * `provider` - The crypto provider to use for verification. + /// * `time` - The time of the connection. + /// * `server_ephemeral_key` - The server's ephemeral key. + /// * `server_name` - The server name. 
+ pub fn verify_with_provider( + &self, + provider: &CryptoProvider, + time: u64, + server_ephemeral_key: &ServerEphemKey, + server_name: &ServerName, + ) -> Result<(), CertificateVerificationError> { + #[allow(irrefutable_let_patterns)] + let HandshakeData::V1_2(HandshakeDataV1_2 { + client_random, + server_random, + server_ephemeral_key: expected_server_ephemeral_key, + }) = &self.handshake + else { + unreachable!("only TLS 1.2 is implemented") + }; + + if server_ephemeral_key != expected_server_ephemeral_key { + return Err(CertificateVerificationError::InvalidServerEphemeralKey); + } + + // Verify server name + let server_name = tls_core::dns::ServerName::try_from(server_name.as_ref()) + .map_err(|_| CertificateVerificationError::InvalidIdentity(server_name.clone()))?; + + // Verify server certificate + let cert_chain = self + .certs + .clone() + .into_iter() + .map(|cert| tls_core::key::Certificate(cert.0)) + .collect::>(); + + let (end_entity, intermediates) = cert_chain + .split_first() + .ok_or(CertificateVerificationError::MissingCerts)?; + + // Verify the end entity cert is valid for the provided server name + // and that it chains to at least one of the roots we trust. + provider + .cert + .verify_server_cert( + end_entity, + intermediates, + &server_name, + &mut [].into_iter(), + &[], + UNIX_EPOCH + Duration::from_secs(time), + ) + .map_err(|_| CertificateVerificationError::InvalidCert)?; + + // Verify the signature matches the certificate and key exchange parameters. 
+ let mut message = Vec::new(); + message.extend_from_slice(client_random); + message.extend_from_slice(server_random); + message.extend_from_slice(&server_ephemeral_key.kx_params()); + + let dss = DigitallySignedStruct::new(self.sig.scheme.into(), self.sig.sig.clone()); + + provider + .cert + .verify_tls12_signature(&message, end_entity, &dss) + .map_err(|_| CertificateVerificationError::InvalidServerSignature)?; + + Ok(()) + } +} + +/// Errors that can occur when verifying a certificate chain or signature. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum CertificateVerificationError { + #[error("invalid server identity: {0}")] + InvalidIdentity(ServerName), + #[error("missing server certificates")] + MissingCerts, + #[error("invalid server certificate")] + InvalidCert, + #[error("invalid server signature")] + InvalidServerSignature, + #[error("invalid server ephemeral key")] + InvalidServerEphemeralKey, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{fixtures::ConnectionFixture, transcript::Transcript}; + + use hex::FromHex; + use rstest::*; + use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON}; + + #[fixture] + #[once] + fn crypto_provider() -> CryptoProvider { + CryptoProvider::default() + } + + fn tlsnotary() -> ConnectionFixture { + ConnectionFixture::tlsnotary(Transcript::new(GET_WITH_HEADER, OK_JSON).length()) + } + + fn appliedzkp() -> ConnectionFixture { + ConnectionFixture::appliedzkp(Transcript::new(GET_WITH_HEADER, OK_JSON).length()) + } + + /// Expect chain verification to succeed. 
+ #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_cert_chain_sucess_ca_implicit( + crypto_provider: &CryptoProvider, + #[case] mut data: ConnectionFixture, + ) { + // Remove the CA cert + data.server_cert_data.certs.pop(); + + assert!(data + .server_cert_data + .verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ) + .is_ok()); + } + + /// Expect chain verification to succeed even when a trusted CA is provided + /// among the intermediate certs. webpki handles such cases properly. + #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_cert_chain_success_ca_explicit( + crypto_provider: &CryptoProvider, + #[case] data: ConnectionFixture, + ) { + assert!(data + .server_cert_data + .verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ) + .is_ok()); + } + + /// Expect to fail since the end entity cert was not valid at the time. + #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_cert_chain_fail_bad_time( + crypto_provider: &CryptoProvider, + #[case] data: ConnectionFixture, + ) { + // unix time when the cert chain was NOT valid + let bad_time: u64 = 1571465711; + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + bad_time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidCert + )); + } + + /// Expect to fail when no intermediate cert provided. 
+ #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_cert_chain_fail_no_interm_cert( + crypto_provider: &CryptoProvider, + #[case] mut data: ConnectionFixture, + ) { + // Remove the CA cert + data.server_cert_data.certs.pop(); + // Remove the intermediate cert + data.server_cert_data.certs.pop(); + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidCert + )); + } + + /// Expect to fail when no intermediate cert provided even if a trusted CA + /// cert is provided. + #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_cert_chain_fail_no_interm_cert_with_ca_cert( + crypto_provider: &CryptoProvider, + #[case] mut data: ConnectionFixture, + ) { + // Remove the intermediate cert + data.server_cert_data.certs.remove(1); + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidCert + )); + } + + /// Expect to fail because end-entity cert is wrong. 
+ #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_cert_chain_fail_bad_ee_cert( + crypto_provider: &CryptoProvider, + #[case] mut data: ConnectionFixture, + ) { + let ee: &[u8] = include_bytes!("./fixtures/data/unknown/ee.der"); + + // Change the end entity cert + data.server_cert_data.certs[0] = Certificate(ee.to_vec()); + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidCert + )); + } + + /// Expect sig verification to fail because client_random is wrong. + #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_sig_ke_params_fail_bad_client_random( + crypto_provider: &CryptoProvider, + #[case] mut data: ConnectionFixture, + ) { + let HandshakeData::V1_2(HandshakeDataV1_2 { client_random, .. }) = + &mut data.server_cert_data.handshake; + client_random[31] = client_random[31].wrapping_add(1); + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidServerSignature + )); + } + + /// Expect sig verification to fail because the sig is wrong. 
+ #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_sig_ke_params_fail_bad_sig( + crypto_provider: &CryptoProvider, + #[case] mut data: ConnectionFixture, + ) { + data.server_cert_data.sig.sig[31] = data.server_cert_data.sig.sig[31].wrapping_add(1); + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidServerSignature + )); + } + + /// Expect to fail because the dns name is not in the cert. + #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_check_dns_name_present_in_cert_fail_bad_host( + crypto_provider: &CryptoProvider, + #[case] data: ConnectionFixture, + ) { + let bad_name = ServerName::from("badhost.com"); + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &bad_name, + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidCert + )); + } + + /// Expect to fail because the ephemeral key provided is wrong. + #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_invalid_ephemeral_key( + crypto_provider: &CryptoProvider, + #[case] data: ConnectionFixture, + ) { + let wrong_ephemeral_key = ServerEphemKey { + typ: KeyType::SECP256R1, + key: Vec::::from_hex(include_bytes!("./fixtures/data/unknown/pubkey")).unwrap(), + }; + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + &wrong_ephemeral_key, + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::InvalidServerEphemeralKey + )); + } + + /// Expect to fail when no cert provided. 
+ #[rstest] + #[case::tlsnotary(tlsnotary())] + #[case::appliedzkp(appliedzkp())] + fn test_verify_cert_chain_fail_no_cert( + crypto_provider: &CryptoProvider, + #[case] mut data: ConnectionFixture, + ) { + // Empty certs + data.server_cert_data.certs = Vec::new(); + + let err = data.server_cert_data.verify_with_provider( + crypto_provider, + data.connection_info.time, + data.server_ephemeral_key(), + &ServerName::from(data.server_name.as_ref()), + ); + + assert!(matches!( + err.unwrap_err(), + CertificateVerificationError::MissingCerts + )); + } +} diff --git a/crates/core/src/connection/commit.rs b/crates/core/src/connection/commit.rs new file mode 100644 index 0000000000..71d15c5342 --- /dev/null +++ b/crates/core/src/connection/commit.rs @@ -0,0 +1,40 @@ +//! Types for committing details of a connection. + +use serde::{Deserialize, Serialize}; + +use crate::{ + connection::ServerCertData, + hash::{impl_domain_separator, Blinded, HashAlgorithm, HashAlgorithmExt, TypedHash}, +}; + +/// Opens a [`ServerCertCommitment`]. +#[derive(Clone, Serialize, Deserialize)] +pub struct ServerCertOpening(Blinded); + +impl_domain_separator!(ServerCertOpening); + +opaque_debug::implement!(ServerCertOpening); + +impl ServerCertOpening { + pub(crate) fn new(data: ServerCertData) -> Self { + Self(Blinded::new(data)) + } + + pub(crate) fn commit(&self, hasher: &dyn HashAlgorithm) -> ServerCertCommitment { + ServerCertCommitment(TypedHash { + alg: hasher.id(), + value: hasher.hash_separated(self), + }) + } + + /// Returns the server identity data. + pub fn data(&self) -> &ServerCertData { + self.0.data() + } +} + +/// Commitment to a server certificate. 
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ServerCertCommitment(pub(crate) TypedHash); + +impl_domain_separator!(ServerCertCommitment); diff --git a/crates/core/src/connection/proof.rs b/crates/core/src/connection/proof.rs new file mode 100644 index 0000000000..9de13ffcdf --- /dev/null +++ b/crates/core/src/connection/proof.rs @@ -0,0 +1,103 @@ +//! Types for proving details of a connection. + +use serde::{Deserialize, Serialize}; + +use crate::{ + connection::{ + commit::{ServerCertCommitment, ServerCertOpening}, + CertificateVerificationError, ServerEphemKey, ServerName, + }, + hash::{HashAlgorithmExt, HashProviderError}, + CryptoProvider, +}; + +/// TLS server identity proof. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerIdentityProof { + name: ServerName, + opening: ServerCertOpening, +} + +impl ServerIdentityProof { + pub(crate) fn new(name: ServerName, opening: ServerCertOpening) -> Self { + Self { name, opening } + } + + /// Verifies the server identity proof. + /// + /// # Arguments + /// + /// * `provider` - The crypto provider to use for verification. + /// * `time` - The time of the connection. + /// * `server_ephemeral_key` - The server's ephemeral key. + /// * `commitment` - Commitment to the server certificate. + pub fn verify_with_provider( + self, + provider: &CryptoProvider, + time: u64, + server_ephemeral_key: &ServerEphemKey, + commitment: &ServerCertCommitment, + ) -> Result { + let hasher = provider.hash.get(&commitment.0.alg)?; + + if commitment.0.value != hasher.hash_separated(&self.opening) { + return Err(ServerIdentityProofError { + kind: ErrorKind::Commitment, + message: "certificate opening does not match commitment".to_string(), + }); + } + + // Verify certificate and identity. + self.opening.data().verify_with_provider( + provider, + time, + server_ephemeral_key, + &self.name, + )?; + + Ok(self.name) + } +} + +/// Error for [`ServerIdentityProof`]. 
+#[derive(Debug, thiserror::Error)] +#[error("server identity proof error: {kind}: {message}")] +pub struct ServerIdentityProofError { + kind: ErrorKind, + message: String, +} + +impl From for ServerIdentityProofError { + fn from(err: HashProviderError) -> Self { + Self { + kind: ErrorKind::Provider, + message: err.to_string(), + } + } +} + +impl From for ServerIdentityProofError { + fn from(err: CertificateVerificationError) -> Self { + Self { + kind: ErrorKind::Certificate, + message: err.to_string(), + } + } +} + +#[derive(Debug)] +enum ErrorKind { + Provider, + Commitment, + Certificate, +} + +impl std::fmt::Display for ErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorKind::Provider => write!(f, "provider"), + ErrorKind::Commitment => write!(f, "commitment"), + ErrorKind::Certificate => write!(f, "certificate"), + } + } +} diff --git a/crates/core/src/fixtures.rs b/crates/core/src/fixtures.rs new file mode 100644 index 0000000000..cb7cab86fc --- /dev/null +++ b/crates/core/src/fixtures.rs @@ -0,0 +1,145 @@ +//! Fixtures for testing + +mod provider; + +pub use provider::ChaChaProvider; + +use hex::FromHex; +use p256::ecdsa::SigningKey; + +use crate::{ + connection::{ + Certificate, ConnectionInfo, HandshakeData, HandshakeDataV1_2, KeyType, ServerCertData, + ServerEphemKey, ServerName, ServerSignature, SignatureScheme, TlsVersion, TranscriptLength, + }, + request::Request, + transcript::{encoding::EncodingProvider, PlaintextHash, Transcript}, +}; + +/// A fixture containing various TLS connection data. +#[allow(missing_docs)] +pub struct ConnectionFixture { + pub server_name: ServerName, + pub connection_info: ConnectionInfo, + pub server_cert_data: ServerCertData, +} + +impl ConnectionFixture { + /// Returns a connection fixture for tlsnotary.org. 
+ pub fn tlsnotary(transcript_length: TranscriptLength) -> Self { + ConnectionFixture { + server_name: ServerName::new("tlsnotary.org".to_string()), + connection_info: ConnectionInfo { + time: 1671637529, + version: TlsVersion::V1_2, + transcript_length, + }, + server_cert_data: ServerCertData { + certs: vec![ + Certificate(include_bytes!("fixtures/data/tlsnotary.org/ee.der").to_vec()), + Certificate(include_bytes!("fixtures/data/tlsnotary.org/inter.der").to_vec()), + Certificate(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()), + ], + sig: ServerSignature { + scheme: SignatureScheme::RSA_PKCS1_SHA256, + sig: Vec::::from_hex(include_bytes!( + "fixtures/data/tlsnotary.org/signature" + )) + .unwrap(), + }, + handshake: HandshakeData::V1_2(HandshakeDataV1_2 { + client_random: <[u8; 32]>::from_hex(include_bytes!( + "fixtures/data/tlsnotary.org/client_random" + )) + .unwrap(), + server_random: <[u8; 32]>::from_hex(include_bytes!( + "fixtures/data/tlsnotary.org/server_random" + )) + .unwrap(), + server_ephemeral_key: ServerEphemKey { + typ: KeyType::SECP256R1, + key: Vec::::from_hex(include_bytes!( + "fixtures/data/tlsnotary.org/pubkey" + )) + .unwrap(), + }, + }), + }, + } + } + + /// Returns a connection fixture for appliedzkp.org. 
+ pub fn appliedzkp(transcript_length: TranscriptLength) -> Self { + ConnectionFixture { + server_name: ServerName::new("appliedzkp.org".to_string()), + connection_info: ConnectionInfo { + time: 1671637529, + version: TlsVersion::V1_2, + transcript_length, + }, + server_cert_data: ServerCertData { + certs: vec![ + Certificate(include_bytes!("fixtures/data/appliedzkp.org/ee.der").to_vec()), + Certificate(include_bytes!("fixtures/data/appliedzkp.org/inter.der").to_vec()), + Certificate(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()), + ], + sig: ServerSignature { + scheme: SignatureScheme::ECDSA_NISTP256_SHA256, + sig: Vec::::from_hex(include_bytes!( + "fixtures/data/appliedzkp.org/signature" + )) + .unwrap(), + }, + handshake: HandshakeData::V1_2(HandshakeDataV1_2 { + client_random: <[u8; 32]>::from_hex(include_bytes!( + "fixtures/data/appliedzkp.org/client_random" + )) + .unwrap(), + server_random: <[u8; 32]>::from_hex(include_bytes!( + "fixtures/data/appliedzkp.org/server_random" + )) + .unwrap(), + server_ephemeral_key: ServerEphemKey { + typ: KeyType::SECP256R1, + key: Vec::::from_hex(include_bytes!( + "fixtures/data/appliedzkp.org/pubkey" + )) + .unwrap(), + }, + }), + }, + } + } + + /// Returns the server_ephemeral_key fixture. + pub fn server_ephemeral_key(&self) -> &ServerEphemKey { + let HandshakeData::V1_2(HandshakeDataV1_2 { + server_ephemeral_key, + .. + }) = &self.server_cert_data.handshake; + server_ephemeral_key + } +} + +/// Returns an encoding provider fixture. +pub fn encoding_provider(tx: &[u8], rx: &[u8]) -> impl EncodingProvider { + ChaChaProvider::new(encoder_seed(), Transcript::new(tx, rx)) +} + +/// Returns an encoder seed fixture. +pub fn encoder_seed() -> [u8; 32] { + [0u8; 32] +} + +/// Returns a notary signing key fixture. +pub fn notary_signing_key() -> SigningKey { + SigningKey::from_slice(&[1; 32]).unwrap() +} + +/// Returns plaintext hashes contained in the request. 
+pub fn plaintext_hashes_from_request(request: &Request) -> Vec { + match &request.plaintext_hashes { + Some(hashes) => hashes.iter().map(|f| f.data.clone()).collect::>(), + None => Vec::new(), + } +} diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/README b/crates/core/src/fixtures/data/README.md similarity index 98% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/README rename to crates/core/src/fixtures/data/README.md index 9c0ad11546..7b84c9b8c3 100644 --- a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/README +++ b/crates/core/src/fixtures/data/README.md @@ -25,5 +25,4 @@ tshark -r out.pcap -Y "tcp.stream==$STREAM_ID and tcp.srcport == 443" -T fields # pubkey (ephemeral public key) tshark -r out.pcap -Y "tcp.stream==$STREAM_ID" -T fields -e tls.handshake.server_point # signature (over the key exchange parameters) -tshark -r out.pcap -Y "tcp.stream==$STREAM_ID" -T fields -e tls.handshake.sig - +tshark -r out.pcap -Y "tcp.stream==$STREAM_ID" -T fields -e tls.handshake.sig \ No newline at end of file diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/ca.der b/crates/core/src/fixtures/data/appliedzkp.org/ca.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/ca.der rename to crates/core/src/fixtures/data/appliedzkp.org/ca.der diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/client_random b/crates/core/src/fixtures/data/appliedzkp.org/client_random similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/client_random rename to crates/core/src/fixtures/data/appliedzkp.org/client_random diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/ee.der b/crates/core/src/fixtures/data/appliedzkp.org/ee.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/ee.der rename to crates/core/src/fixtures/data/appliedzkp.org/ee.der 
diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/inter.der b/crates/core/src/fixtures/data/appliedzkp.org/inter.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/inter.der rename to crates/core/src/fixtures/data/appliedzkp.org/inter.der diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/pubkey b/crates/core/src/fixtures/data/appliedzkp.org/pubkey similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/pubkey rename to crates/core/src/fixtures/data/appliedzkp.org/pubkey diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/server_random b/crates/core/src/fixtures/data/appliedzkp.org/server_random similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/server_random rename to crates/core/src/fixtures/data/appliedzkp.org/server_random diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/signature b/crates/core/src/fixtures/data/appliedzkp.org/signature similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/appliedzkp.org/signature rename to crates/core/src/fixtures/data/appliedzkp.org/signature diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/ca.der b/crates/core/src/fixtures/data/tlsnotary.org/ca.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/ca.der rename to crates/core/src/fixtures/data/tlsnotary.org/ca.der diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/client_random b/crates/core/src/fixtures/data/tlsnotary.org/client_random similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/client_random rename to crates/core/src/fixtures/data/tlsnotary.org/client_random diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/ee.der 
b/crates/core/src/fixtures/data/tlsnotary.org/ee.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/ee.der rename to crates/core/src/fixtures/data/tlsnotary.org/ee.der diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/inter.der b/crates/core/src/fixtures/data/tlsnotary.org/inter.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/inter.der rename to crates/core/src/fixtures/data/tlsnotary.org/inter.der diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/pubkey b/crates/core/src/fixtures/data/tlsnotary.org/pubkey similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/pubkey rename to crates/core/src/fixtures/data/tlsnotary.org/pubkey diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/server_random b/crates/core/src/fixtures/data/tlsnotary.org/server_random similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/server_random rename to crates/core/src/fixtures/data/tlsnotary.org/server_random diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/signature b/crates/core/src/fixtures/data/tlsnotary.org/signature similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/tlsnotary.org/signature rename to crates/core/src/fixtures/data/tlsnotary.org/signature diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/unknown/ca.der b/crates/core/src/fixtures/data/unknown/ca.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/unknown/ca.der rename to crates/core/src/fixtures/data/unknown/ca.der diff --git a/tlsn/tlsn-core/src/fixtures/testdata/key_exchange/unknown/ee.der b/crates/core/src/fixtures/data/unknown/ee.der similarity index 100% rename from tlsn/tlsn-core/src/fixtures/testdata/key_exchange/unknown/ee.der rename to 
crates/core/src/fixtures/data/unknown/ee.der diff --git a/crates/core/src/fixtures/data/unknown/pubkey b/crates/core/src/fixtures/data/unknown/pubkey new file mode 100644 index 0000000000..e9166c19fe --- /dev/null +++ b/crates/core/src/fixtures/data/unknown/pubkey @@ -0,0 +1 @@ +14e1f634ecfee5bd4f987f8c571146cb2acb432e400b2fabcbd8ed77f6ef08bd5496cd51d449ce131efd74a24d07b01c38ec794d22d3d43b2b05b907e72797534e \ No newline at end of file diff --git a/crates/core/src/fixtures/provider.rs b/crates/core/src/fixtures/provider.rs new file mode 100644 index 0000000000..179e2995dd --- /dev/null +++ b/crates/core/src/fixtures/provider.rs @@ -0,0 +1,33 @@ +use mpz_garble_core::ChaChaEncoder; + +use crate::transcript::{ + encoding::{Encoder, EncodingProvider}, + Direction, Idx, Transcript, +}; + +/// A ChaCha encoding provider fixture. +pub struct ChaChaProvider { + encoder: ChaChaEncoder, + transcript: Transcript, +} + +impl ChaChaProvider { + /// Creates a new ChaCha encoding provider. + pub(crate) fn new(seed: [u8; 32], transcript: Transcript) -> Self { + Self { + encoder: ChaChaEncoder::new(seed), + transcript, + } + } +} + +impl EncodingProvider for ChaChaProvider { + fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option> { + let seq = self.transcript.get(direction, idx)?; + Some(self.encoder.encode_subsequence(direction, &seq)) + } + + fn provide_bit_encodings(&self, _direction: Direction, _idx: &Idx) -> Option>> { + unimplemented!() + } +} diff --git a/crates/core/src/hash.rs b/crates/core/src/hash.rs new file mode 100644 index 0000000000..b67867ed4b --- /dev/null +++ b/crates/core/src/hash.rs @@ -0,0 +1,466 @@ +//! Hash types. 
+ +use std::{collections::HashMap, fmt::Display}; + +use rand::{distributions::Standard, prelude::Distribution}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::serialize::CanonicalSerialize; + +pub(crate) const DEFAULT_SUPPORTED_HASH_ALGS: &[HashAlgId] = + &[HashAlgId::SHA256, HashAlgId::BLAKE3, HashAlgId::KECCAK256]; + +/// Maximum length of a hash value. +const MAX_LEN: usize = 64; + +/// An error for [`HashProvider`]. +#[derive(Debug, thiserror::Error)] +#[error("unknown hash algorithm id: {}", self.0)] +pub struct HashProviderError(HashAlgId); + +/// Hash provider. +pub struct HashProvider { + algs: HashMap>, +} + +impl Default for HashProvider { + fn default() -> Self { + let mut algs: HashMap<_, Box> = HashMap::new(); + + algs.insert(HashAlgId::SHA256, Box::new(Sha256::default())); + algs.insert(HashAlgId::BLAKE3, Box::new(Blake3::default())); + algs.insert(HashAlgId::KECCAK256, Box::new(Keccak256::default())); + #[cfg(feature = "poseidon")] + algs.insert( + HashAlgId::POSEIDON_BN256_434, + Box::new(PoseidonBn256::default()), + ); + + Self { algs } + } +} + +impl HashProvider { + /// Sets a hash algorithm. + /// + /// This can be used to add or override implementations of hash algorithms. + pub fn set_algorithm( + &mut self, + id: HashAlgId, + algorithm: Box, + ) { + self.algs.insert(id, algorithm); + } + + /// Returns the hash algorithm with the given identifier, or an error if the + /// hash algorithm does not exist. + pub fn get( + &self, + id: &HashAlgId, + ) -> Result<&(dyn HashAlgorithm + Send + Sync), HashProviderError> { + self.algs + .get(id) + .map(|alg| &**alg) + .ok_or(HashProviderError(*id)) + } +} + +/// A hash algorithm identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct HashAlgId(u8); + +impl HashAlgId { + /// SHA-256 hash algorithm. + pub const SHA256: Self = Self(1); + /// BLAKE3 hash algorithm. + pub const BLAKE3: Self = Self(2); + /// Keccak-256 hash algorithm. 
+ pub const KECCAK256: Self = Self(3); + /// Poseidon hash algorithm over the BN256 curve with 434-byte padding length. + #[cfg(feature = "poseidon")] + pub const POSEIDON_BN256_434: Self = Self(4); + + /// Creates a new hash algorithm identifier. + /// + /// # Panics + /// + /// Panics if the identifier is in the reserved range 0-127. + /// + /// # Arguments + /// + /// * id - Unique identifier for the hash algorithm. + pub const fn new(id: u8) -> Self { + assert!(id >= 128, "hash algorithm id range 0-127 is reserved"); + + Self(id) + } + + /// Returns the id as a `u8`. + pub const fn as_u8(&self) -> u8 { + self.0 + } +} + +impl Display for HashAlgId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:02x}", self.0) + } +} + +/// A typed hash value. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, std::hash::Hash)] +pub struct TypedHash { + /// The algorithm of the hash. + pub alg: HashAlgId, + /// The hash value. + pub value: Hash, +} + +/// A hash value. +#[derive(Debug, Clone, Copy, PartialEq, Eq, std::hash::Hash)] +pub struct Hash { + // To avoid heap allocation, we use a fixed-size array. + // 64 bytes should be sufficient for most hash algorithms. 
+ value: [u8; MAX_LEN], + len: usize, +} + +impl Default for Hash { + fn default() -> Self { + Self { + value: [0u8; MAX_LEN], + len: 0, + } + } +} + +impl Serialize for Hash { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.collect_seq(&self.value[..self.len]) + } +} + +impl<'de> Deserialize<'de> for Hash { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use core::marker::PhantomData; + use serde::de::{Error, SeqAccess, Visitor}; + + struct HashVisitor<'de>(PhantomData<&'de ()>); + + impl<'de> Visitor<'de> for HashVisitor<'de> { + type Value = Hash; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "an array at most 64 bytes long") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut value = [0; MAX_LEN]; + let mut len = 0; + + while let Some(byte) = seq.next_element()? { + if len >= MAX_LEN { + return Err(A::Error::invalid_length(len, &self)); + } + + value[len] = byte; + len += 1; + } + + Ok(Hash { value, len }) + } + } + + deserializer.deserialize_seq(HashVisitor(PhantomData)) + } +} + +impl Hash { + /// Creates a new hash value. + /// + /// # Panics + /// + /// Panics if the length of the value is greater than 64 bytes. 
+ fn new(value: &[u8]) -> Self { + assert!( + value.len() <= MAX_LEN, + "hash value must be at most 64 bytes" + ); + + let mut bytes = [0; MAX_LEN]; + bytes[..value.len()].copy_from_slice(value); + + Self { + value: bytes, + len: value.len(), + } + } +} + +impl rs_merkle::Hash for Hash { + const SIZE: usize = MAX_LEN; +} + +impl TryFrom> for Hash { + type Error = &'static str; + + fn try_from(value: Vec) -> Result { + if value.len() > MAX_LEN { + return Err("hash value must be at most 64 bytes"); + } + + let mut bytes = [0; MAX_LEN]; + bytes[..value.len()].copy_from_slice(&value); + + Ok(Self { + value: bytes, + len: value.len(), + }) + } +} + +impl From for Vec { + fn from(value: Hash) -> Self { + value.value[..value.len].to_vec() + } +} + +/// A hashing algorithm. +pub trait HashAlgorithm { + /// Returns the hash algorithm identifier. + fn id(&self) -> HashAlgId; + + /// Computes the hash of the provided data. + fn hash(&self, data: &[u8]) -> Hash; + + /// Computes the hash of the provided data with a prefix. + fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> Hash; + + /// Computes the hash of the provided blinded data. + fn hash_blinded(&self, data: &Blinded>) -> Hash { + self.hash_canonical(data) + } +} + +pub(crate) trait HashAlgorithmExt: HashAlgorithm { + fn hash_canonical(&self, data: &T) -> Hash { + self.hash(&data.serialize()) + } + + fn hash_separated(&self, data: &T) -> Hash { + self.hash_prefixed(data.domain(), &data.serialize()) + } +} + +impl HashAlgorithmExt for T {} + +/// A hash blinder. +#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, std::hash::Hash)] +pub struct Blinder([u8; 16]); + +impl Blinder { + /// Returns a reference to the inner value. + pub fn as_inner(&self) -> &[u8] { + &self.0 + } +} + +opaque_debug::implement!(Blinder); + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> Blinder { + let mut blinder = [0; 16]; + rng.fill(&mut blinder); + Blinder(blinder) + } +} + +/// A blinded pre-image of a hash. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Blinded { + data: T, + blinder: Blinder, +} + +impl Blinded { + /// Creates a new blinded pre-image. + pub(crate) fn new(data: T) -> Self { + Self { + data, + blinder: rand::random(), + } + } + + pub(crate) fn new_with_blinder(data: T, blinder: Blinder) -> Self { + Self { data, blinder } + } + + pub(crate) fn data(&self) -> &T { + &self.data + } + + pub(crate) fn into_parts(self) -> (T, Blinder) { + (self.data, self.blinder) + } +} + +/// A type with a domain separator which is used during hashing to mitigate type +/// confusion attacks. +pub(crate) trait DomainSeparator { + /// Returns the domain separator for the type. + fn domain(&self) -> &[u8]; +} + +macro_rules! impl_domain_separator { + ($type:ty) => { + impl $crate::hash::DomainSeparator for $type { + fn domain(&self) -> &[u8] { + use std::sync::LazyLock; + + // Computes a 16 byte hash of the types name to use as a domain separator. + static DOMAIN: LazyLock<[u8; 16]> = LazyLock::new(|| { + let domain: [u8; 32] = blake3::hash(stringify!($type).as_bytes()).into(); + domain[..16].try_into().unwrap() + }); + + &*DOMAIN + } + } + }; +} + +pub(crate) use impl_domain_separator; + +mod sha2 { + use ::sha2::Digest; + + /// SHA-256 hash algorithm. + #[derive(Default, Clone)] + pub struct Sha256 {} + + impl super::HashAlgorithm for Sha256 { + fn id(&self) -> super::HashAlgId { + super::HashAlgId::SHA256 + } + + fn hash(&self, data: &[u8]) -> super::Hash { + let mut hasher = ::sha2::Sha256::default(); + hasher.update(data); + super::Hash::new(hasher.finalize().as_slice()) + } + + fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash { + let mut hasher = ::sha2::Sha256::default(); + hasher.update(prefix); + hasher.update(data); + super::Hash::new(hasher.finalize().as_slice()) + } + } +} + +pub use sha2::Sha256; + +mod blake3 { + /// BLAKE3 hash algorithm. 
+ #[derive(Default, Clone)] + pub struct Blake3 {} + + impl super::HashAlgorithm for Blake3 { + fn id(&self) -> super::HashAlgId { + super::HashAlgId::BLAKE3 + } + + fn hash(&self, data: &[u8]) -> super::Hash { + super::Hash::new(::blake3::hash(data).as_bytes()) + } + + fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash { + let mut hasher = ::blake3::Hasher::new(); + hasher.update(prefix); + hasher.update(data); + super::Hash::new(hasher.finalize().as_bytes()) + } + } +} + +pub use blake3::Blake3; + +mod keccak { + use tiny_keccak::Hasher; + + /// Keccak-256 hash algorithm. + #[derive(Default, Clone)] + pub struct Keccak256 {} + + impl super::HashAlgorithm for Keccak256 { + fn id(&self) -> super::HashAlgId { + super::HashAlgId::KECCAK256 + } + + fn hash(&self, data: &[u8]) -> super::Hash { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(data); + let mut output = vec![0; 32]; + hasher.finalize(&mut output); + super::Hash::new(&output) + } + + fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(prefix); + hasher.update(data); + let mut output = vec![0; 32]; + hasher.finalize(&mut output); + super::Hash::new(&output) + } + } +} + +pub use keccak::Keccak256; + +#[cfg(feature = "poseidon")] +mod poseidon_halo2 { + use super::Blinded; + use poseidon_halo2::hash as poseidon_hash; + + /// Poseidon hash algorithm with preimage padding. + #[derive(Default, Clone)] + pub struct PoseidonBn256 {} + + /// Maximum allowed bytesize of the hash preimage. + /// + /// If more data than the allowed maximum needs to be hashed, then the data should be split up + /// into chunks and each chunk should be hashed separately. 
+ pub const POSEIDON_MAX_INPUT_SIZE: usize = 434; + + impl super::HashAlgorithm for PoseidonBn256 { + fn id(&self) -> super::HashAlgId { + super::HashAlgId::POSEIDON_BN256_434 + } + + fn hash(&self, _data: &[u8]) -> super::Hash { + unimplemented!() + } + + fn hash_prefixed(&self, _prefix: &[u8], _data: &[u8]) -> super::Hash { + unimplemented!() + } + + fn hash_blinded(&self, data: &Blinded>) -> super::Hash { + let (data, blinder) = data.clone().into_parts(); + super::Hash::new(&poseidon_hash(data, blinder.as_inner().to_vec())) + } + } +} + +#[cfg(feature = "poseidon")] +pub use poseidon_halo2::{PoseidonBn256, POSEIDON_MAX_INPUT_SIZE}; diff --git a/crates/core/src/index.rs b/crates/core/src/index.rs new file mode 100644 index 0000000000..94e7dd1c9a --- /dev/null +++ b/crates/core/src/index.rs @@ -0,0 +1,180 @@ +//! Index types. + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::{ + attestation::{Field, FieldId}, + transcript::{Direction, Idx, PlaintextHash, PlaintextHashSecret}, +}; + +/// Index for items which can be looked up by either the transcript index and direction or by the +/// field id. 
+#[derive(Debug, Clone)] +pub(crate) struct Index { + items: Vec, + // Lookup by field id + field_ids: HashMap, + // Lookup by transcript direction and index + transcript_idxs: HashMap<(Direction, Idx), usize>, +} + +impl Default for Index { + fn default() -> Self { + Self { + items: Default::default(), + field_ids: Default::default(), + transcript_idxs: Default::default(), + } + } +} + +impl Serialize for Index { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.items.serialize(serializer) + } +} + +impl<'de, T: Deserialize<'de>> Deserialize<'de> for Index +where + Index: From>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Vec::::deserialize(deserializer).map(Index::from) + } +} + +impl From> for Vec { + fn from(value: Index) -> Self { + value.items + } +} + +impl Index { + pub(crate) fn new(items: Vec, f: F) -> Self + where + F: Fn(&T) -> (&FieldId, &Idx, &Direction), + { + let mut field_ids = HashMap::new(); + let mut transcript_idxs = HashMap::new(); + for (i, item) in items.iter().enumerate() { + let (id, idx, dir) = f(item); + field_ids.insert(*id, i); + transcript_idxs.insert((*dir, idx.clone()), i); + } + Self { + items, + field_ids, + transcript_idxs, + } + } + + #[allow(dead_code)] + pub(crate) fn iter(&self) -> impl Iterator { + self.items.iter() + } + + pub(crate) fn get_by_field_id(&self, id: &FieldId) -> Option<&T> { + self.field_ids.get(id).map(|i| &self.items[*i]) + } + + pub(crate) fn get_by_transcript_idx(&self, dir: &Direction, idx: &Idx) -> Option<&T> { + self.transcript_idxs + .get(&(*dir, idx.clone())) + .map(|i| &self.items[*i]) + } +} + +impl From>> for Index> { + fn from(items: Vec>) -> Self { + Self::new(items, |field: &Field| { + (&field.id, &field.data.idx, &field.data.direction) + }) + } +} + +impl From> for Index { + fn from(items: Vec) -> Self { + Self::new(items, |item: &PlaintextHashSecret| { + (&item.commitment, &item.idx, &item.direction) + }) + 
} +} + +#[cfg(test)] +mod test { + use utils::range::RangeSet; + + use super::*; + + #[derive(PartialEq, Debug, Clone)] + struct Stub { + field_index: FieldId, + index: Idx, + } + + impl From> for Index { + fn from(items: Vec) -> Self { + Self::new(items, |item: &Stub| (&item.field_index, &item.index)) + } + } + + fn stubs() -> Vec { + vec![ + Stub { + field_index: FieldId(1), + index: Idx::new(RangeSet::from([0..1, 18..21])), + }, + Stub { + field_index: FieldId(2), + index: Idx::new(RangeSet::from([1..5, 8..11])), + }, + ] + } + + #[test] + fn test_successful_retrieval() { + let stub_a_index = Idx::new(RangeSet::from([0..4, 7..10])); + let stub_b_field_index = FieldId(8); + + let stubs = vec![ + Stub { + field_index: FieldId(1), + index: stub_a_index.clone(), + }, + Stub { + field_index: stub_b_field_index, + index: Idx::new(RangeSet::from([1..5, 8..11])), + }, + ]; + let stubs_index: Index = stubs.clone().into(); + + assert_eq!( + stubs_index.get_by_field_id(&stub_b_field_index), + Some(&stubs[1]) + ); + assert_eq!( + stubs_index.get_by_transcript_idx(&stub_a_index), + Some(&stubs[0]) + ); + } + + #[test] + fn test_failed_retrieval() { + let stubs = stubs(); + let stubs_index: Index = stubs.clone().into(); + + let wrong_index = Idx::new(RangeSet::from([0..3, 4..5])); + let wrong_field_index = FieldId(200); + + assert_eq!(stubs_index.get_by_field_id(&wrong_field_index), None); + assert_eq!(stubs_index.get_by_transcript_idx(&wrong_index), None); + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs new file mode 100644 index 0000000000..0272cad2b3 --- /dev/null +++ b/crates/core/src/lib.rs @@ -0,0 +1,192 @@ +//! TLSNotary core library. +//! +//! # Introduction +//! +//! This library provides core functionality for the TLSNotary **attestation** +//! protocol, including some more general types which are useful outside +//! of attestations. +//! +//! Once the MPC-TLS protocol has been completed the Prover holds a collection +//! 
of commitments pertaining to the TLS connection. Most importantly, the +//! Prover is committed to the [`ServerName`](crate::connection::ServerName), +//! and the [`Transcript`](crate::transcript::Transcript) of application data. +//! Subsequently, the Prover can request an +//! [`Attestation`](crate::attestation::Attestation) from the Notary who will +//! include the commitments as well as any additional information which may be +//! useful to an attestation Verifier. +//! +//! Holding an attestation, the Prover can construct a +//! [`Presentation`](crate::presentation::Presentation) which facilitates +//! selectively disclosing various aspects of the TLS connection to a Verifier. +//! If the Verifier trusts the Notary, or more specifically the verifying key of +//! the attestation, then the Verifier can trust the authenticity of the +//! information disclosed in the presentation. +//! +//! **Be sure to check out the various submodules for more information.** +//! +//! # Committing to the transcript +//! +//! The MPC-TLS protocol produces commitments to the entire transcript of +//! application data. However, we may want to disclose only a subset of the data +//! in a presentation. Prior to attestation, the Prover has the opportunity to +//! slice and dice the commitments into smaller sections which can be +//! selectively disclosed. Additionally, the Prover may want to use different +//! commitment schemes depending on the context they expect to disclose. +//! +//! The primary API for this process is the +//! [`TranscriptCommitConfigBuilder`](crate::transcript::TranscriptCommitConfigBuilder) +//! which is used to build up a configuration. +//! +//! Currently, only the +//! [`Encoding`](crate::transcript::TranscriptCommitmentKind::Encoding) +//! commitment kind is supported. In the future you will be able to acquire hash +//! commitments directly to the transcript data. +//! +//! ```no_run +//! 
# use tlsn_core::transcript::{TranscriptCommitConfigBuilder, Transcript, Direction}; +//! # use tlsn_core::hash::HashAlgId; +//! # fn main() -> Result<(), Box> { +//! # let transcript: Transcript = unimplemented!(); +//! let (sent_len, recv_len) = transcript.len(); +//! +//! // Create a new configuration builder. +//! let mut builder = TranscriptCommitConfigBuilder::new(&transcript); +//! +//! // Specify all the transcript commitments we want to make. +//! builder +//! // Use BLAKE3 for encoding commitments. +//! .encoding_hash_alg(HashAlgId::BLAKE3) +//! // Commit to all sent data. +//! .commit_sent(&(0..sent_len))? +//! // Commit to the first 10 bytes of sent data. +//! .commit_sent(&(0..10))? +//! // Skip some bytes so it can be omitted in the presentation. +//! .commit_sent(&(20..sent_len))? +//! // Commit to all received data. +//! .commit_recv(&(0..recv_len))?; +//! +//! let config = builder.build()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Requesting an attestation +//! +//! The first step in the attestation protocol is for the Prover to make a +//! [`Request`](crate::request::Request), which can be configured using the +//! associated [builder](crate::request::RequestConfigBuilder). With it the +//! Prover can configure some of the details of the attestation, such as which +//! cryptographic algorithms are used (if the Notary supports them). +//! +//! Upon being issued an attestation, the Prover will also hold a corresponding +//! [`Secrets`] which contains all private information. This pair can be stored +//! and used later to construct a +//! [`Presentation`](crate::presentation::Presentation), [see +//! below](#constructing-a-presentation). +//! +//! # Issuing an attestation +//! +//! Upon receiving a request, the Notary can issue an +//! [`Attestation`](crate::attestation::Attestation) which can be configured +//! using the associated +//! [builder](crate::attestation::AttestationConfigBuilder). +//! +//! 
The Notary's [`CryptoProvider`] must be configured with an appropriate +//! signing key for attestations. See +//! [`SignerProvider`](crate::signing::SignerProvider) for more information. +//! +//! # Constructing a presentation +//! +//! A Prover can use an [`Attestation`](crate::attestation::Attestation) and the +//! corresponding [`Secrets`] to construct a verifiable +//! [`Presentation`](crate::presentation::Presentation). +//! +//! ```no_run +//! # use tlsn_core::presentation::Presentation; +//! # use tlsn_core::attestation::Attestation; +//! # use tlsn_core::transcript::{TranscriptCommitmentKind, Direction}; +//! # use tlsn_core::{Secrets, CryptoProvider}; +//! # fn main() -> Result<(), Box> { +//! # let attestation: Attestation = unimplemented!(); +//! # let secrets: Secrets = unimplemented!(); +//! # let crypto_provider: CryptoProvider = unimplemented!(); +//! let (_sent_len, recv_len) = secrets.transcript().len(); +//! +//! // First, we decide which application data we would like to disclose. +//! let mut builder = secrets.transcript_proof_builder(); +//! +//! builder +//! // Use transcript encoding commitments. +//! .default_kind(TranscriptCommitmentKind::Encoding) +//! // Disclose the first 10 bytes of the sent data. +//! .reveal(&(0..10), Direction::Sent)? +//! // Disclose all of the received data. +//! .reveal(&(0..recv_len), Direction::Received)?; +//! +//! let transcript_proof = builder.build()?; +//! +//! // Most cases we will also disclose the server identity. +//! let identity_proof = secrets.identity_proof(); +//! +//! // Now we can construct the presentation. +//! let mut builder = attestation.presentation_builder(&crypto_provider); +//! +//! builder +//! .identity_proof(identity_proof) +//! .transcript_proof(transcript_proof); +//! +//! // Finally, we build the presentation. Send it to a verifier! +//! let presentation: Presentation = builder.build()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Verifying a presentation +//! +//! 
Verifying a presentation is as simple as checking the verifier trusts the +//! verifying key then calling +//! [`Presentation::verify`](crate::presentation::Presentation::verify). +//! +//! ```no_run +//! # use tlsn_core::presentation::{Presentation, PresentationOutput}; +//! # use tlsn_core::signing::VerifyingKey; +//! # use tlsn_core::CryptoProvider; +//! # fn main() -> Result<(), Box> { +//! # let presentation: Presentation = unimplemented!(); +//! # let trusted_key: VerifyingKey = unimplemented!(); +//! # let crypto_provider: CryptoProvider = unimplemented!(); +//! // Assert that we trust the verifying key. +//! assert_eq!(presentation.verifying_key(), &trusted_key); +//! +//! let PresentationOutput { +//! attestation, +//! server_name, +//! connection_info, +//! transcript, +//! .. +//! } = presentation.verify(&crypto_provider)?; +//! # Ok(()) +//! # } +//! ``` + +#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(clippy::all)] +#![forbid(unsafe_code)] + +pub mod attestation; +pub mod connection; +#[cfg(any(test, feature = "fixtures"))] +pub mod fixtures; +pub mod hash; +pub(crate) mod index; +pub(crate) mod merkle; +pub mod presentation; +mod provider; +pub mod request; +mod secrets; +pub(crate) mod serialize; +pub mod signing; +pub mod transcript; + +pub use provider::CryptoProvider; +pub use secrets::Secrets; diff --git a/crates/core/src/merkle.rs b/crates/core/src/merkle.rs new file mode 100644 index 0000000000..ebb200c170 --- /dev/null +++ b/crates/core/src/merkle.rs @@ -0,0 +1,308 @@ +//! Merkle tree types. 
+ +use serde::{Deserialize, Serialize}; +use utils::iter::DuplicateCheck; + +use crate::hash::{Hash, HashAlgId, HashAlgorithm, TypedHash}; + +/// Errors that can occur during operations with Merkle tree and Merkle proof +#[derive(Debug, thiserror::Error)] +#[error("merkle error: {0}")] +pub(crate) struct MerkleError(String); + +impl MerkleError { + fn new(msg: impl Into) -> Self { + Self(msg.into()) + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub(crate) struct MerkleProof { + alg: HashAlgId, + tree_len: usize, + proof: rs_merkle::MerkleProof, +} + +opaque_debug::implement!(MerkleProof); + +impl MerkleProof { + /// Checks if indices, hashes and leaves count are valid for the provided + /// root + /// + /// # Panics + /// + /// - If the length of `leaf_indices` and `leaf_hashes` does not match. + /// - If `leaf_indices` contains duplicates. + pub(crate) fn verify( + &self, + hasher: &dyn HashAlgorithm, + root: &TypedHash, + leaves: impl IntoIterator, + ) -> Result<(), MerkleError> { + let mut leaves = leaves.into_iter().collect::>(); + + // Sort by index + leaves.sort_by_key(|(idx, _)| *idx); + + let (indices, leaves): (Vec<_>, Vec<_>) = leaves.into_iter().unzip(); + + if indices.iter().contains_dups() { + return Err(MerkleError::new("duplicate leaf indices provided")); + } + + if !self.proof.verify( + &RsMerkleHasher(hasher), + root.value, + &indices, + &leaves, + self.tree_len, + ) { + return Err(MerkleError::new("invalid merkle proof")); + } + + Ok(()) + } +} + +#[derive(Clone)] +struct RsMerkleHasher<'a>(&'a dyn HashAlgorithm); + +impl rs_merkle::Hasher for RsMerkleHasher<'_> { + type Hash = Hash; + + fn hash(&self, data: &[u8]) -> Hash { + self.0.hash(data) + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub(crate) struct MerkleTree { + alg: HashAlgId, + tree: rs_merkle::MerkleTree, +} + +impl MerkleTree { + pub(crate) fn new(alg: HashAlgId) -> Self { + Self { + alg, + tree: Default::default(), + } + } + + pub(crate) fn algorithm(&self) -> HashAlgId 
{ + self.alg + } + + pub(crate) fn root(&self) -> TypedHash { + TypedHash { + alg: self.alg, + value: self.tree.root().expect("tree should not be empty"), + } + } + + /// Inserts leaves into the tree. + /// + /// # Panics + /// + /// - If the provided hasher is not the same as the one used to create the + /// tree. + pub(crate) fn insert(&mut self, hasher: &dyn HashAlgorithm, mut leaves: Vec) { + assert_eq!(self.alg, hasher.id(), "hash algorithm mismatch"); + + self.tree.append(&mut leaves); + self.tree.commit(&RsMerkleHasher(hasher)) + } + + /// Returns a Merkle proof for the provided indices. + /// + /// # Panics + /// + /// - If the provided indices are not unique and sorted. + pub(crate) fn proof(&self, indices: &[usize]) -> MerkleProof { + assert!( + indices.windows(2).all(|w| w[0] < w[1]), + "indices must be unique and sorted" + ); + + MerkleProof { + alg: self.alg, + tree_len: self.tree.leaves_len(), + proof: self.tree.proof(indices), + } + } +} + +#[cfg(test)] +mod test { + use crate::hash::{impl_domain_separator, Blake3, HashAlgorithmExt, Keccak256, Sha256}; + + use super::*; + use rstest::*; + + #[derive(Serialize)] + struct T(u64); + + impl_domain_separator!(T); + + fn leaves(hasher: &H, leaves: impl IntoIterator) -> Vec { + leaves + .into_iter() + .map(|x| hasher.hash_canonical(&x)) + .collect() + } + + fn choose_leaves( + indices: impl IntoIterator, + leaves: &[Hash], + ) -> Vec<(usize, Hash)> { + indices.into_iter().map(|i| (i, leaves[i])).collect() + } + + // Expect Merkle proof verification to succeed + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + fn test_verify_success(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + let proof = tree.proof(&[2, 3, 4]); + + assert!(proof + .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves)) + 
.is_ok()); + } + + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + fn test_verify_fail_wrong_leaf(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + let proof = tree.proof(&[2, 3, 4]); + + let mut choices = choose_leaves([2, 3, 4], &leaves); + + choices[1].1 = leaves[0]; + + // fail because the leaf is wrong + assert!(proof.verify(&hasher, &tree.root(), choices).is_err()); + } + + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + #[should_panic] + fn test_proof_fail_length_unsorted(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + _ = tree.proof(&[2, 4, 3]); + } + + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + #[should_panic] + fn test_proof_fail_length_duplicates(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + _ = tree.proof(&[2, 2, 3]); + } + + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + fn test_verify_fail_duplicates(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + let proof = tree.proof(&[2, 3, 4]); + + assert!(proof + .verify(&hasher, &tree.root(), choose_leaves([2, 2, 3], &leaves)) + .is_err()); + } + + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + fn 
test_verify_fail_incorrect_leaf_count(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + let mut proof = tree.proof(&[2, 3, 4]); + + proof.tree_len += 1; + + // fail because leaf count is wrong + assert!(proof + .verify(&hasher, &tree.root(), choose_leaves([2, 3, 4], &leaves)) + .is_err()); + } + + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + fn test_verify_fail_incorrect_indices(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + let proof = tree.proof(&[2, 3, 4]); + + let mut choices = choose_leaves([2, 3, 4], &leaves); + choices[1].0 = 1; + + // fail because leaf index is wrong + assert!(proof.verify(&hasher, &tree.root(), choices).is_err()); + } + + #[rstest] + #[case::sha2(Sha256::default())] + #[case::blake3(Blake3::default())] + #[case::keccak(Keccak256::default())] + fn test_verify_fail_fewer_indices(#[case] hasher: H) { + let mut tree = MerkleTree::new(hasher.id()); + + let leaves = leaves(&hasher, [T(0), T(1), T(2), T(3), T(4)]); + + tree.insert(&hasher, leaves.clone()); + + let proof = tree.proof(&[2, 3, 4]); + + // trying to verify less leaves than what was included in the proof + assert!(proof + .verify(&hasher, &tree.root(), choose_leaves([2, 3], &leaves)) + .is_err()); + } +} diff --git a/crates/core/src/presentation.rs b/crates/core/src/presentation.rs new file mode 100644 index 0000000000..0557e4b00a --- /dev/null +++ b/crates/core/src/presentation.rs @@ -0,0 +1,251 @@ +//! Verifiable presentation. +//! +//! We borrow the term "presentation" from the +//! [W3C Verifiable Credentials Data Model](https://www.w3.org/TR/vc-data-model/#presentations-0). +//! +//! 
> Data derived from one or more verifiable credentials, issued by one or +//! > more issuers, that is shared with a specific verifier. A verifiable +//! > presentation is a tamper-evident presentation encoded in such a way that +//! > authorship of the data can be trusted after a process of cryptographic +//! > verification. Certain types of verifiable presentations might contain data +//! > that is synthesized from, but do not contain, the original verifiable +//! > credentials (for example, zero-knowledge proofs). +//! +//! Instead of a credential, a presentation in this context is a proof of an +//! attestation from a Notary along with additional selectively disclosed +//! information about the TLS connection such as the server's identity and the +//! application data communicated with the server. +//! +//! A presentation is self-contained and can be verified by a Verifier without +//! needing access to external data. The Verifier need only check that the key +//! used to sign the attestation, referred to as a [`VerifyingKey`], is from a +//! Notary they trust. See an [example](crate#verifying-a-presentation) in the +//! crate level documentation. + +use std::fmt; + +use serde::{Deserialize, Serialize}; + +use crate::{ + attestation::{Attestation, AttestationError, AttestationProof}, + connection::{ConnectionInfo, ServerIdentityProof, ServerIdentityProofError, ServerName}, + signing::VerifyingKey, + transcript::{PartialTranscript, TranscriptProof, TranscriptProofError}, + CryptoProvider, +}; + +/// A verifiable presentation. +/// +/// See the [module level documentation](crate::presentation) for more +/// information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Presentation { + attestation: AttestationProof, + identity: Option, + transcript: Option, +} + +impl Presentation { + /// Creates a new builder. 
+ pub fn builder<'a>( + provider: &'a CryptoProvider, + attestation: &'a Attestation, + ) -> PresentationBuilder<'a> { + PresentationBuilder::new(provider, attestation) + } + + /// Returns the verifying key. + pub fn verifying_key(&self) -> &VerifyingKey { + self.attestation.verifying_key() + } + + /// Verifies the presentation. + pub fn verify( + self, + provider: &CryptoProvider, + ) -> Result { + let Self { + attestation, + identity, + transcript, + } = self; + + let attestation = attestation.verify(provider)?; + + let server_name = identity + .map(|identity| { + identity.verify_with_provider( + provider, + attestation.body.connection_info().time, + attestation.body.server_ephemeral_key(), + attestation.body.cert_commitment(), + ) + }) + .transpose()?; + + let transcript = transcript + .map(|transcript| transcript.verify_with_provider(provider, &attestation.body)) + .transpose()?; + + let connection_info = attestation.body.connection_info().clone(); + + Ok(PresentationOutput { + attestation, + server_name, + connection_info, + transcript, + }) + } +} + +/// Output of a verified [`Presentation`]. +#[derive(Debug)] +#[non_exhaustive] +pub struct PresentationOutput { + /// Verified attestation. + pub attestation: Attestation, + /// Authenticated server name. + pub server_name: Option, + /// Connection information. + pub connection_info: ConnectionInfo, + /// Authenticated transcript data. + pub transcript: Option, +} + +/// Builder for [`Presentation`]. +pub struct PresentationBuilder<'a> { + provider: &'a CryptoProvider, + attestation: &'a Attestation, + identity_proof: Option, + transcript_proof: Option, +} + +impl<'a> PresentationBuilder<'a> { + pub(crate) fn new(provider: &'a CryptoProvider, attestation: &'a Attestation) -> Self { + Self { + provider, + attestation, + identity_proof: None, + transcript_proof: None, + } + } + + /// Includes a server identity proof. 
+ pub fn identity_proof(&mut self, proof: ServerIdentityProof) -> &mut Self { + self.identity_proof = Some(proof); + self + } + + /// Includes a transcript proof. + pub fn transcript_proof(&mut self, proof: TranscriptProof) -> &mut Self { + self.transcript_proof = Some(proof); + self + } + + /// Builds the presentation. + pub fn build(self) -> Result { + let attestation = AttestationProof::new(self.provider, self.attestation)?; + + Ok(Presentation { + attestation, + identity: self.identity_proof, + transcript: self.transcript_proof, + }) + } +} + +/// Error for [`PresentationBuilder`]. +#[derive(Debug, thiserror::Error)] +pub struct PresentationBuilderError { + kind: BuilderErrorKind, + source: Option>, +} + +#[derive(Debug)] +enum BuilderErrorKind { + Attestation, +} + +impl fmt::Display for PresentationBuilderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("presentation builder error: ")?; + + match self.kind { + BuilderErrorKind::Attestation => f.write_str("attestation error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for PresentationBuilderError { + fn from(error: AttestationError) -> Self { + Self { + kind: BuilderErrorKind::Attestation, + source: Some(Box::new(error)), + } + } +} + +/// Error for [`Presentation`]. 
+#[derive(Debug, thiserror::Error)] +pub struct PresentationError { + kind: ErrorKind, + source: Option>, +} + +#[derive(Debug)] +enum ErrorKind { + Attestation, + Identity, + Transcript, +} + +impl fmt::Display for PresentationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("presentation error: ")?; + + match self.kind { + ErrorKind::Attestation => f.write_str("attestation error")?, + ErrorKind::Identity => f.write_str("server identity error")?, + ErrorKind::Transcript => f.write_str("transcript error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for PresentationError { + fn from(error: AttestationError) -> Self { + Self { + kind: ErrorKind::Attestation, + source: Some(Box::new(error)), + } + } +} + +impl From for PresentationError { + fn from(error: ServerIdentityProofError) -> Self { + Self { + kind: ErrorKind::Identity, + source: Some(Box::new(error)), + } + } +} + +impl From for PresentationError { + fn from(error: TranscriptProofError) -> Self { + Self { + kind: ErrorKind::Transcript, + source: Some(Box::new(error)), + } + } +} diff --git a/crates/core/src/provider.rs b/crates/core/src/provider.rs new file mode 100644 index 0000000000..3dabf4a039 --- /dev/null +++ b/crates/core/src/provider.rs @@ -0,0 +1,67 @@ +use tls_core::{ + anchors::{OwnedTrustAnchor, RootCertStore}, + verify::WebPkiVerifier, +}; + +use crate::{ + hash::HashProvider, + signing::{SignatureVerifierProvider, SignerProvider}, +}; + +/// Cryptography provider. +/// +/// ## Custom Algorithms +/// +/// This is the primary interface for extending cryptographic functionality. The +/// various providers can be configured with custom algorithms and +/// implementations. +/// +/// Algorithms are uniquely identified using an 8-bit ID, eg. +/// [`HashAlgId`](crate::hash::HashAlgId), half of which is reserved for the +/// officially supported algorithms. 
If you think that a new algorithm should be +/// added to the official set, please open an issue. Beware that other parties +/// may assign different algorithms to the same ID as you, and we make no effort +/// to mitigate this. +pub struct CryptoProvider { + /// Hash provider. + pub hash: HashProvider, + /// Certificate verifier. + /// + /// This is used to verify the server's certificate chain. + /// + /// The default verifier uses the Mozilla root certificates. + pub cert: WebPkiVerifier, + /// Signer provider. + /// + /// This is used for signing attestations. + pub signer: SignerProvider, + /// Signature verifier provider. + /// + /// This is used for verifying signatures of attestations. + pub signature: SignatureVerifierProvider, +} + +opaque_debug::implement!(CryptoProvider); + +impl Default for CryptoProvider { + fn default() -> Self { + Self { + hash: Default::default(), + cert: default_cert_verifier(), + signer: Default::default(), + signature: Default::default(), + } + } +} + +pub(crate) fn default_cert_verifier() -> WebPkiVerifier { + let mut root_store = RootCertStore::empty(); + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject.as_ref(), + ta.subject_public_key_info.as_ref(), + ta.name_constraints.as_ref().map(|nc| nc.as_ref()), + ) + })); + WebPkiVerifier::new(root_store, None) +} diff --git a/crates/core/src/request.rs b/crates/core/src/request.rs new file mode 100644 index 0000000000..047ee33d10 --- /dev/null +++ b/crates/core/src/request.rs @@ -0,0 +1,282 @@ +//! Attestation requests. +//! +//! After the TLS connection, a Prover can request an attestation from the +//! Notary which contains various information about the connection. During this +//! process the Prover has the opportunity to configure certain aspects of the +//! attestation, such as which signature algorithm the Notary should use to sign +//! the attestation. 
Or which hash algorithm the Notary should use to merkelize +//! the fields. +//! +//! A [`Request`] can be created using a [`RequestBuilder`]. The builder will +//! take both configuration via a [`RequestConfig`] as well as the Prover's +//! secret data. The [`Secrets`](crate::Secrets) are of course not shared with +//! the Notary but are used to create commitments which are included in the +//! attestation. + +mod builder; +mod config; + +use serde::{Deserialize, Serialize}; + +use crate::{ + attestation::{Attestation, Field}, + connection::ServerCertCommitment, + hash::{HashAlgId, TypedHash}, + signing::SignatureAlgId, + transcript::PlaintextHash, +}; + +pub use builder::{RequestBuilder, RequestBuilderError}; +pub use config::{RequestConfig, RequestConfigBuilder, RequestConfigBuilderError}; + +/// Attestation request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Request { + pub(crate) signature_alg: SignatureAlgId, + pub(crate) hash_alg: HashAlgId, + pub(crate) server_cert_commitment: ServerCertCommitment, + pub(crate) encoding_commitment_root: Option, + /// Plaintext hash commitments sorted by field id. + /// + /// The field ids start from [PLAINTEXT_HASH_INITIAL_FIELD_ID]. + pub(crate) plaintext_hashes: Option>>, +} + +impl Request { + /// Returns a new request builder. + pub fn builder(config: &RequestConfig) -> RequestBuilder { + RequestBuilder::new(config) + } + + /// Validates the content of the attestation against this request. 
+ pub fn validate(&self, attestation: &Attestation) -> Result<(), InconsistentAttestation> { + if attestation.signature.alg != self.signature_alg { + return Err(InconsistentAttestation(format!( + "signature algorithm: expected {:?}, got {:?}", + self.signature_alg, attestation.signature.alg + ))); + } + + if attestation.header.root.alg != self.hash_alg { + return Err(InconsistentAttestation(format!( + "hash algorithm: expected {:?}, got {:?}", + self.hash_alg, attestation.header.root.alg + ))); + } + + if attestation.body.cert_commitment() != &self.server_cert_commitment { + return Err(InconsistentAttestation( + "server certificate commitment does not match".to_string(), + )); + } + + if let Some(encoding_commitment_root) = &self.encoding_commitment_root { + let Some(encoding_commitment) = attestation.body.encoding_commitment() else { + return Err(InconsistentAttestation( + "encoding commitment is missing".to_string(), + )); + }; + + if &encoding_commitment.root != encoding_commitment_root { + return Err(InconsistentAttestation( + "encoding commitment root does not match".to_string(), + )); + } + } + + match (&self.plaintext_hashes, attestation.body.plaintext_hashes()) { + (Some(request_hashes), Some(attested_hashes)) => { + if request_hashes != attested_hashes { + return Err(InconsistentAttestation( + "plaintext hash commitments do not match".to_string(), + )); + } + } + // If there are no hashes in the request, do nothing. + (None, Some(_attested_hashes)) => {} + (None, None) => {} + (Some(_request_hashes), None) => { + return Err(InconsistentAttestation( + "plaintext hash commitments do not match".to_string(), + )); + } + } + + // TODO: add signature verification. + + Ok(()) + } +} + +/// Error for [`Request::validate`]. 
+#[derive(Debug, thiserror::Error)] +#[error("inconsistent attestation: {0}")] +pub struct InconsistentAttestation(String); + +#[cfg(test)] +mod test { + use super::*; + + use crate::{ + attestation::{Attestation, AttestationConfig}, + connection::{HandshakeData, HandshakeDataV1_2, ServerCertOpening, TranscriptLength}, + fixtures::{encoder_seed, encoding_provider, ConnectionFixture}, + hash::{Blake3, Hash, HashAlgId}, + signing::SignatureAlgId, + transcript::{encoding::EncodingTree, Transcript, TranscriptCommitConfigBuilder}, + CryptoProvider, + }; + + use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON}; + + fn attestation(payload: (Request, ConnectionFixture)) -> Attestation { + let (request, connection) = payload; + + let ConnectionFixture { + connection_info, + server_cert_data, + .. + } = connection; + + let HandshakeData::V1_2(HandshakeDataV1_2 { + server_ephemeral_key, + .. + }) = server_cert_data.handshake.clone(); + + let mut provider = CryptoProvider::default(); + provider.signer.set_secp256k1(&[42u8; 32]).unwrap(); + + let attestation_config = AttestationConfig::builder() + .supported_signature_algs([SignatureAlgId::SECP256K1]) + .build() + .unwrap(); + + let mut attestation_builder = Attestation::builder(&attestation_config) + .accept_request(request.clone()) + .unwrap(); + + attestation_builder + .connection_info(connection_info.clone()) + .server_ephemeral_key(server_ephemeral_key) + .encoding_seed(encoder_seed().to_vec()); + + attestation_builder.build(&provider).unwrap() + } + + fn request_and_connection() -> (Request, ConnectionFixture) { + let provider = CryptoProvider::default(); + + let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON); + let (sent_len, recv_len) = transcript.len(); + // Plaintext encodings which the Prover obtained from GC evaluation + let encodings_provider = encoding_provider(GET_WITH_HEADER, OK_JSON); + + // At the end of the TLS connection the Prover holds the: + let ConnectionFixture { + 
server_name, + server_cert_data, + .. + } = ConnectionFixture::tlsnotary(transcript.length()); + + // Prover specifies the ranges it wants to commit to. + let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript); + transcript_commitment_builder + .commit_sent(&(0..sent_len)) + .unwrap() + .commit_recv(&(0..recv_len)) + .unwrap(); + + let transcripts_commitment_config = transcript_commitment_builder.build().unwrap(); + + // Prover constructs encoding tree. + let encoding_tree = EncodingTree::new( + &Blake3::default(), + transcripts_commitment_config.iter_encoding(), + &encodings_provider, + &transcript.length(), + ) + .unwrap(); + + let request_config = RequestConfig::default(); + let mut request_builder = Request::builder(&request_config); + + request_builder + .server_name(server_name.clone()) + .server_cert_data(server_cert_data) + .transcript(transcript.clone()) + .encoding_tree(encoding_tree); + let (request, _) = request_builder.build(&provider).unwrap(); + + (request, ConnectionFixture::tlsnotary(transcript.length())) + } + + #[test] + fn test_success() { + let (request, connection) = request_and_connection(); + + let attestation = attestation((request.clone(), connection)); + + assert!(request.validate(&attestation).is_ok()) + } + + #[test] + fn test_wrong_signature_alg() { + let (mut request, connection) = request_and_connection(); + + let attestation = attestation((request.clone(), connection)); + + request.signature_alg = SignatureAlgId::SECP256R1; + + let res = request.validate(&attestation); + assert!(res.is_err()); + } + + #[test] + fn test_wrong_hash_alg() { + let (mut request, connection) = request_and_connection(); + + let attestation = attestation((request.clone(), connection)); + + request.hash_alg = HashAlgId::SHA256; + + let res = request.validate(&attestation); + assert!(res.is_err()) + } + + #[test] + fn test_wrong_server_commitment() { + let (mut request, connection) = request_and_connection(); + + let 
attestation = attestation((request.clone(), connection)); + + let ConnectionFixture { + server_cert_data, .. + } = ConnectionFixture::appliedzkp(TranscriptLength { + sent: 100, + received: 100, + }); + let opening = ServerCertOpening::new(server_cert_data); + + let crypto_provider = CryptoProvider::default(); + request.server_cert_commitment = + opening.commit(crypto_provider.hash.get(&HashAlgId::BLAKE3).unwrap()); + + let res = request.validate(&attestation); + assert!(res.is_err()) + } + + #[test] + fn test_wrong_encoding_commitment_root() { + let (mut request, connection) = request_and_connection(); + + let attestation = attestation((request.clone(), connection)); + + request.encoding_commitment_root = Some(TypedHash { + alg: HashAlgId::BLAKE3, + value: Hash::default(), + }); + + let res = request.validate(&attestation); + assert!(res.is_err()) + } +} diff --git a/crates/core/src/request/builder.rs b/crates/core/src/request/builder.rs new file mode 100644 index 0000000000..b167c8928d --- /dev/null +++ b/crates/core/src/request/builder.rs @@ -0,0 +1,224 @@ +use rand::{ + distributions::{Distribution, Standard}, + thread_rng, +}; + +use crate::{ + attestation::{Field, FieldId, PLAINTEXT_HASH_INITIAL_FIELD_ID}, + connection::{ServerCertData, ServerCertOpening, ServerName}, + hash::{Blinded, Blinder, TypedHash}, + request::{Request, RequestConfig}, + secrets::Secrets, + transcript::{ + encoding::EncodingTree, PlaintextHash, PlaintextHashSecret, Transcript, TranscriptCommitmentKind + }, + CryptoProvider, +}; +use crate::transcript::commit::CommitInfo; + +#[cfg(feature = "poseidon")] +use crate::hash::HashAlgId; + +/// Builder for [`Request`]. +pub struct RequestBuilder<'a> { + config: &'a RequestConfig, + server_name: Option, + server_cert_data: Option, + encoding_tree: Option, + transcript: Option, + plaintext_hashes: Option>, +} + +impl<'a> RequestBuilder<'a> { + /// Creates a new request builder. 
+ pub fn new(config: &'a RequestConfig) -> Self { + Self { + config, + server_name: None, + server_cert_data: None, + encoding_tree: None, + transcript: None, + plaintext_hashes: None, + } + } + + /// Sets the server name. + pub fn server_name(&mut self, name: ServerName) -> &mut Self { + self.server_name = Some(name); + self + } + + /// Sets the server identity data. + pub fn server_cert_data(&mut self, data: ServerCertData) -> &mut Self { + self.server_cert_data = Some(data); + self + } + + /// Sets the tree to commit to the transcript encodings. + pub fn encoding_tree(&mut self, tree: EncodingTree) -> &mut Self { + self.encoding_tree = Some(tree); + self + } + + /// Sets the transcript. + pub fn transcript(&mut self, transcript: Transcript) -> &mut Self { + self.transcript = Some(transcript); + self + } + + /// Sets the plaintext hash commitment info. + pub fn plaintext_hashes( + &mut self, + plaintext_hashes: impl Iterator, + ) -> &mut Self { + self.plaintext_hashes = Some(plaintext_hashes.collect::>()); + self + } + + /// Builds the attestation request and returns the corresponding secrets. 
+ pub fn build( + self, + provider: &CryptoProvider, + ) -> Result<(Request, Secrets), RequestBuilderError> { + let Self { + config, + server_name, + server_cert_data, + encoding_tree, + transcript, + plaintext_hashes, + } = self; + + let signature_alg = *config.signature_alg(); + let hash_alg = *config.hash_alg(); + + let hasher = provider.hash.get(&hash_alg).map_err(|_| { + RequestBuilderError::new(format!("unsupported hash algorithm: {hash_alg}")) + })?; + + let server_name = + server_name.ok_or_else(|| RequestBuilderError::new("server name is missing"))?; + + let server_cert_opening = ServerCertOpening::new( + server_cert_data + .ok_or_else(|| RequestBuilderError::new("server identity data is missing"))?, + ); + + let transcript = + transcript.ok_or_else(|| RequestBuilderError::new("transcript is missing"))?; + + let server_cert_commitment = server_cert_opening.commit(hasher); + + let encoding_commitment_root = encoding_tree.as_ref().map(|tree| tree.root()); + + let (pt_hashes, pt_secrets) = match plaintext_hashes { + Some(plaintext_hashes) => { + + if plaintext_hashes.is_empty() { + return Err(RequestBuilderError::new("empty plaintext hash info was set")); + } + + let mut field_id = FieldId::new(PLAINTEXT_HASH_INITIAL_FIELD_ID); + + let (pt_hashes, pt_secrets): (Vec>, Vec) = plaintext_hashes.into_iter().map(|info|{ + let alg = if let TranscriptCommitmentKind::Hash{alg} = info.kind() { + alg + } else { + return Err(RequestBuilderError::new("only plaintext commitments are allowed")); + }; + + let (dir, idx) = info.idx().clone(); + + let data = transcript.get(dir, &idx).ok_or_else(|| { + RequestBuilderError::new(format!( + "direction {} and index {:?} were not found in the transcript", + dir, &idx + )) + })?; + + let blinder = match info.blinder() { + Some(blinder) => { + // The hash was computed earlier. 
+ blinder.clone() + }, + None => { + let blinder: Blinder = Standard.sample(&mut thread_rng()); + blinder + } + }; + + let hasher = provider + .hash + .get(alg) + .map_err(|_| RequestBuilderError::new("hash provider is missing"))?; + + #[cfg(feature = "poseidon")] + if alg == &HashAlgId::POSEIDON_BN256_434 { + if idx.count() != 1 { + return Err(RequestBuilderError::new("committing to more than one range with POSEIDON_HALO2 is not supported")); + } else if idx.len() > crate::hash::POSEIDON_MAX_INPUT_SIZE { + return Err(RequestBuilderError::new(format!("committing to more than {} bytes with POSEIDON_HALO2 is not supported", crate::hash::POSEIDON_MAX_INPUT_SIZE))); + } + } + + let data = Blinded::new_with_blinder(data.data().to_vec(), blinder.clone()); + let hash = hasher.hash_blinded(&data); + + let field = field_id.next(PlaintextHash { + direction: dir, + idx: idx.clone(), + hash: TypedHash { + alg: *alg, + value: hash + }, + }); + + let id = field.id; + Ok((field, PlaintextHashSecret { + blinder, + idx, + direction: dir, + commitment: id, + })) + + }).collect::, RequestBuilderError>>()?.into_iter().unzip(); + + (Some(pt_hashes), Some(pt_secrets.into())) + } + None => (None, None), + }; + + let request = Request { + signature_alg, + hash_alg, + server_cert_commitment, + encoding_commitment_root, + plaintext_hashes: pt_hashes, + }; + + let secrets = Secrets { + server_name, + server_cert_opening, + encoding_tree, + plaintext_hash_secrets: pt_secrets, + transcript, + }; + + Ok((request, secrets)) + } +} + +/// Error for [`RequestBuilder`]. 
+#[derive(Debug, thiserror::Error)] +#[error("request builder error: {message}")] +pub struct RequestBuilderError { + message: String, +} + +impl RequestBuilderError { + fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} diff --git a/crates/core/src/request/config.rs b/crates/core/src/request/config.rs new file mode 100644 index 0000000000..d526e04bc5 --- /dev/null +++ b/crates/core/src/request/config.rs @@ -0,0 +1,76 @@ +use crate::{hash::HashAlgId, signing::SignatureAlgId}; + +/// Request configuration. +#[derive(Debug, Clone)] +pub struct RequestConfig { + signature_alg: SignatureAlgId, + hash_alg: HashAlgId, +} + +impl Default for RequestConfig { + fn default() -> Self { + Self::builder().build().unwrap() + } +} + +impl RequestConfig { + /// Creates a new builder. + pub fn builder() -> RequestConfigBuilder { + RequestConfigBuilder::default() + } + + /// Returns the signature algorithm. + pub fn signature_alg(&self) -> &SignatureAlgId { + &self.signature_alg + } + + /// Returns the hash algorithm. + pub fn hash_alg(&self) -> &HashAlgId { + &self.hash_alg + } +} + +/// Builder for [`RequestConfig`]. +#[derive(Debug)] +pub struct RequestConfigBuilder { + signature_alg: SignatureAlgId, + hash_alg: HashAlgId, +} + +impl Default for RequestConfigBuilder { + fn default() -> Self { + Self { + signature_alg: SignatureAlgId::SECP256K1, + hash_alg: HashAlgId::BLAKE3, + } + } +} + +impl RequestConfigBuilder { + /// Sets the signature algorithm. + pub fn signature_alg(&mut self, signature_alg: SignatureAlgId) -> &mut Self { + self.signature_alg = signature_alg; + self + } + + /// Sets the hash algorithm. + pub fn hash_alg(&mut self, hash_alg: HashAlgId) -> &mut Self { + self.hash_alg = hash_alg; + self + } + + /// Builds the config. + pub fn build(self) -> Result { + Ok(RequestConfig { + signature_alg: self.signature_alg, + hash_alg: self.hash_alg, + }) + } +} + +/// Error for [`RequestConfigBuilder`]. 
+#[derive(Debug, thiserror::Error)] +#[error("request configuration builder error: {message}")] +pub struct RequestConfigBuilderError { + message: String, +} diff --git a/crates/core/src/secrets.rs b/crates/core/src/secrets.rs new file mode 100644 index 0000000000..f7bee95001 --- /dev/null +++ b/crates/core/src/secrets.rs @@ -0,0 +1,47 @@ +//! Secrets types. + +use serde::{Deserialize, Serialize}; + +use crate::{ + connection::{ServerCertOpening, ServerIdentityProof, ServerName}, + index::Index, + transcript::{encoding::EncodingTree, PlaintextHashSecret, Transcript, TranscriptProofBuilder}, +}; + +/// Secret data of an [`Attestation`](crate::attestation::Attestation). +#[derive(Clone, Serialize, Deserialize)] +pub struct Secrets { + pub(crate) server_name: ServerName, + pub(crate) server_cert_opening: ServerCertOpening, + pub(crate) encoding_tree: Option, + pub(crate) plaintext_hash_secrets: Option>, + pub(crate) transcript: Transcript, +} + +opaque_debug::implement!(Secrets); + +impl Secrets { + /// Returns the server name. + pub fn server_name(&self) -> &ServerName { + &self.server_name + } + + /// Returns the transcript. + pub fn transcript(&self) -> &Transcript { + &self.transcript + } + + /// Returns a server identity proof. + pub fn identity_proof(&self) -> ServerIdentityProof { + ServerIdentityProof::new(self.server_name.clone(), self.server_cert_opening.clone()) + } + + /// Returns a transcript proof builder. + pub fn transcript_proof_builder(&self) -> TranscriptProofBuilder<'_> { + TranscriptProofBuilder::new( + &self.transcript, + self.encoding_tree.as_ref(), + &self.plaintext_hash_secrets, + ) + } +} diff --git a/crates/core/src/serialize.rs b/crates/core/src/serialize.rs new file mode 100644 index 0000000000..777149fbbf --- /dev/null +++ b/crates/core/src/serialize.rs @@ -0,0 +1,18 @@ +/// Canonical serialization of TLSNotary types. +/// +/// This trait is used to serialize types into a canonical byte representation. 
+pub(crate) trait CanonicalSerialize { + /// Serializes the type. + fn serialize(&self) -> Vec; +} + +impl CanonicalSerialize for T +where + T: serde::Serialize, +{ + fn serialize(&self) -> Vec { + // For now we use BCS for serialization. In future releases we will want to + // consider this further, particularly with respect to EVM compatibility. + bcs::to_bytes(self).unwrap() + } +} diff --git a/crates/core/src/signing.rs b/crates/core/src/signing.rs new file mode 100644 index 0000000000..eb2c902231 --- /dev/null +++ b/crates/core/src/signing.rs @@ -0,0 +1,449 @@ +//! Cryptographic signatures. + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::hash::impl_domain_separator; + +/// Key algorithm identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct KeyAlgId(u8); + +impl KeyAlgId { + /// secp256k1 elliptic curve key algorithm. + pub const K256: Self = Self(1); + /// NIST P-256 elliptic curve key algorithm. + pub const P256: Self = Self(2); + + /// Creates a new key algorithm identifier. + /// + /// # Panics + /// + /// Panics if the identifier is in the reserved range 0-127. + /// + /// # Arguments + /// + /// * id - Unique identifier for the key algorithm. + pub const fn new(id: u8) -> Self { + assert!(id >= 128, "key algorithm id range 0-127 is reserved"); + + Self(id) + } + + /// Returns the id as a `u8`. + pub const fn as_u8(&self) -> u8 { + self.0 + } +} + +impl std::fmt::Display for KeyAlgId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + KeyAlgId::K256 => write!(f, "k256"), + KeyAlgId::P256 => write!(f, "p256"), + _ => write!(f, "custom({:02x})", self.0), + } + } +} + +/// Signature algorithm identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct SignatureAlgId(u8); + +impl SignatureAlgId { + /// secp256k1 signature algorithm. 
+ pub const SECP256K1: Self = Self(1); + /// secp256r1 signature algorithm. + pub const SECP256R1: Self = Self(2); + + /// Creates a new signature algorithm identifier. + /// + /// # Panics + /// + /// Panics if the identifier is in the reserved range 0-127. + /// + /// # Arguments + /// + /// * id - Unique identifier for the signature algorithm. + pub const fn new(id: u8) -> Self { + assert!(id >= 128, "signature algorithm id range 0-127 is reserved"); + + Self(id) + } + + /// Returns the id as a `u8`. + pub const fn as_u8(&self) -> u8 { + self.0 + } +} + +impl std::fmt::Display for SignatureAlgId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + SignatureAlgId::SECP256K1 => write!(f, "secp256k1"), + SignatureAlgId::SECP256R1 => write!(f, "secp256r1"), + _ => write!(f, "custom({:02x})", self.0), + } + } +} + +/// Unknown signature algorithm error. +#[derive(Debug, thiserror::Error)] +#[error("unknown signature algorithm id: {0:?}")] +pub struct UnknownSignatureAlgId(SignatureAlgId); + +/// Provider of signers. +#[derive(Default)] +pub struct SignerProvider { + signers: HashMap>, +} + +impl SignerProvider { + /// Returns the supported signature algorithms. + pub fn supported_algs(&self) -> impl Iterator + '_ { + self.signers.keys().copied() + } + + /// Configures a signer. + pub fn set_signer(&mut self, signer: Box) { + self.signers.insert(signer.alg_id(), signer); + } + + /// Configures a secp256k1 signer with the provided signing key. + pub fn set_secp256k1(&mut self, key: &[u8]) -> Result<&mut Self, SignerError> { + self.set_signer(Box::new(Secp256k1Signer::new(key)?)); + + Ok(self) + } + + /// Configures a secp256r1 signer with the provided signing key. + pub fn set_secp256r1(&mut self, key: &[u8]) -> Result<&mut Self, SignerError> { + self.set_signer(Box::new(Secp256r1Signer::new(key)?)); + + Ok(self) + } + + /// Returns a signer for the given algorithm. 
+ pub(crate) fn get( + &self, + alg: &SignatureAlgId, + ) -> Result<&(dyn Signer + Send + Sync), UnknownSignatureAlgId> { + self.signers + .get(alg) + .map(|s| &**s) + .ok_or(UnknownSignatureAlgId(*alg)) + } +} + +/// Error for [`Signer`]. +#[derive(Debug, thiserror::Error)] +#[error("signer error: {0}")] +pub struct SignerError(String); + +/// Cryptographic signer. +pub trait Signer { + /// Returns the algorithm used by this signer. + fn alg_id(&self) -> SignatureAlgId; + + /// Signs the message. + fn sign(&self, msg: &[u8]) -> Result; + + /// Returns the verifying key for this signer. + fn verifying_key(&self) -> VerifyingKey; +} + +/// Provider of signature verifiers. +pub struct SignatureVerifierProvider { + verifiers: HashMap>, +} + +impl Default for SignatureVerifierProvider { + fn default() -> Self { + let mut verifiers = HashMap::new(); + + verifiers.insert(SignatureAlgId::SECP256K1, Box::new(Secp256k1Verifier) as _); + verifiers.insert(SignatureAlgId::SECP256R1, Box::new(Secp256r1Verifier) as _); + + Self { verifiers } + } +} + +impl SignatureVerifierProvider { + /// Configures a signature verifier. + pub fn set_verifier(&mut self, verifier: Box) { + self.verifiers.insert(verifier.alg_id(), verifier); + } + + /// Returns the verifier for the given algorithm. + pub(crate) fn get( + &self, + alg: &SignatureAlgId, + ) -> Result<&(dyn SignatureVerifier + Send + Sync), UnknownSignatureAlgId> { + self.verifiers + .get(alg) + .map(|s| &**s) + .ok_or(UnknownSignatureAlgId(*alg)) + } +} + +/// Signature verifier. +pub trait SignatureVerifier { + /// Returns the algorithm used by this verifier. + fn alg_id(&self) -> SignatureAlgId; + + /// Verifies the signature. + fn verify(&self, key: &VerifyingKey, msg: &[u8], sig: &[u8]) -> Result<(), SignatureError>; +} + +/// Verifying key. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct VerifyingKey { + /// The key algorithm. + pub alg: KeyAlgId, + /// The key data. 
+ pub data: Vec, +} + +impl_domain_separator!(VerifyingKey); + +/// Error occurred while verifying a signature. +#[derive(Debug, thiserror::Error)] +#[error("signature verification failed: {0}")] +pub struct SignatureError(String); + +/// A signature. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Signature { + /// The algorithm used to sign the data. + pub alg: SignatureAlgId, + /// The signature data. + pub data: Vec, +} + +mod secp256k1 { + use std::sync::{Arc, Mutex}; + + use k256::ecdsa::{ + signature::{SignerMut, Verifier}, + Signature as Secp256K1Signature, SigningKey, + }; + + use super::*; + + /// secp256k1 signer. + pub struct Secp256k1Signer(Arc>); + + impl Secp256k1Signer { + /// Creates a new secp256k1 signer with the provided signing key. + pub fn new(key: &[u8]) -> Result { + SigningKey::from_slice(key) + .map(|key| Self(Arc::new(Mutex::new(key)))) + .map_err(|_| SignerError("invalid key".to_string())) + } + } + + impl Signer for Secp256k1Signer { + fn alg_id(&self) -> SignatureAlgId { + SignatureAlgId::SECP256K1 + } + + fn sign(&self, msg: &[u8]) -> Result { + let sig: Secp256K1Signature = self.0.lock().unwrap().sign(msg); + + Ok(Signature { + alg: SignatureAlgId::SECP256K1, + data: sig.to_vec(), + }) + } + + fn verifying_key(&self) -> VerifyingKey { + let key = self.0.lock().unwrap().verifying_key().to_sec1_bytes(); + + VerifyingKey { + alg: KeyAlgId::K256, + data: key.to_vec(), + } + } + } + + /// secp256k1 verifier. 
+ pub struct Secp256k1Verifier; + + impl SignatureVerifier for Secp256k1Verifier { + fn alg_id(&self) -> SignatureAlgId { + SignatureAlgId::SECP256K1 + } + + fn verify(&self, key: &VerifyingKey, msg: &[u8], sig: &[u8]) -> Result<(), SignatureError> { + if key.alg != KeyAlgId::K256 { + return Err(SignatureError("key algorithm is not k256".to_string())); + } + + let key = k256::ecdsa::VerifyingKey::from_sec1_bytes(&key.data) + .map_err(|_| SignatureError("invalid k256 key".to_string()))?; + + let sig = Secp256K1Signature::from_slice(sig) + .map_err(|_| SignatureError("invalid secp256k1 signature".to_string()))?; + + key.verify(msg, &sig).map_err(|_| { + SignatureError("secp256k1 signature verification failed".to_string()) + })?; + + Ok(()) + } + } +} + +pub use secp256k1::{Secp256k1Signer, Secp256k1Verifier}; + +mod secp256r1 { + use std::sync::{Arc, Mutex}; + + use p256::ecdsa::{ + signature::{SignerMut, Verifier}, + Signature as Secp256R1Signature, SigningKey, + }; + + use super::*; + + /// secp256r1 signer. + pub struct Secp256r1Signer(Arc>); + + impl Secp256r1Signer { + /// Creates a new secp256r1 signer with the provided signing key. + pub fn new(key: &[u8]) -> Result { + SigningKey::from_slice(key) + .map(|key| Self(Arc::new(Mutex::new(key)))) + .map_err(|_| SignerError("invalid key".to_string())) + } + } + + impl Signer for Secp256r1Signer { + fn alg_id(&self) -> SignatureAlgId { + SignatureAlgId::SECP256R1 + } + + fn sign(&self, msg: &[u8]) -> Result { + let sig: Secp256R1Signature = self.0.lock().unwrap().sign(msg); + + Ok(Signature { + alg: SignatureAlgId::SECP256R1, + data: sig.to_vec(), + }) + } + + fn verifying_key(&self) -> VerifyingKey { + let key = self.0.lock().unwrap().verifying_key().to_sec1_bytes(); + + VerifyingKey { + alg: KeyAlgId::P256, + data: key.to_vec(), + } + } + } + + /// secp256r1 verifier. 
+ pub struct Secp256r1Verifier; + + impl SignatureVerifier for Secp256r1Verifier { + fn alg_id(&self) -> SignatureAlgId { + SignatureAlgId::SECP256R1 + } + + fn verify(&self, key: &VerifyingKey, msg: &[u8], sig: &[u8]) -> Result<(), SignatureError> { + if key.alg != KeyAlgId::P256 { + return Err(SignatureError("key algorithm is not p256".to_string())); + } + + let key = p256::ecdsa::VerifyingKey::from_sec1_bytes(&key.data) + .map_err(|_| SignatureError("invalid p256 key".to_string()))?; + + let sig = Secp256R1Signature::from_slice(sig) + .map_err(|_| SignatureError("invalid secp256r1 signature".to_string()))?; + + key.verify(msg, &sig).map_err(|_| { + SignatureError("secp256r1 signature verification failed".to_string()) + })?; + + Ok(()) + } + } +} + +pub use secp256r1::{Secp256r1Signer, Secp256r1Verifier}; + +#[cfg(test)] +mod test { + use super::*; + use rand_core::OsRng; + use rstest::{fixture, rstest}; + + #[fixture] + #[once] + fn secp256k1_signer() -> Secp256k1Signer { + let signing_key = k256::ecdsa::SigningKey::random(&mut OsRng); + Secp256k1Signer::new(&signing_key.to_bytes()).unwrap() + } + + #[fixture] + #[once] + fn secp256r1_signer() -> Secp256r1Signer { + let signing_key = p256::ecdsa::SigningKey::random(&mut OsRng); + Secp256r1Signer::new(&signing_key.to_bytes()).unwrap() + } + + #[rstest] + fn test_secp256k1_success(secp256k1_signer: &Secp256k1Signer) { + assert_eq!(secp256k1_signer.alg_id(), SignatureAlgId::SECP256K1); + + let msg = "test payload"; + let signature = secp256k1_signer.sign(msg.as_bytes()).unwrap(); + let verifying_key = secp256k1_signer.verifying_key(); + + let verifier = Secp256k1Verifier {}; + assert_eq!(verifier.alg_id(), SignatureAlgId::SECP256K1); + let result = verifier.verify(&verifying_key, msg.as_bytes(), &signature.data); + assert!(result.is_ok()); + } + + #[rstest] + fn test_secp256r1_success(secp256r1_signer: &Secp256r1Signer) { + assert_eq!(secp256r1_signer.alg_id(), SignatureAlgId::SECP256R1); + + let msg = "test 
payload"; + let signature = secp256r1_signer.sign(msg.as_bytes()).unwrap(); + let verifying_key = secp256r1_signer.verifying_key(); + + let verifier = Secp256r1Verifier {}; + assert_eq!(verifier.alg_id(), SignatureAlgId::SECP256R1); + let result = verifier.verify(&verifying_key, msg.as_bytes(), &signature.data); + assert!(result.is_ok()); + } + + #[rstest] + #[case::wrong_signer(&secp256r1_signer(), false, false)] + #[case::corrupted_signature(&secp256k1_signer(), true, false)] + #[case::wrong_signature(&secp256k1_signer(), false, true)] + fn test_failure( + #[case] signer: &dyn Signer, + #[case] corrupted_signature: bool, + #[case] wrong_signature: bool, + ) { + let msg = "test payload"; + let mut signature = signer.sign(msg.as_bytes()).unwrap(); + let verifying_key = signer.verifying_key(); + + if corrupted_signature { + signature.data.push(0); + } + + if wrong_signature { + signature = signer.sign("different payload".as_bytes()).unwrap(); + } + + let verifier = Secp256k1Verifier {}; + let result = verifier.verify(&verifying_key, msg.as_bytes(), &signature.data); + assert!(result.is_err()); + } +} diff --git a/crates/core/src/transcript.rs b/crates/core/src/transcript.rs new file mode 100644 index 0000000000..cb64755b01 --- /dev/null +++ b/crates/core/src/transcript.rs @@ -0,0 +1,634 @@ +//! Transcript types. +//! +//! All application data communicated over a TLS connection is referred to as a +//! [`Transcript`]. A transcript is essentially just two vectors of bytes, each +//! corresponding to a [`Direction`]. +//! +//! TLS operates over a bidirectional byte stream, and thus there are no +//! application layer semantics present in the transcript. For example, HTTPS is +//! an application layer protocol that runs *over TLS* so there is no concept of +//! "requests" or "responses" in the transcript itself. These semantics must be +//! recovered by parsing the application data and relating it to the bytes +//! in the transcript. +//! +//! ## Commitments +//! +//! 
During the attestation process a Prover can generate multiple commitments to +//! various parts of the transcript. These commitments are inserted into the +//! attestation body and can be used by the Verifier to verify transcript proofs +//! later. +//! +//! To configure the transcript commitments, use the +//! [`TranscriptCommitConfigBuilder`]. +//! +//! ## Selective Disclosure +//! +//! Using a [`TranscriptProof`] a Prover can selectively disclose parts of a +//! transcript to a Verifier in the form of a [`PartialTranscript`]. A Verifier +//! always learns the length of the transcript, but sensitive data can be +//! withheld. +//! +//! To create a proof, use the [`TranscriptProofBuilder`] which is returned by +//! [`Secrets::transcript_proof_builder`](crate::Secrets::transcript_proof_builder). + +pub mod commit; +#[doc(hidden)] +pub mod encoding; +mod hash; +mod proof; + +use std::{fmt, ops::Range}; + +use serde::{Deserialize, Serialize}; +use utils::range::{Difference, IndexRanges, RangeSet, ToRangeSet, Union}; + +use crate::connection::TranscriptLength; + +pub use commit::{ + TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitConfigBuilderError, + TranscriptCommitmentKind, +}; +pub use hash::{PlaintextHash, PlaintextHashSecret}; +pub use proof::{ + TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError, +}; + +/// Sent data transcript ID. +pub static TX_TRANSCRIPT_ID: &str = "tx"; +/// Received data transcript ID. +pub static RX_TRANSCRIPT_ID: &str = "rx"; + +/// A transcript contains all the data communicated over a TLS connection. +#[derive(Clone, Serialize, Deserialize)] +pub struct Transcript { + /// Data sent from the Prover to the Server. + sent: Vec, + /// Data received by the Prover from the Server. + received: Vec, +} + +opaque_debug::implement!(Transcript); + +impl Transcript { + /// Creates a new transcript. 
+ pub fn new(sent: impl Into>, received: impl Into>) -> Self { + Self { + sent: sent.into(), + received: received.into(), + } + } + + /// Returns a reference to the sent data. + pub fn sent(&self) -> &[u8] { + &self.sent + } + + /// Returns a reference to the received data. + pub fn received(&self) -> &[u8] { + &self.received + } + + /// Returns the length of the sent and received data, respectively. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> (usize, usize) { + (self.sent.len(), self.received.len()) + } + + /// Returns the length of the transcript in the given direction. + pub(crate) fn len_of_direction(&self, direction: Direction) -> usize { + match direction { + Direction::Sent => self.sent.len(), + Direction::Received => self.received.len(), + } + } + + /// Returns the transcript length. + pub fn length(&self) -> TranscriptLength { + TranscriptLength { + sent: self.sent.len() as u32, + received: self.received.len() as u32, + } + } + + /// Returns the subsequence of the transcript with the provided index, + /// returning `None` if the index is out of bounds. + pub fn get(&self, direction: Direction, idx: &Idx) -> Option { + let data = match direction { + Direction::Sent => &self.sent, + Direction::Received => &self.received, + }; + + if idx.end() > data.len() { + return None; + } + + Some( + Subsequence::new(idx.clone(), data.index_ranges(&idx.0)) + .expect("data is same length as index"), + ) + } + + /// Returns a partial transcript containing the provided indices. + /// + /// # Panics + /// + /// Panics if the indices are out of bounds. + /// + /// # Arguments + /// + /// * `sent_idx` - The indices of the sent data to include. + /// * `recv_idx` - The indices of the received data to include. 
+ pub fn to_partial(&self, sent_idx: Idx, recv_idx: Idx) -> PartialTranscript { + let mut sent = vec![0; self.sent.len()]; + let mut received = vec![0; self.received.len()]; + + for range in sent_idx.iter_ranges() { + sent[range.clone()].copy_from_slice(&self.sent[range]); + } + + for range in recv_idx.iter_ranges() { + received[range.clone()].copy_from_slice(&self.received[range]); + } + + PartialTranscript { + sent, + received, + sent_authed: sent_idx, + received_authed: recv_idx, + } + } +} + +/// A partial transcript. +/// +/// A partial transcript is a transcript which may not have all the data authenticated. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(try_from = "validation::PartialTranscriptUnchecked")] +pub struct PartialTranscript { + /// Data sent from the Prover to the Server. + sent: Vec, + /// Data received by the Prover from the Server. + received: Vec, + /// Index of `sent` which have been authenticated. + sent_authed: Idx, + /// Index of `received` which have been authenticated. + received_authed: Idx, +} + +impl PartialTranscript { + /// Creates a new partial transcript initalized to all 0s. + /// + /// # Arguments + /// + /// * `sent_len` - The length of the sent data. + /// * `received_len` - The length of the received data. + pub fn new(sent_len: usize, received_len: usize) -> Self { + Self { + sent: vec![0; sent_len], + received: vec![0; received_len], + sent_authed: Idx::default(), + received_authed: Idx::default(), + } + } + + /// Returns the length of the sent transcript. + pub fn len_sent(&self) -> usize { + self.sent.len() + } + + /// Returns the length of the received transcript. + pub fn len_received(&self) -> usize { + self.received.len() + } + + /// Returns whether the transcript is complete. + pub fn is_complete(&self) -> bool { + self.sent_authed.len() == self.sent.len() + && self.received_authed.len() == self.received.len() + } + + /// Returns whether the index is in bounds of the transcript. 
+ pub fn contains(&self, direction: Direction, idx: &Idx) -> bool { + match direction { + Direction::Sent => idx.end() <= self.sent.len(), + Direction::Received => idx.end() <= self.received.len(), + } + } + + /// Returns a reference to the sent data. + /// + /// # Warning + /// + /// Not all of the data in the transcript may have been authenticated. See + /// [sent_authed](PartialTranscript::sent_authed) for a set of ranges which + /// have been. + pub fn sent_unsafe(&self) -> &[u8] { + &self.sent + } + + /// Returns a reference to the received data. + /// + /// # Warning + /// + /// Not all of the data in the transcript may have been authenticated. See + /// [received_authed](PartialTranscript::received_authed) for a set of + /// ranges which have been. + pub fn received_unsafe(&self) -> &[u8] { + &self.received + } + + /// Returns the index of sent data which have been authenticated. + pub fn sent_authed(&self) -> &Idx { + &self.sent_authed + } + + /// Returns the index of received data which have been authenticated. + pub fn received_authed(&self) -> &Idx { + &self.received_authed + } + + /// Returns the index of sent data which haven't been authenticated. + pub fn sent_unauthed(&self) -> Idx { + Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed.0)) + } + + /// Returns the index of received data which haven't been authenticated. + pub fn received_unauthed(&self) -> Idx { + Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed.0)) + } + + /// Returns an iterator over the authenticated data in the transcript. + pub fn iter(&self, direction: Direction) -> impl Iterator + '_ { + let (data, authed) = match direction { + Direction::Sent => (&self.sent, &self.sent_authed), + Direction::Received => (&self.received, &self.received_authed), + }; + + authed.0.iter().map(|i| data[i]) + } + + /// Unions the authenticated data of this transcript with another. 
+ /// + /// # Panics + /// + /// Panics if the other transcript is not the same length. + pub fn union_transcript(&mut self, other: &PartialTranscript) { + assert_eq!( + self.sent.len(), + other.sent.len(), + "sent data are not the same length" + ); + assert_eq!( + self.received.len(), + other.received.len(), + "received data are not the same length" + ); + + for range in other + .sent_authed + .0 + .difference(&self.sent_authed.0) + .iter_ranges() + { + self.sent[range.clone()].copy_from_slice(&other.sent[range]); + } + + for range in other + .received_authed + .0 + .difference(&self.received_authed.0) + .iter_ranges() + { + self.received[range.clone()].copy_from_slice(&other.received[range]); + } + + self.sent_authed = self.sent_authed.union(&other.sent_authed); + self.received_authed = self.received_authed.union(&other.received_authed); + } + + /// Unions an authenticated subsequence into this transcript. + /// + /// # Panics + /// + /// Panics if the subsequence is outside the bounds of the transcript. + pub fn union_subsequence(&mut self, direction: Direction, seq: &Subsequence) { + match direction { + Direction::Sent => { + seq.copy_to(&mut self.sent); + self.sent_authed = self.sent_authed.union(&seq.idx); + } + Direction::Received => { + seq.copy_to(&mut self.received); + self.received_authed = self.received_authed.union(&seq.idx); + } + } + } + + /// Sets all bytes in the transcript which haven't been authenticated. + /// + /// # Arguments + /// + /// * `value` - The value to set the unauthenticated bytes to + pub fn set_unauthed(&mut self, value: u8) { + for range in self.sent_unauthed().iter_ranges() { + self.sent[range].fill(value); + } + for range in self.received_unauthed().iter_ranges() { + self.received[range].fill(value); + } + } + + /// Sets all bytes in the transcript which haven't been authenticated within + /// the given range. 
+ /// + /// # Arguments + /// + /// * `value` - The value to set the unauthenticated bytes to + /// * `range` - The range of bytes to set + pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range) { + match direction { + Direction::Sent => { + for range in range.difference(&self.sent_authed.0).iter_ranges() { + self.sent[range].fill(value); + } + } + Direction::Received => { + for range in range.difference(&self.received_authed.0).iter_ranges() { + self.received[range].fill(value); + } + } + } + } +} + +/// The direction of data communicated over a TLS connection. +/// +/// This is used to differentiate between data sent from the Prover to the TLS peer, +/// and data received by the Prover from the TLS peer (client or server). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub enum Direction { + /// Sent from the Prover to the TLS peer. + Sent = 0x00, + /// Received by the prover from the TLS peer. + Received = 0x01, +} + +impl fmt::Display for Direction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Direction::Sent => write!(f, "sent"), + Direction::Received => write!(f, "received"), + } + } +} + +/// Transcript index. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct Idx(RangeSet); + +impl Idx { + /// Creates a new index builder. + pub fn builder() -> IdxBuilder { + IdxBuilder::default() + } + + /// Creates an empty index. + pub fn empty() -> Self { + Self(RangeSet::default()) + } + + /// Creates a new transcript index. + pub fn new(ranges: impl Into>) -> Self { + Self(ranges.into()) + } + + /// Returns the start of the index. + pub fn start(&self) -> usize { + self.0.min().unwrap_or_default() + } + + /// Returns the end of the index, non-inclusive. + pub fn end(&self) -> usize { + self.0.end().unwrap_or_default() + } + + /// Returns an iterator over the values in the index. 
+ pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter() + } + + /// Returns an iterator over the ranges of the index. + pub fn iter_ranges(&self) -> impl Iterator> + '_ { + self.0.iter_ranges() + } + + /// Returns the number of values in the index. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns whether the index is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Returns the number of disjoint ranges in the index. + pub fn count(&self) -> usize { + self.0.len_ranges() + } + + /// Returns the union of this index with another. + pub fn union(&self, other: &Idx) -> Idx { + Idx(self.0.union(&other.0)) + } +} + +/// Builder for [`Idx`]. +#[derive(Debug, Default)] +pub struct IdxBuilder(RangeSet); + +impl IdxBuilder { + /// Unions ranges. + pub fn union(self, ranges: &dyn ToRangeSet) -> Self { + IdxBuilder(self.0.union(&ranges.to_range_set())) + } + + /// Builds the index. + pub fn build(self) -> Idx { + Idx(self.0) + } +} + +/// Transcript subsequence. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "validation::SubsequenceUnchecked")] +pub struct Subsequence { + /// Index of the subsequence. + idx: Idx, + /// Data of the subsequence. + data: Vec, +} + +impl Subsequence { + /// Creates a new subsequence. + pub fn new(idx: Idx, data: Vec) -> Result { + if idx.len() != data.len() { + return Err(InvalidSubsequence( + "index length does not match data length", + )); + } + + Ok(Self { idx, data }) + } + + /// Returns the index of the subsequence. + pub fn index(&self) -> &Idx { + &self.idx + } + + /// Returns the data of the subsequence. + pub fn data(&self) -> &[u8] { + &self.data + } + + /// Returns the length of the subsequence. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns the inner parts of the subsequence. 
+ pub fn into_parts(self) -> (Idx, Vec) { + (self.idx, self.data) + } + + /// Copies the subsequence data into the given destination. + /// + /// # Panics + /// + /// Panics if the subsequence ranges are out of bounds. + pub(crate) fn copy_to(&self, dest: &mut [u8]) { + let mut offset = 0; + for range in self.idx.iter_ranges() { + dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]); + offset += range.len(); + } + } +} + +/// Invalid subsequence error. +#[derive(Debug, thiserror::Error)] +#[error("invalid subsequence: {0}")] +pub struct InvalidSubsequence(&'static str); + +/// Returns the value ID for each byte in the provided range set. +#[doc(hidden)] +pub fn get_value_ids(direction: Direction, idx: &Idx) -> impl Iterator + '_ { + let id = match direction { + Direction::Sent => TX_TRANSCRIPT_ID, + Direction::Received => RX_TRANSCRIPT_ID, + }; + + idx.iter().map(move |idx| format!("{}/{}", id, idx)) +} + +mod validation { + use super::*; + + #[derive(Debug, Deserialize)] + pub(super) struct SubsequenceUnchecked { + idx: Idx, + data: Vec, + } + + impl TryFrom for Subsequence { + type Error = InvalidSubsequence; + + fn try_from(unchecked: SubsequenceUnchecked) -> Result { + Self::new(unchecked.idx, unchecked.data) + } + } + + /// Invalid partial transcript error. 
+ #[derive(Debug, thiserror::Error)] + #[error("invalid partial transcript: {0}")] + pub struct InvalidPartialTranscript(&'static str); + + #[derive(Debug, Deserialize)] + pub(super) struct PartialTranscriptUnchecked { + sent: Vec, + received: Vec, + sent_authed: Idx, + received_authed: Idx, + } + + impl TryFrom for PartialTranscript { + type Error = InvalidPartialTranscript; + + fn try_from(unchecked: PartialTranscriptUnchecked) -> Result { + if unchecked.sent_authed.end() > unchecked.sent.len() + || unchecked.received_authed.end() > unchecked.received.len() + { + return Err(InvalidPartialTranscript( + "authenticated ranges are not in bounds of the data", + )); + } + + // Rewrite the data to ensure that unauthenticated data is zeroed out. + let mut sent = vec![0; unchecked.sent.len()]; + let mut received = vec![0; unchecked.received.len()]; + + for range in unchecked.sent_authed.iter_ranges() { + sent[range.clone()].copy_from_slice(&unchecked.sent[range]); + } + + for range in unchecked.received_authed.iter_ranges() { + received[range.clone()].copy_from_slice(&unchecked.received[range]); + } + + Ok(Self { + sent, + received, + sent_authed: unchecked.sent_authed, + received_authed: unchecked.received_authed, + }) + } + } +} + +#[cfg(test)] +mod tests { + use rstest::{fixture, rstest}; + + use super::*; + + #[fixture] + fn transcript() -> Transcript { + Transcript::new( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + } + + #[rstest] + fn test_get_subsequence(transcript: Transcript) { + let subseq = transcript + .get(Direction::Received, &Idx(RangeSet::from([0..4, 7..10]))) + .unwrap(); + assert_eq!(subseq.data, vec![0, 1, 2, 3, 7, 8, 9]); + + let subseq = transcript + .get(Direction::Sent, &Idx(RangeSet::from([0..4, 9..12]))) + .unwrap(); + assert_eq!(subseq.data, vec![0, 1, 2, 3, 9, 10, 11]); + + let subseq = transcript.get( + Direction::Received, + &Idx(RangeSet::from([0..4, 7..10, 11..13])), + ); + assert_eq!(subseq, 
None); + + let subseq = transcript.get(Direction::Sent, &Idx(RangeSet::from([0..4, 7..10, 11..13]))); + assert_eq!(subseq, None); + } +} diff --git a/crates/core/src/transcript/commit.rs b/crates/core/src/transcript/commit.rs new file mode 100644 index 0000000000..a8431d87f3 --- /dev/null +++ b/crates/core/src/transcript/commit.rs @@ -0,0 +1,103 @@ +//! Transcript commitments. + +mod builder; + +use std::collections::HashSet; + +use getset::Getters; +use serde::{Deserialize, Serialize}; + +use crate::{ + hash::{Blinder, HashAlgId}, + transcript::{Direction, Idx, Transcript}, +}; + +pub use builder::{TranscriptCommitConfigBuilder, TranscriptCommitConfigBuilderError}; + +#[cfg(feature = "poseidon")] +pub(crate) const SUPPORTED_PLAINTEXT_HASH_ALGS: &[HashAlgId] = &[HashAlgId::POSEIDON_BN256_434]; + +#[cfg(not(feature = "poseidon"))] +pub(crate) const SUPPORTED_PLAINTEXT_HASH_ALGS: &[HashAlgId] = &[]; + +/// Kind of transcript commitment. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum TranscriptCommitmentKind { + /// A commitment to encodings of the transcript. + Encoding, + /// A hash commitment to plaintext in the transcript. + Hash { + /// The hash algorithm used. + alg: HashAlgId, + }, +} + +/// Configuration for transcript commitments. +#[derive(Debug, Clone)] +pub struct TranscriptCommitConfig { + encoding_hash_alg: HashAlgId, + /// Commitment information. + commits: HashSet, +} + +impl TranscriptCommitConfig { + /// Creates a new commit config builder. + pub fn builder(transcript: &Transcript) -> TranscriptCommitConfigBuilder { + TranscriptCommitConfigBuilder::new(transcript) + } + + /// Returns the hash algorithm to use for encoding commitments. + pub fn encoding_hash_alg(&self) -> &HashAlgId { + &self.encoding_hash_alg + } + + /// Returns whether the configuration has any encoding commitments. 
+ pub fn has_encoding(&self) -> bool { + self.commits + .iter() + .any(|commit| matches!(commit.kind, TranscriptCommitmentKind::Encoding)) + } + + /// Returns an iterator over the encoding commitment indices. + pub fn iter_encoding(&self) -> impl Iterator { + self.commits.iter().filter_map(|commit| match commit.kind { + TranscriptCommitmentKind::Encoding => Some(&commit.idx), + _ => None, + }) + } + + /// Returns whether the configuration has any plaintext hash commitments. + pub fn has_plaintext_hashes(&self) -> bool { + self.commits + .iter() + .any(|commit| matches!(commit.kind, TranscriptCommitmentKind::Hash { .. })) + } + + /// Returns an iterator over the plaintext hash commitment info. + pub fn plaintext_hashes(&self) -> impl Iterator + '_ { + self.commits.iter().filter_map(|commit| { + if matches!(commit.kind, TranscriptCommitmentKind::Hash { .. }) { + Some(commit.clone()) + } else { + None + } + }) + } +} + +/// The information required to create a commitment to a subset of transcript data. +#[derive(Debug, Clone, Eq, PartialEq, Getters, std::hash::Hash)] +pub struct CommitInfo { + /// The index of data in a transcript. + #[getset(get = "pub")] + idx: (Direction, Idx), + /// The commitment kind. + #[getset(get = "pub")] + kind: TranscriptCommitmentKind, + /// The blinder to use for the commitment. + /// + /// None value means that the blinder will be generated later at the time of creating the + /// commitment. 
+ #[getset(get = "pub")] + blinder: Option, +} diff --git a/crates/core/src/transcript/commit/builder.rs b/crates/core/src/transcript/commit/builder.rs new file mode 100644 index 0000000000..109575c5f6 --- /dev/null +++ b/crates/core/src/transcript/commit/builder.rs @@ -0,0 +1,245 @@ +use std::{collections::HashSet, fmt}; + +use rand::{distributions::Standard, prelude::Distribution, thread_rng}; +use utils::range::ToRangeSet; + +use crate::{ + hash::{Blinder, HashAlgId}, + transcript::{ + commit::{CommitInfo, SUPPORTED_PLAINTEXT_HASH_ALGS}, + Direction, Idx, Transcript, TranscriptCommitConfig, TranscriptCommitmentKind, + }, +}; + +/// A builder for [`TranscriptCommitConfig`]. +/// +/// The default hash algorithm is [`HashAlgId::BLAKE3`] and the default kind +/// is [`TranscriptCommitmentKind::Encoding`]. +#[derive(Debug)] +pub struct TranscriptCommitConfigBuilder<'a> { + transcript: &'a Transcript, + encoding_hash_alg: HashAlgId, + default_kind: TranscriptCommitmentKind, + /// Commitment information. + commits: HashSet, +} + +impl<'a> TranscriptCommitConfigBuilder<'a> { + /// Creates a new commit config builder. + pub fn new(transcript: &'a Transcript) -> Self { + Self { + transcript, + encoding_hash_alg: HashAlgId::BLAKE3, + default_kind: TranscriptCommitmentKind::Encoding, + commits: HashSet::default(), + } + } + + /// Sets the hash algorithm to use for encoding commitments. + pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self { + self.encoding_hash_alg = alg; + self + } + + /// Sets the default kind of commitment to use. + pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self { + self.default_kind = default_kind; + self + } + + /// Adds a commitment. + /// + /// # Arguments + /// + /// * `ranges` - The ranges of the commitment. + /// * `direction` - The direction of the transcript. + /// * `kind` - The kind of commitment. 
+ pub fn commit_with_kind( + &mut self, + ranges: &dyn ToRangeSet, + direction: Direction, + kind: TranscriptCommitmentKind, + ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> { + self.commit_with_kind_inner(ranges, direction, kind, None) + } + + /// Adds a commitment with the default kind with a random blinder and returns the blinder. + /// + /// # Arguments + /// + /// * `ranges` - The ranges of the commitment. + /// * `direction` - The direction of the transcript. + pub fn commit_with_blinder( + &mut self, + ranges: &dyn ToRangeSet, + direction: Direction, + ) -> Result { + let kind = self.default_kind; + + let TranscriptCommitmentKind::Hash { .. } = kind else { + return Err(TranscriptCommitConfigBuilderError::new( + ErrorKind::Algorithm, + "commit_with_blinder is only supported for plaintext commitments", + )); + }; + + let blinder: Blinder = Standard.sample(&mut thread_rng()); + + self.commit_with_kind_inner(ranges, direction, kind, Some(blinder.clone()))?; + + Ok(blinder) + } + + /// Adds a commitment with the default kind. + /// + /// # Arguments + /// + /// * `ranges` - The ranges of the commitment. + /// * `direction` - The direction of the transcript. + pub fn commit( + &mut self, + ranges: &dyn ToRangeSet, + direction: Direction, + ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> { + self.commit_with_kind(ranges, direction, self.default_kind) + } + + /// Adds a commitment with the default kind to the sent data transcript. + /// + /// # Arguments + /// + /// * `ranges` - The ranges of the commitment. + pub fn commit_sent( + &mut self, + ranges: &dyn ToRangeSet, + ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> { + self.commit(ranges, Direction::Sent) + } + + /// Adds a commitment with the default kind to the received data transcript. + /// + /// # Arguments + /// + /// * `ranges` - The ranges of the commitment. 
+ pub fn commit_recv( + &mut self, + ranges: &dyn ToRangeSet, + ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> { + self.commit(ranges, Direction::Received) + } + + /// Builds the configuration. + pub fn build(self) -> Result { + Ok(TranscriptCommitConfig { + encoding_hash_alg: self.encoding_hash_alg, + commits: self.commits, + }) + } + + /// Returns plaintext hash commitments. + pub fn plaintext_hashes(&self) -> Vec { + self.commits + .iter() + .filter_map(|commit| match commit.kind { + TranscriptCommitmentKind::Hash { .. } => Some(commit.clone()), + _ => None, + }) + .collect::>() + } + + fn commit_with_kind_inner( + &mut self, + ranges: &dyn ToRangeSet, + direction: Direction, + kind: TranscriptCommitmentKind, + blinder: Option, + ) -> Result<&mut Self, TranscriptCommitConfigBuilderError> { + let idx = Idx::new(ranges.to_range_set()); + + if idx.end() > self.transcript.len_of_direction(direction) { + return Err(TranscriptCommitConfigBuilderError::new( + ErrorKind::Index, + format!( + "range is out of bounds of the transcript ({}): {} > {}", + direction, + idx.end(), + self.transcript.len_of_direction(direction) + ), + )); + } + + if let TranscriptCommitmentKind::Hash { alg } = kind { + if !SUPPORTED_PLAINTEXT_HASH_ALGS.contains(&alg) { + return Err(TranscriptCommitConfigBuilderError::new( + ErrorKind::Algorithm, + format!("unsupported plaintext commitment algorithm {}", alg,), + )); + } + } + + self.commits.insert(CommitInfo { + idx: (direction, idx), + kind, + blinder, + }); + + Ok(self) + } +} + +/// Error for [`TranscriptCommitConfigBuilder`]. 
+#[derive(Debug, thiserror::Error)] +pub struct TranscriptCommitConfigBuilderError { + kind: ErrorKind, + source: Option>, +} + +impl TranscriptCommitConfigBuilderError { + fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } +} + +#[derive(Debug)] +enum ErrorKind { + Index, + Algorithm, +} + +impl fmt::Display for TranscriptCommitConfigBuilderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.kind { + ErrorKind::Index => f.write_str("index error")?, + ErrorKind::Algorithm => f.write_str("algorithm error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_range_out_of_bounds() { + let transcript = Transcript::new( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ); + let mut builder = TranscriptCommitConfigBuilder::new(&transcript); + + assert!(builder.commit_sent(&(10..15)).is_err()); + assert!(builder.commit_recv(&(10..15)).is_err()); + } +} diff --git a/crates/core/src/transcript/encoding.rs b/crates/core/src/transcript/encoding.rs new file mode 100644 index 0000000000..940c821d63 --- /dev/null +++ b/crates/core/src/transcript/encoding.rs @@ -0,0 +1,38 @@ +//! Transcript encoding commitments and proofs. +//! +//! This is an internal module that is not intended to be used directly by +//! users. + +mod encoder; +mod proof; +mod provider; +mod tree; + +pub(crate) use encoder::{new_encoder, Encoder}; +pub use proof::{EncodingProof, EncodingProofError}; +pub use provider::EncodingProvider; +pub use tree::EncodingTree; + +use serde::{Deserialize, Serialize}; + +use crate::hash::{impl_domain_separator, TypedHash}; + +/// The maximum allowed total bytelength of all committed data. Used to prevent +/// DoS during verification. 
(this will cause the verifier to hash up to a max +/// of 1GB * 128 = 128GB of plaintext encodings if the commitment type is +/// [crate::commitment::Blake3]). +/// +/// This value must not exceed bcs's MAX_SEQUENCE_LENGTH limit (which is (1 << +/// 31) - 1 by default) +const MAX_TOTAL_COMMITTED_DATA: usize = 1_000_000_000; + +/// Transcript encoding commitment. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct EncodingCommitment { + /// Merkle root of the encoding commitments. + pub root: TypedHash, + /// Seed used to generate the encodings. + pub seed: Vec, +} + +impl_domain_separator!(EncodingCommitment); diff --git a/crates/core/src/transcript/encoding/encoder.rs b/crates/core/src/transcript/encoding/encoder.rs new file mode 100644 index 0000000000..5e66327929 --- /dev/null +++ b/crates/core/src/transcript/encoding/encoder.rs @@ -0,0 +1,49 @@ +use mpz_circuits::types::ValueType; +use mpz_core::serialize::CanonicalSerialize; +use mpz_garble_core::ChaChaEncoder; + +use crate::transcript::{Direction, Subsequence, RX_TRANSCRIPT_ID, TX_TRANSCRIPT_ID}; + +pub(crate) fn new_encoder(seed: [u8; 32]) -> impl Encoder { + ChaChaEncoder::new(seed) +} + +/// A transcript encoder. +/// +/// This is an internal implementation detail that should not be exposed to the +/// public API. +pub(crate) trait Encoder { + /// Returns the encoding for the given subsequence of the transcript. + /// + /// # Arguments + /// + /// * `seq` - The subsequence to encode. 
+ fn encode_subsequence(&self, direction: Direction, seq: &Subsequence) -> Vec; +} + +impl Encoder for ChaChaEncoder { + fn encode_subsequence(&self, direction: Direction, seq: &Subsequence) -> Vec { + let id = match direction { + Direction::Sent => TX_TRANSCRIPT_ID, + Direction::Received => RX_TRANSCRIPT_ID, + }; + + let mut encoding = Vec::with_capacity(seq.len() * 16); + for (byte_id, &byte) in seq.index().iter().zip(seq.data()) { + let id_hash = mpz_core::utils::blake3(format!("{}/{}", id, byte_id).as_bytes()); + let id = u64::from_be_bytes(id_hash[..8].try_into().unwrap()); + + encoding.extend( + ::encode_by_type( + self, + id, + &ValueType::U8, + ) + .select(byte) + .expect("encoding is a byte encoding") + .to_bytes(), + ) + } + encoding + } +} diff --git a/crates/core/src/transcript/encoding/proof.rs b/crates/core/src/transcript/encoding/proof.rs new file mode 100644 index 0000000000..ead2a25b5b --- /dev/null +++ b/crates/core/src/transcript/encoding/proof.rs @@ -0,0 +1,182 @@ +use std::{collections::HashMap, fmt}; + +use serde::{Deserialize, Serialize}; + +use crate::{ + connection::TranscriptLength, + hash::{Blinded, Blinder, HashAlgorithmExt, HashProviderError}, + merkle::{MerkleError, MerkleProof}, + transcript::{ + encoding::{ + new_encoder, tree::EncodingLeaf, Encoder, EncodingCommitment, MAX_TOTAL_COMMITTED_DATA, + }, + Direction, PartialTranscript, Subsequence, + }, + CryptoProvider, +}; + +/// An opening of a leaf in the encoding tree. +#[derive(Clone, Serialize, Deserialize)] +pub(super) struct Opening { + pub(super) direction: Direction, + pub(super) seq: Subsequence, + pub(super) blinder: Blinder, +} + +opaque_debug::implement!(Opening); + +/// An encoding proof. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EncodingProof { + pub(super) inclusion_proof: MerkleProof, + pub(super) openings: HashMap, +} + +impl EncodingProof { + /// Verifies the proof against the commitment. 
+ /// + /// Returns the partial sent and received transcripts, respectively. + /// + /// # Arguments + /// + /// * `transcript_length` - The length of the transcript. + /// * `commitment` - The encoding commitment to verify against. + pub fn verify_with_provider( + self, + provider: &CryptoProvider, + transcript_length: &TranscriptLength, + commitment: &EncodingCommitment, + ) -> Result { + let hasher = provider.hash.get(&commitment.root.alg)?; + + let seed: [u8; 32] = commitment.seed.clone().try_into().map_err(|_| { + EncodingProofError::new(ErrorKind::Commitment, "encoding seed not 32 bytes") + })?; + + let encoder = new_encoder(seed); + let Self { + inclusion_proof, + openings, + } = self; + let (sent_len, recv_len) = ( + transcript_length.sent as usize, + transcript_length.received as usize, + ); + + let mut leaves = Vec::with_capacity(openings.len()); + let mut transcript = PartialTranscript::new(sent_len, recv_len); + let mut total_opened = 0u128; + for ( + id, + Opening { + direction, + seq, + blinder, + }, + ) in openings + { + // Make sure the amount of data being proved is bounded. + total_opened += seq.len() as u128; + if total_opened > MAX_TOTAL_COMMITTED_DATA as u128 { + return Err(EncodingProofError::new( + ErrorKind::Proof, + "exceeded maximum allowed data", + ))?; + } + + // Make sure the ranges are within the bounds of the transcript + let transcript_len = match direction { + Direction::Sent => sent_len, + Direction::Received => recv_len, + }; + + if seq.index().end() > transcript_len { + return Err(EncodingProofError::new( + ErrorKind::Proof, + format!( + "index out of bounds of the transcript ({}): {} > {}", + direction, + seq.index().end(), + transcript_len + ), + )); + } + + let expected_encoding = encoder.encode_subsequence(direction, &seq); + let expected_leaf = + Blinded::new_with_blinder(EncodingLeaf::new(expected_encoding), blinder); + + // Compute the expected hash of the commitment to make sure it is + // present in the merkle tree. 
+ leaves.push((id, hasher.hash_canonical(&expected_leaf))); + + // Union the authenticated subsequence into the transcript. + transcript.union_subsequence(direction, &seq); + } + + // Verify that the expected hashes are present in the merkle tree. + // + // This proves the Prover committed to the purported data prior to the encoder + // seed being revealed. Ergo, if the encodings are authentic then the purported + // data is authentic. + inclusion_proof.verify(hasher, &commitment.root, leaves)?; + + Ok(transcript) + } +} + +/// Error for [`EncodingProof`]. +#[derive(Debug, thiserror::Error)] +pub struct EncodingProofError { + kind: ErrorKind, + source: Option>, +} + +impl EncodingProofError { + fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } +} + +#[derive(Debug)] +enum ErrorKind { + Provider, + Commitment, + Proof, +} + +impl fmt::Display for EncodingProofError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("encoding proof error: ")?; + + match self.kind { + ErrorKind::Provider => f.write_str("provider error")?, + ErrorKind::Commitment => f.write_str("commitment error")?, + ErrorKind::Proof => f.write_str("proof error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for EncodingProofError { + fn from(error: HashProviderError) -> Self { + Self::new(ErrorKind::Provider, error) + } +} + +impl From for EncodingProofError { + fn from(error: MerkleError) -> Self { + Self::new(ErrorKind::Proof, error) + } +} diff --git a/crates/core/src/transcript/encoding/provider.rs b/crates/core/src/transcript/encoding/provider.rs new file mode 100644 index 0000000000..d30cd0286e --- /dev/null +++ b/crates/core/src/transcript/encoding/provider.rs @@ -0,0 +1,10 @@ +use crate::transcript::{Direction, Idx}; + +/// A provider of plaintext encodings. 
+pub trait EncodingProvider { + /// Provides the encoding of a subsequence of plaintext. + fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option>; + + /// Provides the encoding of each individual bit of a subsequence of plaintext in LSB0 bit order. + fn provide_bit_encodings(&self, direction: Direction, idx: &Idx) -> Option>>; +} diff --git a/crates/core/src/transcript/encoding/tree.rs b/crates/core/src/transcript/encoding/tree.rs new file mode 100644 index 0000000000..2a67900318 --- /dev/null +++ b/crates/core/src/transcript/encoding/tree.rs @@ -0,0 +1,331 @@ +use std::collections::HashMap; + +use bimap::BiMap; +use serde::{Deserialize, Serialize}; + +use crate::{ + connection::TranscriptLength, + hash::{Blinded, Blinder, HashAlgId, HashAlgorithm, TypedHash}, + merkle::MerkleTree, + serialize::CanonicalSerialize, + transcript::{ + encoding::{ + proof::{EncodingProof, Opening}, + EncodingProvider, + }, + Direction, Idx, Transcript, + }, +}; + +/// Encoding tree builder error. +#[derive(Debug, thiserror::Error)] +pub enum EncodingTreeError { + /// Index is out of bounds of the transcript. + #[error("index is out of bounds of the transcript")] + OutOfBounds { + /// The index. + index: Idx, + /// The transcript length. + transcript_length: usize, + }, + /// Encoding provider is missing an encoding for an index. + #[error("encoding provider is missing an encoding for an index")] + MissingEncoding { + /// The index which is missing. + index: Idx, + }, + /// Index is missing from the tree. + #[error("index is missing from the tree")] + MissingLeaf { + /// The index which is missing. + index: Idx, + }, +} + +#[derive(Serialize)] +pub(crate) struct EncodingLeaf(Vec); + +impl EncodingLeaf { + pub(super) fn new(encoding: Vec) -> Self { + Self(encoding) + } +} + +/// A merkle tree of transcript encodings. +#[derive(Clone, Serialize, Deserialize)] +pub struct EncodingTree { + /// Merkle tree of the commitments. 
+ tree: MerkleTree, + /// Nonces used to blind the hashes. + nonces: Vec, + /// Mapping between the index of a leaf and the transcript index it + /// corresponds to. + idxs: BiMap, +} + +opaque_debug::implement!(EncodingTree); + +impl EncodingTree { + /// Creates a new encoding tree. + /// + /// # Arguments + /// + /// * `alg` - The hash algorithm to use. + /// * `idxs` - The subsequence indices to commit to. + /// * `provider` - The encoding provider. + /// * `transcript_length` - The length of the transcript. + pub fn new<'idx>( + hasher: &dyn HashAlgorithm, + idxs: impl IntoIterator, + provider: &dyn EncodingProvider, + transcript_length: &TranscriptLength, + ) -> Result { + let mut this = Self { + tree: MerkleTree::new(hasher.id()), + nonces: Vec::new(), + idxs: BiMap::new(), + }; + + let mut leaves = Vec::new(); + for dir_idx in idxs { + let direction = dir_idx.0; + let idx = &dir_idx.1; + + // Ignore empty indices. + if idx.is_empty() { + continue; + } + + let len = match direction { + Direction::Sent => transcript_length.sent as usize, + Direction::Received => transcript_length.received as usize, + }; + + if idx.end() > len { + return Err(EncodingTreeError::OutOfBounds { + index: idx.clone().clone(), + transcript_length: len, + }); + } + + if this.idxs.contains_right(&(direction, idx.clone().clone())) { + // The subsequence is already in the tree. + continue; + } + + let encoding = provider.provide_encoding(direction, idx).ok_or_else(|| { + EncodingTreeError::MissingEncoding { + index: idx.clone().clone(), + } + })?; + + let leaf = Blinded::new(EncodingLeaf::new(encoding)); + + leaves.push(hasher.hash(&CanonicalSerialize::serialize(&leaf))); + this.nonces.push(leaf.into_parts().1); + this.idxs + .insert(this.idxs.len(), (direction, idx.clone().clone()).clone()); + } + + this.tree.insert(hasher, leaves); + + Ok(this) + } + + /// Returns the root of the tree. 
+ pub fn root(&self) -> TypedHash { + self.tree.root() + } + + /// Returns the hash algorithm of the tree. + pub fn algorithm(&self) -> HashAlgId { + self.tree.algorithm() + } + + /// Generates a proof for the given indices. + /// + /// # Arguments + /// + /// * `transcript` - The transcript to prove against. + /// * `idxs` - The transcript indices to prove. + pub fn proof<'idx>( + &self, + transcript: &Transcript, + idxs: impl Iterator, + ) -> Result { + let mut openings = HashMap::new(); + for dir_idx in idxs { + let direction = dir_idx.0; + let idx = &dir_idx.1; + + let leaf_idx = *self + .idxs + .get_by_right(dir_idx) + .ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?; + + let seq = + transcript + .get(direction, idx) + .ok_or_else(|| EncodingTreeError::OutOfBounds { + index: idx.clone(), + transcript_length: transcript.len_of_direction(direction), + })?; + let nonce = self.nonces[leaf_idx].clone(); + + openings.insert( + leaf_idx, + Opening { + direction, + seq, + blinder: nonce, + }, + ); + } + + let mut indices = openings.keys().copied().collect::>(); + indices.sort(); + + Ok(EncodingProof { + inclusion_proof: self.tree.proof(&indices), + openings, + }) + } + + /// Returns whether the tree contains the given transcript index. 
+ pub fn contains(&self, idx: &(Direction, Idx)) -> bool { + self.idxs.contains_right(idx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + fixtures::{encoder_seed, encoding_provider}, + hash::Blake3, + transcript::encoding::EncodingCommitment, + CryptoProvider, + }; + use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON}; + + fn new_tree<'seq>( + transcript: &Transcript, + idxs: impl Iterator, + ) -> Result { + let provider = encoding_provider(transcript.sent(), transcript.received()); + let transcript_length = TranscriptLength { + sent: transcript.sent().len() as u32, + received: transcript.received().len() as u32, + }; + EncodingTree::new(&Blake3::default(), idxs, &provider, &transcript_length) + } + + #[test] + fn test_encoding_tree() { + let transcript = Transcript::new(POST_JSON, OK_JSON); + + let idx_0 = &(Direction::Sent, Idx::new(0..POST_JSON.len())); + let idx_1 = &(Direction::Received, Idx::new(0..OK_JSON.len())); + + let tree = new_tree(&transcript, [idx_0, idx_1].into_iter()).unwrap(); + + assert!(tree.contains(idx_0)); + assert!(tree.contains(idx_1)); + + let proof = tree.proof(&transcript, [idx_0, idx_1].into_iter()).unwrap(); + + let commitment = EncodingCommitment { + root: tree.root(), + seed: encoder_seed().to_vec(), + }; + + let partial_transcript = proof + .verify_with_provider( + &CryptoProvider::default(), + &transcript.length(), + &commitment, + ) + .unwrap(); + + assert_eq!(partial_transcript.sent_unsafe(), transcript.sent()); + assert_eq!(partial_transcript.received_unsafe(), transcript.received()); + } + + #[test] + fn test_encoding_tree_multiple_ranges() { + let transcript = Transcript::new(POST_JSON, OK_JSON); + + let idx_0 = (Direction::Sent, Idx::new(0..1)); + let idx_1 = (Direction::Sent, Idx::new(1..POST_JSON.len())); + let idx_2 = (Direction::Received, Idx::new(0..1)); + let idx_3 = (Direction::Received, Idx::new(1..OK_JSON.len())); + + let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, 
&idx_3].into_iter()).unwrap(); + + assert!(tree.contains(&idx_0)); + assert!(tree.contains(&idx_1)); + assert!(tree.contains(&idx_2)); + assert!(tree.contains(&idx_3)); + + let proof = tree + .proof(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()) + .unwrap(); + + let commitment = EncodingCommitment { + root: tree.root(), + seed: encoder_seed().to_vec(), + }; + + let partial_transcript = proof + .verify_with_provider( + &CryptoProvider::default(), + &transcript.length(), + &commitment, + ) + .unwrap(); + + assert_eq!(partial_transcript.sent_unsafe(), transcript.sent()); + assert_eq!(partial_transcript.received_unsafe(), transcript.received()); + } + + #[test] + fn test_encoding_tree_out_of_bounds() { + let transcript = Transcript::new(POST_JSON, OK_JSON); + + let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len() + 1)); + let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len() + 1)); + + let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err(); + assert!(matches!(result, EncodingTreeError::OutOfBounds { .. })); + + let result = new_tree(&transcript, [&idx_1].into_iter()).unwrap_err(); + assert!(matches!(result, EncodingTreeError::OutOfBounds { .. })); + } + + #[test] + fn test_encoding_tree_missing_encoding() { + let provider = encoding_provider(&[], &[]); + let transcript_length = TranscriptLength { + sent: 8, + received: 8, + }; + + let result = EncodingTree::new( + &Blake3::default(), + [(Direction::Sent, Idx::new(0..8))].iter(), + &provider, + &transcript_length, + ) + .unwrap_err(); + assert!(matches!(result, EncodingTreeError::MissingEncoding { .. })); + + let result = EncodingTree::new( + &Blake3::default(), + [(Direction::Sent, Idx::new(0..8))].iter(), + &provider, + &transcript_length, + ) + .unwrap_err(); + assert!(matches!(result, EncodingTreeError::MissingEncoding { .. 
})); + } +} diff --git a/crates/core/src/transcript/hash.rs b/crates/core/src/transcript/hash.rs new file mode 100644 index 0000000000..b32261d04f --- /dev/null +++ b/crates/core/src/transcript/hash.rs @@ -0,0 +1,96 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + attestation::FieldId, + hash::{impl_domain_separator, Blinded, Blinder, HashProvider, HashProviderError, TypedHash}, + transcript::{Direction, Idx, InvalidSubsequence, Subsequence}, +}; + +/// Hash of plaintext in the transcript. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PlaintextHash { + /// Direction of the plaintext. + pub direction: Direction, + /// Index of plaintext. + pub idx: Idx, + /// The hash of the data. + pub hash: TypedHash, +} + +impl_domain_separator!(PlaintextHash); + +/// Secret data for a plaintext hash commitment. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlaintextHashSecret { + pub(crate) direction: Direction, + pub(crate) idx: Idx, + pub(crate) commitment: FieldId, + pub(crate) blinder: Blinder, +} + +/// Proof of the plaintext of a hash. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct PlaintextHashProof { + data: Blinded>, + commitment: FieldId, +} + +impl PlaintextHashProof { + pub(crate) fn new(data: Blinded>, commitment: FieldId) -> Self { + Self { data, commitment } + } +} + +impl PlaintextHashProof { + /// Returns the field id of the commitment this opening corresponds to. + pub(crate) fn commitment_id(&self) -> &FieldId { + &self.commitment + } + + /// Verifies the proof, returning the subsequence of plaintext. + /// + /// # Arguments + /// + /// * `commitment` - The commitment attested to by a Notary. 
+ pub(crate) fn verify( + self, + provider: &HashProvider, + commitment: &PlaintextHash, + ) -> Result<(Direction, Subsequence), PlaintextHashProofError> { + let alg = provider.get(&commitment.hash.alg)?; + + if commitment.hash.value != alg.hash_blinded(&self.data) { + return Err(PlaintextHashProofError::new( + "hash does not match commitment", + )); + } + + Ok(( + commitment.direction, + Subsequence::new(commitment.idx.clone(), self.data.into_parts().0)?, + )) + } +} + +/// Error for [`PlaintextHashProof`]. +#[derive(Debug, thiserror::Error)] +#[error("invalid plaintext hash proof: {0}")] +pub(crate) struct PlaintextHashProofError(String); + +impl PlaintextHashProofError { + fn new>(msg: T) -> Self { + Self(msg.into()) + } +} + +impl From for PlaintextHashProofError { + fn from(err: HashProviderError) -> Self { + Self(err.to_string()) + } +} + +impl From for PlaintextHashProofError { + fn from(err: InvalidSubsequence) -> Self { + Self(err.to_string()) + } +} diff --git a/crates/core/src/transcript/proof.rs b/crates/core/src/transcript/proof.rs new file mode 100644 index 0000000000..b5378fe1b3 --- /dev/null +++ b/crates/core/src/transcript/proof.rs @@ -0,0 +1,407 @@ +//! Transcript proofs. + +use std::{collections::HashSet, fmt}; + +use serde::{Deserialize, Serialize}; +use utils::range::ToRangeSet; + +use crate::{ + attestation::{Body, Field}, + hash::Blinded, + index::Index, + transcript::{ + commit::TranscriptCommitmentKind, + encoding::{EncodingProof, EncodingProofError, EncodingTree}, + hash::{PlaintextHashProof, PlaintextHashProofError, PlaintextHashSecret}, + Direction, Idx, PartialTranscript, Transcript, + }, + CryptoProvider, +}; + +use super::hash::PlaintextHash; + +/// Proof of the contents of a transcript. +#[derive(Clone, Serialize, Deserialize)] +pub struct TranscriptProof { + encoding_proof: Option, + hash_proofs: Option>, +} + +opaque_debug::implement!(TranscriptProof); + +impl TranscriptProof { + /// Verifies the proof. 
+ /// + /// Returns a partial transcript of authenticated data. + /// + /// # Arguments + /// + /// * `provider` - The crypto provider to use for verification. + /// * `attestation_body` - The attestation body to verify against. + pub fn verify_with_provider( + self, + provider: &CryptoProvider, + attestation_body: &Body, + ) -> Result { + let info = attestation_body.connection_info(); + + let mut transcript = PartialTranscript::new( + info.transcript_length.sent as usize, + info.transcript_length.received as usize, + ); + + // Verify encoding proof. + if let Some(proof) = self.encoding_proof { + let commitment = attestation_body.encoding_commitment().ok_or_else(|| { + TranscriptProofError::new( + ErrorKind::Encoding, + "contains an encoding proof but attestation is missing encoding commitment", + ) + })?; + let seq = proof.verify_with_provider(provider, &info.transcript_length, commitment)?; + transcript.union_transcript(&seq); + } + + match ( + attestation_body.plaintext_hashes().clone(), + self.hash_proofs, + ) { + (Some(attested_hashes), Some(hash_proofs)) => { + let index: Index> = attested_hashes.into(); + + // Verify hash openings. + for proof in hash_proofs { + let commitment = index + .get_by_field_id(proof.commitment_id()) + .map(|field| &field.data) + .ok_or_else(|| { + TranscriptProofError::new( + ErrorKind::Hash, + format!("contains a hash opening but attestation is missing corresponding commitment (id: {})", proof.commitment_id()), + ) + })?; + + let (direction, seq) = proof.verify(&provider.hash, commitment)?; + transcript.union_subsequence(direction, &seq); + } + } + // If there are no hash proofs, do nothing. + (None, None) => {} + (Some(_attested_hashes), None) => {} + (None, Some(_hash_proofs)) => { + return Err(TranscriptProofError::new( + ErrorKind::Hash, + "contains a hash opening but attestation contains no hash commitments", + )); + } + } + + Ok(transcript) + } +} + +/// Error for [`TranscriptProof`]. 
+#[derive(Debug, thiserror::Error)] +pub struct TranscriptProofError { + kind: ErrorKind, + source: Option>, +} + +impl TranscriptProofError { + fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } +} + +#[derive(Debug)] +enum ErrorKind { + Encoding, + Hash, +} + +impl fmt::Display for TranscriptProofError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("transcript proof error: ")?; + + match self.kind { + ErrorKind::Encoding => f.write_str("encoding error")?, + ErrorKind::Hash => f.write_str("hash error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for TranscriptProofError { + fn from(e: EncodingProofError) -> Self { + TranscriptProofError::new(ErrorKind::Encoding, e) + } +} + +impl From for TranscriptProofError { + fn from(e: PlaintextHashProofError) -> Self { + TranscriptProofError::new(ErrorKind::Hash, e) + } +} + +/// Builder for [`TranscriptProof`]. +#[derive(Debug)] +pub struct TranscriptProofBuilder<'a> { + default_kind: TranscriptCommitmentKind, + transcript: &'a Transcript, + encoding_tree: Option<&'a EncodingTree>, + plaintext_hashes: &'a Option>, + encoding_proof_idxs: HashSet<(Direction, Idx)>, + hash_proofs: Vec, +} + +impl<'a> TranscriptProofBuilder<'a> { + /// Creates a new proof config builder. + pub(crate) fn new( + transcript: &'a Transcript, + encoding_tree: Option<&'a EncodingTree>, + plaintext_hashes: &'a Option>, + ) -> Self { + Self { + default_kind: TranscriptCommitmentKind::Encoding, + transcript, + encoding_tree, + plaintext_hashes, + encoding_proof_idxs: HashSet::default(), + hash_proofs: Vec::new(), + } + } + + /// Sets the default kind of commitment to open when revealing ranges. 
+ pub fn default_kind(&mut self, kind: TranscriptCommitmentKind) -> &mut Self { + self.default_kind = kind; + self + } + + /// Reveals the given ranges in the transcript using the provided kind of + /// commitment. + /// + /// # Arguments + /// + /// * `ranges` - The ranges to reveal. + /// * `direction` - The direction of the transcript. + /// * `kind` - The kind of commitment to open. + pub fn reveal_with_kind( + &mut self, + ranges: &dyn ToRangeSet, + direction: Direction, + kind: TranscriptCommitmentKind, + ) -> Result<&mut Self, TranscriptProofBuilderError> { + let idx = Idx::new(ranges.to_range_set()); + + if idx.end() > self.transcript.len_of_direction(direction) { + return Err(TranscriptProofBuilderError::new( + BuilderErrorKind::Index, + format!( + "range is out of bounds of the transcript ({}): {} > {}", + direction, + idx.end(), + self.transcript.len_of_direction(direction) + ), + )); + } + + match kind { + TranscriptCommitmentKind::Encoding => { + let Some(encoding_tree) = self.encoding_tree else { + return Err(TranscriptProofBuilderError::new( + BuilderErrorKind::MissingCommitment, + "encoding tree is missing", + )); + }; + + if !encoding_tree.contains(&(direction, idx.clone())) { + return Err(TranscriptProofBuilderError::new( + BuilderErrorKind::MissingCommitment, + format!( + "encoding commitment is missing for ranges in {} transcript", + direction + ), + )); + } + + self.encoding_proof_idxs.insert((direction, idx)); + } + TranscriptCommitmentKind::Hash { .. } => { + let Some(hashes) = self.plaintext_hashes else { + return Err(TranscriptProofBuilderError::new( + BuilderErrorKind::MissingCommitment, + format!( + "hash commitment is missing for ranges in {} transcript", + direction + ), + )); + }; + + let Some(PlaintextHashSecret { + direction, + commitment, + blinder, + .. 
+ }) = hashes.get_by_transcript_idx(&direction, &idx) + else { + return Err(TranscriptProofBuilderError::new( + BuilderErrorKind::MissingCommitment, + format!( + "hash commitment is missing for ranges in {} transcript", + direction + ), + )); + }; + + let (_, data) = self + .transcript + .get(*direction, &idx) + .expect("subsequence was checked to be in transcript") + .into_parts(); + + self.hash_proofs.push(PlaintextHashProof::new( + Blinded::new_with_blinder(data, blinder.clone()), + *commitment, + )); + } + } + + Ok(self) + } + + /// Reveals the given ranges in the transcript using the default kind of + /// commitment. + /// + /// # Arguments + /// + /// * `ranges` - The ranges to reveal. + /// * `direction` - The direction of the transcript. + pub fn reveal( + &mut self, + ranges: &dyn ToRangeSet, + direction: Direction, + ) -> Result<&mut Self, TranscriptProofBuilderError> { + self.reveal_with_kind(ranges, direction, self.default_kind) + } + + /// Reveals the given ranges in the sent transcript using the default kind + /// of commitment. + /// + /// # Arguments + /// + /// * `ranges` - The ranges to reveal. + pub fn reveal_sent( + &mut self, + ranges: &dyn ToRangeSet, + ) -> Result<&mut Self, TranscriptProofBuilderError> { + self.reveal(ranges, Direction::Sent) + } + + /// Reveals the given ranges in the received transcript using the default + /// kind of commitment. + /// + /// # Arguments + /// + /// * `ranges` - The ranges to reveal. + pub fn reveal_recv( + &mut self, + ranges: &dyn ToRangeSet, + ) -> Result<&mut Self, TranscriptProofBuilderError> { + self.reveal(ranges, Direction::Received) + } + + /// Builds the transcript proof. 
+    pub fn build(self) -> Result<TranscriptProof, TranscriptProofBuilderError> {
+        let encoding_proof = if !self.encoding_proof_idxs.is_empty() {
+            let encoding_tree = self.encoding_tree.expect("encoding tree is present");
+            let proof = encoding_tree
+                .proof(self.transcript, self.encoding_proof_idxs.iter())
+                .expect("subsequences were checked to be in tree");
+            Some(proof)
+        } else {
+            None
+        };
+
+        let hash_proofs = if !self.hash_proofs.is_empty() {
+            Some(self.hash_proofs)
+        } else {
+            None
+        };
+
+        Ok(TranscriptProof {
+            encoding_proof,
+            hash_proofs,
+        })
+    }
+}
+
+/// Error for [`TranscriptProofBuilder`].
+#[derive(Debug, thiserror::Error)]
+pub struct TranscriptProofBuilderError {
+    kind: BuilderErrorKind,
+    source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
+}
+
+impl TranscriptProofBuilderError {
+    fn new<E>(kind: BuilderErrorKind, source: E) -> Self
+    where
+        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
+    {
+        Self {
+            kind,
+            source: Some(source.into()),
+        }
+    }
+}
+
+#[derive(Debug)]
+enum BuilderErrorKind {
+    Index,
+    MissingCommitment,
+}
+
+impl fmt::Display for TranscriptProofBuilderError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("transcript proof builder error: ")?;
+
+        match self.kind {
+            BuilderErrorKind::Index => f.write_str("index error")?,
+            BuilderErrorKind::MissingCommitment => f.write_str("commitment error")?,
+        }
+
+        if let Some(source) = &self.source {
+            write!(f, " caused by: {}", source)?;
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_range_out_of_bounds() {
+        let transcript = Transcript::new(
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+        );
+        let index = Some(Index::default());
+        let mut builder = TranscriptProofBuilder::new(&transcript, None, &index);
+
+        assert!(builder.reveal(&(10..15), Direction::Sent).is_err());
+        assert!(builder.reveal(&(10..15), Direction::Received).is_err());
+    }
+}
diff --git a/crates/core/tests/api.rs b/crates/core/tests/api.rs
new file mode 100644
index 0000000000..2a184cea1e
--- /dev/null
+++
b/crates/core/tests/api.rs @@ -0,0 +1,184 @@ +use tlsn_core::{ + attestation::{Attestation, AttestationConfig}, + connection::{HandshakeData, HandshakeDataV1_2}, + fixtures::{self, encoder_seed, plaintext_hashes_from_request, ConnectionFixture}, + hash::{Blake3, HashAlgId}, + presentation::PresentationOutput, + request::{Request, RequestConfig}, + signing::SignatureAlgId, + transcript::{ + encoding::EncodingTree, Direction, Idx, Transcript, TranscriptCommitConfigBuilder, + TranscriptCommitmentKind, + }, + CryptoProvider, +}; +use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON}; +use utils::range::{RangeSet, Union}; + +/// Tests that the attestation protocol and verification work end-to-end. +#[test] +fn test_api() { + let mut provider = CryptoProvider::default(); + + // Configure signer for Notary. + provider.signer.set_secp256k1(&[42u8; 32]).unwrap(); + + let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON); + let (sent_len, recv_len) = transcript.len(); + // Plaintext encodings which the Prover obtained from GC evaluation. + let encodings_provider = fixtures::encoding_provider(GET_WITH_HEADER, OK_JSON); + + // At the end of the TLS connection the Prover holds the: + let ConnectionFixture { + server_name, + connection_info, + server_cert_data, + } = ConnectionFixture::tlsnotary(transcript.length()); + + let HandshakeData::V1_2(HandshakeDataV1_2 { + server_ephemeral_key, + .. + }) = server_cert_data.handshake.clone() + else { + unreachable!() + }; + + // Prover specifies the ranges it wants to commit to. 
+ let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript); + transcript_commitment_builder + .commit_sent(&(0..sent_len / 2)) + .unwrap() + .commit_recv(&(0..recv_len / 2)) + .unwrap(); + + #[cfg(feature = "use_poseidon_halo2")] + { + transcript_commitment_builder.default_kind(TranscriptCommitmentKind::Hash { + alg: HashAlgId::POSEIDON_HALO2, + }); + transcript_commitment_builder + .commit_sent(&(sent_len / 2..sent_len)) + .unwrap() + .commit_recv(&(recv_len / 2..recv_len)) + .unwrap(); + } + + let transcripts_commitment_config = transcript_commitment_builder.build().unwrap(); + + // Prover constructs encoding tree. + let encoding_tree = EncodingTree::new( + &Blake3::default(), + transcripts_commitment_config.iter_encoding(), + &encodings_provider, + &transcript.length(), + ) + .unwrap(); + + let request_config = RequestConfig::default(); + let mut request_builder = Request::builder(&request_config); + + request_builder + .server_name(server_name.clone()) + .server_cert_data(server_cert_data) + .transcript(transcript) + .encoding_tree(encoding_tree); + + if transcripts_commitment_config.has_plaintext_hashes() { + request_builder.plaintext_hashes(transcripts_commitment_config.plaintext_hashes()); + } + + let (request, secrets) = request_builder.build(&provider).unwrap(); + + // At this point the Authdecode protocol must be run if there was a commitment algorithm used + // which requires it. After that, the Notary can proceed to create an attestation. + + let attestation_config = AttestationConfig::builder() + .supported_signature_algs([SignatureAlgId::SECP256K1]) + .build() + .unwrap(); + + // Notary builds and signs an attestation according to their view of the connection. + let mut attestation_builder = Attestation::builder(&attestation_config); + + // Optionally, Notary obtains authenticated plaintext hashes from an external context and adds them + // to the attestation. 
+ let authenticated_hashes = plaintext_hashes_from_request(&request); + if !authenticated_hashes.is_empty() { + attestation_builder.plaintext_hashes(authenticated_hashes); + } + + let mut attestation_builder = attestation_builder.accept_request(request.clone()).unwrap(); + + attestation_builder + // Notary's view of the connection + .connection_info(connection_info.clone()) + // Server key Notary received during handshake + .server_ephemeral_key(server_ephemeral_key) + .encoding_seed(encoder_seed().to_vec()); + + let attestation = attestation_builder.build(&provider).unwrap(); + + // Prover validates the attestation is consistent with its request. + request.validate(&attestation).unwrap(); + + let mut transcript_proof_builder = secrets.transcript_proof_builder(); + + // Stores the ranges which were revealed for the sent and the received data. + let mut revealed_sent = RangeSet::default(); + let mut revealed_recv = RangeSet::default(); + + transcript_proof_builder + .reveal(&(0..sent_len / 2), Direction::Sent) + .unwrap(); + revealed_sent = revealed_sent.union(&(0..sent_len / 2)); + + transcript_proof_builder + .reveal(&(0..recv_len / 2), Direction::Received) + .unwrap(); + revealed_recv = revealed_recv.union(&(0..recv_len / 2)); + + #[cfg(feature = "use_poseidon_halo2")] + { + transcript_proof_builder.default_kind(TranscriptCommitmentKind::Hash { + alg: HashAlgId::POSEIDON_HALO2, + }); + transcript_proof_builder + .reveal(&(sent_len / 2..sent_len), Direction::Sent) + .unwrap(); + revealed_sent = revealed_sent.union(&(sent_len / 2..sent_len)); + + transcript_proof_builder + .reveal(&(recv_len / 2..recv_len), Direction::Received) + .unwrap(); + revealed_recv = revealed_recv.union(&(recv_len / 2..recv_len)); + } + + let transcript_proof = transcript_proof_builder.build().unwrap(); + + let mut builder = attestation.presentation_builder(&provider); + + builder.identity_proof(secrets.identity_proof()); + builder.transcript_proof(transcript_proof); + + let presentation = 
builder.build().unwrap(); + + // Verifier verifies the presentation. + let PresentationOutput { + server_name: presented_server_name, + connection_info: presented_connection_info, + transcript: presented_transcript, + .. + } = presentation.verify(&provider).unwrap(); + + assert_eq!(presented_server_name.unwrap(), server_name); + assert_eq!(presented_connection_info, connection_info); + + let presented_transcript = presented_transcript.unwrap(); + + assert_eq!( + presented_transcript, + secrets + .transcript() + .to_partial(Idx::new(revealed_sent), Idx::new(revealed_recv)) + ); +} diff --git a/crates/data-fixtures/Cargo.toml b/crates/data-fixtures/Cargo.toml new file mode 100644 index 0000000000..5c12feebf8 --- /dev/null +++ b/crates/data-fixtures/Cargo.toml @@ -0,0 +1,5 @@ +[package] +name = "tlsn-data-fixtures" +version = "0.0.0" +edition = "2021" +publish = false diff --git a/crates/data-fixtures/data/http/request_get_empty b/crates/data-fixtures/data/http/request_get_empty new file mode 100644 index 0000000000..89358940f2 --- /dev/null +++ b/crates/data-fixtures/data/http/request_get_empty @@ -0,0 +1,2 @@ +GET / HTTP/1.1 + diff --git a/crates/data-fixtures/data/http/request_get_empty_header b/crates/data-fixtures/data/http/request_get_empty_header new file mode 100644 index 0000000000..c782810863 --- /dev/null +++ b/crates/data-fixtures/data/http/request_get_empty_header @@ -0,0 +1,4 @@ +GET / HTTP/1.1 +Host: localhost +Empty-Header: + diff --git a/crates/data-fixtures/data/http/request_get_with_header b/crates/data-fixtures/data/http/request_get_with_header new file mode 100644 index 0000000000..4f058ccaa0 --- /dev/null +++ b/crates/data-fixtures/data/http/request_get_with_header @@ -0,0 +1,3 @@ +GET / HTTP/1.1 +Host: localhost + diff --git a/crates/data-fixtures/data/http/request_post_json b/crates/data-fixtures/data/http/request_post_json new file mode 100644 index 0000000000..86f47b5ee0 --- /dev/null +++ b/crates/data-fixtures/data/http/request_post_json @@ 
-0,0 +1,6 @@ +POST /hello HTTP/1.1 +Host: localhost +Content-Length: 44 +Content-Type: application/json + +{"foo": "bar", "bazz": 123, "buzz": [1,"5"]} \ No newline at end of file diff --git a/crates/data-fixtures/data/http/response_empty b/crates/data-fixtures/data/http/response_empty new file mode 100644 index 0000000000..63eb7a5e1c --- /dev/null +++ b/crates/data-fixtures/data/http/response_empty @@ -0,0 +1,3 @@ +HTTP/1.1 200 OK +Content-Length: 0 + diff --git a/crates/data-fixtures/data/http/response_empty_header b/crates/data-fixtures/data/http/response_empty_header new file mode 100644 index 0000000000..598d763550 --- /dev/null +++ b/crates/data-fixtures/data/http/response_empty_header @@ -0,0 +1,5 @@ +HTTP/1.1 200 OK +Cookie: very-secret-cookie +Content-Length: 0 +Empty-Header: + diff --git a/crates/data-fixtures/data/http/response_json b/crates/data-fixtures/data/http/response_json new file mode 100644 index 0000000000..bd42e134f7 --- /dev/null +++ b/crates/data-fixtures/data/http/response_json @@ -0,0 +1,6 @@ +HTTP/1.1 200 OK +Cookie: very-secret-cookie +Content-Length: 44 +Content-Type: application/json + +{"foo": "bar", "bazz": 123, "buzz": [1,"5"]} diff --git a/crates/data-fixtures/data/http/response_text b/crates/data-fixtures/data/http/response_text new file mode 100644 index 0000000000..e07d0f01b3 --- /dev/null +++ b/crates/data-fixtures/data/http/response_text @@ -0,0 +1,6 @@ +HTTP/1.1 200 OK +Cookie: very-secret-cookie +Content-Length: 14 +Content-Type: text/plain + +Hello World!!! \ No newline at end of file diff --git a/crates/data-fixtures/src/http.rs b/crates/data-fixtures/src/http.rs new file mode 100644 index 0000000000..3ffef3dde2 --- /dev/null +++ b/crates/data-fixtures/src/http.rs @@ -0,0 +1,53 @@ +//! 
HTTP data fixtures
+
+/// HTTP requests
+pub mod request {
+    use crate::define_fixture;
+
+    define_fixture!(
+        GET_EMPTY,
+        "A GET request without a body or headers.",
+        "../data/http/request_get_empty"
+    );
+    define_fixture!(
+        GET_EMPTY_HEADER,
+        "A GET request with an empty header.",
+        "../data/http/request_get_empty_header"
+    );
+    define_fixture!(
+        GET_WITH_HEADER,
+        "A GET request with a header.",
+        "../data/http/request_get_with_header"
+    );
+    define_fixture!(
+        POST_JSON,
+        "A POST request with a JSON body.",
+        "../data/http/request_post_json"
+    );
+}
+
+/// HTTP responses
+pub mod response {
+    use crate::define_fixture;
+
+    define_fixture!(
+        OK_EMPTY,
+        "An OK response without a body.",
+        "../data/http/response_empty"
+    );
+    define_fixture!(
+        OK_EMPTY_HEADER,
+        "An OK response with an empty header.",
+        "../data/http/response_empty_header"
+    );
+    define_fixture!(
+        OK_TEXT,
+        "An OK response with a text body.",
+        "../data/http/response_text"
+    );
+    define_fixture!(
+        OK_JSON,
+        "An OK response with a JSON body.",
+        "../data/http/response_json"
+    );
+}
diff --git a/crates/data-fixtures/src/lib.rs b/crates/data-fixtures/src/lib.rs
new file mode 100644
index 0000000000..658c7f70d1
--- /dev/null
+++ b/crates/data-fixtures/src/lib.rs
@@ -0,0 +1,14 @@
+pub mod http;
+
+macro_rules! define_fixture {
+    ($name:ident, $doc:tt, $path:tt) => {
+        #[doc = $doc]
+        ///
+        /// ```text
+        #[doc = include_str!($path)]
+        /// ```
+        pub const $name: &[u8] = include_bytes!($path);
+    };
+}
+
+pub(crate) use define_fixture;
diff --git a/crates/examples/.gitignore b/crates/examples/.gitignore
new file mode 100644
index 0000000000..824f69e762
--- /dev/null
+++ b/crates/examples/.gitignore
@@ -0,0 +1,2 @@
+# Ignore files from examples.
+*.tlsn \ No newline at end of file diff --git a/crates/examples/Cargo.toml b/crates/examples/Cargo.toml new file mode 100644 index 0000000000..59761f1ddc --- /dev/null +++ b/crates/examples/Cargo.toml @@ -0,0 +1,59 @@ +[package] +edition = "2021" +name = "tlsn-examples" +publish = false +version = "0.0.0" + +[dependencies] +notary-client = { workspace = true } +tlsn-common = { workspace = true } +tlsn-core = { workspace = true } +tlsn-prover = { workspace = true } +tlsn-utils = { workspace = true } +tlsn-verifier = { workspace = true } +tlsn-formats = { workspace = true } + +bincode = { workspace = true } +chrono = { workspace = true } +dotenv = { version = "0.15.0" } +elliptic-curve = { workspace = true, features = ["pkcs8"] } +futures = { workspace = true } +http-body-util = { workspace = true } +hex = { workspace = true } +hyper = { workspace = true, features = ["client", "http1"] } +hyper-util = { workspace = true, features = ["full"] } +k256 = { workspace = true, features = ["ecdsa"] } +regex = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tokio = { workspace = true, features = [ + "rt", + "rt-multi-thread", + "macros", + "net", + "io-std", + "fs", +] } +tokio-util = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +[[example]] +name = "attestation_prove" +path = "attestation/prove.rs" + +[[example]] +name = "attestation_present" +path = "attestation/present.rs" + +[[example]] +name = "attestation_verify" +path = "attestation/verify.rs" + +[[example]] +name = "interactive" +path = "interactive/interactive.rs" + +[[example]] +name = "discord_dm" +path = "discord/discord_dm.rs" diff --git a/tlsn/examples/README.md b/crates/examples/README.md similarity index 64% rename from tlsn/examples/README.md rename to crates/examples/README.md index 43fc7d59ac..bff6bcb9e9 100644 --- a/tlsn/examples/README.md +++ b/crates/examples/README.md @@ -2,7 +2,8 @@ This 
folder contains examples showing how to use the TLSNotary protocol.
 
-* [simple](./simple/README.md) shows how to perform a simple notarization.
+* [attestation](./attestation/README.md) shows how to perform a simple notarization.
+* [interactive](./interactive/README.md) interactive Prover and Verifier, without a trusted notary.
 * [twitter](./twitter/README.md) shows how to notarize a Twitter DM.
 * [discord](./discord/README.md) shows how to notarize a Discord DM.
diff --git a/crates/examples/attestation/README.md b/crates/examples/attestation/README.md
new file mode 100644
index 0000000000..6cfb1db3de
--- /dev/null
+++ b/crates/examples/attestation/README.md
@@ -0,0 +1,75 @@
+## Simple Attestation Example: Notarize Public Data from example.com (Rust)
+
+This example demonstrates the simplest possible use case for TLSNotary:
+1. Fetch <https://example.com/> and acquire an attestation of its content.
+2. Create a verifiable presentation using the attestation, while redacting the value of a header.
+3. Verify the presentation.
+
+### 1. Notarize
+
+Run the `prove` binary.
+
+```shell
+cargo run --release --example attestation_prove
+```
+
+If the notarization was successful, you should see this output in the console:
+
+```log
+Starting an MPC TLS connection with the server
+Got a response from the server
+Notarization completed successfully!
+The attestation has been written to `example.attestation.tlsn` and the corresponding secrets to `example.secrets.tlsn`.
+```
+
+⚠️ In this simple example the `Notary` server is automatically started in the background. Note that this is for demonstration purposes only. In a real world example, the notary should be run by a trusted party. Consult the [Notary Server Docs](https://docs.tlsnotary.org/developers/notary_server.html) for more details on how to run a notary server.
+
+### 2. Build a verifiable presentation
+
+This will build a verifiable presentation with the `User-Agent` header redacted from the request.
This presentation can be shared with any verifier you wish to present the data to. + +Run the `present` binary. + +```shell +cargo run --release --example attestation_present +``` + +If successful, you should see this output in the console: + +```log +Presentation built successfully! +The presentation has been written to `example.presentation.tlsn`. +``` + +### 3. Verify the presentation + +This will read the presentation from the previous step, verify it, and print the disclosed data to console. + +Run the `verify` binary. + +```shell +cargo run --release --example attestation_verify +``` + +If successful, you should see this output in the console: + +```log +Verifying presentation with {key algorithm} key: { hex encoded key } + +**Ask yourself, do you trust this key?** + +------------------------------------------------------------------- +Successfully verified that the data below came from a session with example.com at 2024-10-03 03:01:40 UTC. +Note that the data which the Prover chose not to disclose are shown as X. + +Data sent: +... +``` + +⚠️ Notice that the presentation comes with a "verifying key". This is the key the Notary used when issuing the attestation that the presentation was built from. If you trust the Notary, or more specifically the verifying key, then you can trust that the presented data is authentic. + +### Next steps + +Try out the [Discord example](../discord/README.md) and notarize a Discord conversations. + + diff --git a/crates/examples/attestation/present.rs b/crates/examples/attestation/present.rs new file mode 100644 index 0000000000..802c2c0d26 --- /dev/null +++ b/crates/examples/attestation/present.rs @@ -0,0 +1,61 @@ +// This example demonstrates how to build a verifiable presentation from an +// attestation and the corresponding connection secrets. See the `prove.rs` +// example to learn how to acquire an attestation from a Notary. 
+
+use tlsn_core::{attestation::Attestation, presentation::Presentation, CryptoProvider, Secrets};
+use tlsn_formats::http::HttpTranscript;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Read attestation from disk.
+    let attestation: Attestation =
+        bincode::deserialize(&std::fs::read("example.attestation.tlsn")?)?;
+
+    // Read secrets from disk.
+    let secrets: Secrets = bincode::deserialize(&std::fs::read("example.secrets.tlsn")?)?;
+
+    // Parse the HTTP transcript.
+    let transcript = HttpTranscript::parse(secrets.transcript())?;
+
+    // Build a transcript proof.
+    let mut builder = secrets.transcript_proof_builder();
+
+    let request = &transcript.requests[0];
+    // Reveal the structure of the request without the headers or body.
+    builder.reveal_sent(&request.without_data())?;
+    // Reveal the request target.
+    builder.reveal_sent(&request.request.target)?;
+    // Reveal all headers except the value of the User-Agent header.
+    for header in &request.headers {
+        if !header.name.as_str().eq_ignore_ascii_case("User-Agent") {
+            builder.reveal_sent(header)?;
+        } else {
+            builder.reveal_sent(&header.without_value())?;
+        }
+    }
+    // Reveal the entire response.
+    builder.reveal_recv(&transcript.responses[0])?;
+
+    let transcript_proof = builder.build()?;
+
+    // Use default crypto provider to build the presentation.
+    let provider = CryptoProvider::default();
+
+    let mut builder = attestation.presentation_builder(&provider);
+
+    builder
+        .identity_proof(secrets.identity_proof())
+        .transcript_proof(transcript_proof);
+
+    let presentation: Presentation = builder.build()?;
+
+    // Write the presentation to disk.
+    std::fs::write(
+        "example.presentation.tlsn",
+        bincode::serialize(&presentation)?,
+    )?;
+
+    println!("Presentation built successfully!");
+    println!("The presentation has been written to `example.presentation.tlsn`.");
+
+    Ok(())
+}
diff --git a/crates/examples/attestation/prove.rs b/crates/examples/attestation/prove.rs
new file mode 100644
index 0000000000..23eb733750
--- /dev/null
+++ b/crates/examples/attestation/prove.rs
@@ -0,0 +1,125 @@
+// This example demonstrates how to use the Prover to acquire an attestation for
+// an HTTP request sent to example.com. The attestation and secrets are saved to
+// disk.
+
+use http_body_util::Empty;
+use hyper::{body::Bytes, Request, StatusCode};
+use hyper_util::rt::TokioIo;
+use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
+
+use tlsn_common::config::ProtocolConfig;
+use tlsn_core::{request::RequestConfig, transcript::TranscriptCommitConfig};
+use tlsn_examples::run_notary;
+use tlsn_formats::http::{DefaultHttpCommitter, HttpCommit, HttpTranscript};
+use tlsn_prover::{Prover, ProverConfig};
+
+// Setting of the application server
+const SERVER_DOMAIN: &str = "example.com";
+const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36";
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing_subscriber::fmt::init();
+
+    let (prover_socket, notary_socket) = tokio::io::duplex(1 << 16);
+
+    // Start a local simple notary service
+    tokio::spawn(run_notary(notary_socket.compat()));
+
+    // Prover configuration.
+    let config = ProverConfig::builder()
+        .server_name(SERVER_DOMAIN)
+        .protocol_config(
+            ProtocolConfig::builder()
+                // We must configure the amount of data we expect to exchange beforehand, which will
+                // be preprocessed prior to the connection. Reducing these limits will improve
+                // performance.
+                .max_sent_data(1024)
+                .max_recv_data(4096)
+                .build()?,
+        )
+        .build()?;
+
+    // Create a new prover and perform necessary setup.
+    let prover = Prover::new(config).setup(prover_socket.compat()).await?;
+
+    // Open a TCP connection to the server.
+    let client_socket = tokio::net::TcpStream::connect((SERVER_DOMAIN, 443)).await?;
+
+    // Bind the prover to the server connection.
+    // The returned `mpc_tls_connection` is an MPC TLS connection to the server: all
+    // data written to/read from it will be encrypted/decrypted using MPC with
+    // the notary.
+    let (mpc_tls_connection, prover_fut) = prover.connect(client_socket.compat()).await?;
+    let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
+
+    // Spawn the prover task to be run concurrently in the background.
+    let prover_task = tokio::spawn(prover_fut);
+
+    // Attach the hyper HTTP client to the connection.
+    let (mut request_sender, connection) =
+        hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
+
+    // Spawn the HTTP task to be run concurrently in the background.
+    tokio::spawn(connection);
+
+    // Build a simple HTTP request with common headers
+    let request = Request::builder()
+        .uri("/")
+        .header("Host", SERVER_DOMAIN)
+        .header("Accept", "*/*")
+        // Using "identity" instructs the Server not to use compression for its HTTP response.
+        // TLSNotary tooling does not support compression.
+        .header("Accept-Encoding", "identity")
+        .header("Connection", "close")
+        .header("User-Agent", USER_AGENT)
+        .body(Empty::<Bytes>::new())?;
+
+    println!("Starting an MPC TLS connection with the server");
+
+    // Send the request to the server and wait for the response.
+    let response = request_sender.send_request(request).await?;
+
+    println!("Got a response from the server");
+
+    assert!(response.status() == StatusCode::OK);
+
+    // The prover task should be done now, so we can await it.
+    let prover = prover_task.await??;
+
+    // Prepare for notarization.
+    let mut prover = prover.start_notarize();
+
+    // Parse the HTTP transcript.
+    let transcript = HttpTranscript::parse(prover.transcript())?;
+
+    // Commit to the transcript.
+    let mut builder = TranscriptCommitConfig::builder(prover.transcript());
+
+    DefaultHttpCommitter::default().commit_transcript(&mut builder, &transcript)?;
+
+    prover.transcript_commit(builder.build()?);
+
+    // Request an attestation.
+    let config = RequestConfig::default();
+
+    let (attestation, secrets) = prover.finalize(&config).await?;
+
+    // Write the attestation to disk.
+    tokio::fs::write(
+        "example.attestation.tlsn",
+        bincode::serialize(&attestation)?,
+    )
+    .await?;
+
+    // Write the secrets to disk.
+    tokio::fs::write("example.secrets.tlsn", bincode::serialize(&secrets)?).await?;
+
+    println!("Notarization completed successfully!");
+    println!(
+        "The attestation has been written to `example.attestation.tlsn` and the \
+         corresponding secrets to `example.secrets.tlsn`."
+    );
+
+    Ok(())
+}
diff --git a/crates/examples/attestation/verify.rs b/crates/examples/attestation/verify.rs
new file mode 100644
index 0000000000..6b715c924f
--- /dev/null
+++ b/crates/examples/attestation/verify.rs
@@ -0,0 +1,60 @@
+// This example demonstrates how to verify a presentation. See `present.rs` for
+// an example of how to build a presentation from an attestation and connection
+// secrets.
+
+use std::time::Duration;
+
+use tlsn_core::{
+    presentation::{Presentation, PresentationOutput},
+    signing::VerifyingKey,
+    CryptoProvider,
+};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Read the presentation from disk.
+ let presentation: Presentation = + bincode::deserialize(&std::fs::read("example.presentation.tlsn")?)?; + + let provider = CryptoProvider::default(); + + let VerifyingKey { + alg, + data: key_data, + } = presentation.verifying_key(); + + println!( + "Verifying presentation with {alg} key: {}\n\n**Ask yourself, do you trust this key?**\n", + hex::encode(key_data) + ); + + // Verify the presentation. + let PresentationOutput { + server_name, + connection_info, + transcript, + .. + } = presentation.verify(&provider).unwrap(); + + // The time at which the connection was started. + let time = chrono::DateTime::UNIX_EPOCH + Duration::from_secs(connection_info.time); + let server_name = server_name.unwrap(); + let mut partial_transcript = transcript.unwrap(); + // Set the unauthenticated bytes so they are distinguishable. + partial_transcript.set_unauthed(b'X'); + + let sent = String::from_utf8_lossy(partial_transcript.sent_unsafe()); + let recv = String::from_utf8_lossy(partial_transcript.received_unsafe()); + + println!("-------------------------------------------------------------------"); + println!( + "Successfully verified that the data below came from a session with {server_name} at {time}.", + ); + println!("Note that the data which the Prover chose not to disclose are shown as X.\n"); + println!("Data sent:\n"); + println!("{}\n", sent); + println!("Data received:\n"); + println!("{}\n", recv); + println!("-------------------------------------------------------------------"); + + Ok(()) +} diff --git a/tlsn/examples/discord/.env.example b/crates/examples/discord/.env.example similarity index 100% rename from tlsn/examples/discord/.env.example rename to crates/examples/discord/.env.example diff --git a/tlsn/examples/discord/README.md b/crates/examples/discord/README.md similarity index 56% rename from tlsn/examples/discord/README.md rename to crates/examples/discord/README.md index 34bce43f08..79ce461242 100644 --- a/tlsn/examples/discord/README.md +++ 
b/crates/examples/discord/README.md @@ -1,6 +1,6 @@ # Notarize Discord DMs -The `discord_dm.rs` example sets up a TLS connection with Discord and notarizes the requested DMs. The notarized session and the proof are written to local JSON files (`discord_dm_notarized_session.json` and `discord_dm_proof.json`) for easier inspection. +The `discord_dm.rs` example sets up a TLS connection with Discord and notarizes the requested DMs. The attestation and secrets are saved to disk. This involves 3 steps: 1. Configure the inputs @@ -27,62 +27,55 @@ You can find the `CHANNEL_ID` directly in the url: `https://discord.com/channels/@me/{CHANNEL_ID)` ## Start the notary server -At the root level of this repository, run -```sh -cd notary-server -cargo run --release -``` +1. Edit the notary server [config file](../../notary/server/config/config.yaml) to turn off TLS so that self-signed certificates can be avoided (⚠️ this is only for local development purposes — TLS must be used in production). + ```yaml + tls: + enabled: false + ... + ``` +2. Run the following at the root level of this repository to start the notary server: + ```shell + cd crates/notary/server + cargo run --release + ``` The notary server will now be running in the background waiting for connections. -For more information on how to configure the `Notary` server, please refer to [this](../../../notary-server/README.md#running-the-server). +For more information on how to configure the `Notary` server, please refer to [this](../../notary/server/README.md#running-the-server). ## Notarize In this tlsn/examples/discord folder, run the following command: ```sh -RUST_LOG=debug,yamux=info cargo run --release --example discord_dm +RUST_LOG=DEBUG,uid_mux=INFO,yamux=INFO cargo run --release --example discord_dm ``` If everything goes well, you should see output similar to the following: ```log -.. -2023-09-22T14:40:51.416047Z DEBUG discord_dm: [ +... 
+2024-06-26T08:49:47.017439Z DEBUG connect:tls_connection: tls_client_async: handshake complete +2024-06-26T08:49:48.676459Z DEBUG connect:tls_connection: tls_client_async: server closed connection +2024-06-26T08:49:48.676481Z DEBUG connect:commit: tls_mpc::leader: committing to transcript +2024-06-26T08:49:48.676503Z DEBUG connect:tls_connection: tls_client_async: client shutdown +2024-06-26T08:49:48.676466Z DEBUG discord_dm: Sent request +2024-06-26T08:49:48.676550Z DEBUG discord_dm: Request OK +2024-06-26T08:49:48.676598Z DEBUG connect:close_connection: tls_mpc::leader: closing connection +2024-06-26T08:49:48.676613Z DEBUG connect: tls_mpc::leader: leader actor stopped +2024-06-26T08:49:48.676618Z DEBUG discord_dm: [ { "attachments": [], - "author": { - "accent_color": null, - "avatar": "dd07631c9613240aa969d6e7916eb7ae", - "avatar_decoration_data": null, - "banner": null, - "banner_color": null, - "discriminator": "0", - "flags": 0, - "global_name": "sinu", - "id": "662709891017867273", - "public_flags": 0, - "username": "sinu_" - }, + ... "channel_id": "1154750485639745567", - "components": [], - "content": "Hello ETHGlobal NY!!", - "edited_timestamp": null, - "embeds": [], - "flags": 0, - "id": "1154750835784429678", - "mention_everyone": false, - "mention_roles": [], - "mentions": [], - "pinned": false, - "timestamp": "2023-09-22T12:07:33.484000+00:00", - "tts": false, - "type": 0 - }, - .. + ... + } ] -2023-09-22T14:40:51.847455Z DEBUG discord_dm: Notarization complete! +2024-06-26T08:49:48.678621Z DEBUG finalize: tlsn_prover::tls::notarize: starting finalization +2024-06-26T08:49:48.680839Z DEBUG finalize: tlsn_prover::tls::notarize: received OT secret +2024-06-26T08:49:50.004432Z INFO finalize:poll{role=Client}:handle_shutdown: uid_mux::yamux: mux connection closed +2024-06-26T08:49:50.004448Z INFO finalize:poll{role=Client}: uid_mux::yamux: connection complete +2024-06-26T08:49:50.004583Z DEBUG discord_dm: Notarization complete! 
``` If the transcript was too long, you may encounter the following error. This occurs because there is a default limit of notarization size to 16kB: @@ -91,16 +84,6 @@ If the transcript was too long, you may encounter the following error. This occu thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: IOError(Custom { kind: InvalidData, error: BackendError(DecryptionError("Other: KOSReceiverActor is not setup")) })', /Users/heeckhau/tlsnotary/tlsn/tlsn/tlsn-prover/src/lib.rs:173:50 ``` -# Verifier - -The `discord_dm` example also generated a proof of the transcript with the `Authorization` header redacted from the request, saved in `discord_dm_proof.json`. - -We can verify this proof using the `discord_dm_verifier` by running: - -``` -cargo run --release --example discord_dm_verifier -``` - -This will verify the proof and print out the redacted transcript! +# Verify -> **_NOTE:_** ℹ️ hosts a generic proof visualizer. Drag and drop your proof into the drop zone to check and render your proof. \ No newline at end of file +See the [`present`](../attestation/present.rs) and [`verify`](../attestation/verify.rs) examples for a demonstration of how to construct a presentation and verify it. \ No newline at end of file diff --git a/crates/examples/discord/discord_dm.rs b/crates/examples/discord/discord_dm.rs new file mode 100644 index 0000000000..dd13198063 --- /dev/null +++ b/crates/examples/discord/discord_dm.rs @@ -0,0 +1,216 @@ +// This example shows how to notarize Discord DMs. 
+// +// The example uses the notary server implemented in ../../notary/server + +use http_body_util::{BodyExt, Empty}; +use hyper::{body::Bytes, Request, StatusCode}; +use hyper_util::rt::TokioIo; +use notary_client::{Accepted, NotarizationRequest, NotaryClient}; +use std::{env, str}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tracing::debug; +use utils::range::RangeSet; + +use tlsn_common::config::ProtocolConfig; +use tlsn_core::{request::RequestConfig, transcript::TranscriptCommitConfig}; +use tlsn_prover::{Prover, ProverConfig}; + +// Setting of the application server +const SERVER_DOMAIN: &str = "discord.com"; + +// Setting of the notary server — make sure these are the same with the config +// in ../../notary/server +const NOTARY_HOST: &str = "127.0.0.1"; +const NOTARY_PORT: u16 = 7047; + +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + // Load secret variables frome environment for discord server connection + dotenv::dotenv().ok(); + let channel_id = env::var("CHANNEL_ID").unwrap(); + let auth_token = env::var("AUTHORIZATION").unwrap(); + let user_agent = env::var("USER_AGENT").unwrap(); + + // Build a client to connect to the notary server. + let notary_client = NotaryClient::builder() + .host(NOTARY_HOST) + .port(NOTARY_PORT) + // WARNING: Always use TLS to connect to notary server, except if notary is running locally + // e.g. this example, hence `enable_tls` is set to False (else it always defaults to True). + .enable_tls(false) + .build() + .unwrap(); + + // Send requests for configuration and notarization to the notary server. 
+ let notarization_request = NotarizationRequest::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let Accepted { + io: notary_connection, + id: _session_id, + .. + } = notary_client + .request_notarization(notarization_request) + .await + .expect("Could not connect to notary. Make sure it is running."); + + // Set up protocol configuration for prover. + let protocol_config = ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + // Create a new prover and set up the MPC backend. + let prover_config = ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config(protocol_config) + .build() + .unwrap(); + let prover = Prover::new(prover_config) + .setup(notary_connection.compat()) + .await + .unwrap(); + + // Open a new socket to the application server. + let client_socket = tokio::net::TcpStream::connect((SERVER_DOMAIN, 443)) + .await + .unwrap(); + + // Bind the Prover to server connection + let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); + + // Spawn the Prover to be run concurrently + let prover_task = tokio::spawn(prover_fut); + + // Attach the hyper HTTP client to the TLS connection + let (mut request_sender, connection) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_connection.compat())) + .await + .unwrap(); + + // Spawn the HTTP task to be run concurrently + tokio::spawn(connection); + + // Build the HTTP request to fetch the DMs + let request = Request::builder() + .uri(format!( + "https://{SERVER_DOMAIN}/api/v9/channels/{channel_id}/messages?limit=2" + )) + .header("Host", SERVER_DOMAIN) + .header("Accept", "*/*") + .header("Accept-Language", "en-US,en;q=0.5") + .header("Accept-Encoding", "identity") + .header("User-Agent", user_agent) + .header("Authorization", &auth_token) + .header("Connection", "close") + .body(Empty::::new()) + .unwrap(); + + debug!("Sending request"); + + 
let response = request_sender.send_request(request).await.unwrap(); + + debug!("Sent request"); + + assert!(response.status() == StatusCode::OK, "{}", response.status()); + + debug!("Request OK"); + + // Pretty printing :) + let payload = response.into_body().collect().await.unwrap().to_bytes(); + let parsed = + serde_json::from_str::(&String::from_utf8_lossy(&payload)).unwrap(); + debug!("{}", serde_json::to_string_pretty(&parsed).unwrap()); + + // The Prover task should be done now, so we can grab it. + let prover = prover_task.await.unwrap().unwrap(); + + // Prepare for notarization + let mut prover = prover.start_notarize(); + + // Identify the ranges in the transcript that contain secrets + let sent_transcript = prover.transcript().sent(); + let recv_transcript = prover.transcript().received(); + + // Identify the ranges in the outbound data which contain data which we want to + // disclose + let (sent_public_ranges, _) = find_ranges(sent_transcript, &[auth_token.as_bytes()]); + #[allow(clippy::single_range_in_vec_init)] + let recv_public_ranges = RangeSet::from([0..recv_transcript.len()]); + + let mut builder = TranscriptCommitConfig::builder(prover.transcript()); + + // Commit to public ranges + builder.commit_sent(&sent_public_ranges).unwrap(); + builder.commit_recv(&recv_public_ranges).unwrap(); + + let config = builder.build().unwrap(); + + prover.transcript_commit(config); + + // Finalize, returning the notarized session + let request_config = RequestConfig::default(); + let (attestation, secrets) = prover.finalize(&request_config).await.unwrap(); + + debug!("Notarization complete!"); + + tokio::fs::write( + "discord.attestation.tlsn", + bincode::serialize(&attestation).unwrap(), + ) + .await + .unwrap(); + + tokio::fs::write( + "discord.secrets.tlsn", + bincode::serialize(&secrets).unwrap(), + ) + .await + .unwrap(); +} + +/// Find the ranges of the public and private parts of a sequence. +/// +/// Returns a tuple of `(public, private)` ranges. 
+fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (RangeSet, RangeSet) { + let mut private_ranges = Vec::new(); + for s in sub_seq { + for (idx, w) in seq.windows(s.len()).enumerate() { + if w == *s { + private_ranges.push(idx..(idx + w.len())); + } + } + } + + let mut sorted_ranges = private_ranges.clone(); + sorted_ranges.sort_by_key(|r| r.start); + + let mut public_ranges = Vec::new(); + let mut last_end = 0; + for r in sorted_ranges { + if r.start > last_end { + public_ranges.push(last_end..r.start); + } + last_end = r.end; + } + + if last_end < seq.len() { + public_ranges.push(last_end..seq.len()); + } + + ( + RangeSet::from(public_ranges), + RangeSet::from(private_ranges), + ) +} diff --git a/crates/examples/interactive/README.md b/crates/examples/interactive/README.md new file mode 100644 index 0000000000..bed068da40 --- /dev/null +++ b/crates/examples/interactive/README.md @@ -0,0 +1,5 @@ +## Simple Interactive Verifier: Verifying Data from an API in Rust + +This example demonstrates how to use TLSNotary in a simple interactive session between a Prover and a Verifier. It involves the Verifier first verifying the MPC-TLS session and then confirming the correctness of the data. + +Note: In this example, the Prover and the Verifier run on the same machine. In real-world scenarios, the Prover and Verifier would be separate entities. 
diff --git a/crates/examples/interactive/interactive.rs b/crates/examples/interactive/interactive.rs new file mode 100644 index 0000000000..dd6fd86a35 --- /dev/null +++ b/crates/examples/interactive/interactive.rs @@ -0,0 +1,201 @@ +use http_body_util::Empty; +use hyper::{body::Bytes, Request, StatusCode, Uri}; +use hyper_util::rt::TokioIo; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_core::transcript::Idx; +use tlsn_prover::{state::Prove, Prover, ProverConfig}; +use tlsn_verifier::{SessionInfo, Verifier, VerifierConfig}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tracing::instrument; + +const SECRET: &str = "TLSNotary's private key 🤡"; +const SERVER_DOMAIN: &str = "example.com"; + +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + let uri = "https://example.com"; + let id = "interactive verifier demo"; + + // Connect prover and verifier. 
+ let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23); + let prover = prover(prover_socket, uri, id); + let verifier = verifier(verifier_socket, id); + let (_, (sent, received, _session_info)) = tokio::join!(prover, verifier); + + println!("Successfully verified {}", &uri); + println!("Verified sent data:\n{}", bytes_to_redacted_string(&sent)); + println!( + "Verified received data:\n{}", + bytes_to_redacted_string(&received) + ); +} + +#[instrument(skip(verifier_socket))] +async fn prover( + verifier_socket: T, + uri: &str, + id: &str, +) { + let uri = uri.parse::().unwrap(); + assert_eq!(uri.scheme().unwrap().as_str(), "https"); + let server_domain = uri.authority().unwrap().host(); + let server_port = uri.port_u16().unwrap_or(443); + + // Create prover and connect to verifier. + // + // Perform the setup phase with the verifier. + let prover = Prover::new( + ProverConfig::builder() + .server_name(server_domain) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(), + ) + .build() + .unwrap(), + ) + .setup(verifier_socket.compat()) + .await + .unwrap(); + + // Connect to TLS Server. + let tls_client_socket = tokio::net::TcpStream::connect((server_domain, server_port)) + .await + .unwrap(); + + // Pass server connection into the prover. + let (mpc_tls_connection, prover_fut) = + prover.connect(tls_client_socket.compat()).await.unwrap(); + + // Wrap the connection in a TokioIo compatibility layer to use it with hyper. + let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); + + // Spawn the Prover to run in the background. + let prover_task = tokio::spawn(prover_fut); + + // MPC-TLS Handshake. + let (mut request_sender, connection) = + hyper::client::conn::http1::handshake(mpc_tls_connection) + .await + .unwrap(); + + // Spawn the connection to run in the background. + tokio::spawn(connection); + + // MPC-TLS: Send Request and wait for Response. 
+ let request = Request::builder() + .uri(uri.clone()) + .header("Host", server_domain) + .header("Connection", "close") + .header("Secret", SECRET) + .method("GET") + .body(Empty::::new()) + .unwrap(); + let response = request_sender.send_request(request).await.unwrap(); + + assert!(response.status() == StatusCode::OK); + + // Create proof for the Verifier. + let mut prover = prover_task.await.unwrap().unwrap().start_prove(); + + let idx_sent = redact_and_reveal_sent_data(&mut prover); + let idx_recv = redact_and_reveal_received_data(&mut prover); + + // Reveal parts of the transcript + prover.prove_transcript(idx_sent, idx_recv).await.unwrap(); + + // Finalize. + prover.finalize().await.unwrap() +} + +#[instrument(skip(socket))] +async fn verifier( + socket: T, + id: &str, +) -> (Vec, Vec, SessionInfo) { + // Setup Verifier. + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let verifier_config = VerifierConfig::builder() + .protocol_config_validator(config_validator) + .build() + .unwrap(); + let verifier = Verifier::new(verifier_config); + + // Verify MPC-TLS and wait for (redacted) data. + let (mut partial_transcript, session_info) = verifier.verify(socket.compat()).await.unwrap(); + partial_transcript.set_unauthed(0); + + // Check sent data: check host. + let sent = partial_transcript.sent_unsafe().to_vec(); + let sent_data = String::from_utf8(sent.clone()).expect("Verifier expected sent data"); + sent_data + .find(SERVER_DOMAIN) + .unwrap_or_else(|| panic!("Verification failed: Expected host {}", SERVER_DOMAIN)); + + // Check received data: check json and version number. + let received = partial_transcript.received_unsafe().to_vec(); + let response = String::from_utf8(received.clone()).expect("Verifier expected received data"); + response + .find("Example Domain") + .expect("Expected valid data from example.com"); + + // Check Session info: server name. 
+ assert_eq!(session_info.server_name.as_str(), SERVER_DOMAIN); + + (sent, received, session_info) +} + +/// Redacts and reveals received data to the verifier. +fn redact_and_reveal_received_data(prover: &mut Prover) -> Idx { + let recv_transcript = prover.transcript().received(); + let recv_transcript_len = recv_transcript.len(); + + // Get the received data as a string. + let received_string = String::from_utf8(recv_transcript.to_vec()).unwrap(); + // Find the substring "illustrative". + let start = received_string + .find("illustrative") + .expect("Error: The substring 'illustrative' was not found in the received data."); + let end = start + "illustrative".len(); + + Idx::new([0..start, end..recv_transcript_len]) +} + +/// Redacts and reveals sent data to the verifier. +fn redact_and_reveal_sent_data(prover: &mut Prover) -> Idx { + let sent_transcript = prover.transcript().sent(); + let sent_transcript_len = sent_transcript.len(); + + let sent_string = String::from_utf8(sent_transcript.to_vec()).unwrap(); + + let secret_start = sent_string.find(SECRET).unwrap(); + + // Reveal everything except for the SECRET. + Idx::new([ + 0..secret_start, + secret_start + SECRET.len()..sent_transcript_len, + ]) +} + +/// Render redacted bytes as `🙈`. +fn bytes_to_redacted_string(bytes: &[u8]) -> String { + String::from_utf8(bytes.to_vec()) + .unwrap() + .replace('\0', "🙈") +} diff --git a/crates/examples/src/lib.rs b/crates/examples/src/lib.rs new file mode 100644 index 0000000000..f74b9c10d0 --- /dev/null +++ b/crates/examples/src/lib.rs @@ -0,0 +1,46 @@ +use futures::{AsyncRead, AsyncWrite}; +use k256::{pkcs8::DecodePrivateKey, SecretKey}; +use tlsn_common::config::ProtocolConfigValidator; +use tlsn_core::{attestation::AttestationConfig, signing::SignatureAlgId, CryptoProvider}; +use tlsn_verifier::{Verifier, VerifierConfig}; + +/// The private key used by the Notary for signing attestations. 
+pub const NOTARY_PRIVATE_KEY: &[u8] = &[1u8; 32]; + +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + +/// Runs a simple Notary with the provided connection to the Prover. +pub async fn run_notary(conn: T) { + let pem_data = include_str!("../../notary/server/fixture/notary/notary.key"); + let secret_key = SecretKey::from_pkcs8_pem(pem_data).unwrap().to_bytes(); + + let mut provider = CryptoProvider::default(); + provider.signer.set_secp256k1(&secret_key).unwrap(); + + // Setup the config. Normally a different ID would be generated + // for each notarization. + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let config = VerifierConfig::builder() + .protocol_config_validator(config_validator) + .crypto_provider(provider) + .build() + .unwrap(); + + let attestation_config = AttestationConfig::builder() + .supported_signature_algs(vec![SignatureAlgId::SECP256K1]) + .build() + .unwrap(); + + Verifier::new(config) + .notarize(conn, &attestation_config) + .await + .unwrap(); +} diff --git a/crates/formats/Cargo.toml b/crates/formats/Cargo.toml new file mode 100644 index 0000000000..ce0a40cf89 --- /dev/null +++ b/crates/formats/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "tlsn-formats" +version = "0.1.0-alpha.7" +edition = "2021" + +[dependencies] +tlsn-core = { workspace = true } +tlsn-utils = { workspace = true } + +bytes = { workspace = true } +serde = { workspace = true } +spansy = { workspace = true, features = ["serde"] } +thiserror = { workspace = true } + +[dev-dependencies] +tlsn-core = { workspace = true, features = ["fixtures"] } +tlsn-data-fixtures = { workspace = true } +rstest = { workspace = true } diff --git a/crates/formats/src/http/commit.rs b/crates/formats/src/http/commit.rs new file 
mode 100644 index 0000000000..525a126d7b --- /dev/null +++ b/crates/formats/src/http/commit.rs @@ -0,0 +1,460 @@ +use std::error::Error; + +use spansy::Spanned; +use tlsn_core::transcript::{Direction, TranscriptCommitConfigBuilder}; + +use crate::{ + http::{Body, BodyContent, Header, HttpTranscript, MessageKind, Request, Response, Target}, + json::{DefaultJsonCommitter, JsonCommit}, +}; + +/// HTTP commitment error. +#[derive(Debug, thiserror::Error)] +#[error("http commit error: {msg}")] +pub struct HttpCommitError { + idx: Option, + record_kind: MessageKind, + msg: String, + #[source] + source: Option>, +} + +impl HttpCommitError { + /// Creates a new HTTP commitment error. + /// + /// # Arguments + /// + /// * `record_kind` - The kind of the record (request or response). + /// * `msg` - The error message. + pub fn new(record_kind: MessageKind, msg: impl Into) -> Self { + Self { + idx: None, + record_kind, + msg: msg.into(), + source: None, + } + } + + /// Creates a new HTTP commitment error with a source. + /// + /// # Arguments + /// + /// * `record_kind` - The kind of the record (request or response). + /// * `msg` - The error message. + /// * `source` - The source error. + pub fn new_with_source(record_kind: MessageKind, msg: impl Into, source: E) -> Self + where + E: Into>, + { + Self { + idx: None, + record_kind, + msg: msg.into(), + source: Some(source.into()), + } + } + + /// Sets the index of the request or response in the transcript. + pub fn set_index(&mut self, idx: usize) { + self.idx = Some(idx); + } + + /// Returns the index of the request or response in the transcript, if set. + pub fn index(&self) -> Option { + self.idx + } + + /// Returns the error message. + pub fn msg(&self) -> &str { + &self.msg + } + + /// Returns the kind of record (request or response). + pub fn record_kind(&self) -> &MessageKind { + &self.record_kind + } +} + +/// An HTTP data committer. 
+#[allow(unused_variables)] +pub trait HttpCommit { + /// Commits to an HTTP transcript. + /// + /// The default implementation commits to each request and response in the + /// transcript separately. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `transcript` - The transcript to commit. + fn commit_transcript( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + transcript: &HttpTranscript, + ) -> Result<(), HttpCommitError> { + for request in &transcript.requests { + self.commit_request(builder, Direction::Sent, request)?; + } + + for response in &transcript.responses { + self.commit_response(builder, Direction::Received, response)?; + } + + Ok(()) + } + + /// Commits to a request. + /// + /// The default implementation commits to the request excluding the target, + /// headers and body. Additionally, it commits to the target, headers + /// and body separately. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `direction` - The direction of the request (sent or received). + /// * `request` - The request to commit to. 
+ fn commit_request( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + direction: Direction, + request: &Request, + ) -> Result<(), HttpCommitError> { + builder.commit(request, direction).map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + "failed to commit to entire request", + e, + ) + })?; + + if !request.headers.is_empty() || request.body.is_some() { + builder + .commit(&request.without_data(), direction) + .map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + "failed to commit to request with excluded data", + e, + ) + })?; + } + + self.commit_target(builder, direction, request, &request.request.target)?; + + for header in &request.headers { + self.commit_request_header(builder, direction, request, header)?; + } + + if let Some(body) = &request.body { + self.commit_request_body(builder, direction, request, body)?; + } + + Ok(()) + } + + /// Commits to a request target. + /// + /// The default implementation commits to the target as a whole. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `direction` - The direction of the request (sent or received). + /// * `request` - The parent request. + /// * `target` - The target to commit to. + fn commit_target( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + direction: Direction, + request: &Request, + target: &Target, + ) -> Result<(), HttpCommitError> { + builder.commit(target, direction).map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + "failed to commit to target in request", + e, + ) + })?; + + Ok(()) + } + + /// Commits to a request header. + /// + /// The default implementation commits to the entire header, and the header + /// excluding the value. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `direction` - The direction of the request (sent or received). + /// * `parent` - The parent request. 
+ /// * `header` - The header to commit to. + fn commit_request_header( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + direction: Direction, + parent: &Request, + header: &Header, + ) -> Result<(), HttpCommitError> { + builder.commit(header, direction).map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + format!("failed to commit to \"{}\" header", header.name.as_str()), + e, + ) + })?; + + if !header.value.span().is_empty() { + builder + .commit(&header.without_value(), direction) + .map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + format!( + "failed to commit to \"{}\" header excluding value", + header.name.as_str() + ), + e, + ) + })?; + } + + Ok(()) + } + + /// Commits to a request body. + /// + /// The default implementation commits using the default implementation for + /// the format type of the body. If the format of the body is unknown, + /// it commits to the body as a whole. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `direction` - The direction of the request (sent or received). + /// * `parent` - The parent request. + /// * `body` - The body to commit to. + fn commit_request_body( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + direction: Direction, + parent: &Request, + body: &Body, + ) -> Result<(), HttpCommitError> { + match &body.content { + BodyContent::Json(body) => { + DefaultJsonCommitter::default() + .commit_value(builder, body, direction) + .map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + "failed to commit to JSON body", + e, + ) + })?; + } + body => { + builder.commit(body, direction).map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + "failed to commit to unknown content body", + e, + ) + })?; + } + } + + Ok(()) + } + + /// Commits to a response. + /// + /// The default implementation commits to the response excluding the headers + /// and body. 
Additionally, it commits to the headers and body + /// separately. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `direction` - The direction of the response (sent or received). + /// * `response` - The response to commit to. + fn commit_response( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + direction: Direction, + response: &Response, + ) -> Result<(), HttpCommitError> { + builder.commit(response, direction).map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Response, + "failed to commit to entire response", + e, + ) + })?; + + if !response.headers.is_empty() || response.body.is_some() { + builder + .commit(&response.without_data(), direction) + .map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Response, + "failed to commit to response excluding data", + e, + ) + })?; + } + + for header in &response.headers { + self.commit_response_header(builder, direction, response, header)?; + } + + if let Some(body) = &response.body { + self.commit_response_body(builder, direction, response, body)?; + } + + Ok(()) + } + + /// Commits to a response header. + /// + /// The default implementation commits to the entire header, and the header + /// excluding the value. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `direction` - The direction of the response (sent or received). + /// * `parent` - The parent response. + /// * `header` - The header to commit to. 
+ fn commit_response_header( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + direction: Direction, + parent: &Response, + header: &Header, + ) -> Result<(), HttpCommitError> { + builder.commit(header, direction).map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Response, + format!("failed to commit to \"{}\" header", header.name.as_str()), + e, + ) + })?; + + if !header.value.span().is_empty() { + builder + .commit(&header.without_value(), direction) + .map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Response, + format!( + "failed to commit to \"{}\" header excluding value in response", + header.name.as_str() + ), + e, + ) + })?; + } + + Ok(()) + } + + /// Commits to a response body. + /// + /// The default implementation commits using the default implementation for + /// the format type of the body. If the format of the body is unknown, + /// it commits to the body as a whole. + /// + /// # Arguments + /// + /// * `builder` - The transcript commitment builder. + /// * `direction` - The direction of the response (sent or received). + /// * `parent` - The parent response. + /// * `body` - The body to commit to. + fn commit_response_body( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + direction: Direction, + parent: &Response, + body: &Body, + ) -> Result<(), HttpCommitError> { + match &body.content { + BodyContent::Json(body) => { + DefaultJsonCommitter::default() + .commit_value(builder, body, direction) + .map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Response, + "failed to commit to JSON body", + e, + ) + })?; + } + body => { + builder.commit(body, direction).map_err(|e| { + HttpCommitError::new_with_source( + MessageKind::Request, + "failed to commit to unknown content body", + e, + ) + })?; + } + } + + Ok(()) + } +} + +/// The default HTTP committer. 
+#[derive(Debug, Default, Clone)] +pub struct DefaultHttpCommitter {} + +impl HttpCommit for DefaultHttpCommitter {} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::*; + use spansy::http::{parse_request, parse_response}; + use tlsn_core::transcript::Transcript; + use tlsn_data_fixtures::http as fixtures; + + #[rstest] + #[case::get_empty(fixtures::request::GET_EMPTY)] + #[case::get_empty_header(fixtures::request::GET_EMPTY_HEADER)] + #[case::get_with_header(fixtures::request::GET_WITH_HEADER)] + #[case::post_json(fixtures::request::POST_JSON)] + fn test_http_default_commit_request(#[case] src: &'static [u8]) { + let transcript = Transcript::new(src, []); + let request = parse_request(src).unwrap(); + let mut committer = DefaultHttpCommitter::default(); + let mut builder = TranscriptCommitConfigBuilder::new(&transcript); + + committer + .commit_request(&mut builder, Direction::Sent, &request) + .unwrap(); + + builder.build().unwrap(); + } + + #[rstest] + #[case::empty(fixtures::response::OK_EMPTY)] + #[case::empty_header(fixtures::response::OK_EMPTY_HEADER)] + #[case::json(fixtures::response::OK_JSON)] + #[case::text(fixtures::response::OK_TEXT)] + fn test_http_default_commit_response(#[case] src: &'static [u8]) { + let transcript = Transcript::new([], src); + let response = parse_response(src).unwrap(); + let mut committer = DefaultHttpCommitter::default(); + let mut builder = TranscriptCommitConfigBuilder::new(&transcript); + + committer + .commit_response(&mut builder, Direction::Received, &response) + .unwrap(); + + builder.build().unwrap(); + } +} diff --git a/crates/formats/src/http/mod.rs b/crates/formats/src/http/mod.rs new file mode 100644 index 0000000000..7640b11abc --- /dev/null +++ b/crates/formats/src/http/mod.rs @@ -0,0 +1,48 @@ +//! Tooling for working with HTTP data. 
+ +mod commit; + +use bytes::Bytes; +pub use commit::{DefaultHttpCommitter, HttpCommit, HttpCommitError}; + +#[doc(hidden)] +pub use spansy::http; + +pub use http::{ + parse_request, parse_response, Body, BodyContent, Header, HeaderName, HeaderValue, Method, + Reason, Request, RequestLine, Requests, Response, Responses, Status, Target, +}; +use tlsn_core::transcript::Transcript; + +/// The kind of HTTP message. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum MessageKind { + /// An HTTP request. + Request, + /// An HTTP response. + Response, +} + +/// An HTTP transcript. +#[derive(Debug)] +pub struct HttpTranscript { + /// The requests sent to the server. + pub requests: Vec, + /// The responses received from the server. + pub responses: Vec, +} + +impl HttpTranscript { + /// Parses the HTTP transcript from the provided transcripts. + pub fn parse(transcript: &Transcript) -> Result { + let requests = Requests::new(Bytes::copy_from_slice(transcript.sent())) + .collect::, _>>()?; + let responses = Responses::new(Bytes::copy_from_slice(transcript.received())) + .collect::, _>>()?; + + Ok(Self { + requests, + responses, + }) + } +} diff --git a/crates/formats/src/json/commit.rs b/crates/formats/src/json/commit.rs new file mode 100644 index 0000000000..725770f5ed --- /dev/null +++ b/crates/formats/src/json/commit.rs @@ -0,0 +1,252 @@ +use std::error::Error; + +use spansy::{json::KeyValue, Spanned}; +use tlsn_core::transcript::{Direction, TranscriptCommitConfigBuilder}; + +use crate::json::{Array, Bool, JsonValue, Null, Number, Object, String as JsonString}; + +/// JSON commitment error. +#[derive(Debug, thiserror::Error)] +#[error("json commitment error: {msg}")] +pub struct JsonCommitError { + msg: String, + #[source] + source: Option>, +} + +impl JsonCommitError { + /// Creates a new JSON commitment error. + /// + /// # Arguments + /// + /// * `msg` - The error message. 
+ pub fn new(msg: impl Into) -> Self { + Self { + msg: msg.into(), + source: None, + } + } + + /// Creates a new JSON commitment error with a source. + /// + /// # Arguments + /// + /// * `msg` - The error message. + /// * `source` - The source error. + pub fn new_with_source(msg: impl Into, source: E) -> Self + where + E: Into>, + { + Self { + msg: msg.into(), + source: Some(source.into()), + } + } + + /// Returns the error message. + pub fn msg(&self) -> &str { + &self.msg + } +} + +/// A JSON committer. +pub trait JsonCommit { + /// Commits to a JSON value. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `value` - The JSON value to commit. + /// * `direction` - The direction of the data (sent or received). + fn commit_value( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + value: &JsonValue, + direction: Direction, + ) -> Result<(), JsonCommitError> { + match value { + JsonValue::Object(obj) => self.commit_object(builder, obj, direction), + JsonValue::Array(arr) => self.commit_array(builder, arr, direction), + JsonValue::String(string) => self.commit_string(builder, string, direction), + JsonValue::Number(number) => self.commit_number(builder, number, direction), + JsonValue::Bool(boolean) => self.commit_bool(builder, boolean, direction), + JsonValue::Null(null) => self.commit_null(builder, null, direction), + } + } + + /// Commits to a JSON object. + /// + /// The default implementation commits the object without any of the + /// key-value pairs, then commits each key-value pair individually. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `object` - The JSON object to commit. + /// * `direction` - The direction of the data (sent or received). 
+ fn commit_object( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + object: &Object, + direction: Direction, + ) -> Result<(), JsonCommitError> { + builder + .commit(&object.without_pairs(), direction) + .map_err(|e| JsonCommitError::new_with_source("failed to commit object", e))?; + + for kv in &object.elems { + self.commit_key_value(builder, kv, direction)?; + } + + Ok(()) + } + + /// Commits to a JSON key-value pair. + /// + /// The default implementation commits the pair without the value, and then + /// commits the value separately. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `kv` - The JSON key-value pair to commit. + /// * `direction` - The direction of the data (sent or received). + fn commit_key_value( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + kv: &KeyValue, + direction: Direction, + ) -> Result<(), JsonCommitError> { + builder + .commit(&kv.without_value(), direction) + .map_err(|e| { + JsonCommitError::new_with_source( + "failed to commit key-value pair excluding the value", + e, + ) + })?; + + self.commit_value(builder, &kv.value, direction) + } + + /// Commits to a JSON array. + /// + /// The default implementation commits to the entire array, then commits the + /// array excluding all values and separators. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `array` - The JSON array to commit. + /// * `direction` - The direction of the data (sent or received). 
+ fn commit_array( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + array: &Array, + direction: Direction, + ) -> Result<(), JsonCommitError> { + builder + .commit(array, direction) + .map_err(|e| JsonCommitError::new_with_source("failed to commit array", e))?; + + if !array.elems.is_empty() { + builder + .commit(&array.without_values(), direction) + .map_err(|e| { + JsonCommitError::new_with_source("failed to commit array excluding values", e) + })?; + } + + // TODO: Commit each value separately, but we need a strategy for handling + // separators. + + Ok(()) + } + + /// Commits to a JSON string. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `string` - The JSON string to commit. + /// * `direction` - The direction of the data (sent or received). + fn commit_string( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + string: &JsonString, + direction: Direction, + ) -> Result<(), JsonCommitError> { + // Skip empty strings. + if string.span().is_empty() { + return Ok(()); + } + + builder + .commit(string, direction) + .map(|_| ()) + .map_err(|e| JsonCommitError::new_with_source("failed to commit string", e)) + } + + /// Commits to a JSON number. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `number` - The JSON number to commit. + /// * `direction` - The direction of the data (sent or received). + fn commit_number( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + number: &Number, + direction: Direction, + ) -> Result<(), JsonCommitError> { + builder + .commit(number, direction) + .map(|_| ()) + .map_err(|e| JsonCommitError::new_with_source("failed to commit number", e)) + } + + /// Commits to a JSON boolean value. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `boolean` - The JSON boolean to commit. + /// * `direction` - The direction of the data (sent or received). 
+ fn commit_bool( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + boolean: &Bool, + direction: Direction, + ) -> Result<(), JsonCommitError> { + builder + .commit(boolean, direction) + .map(|_| ()) + .map_err(|e| JsonCommitError::new_with_source("failed to commit boolean", e)) + } + + /// Commits to a JSON null value. + /// + /// # Arguments + /// + /// * `builder` - The commitment builder. + /// * `null` - The JSON null to commit. + /// * `direction` - The direction of the data (sent or received). + fn commit_null( + &mut self, + builder: &mut TranscriptCommitConfigBuilder, + null: &Null, + direction: Direction, + ) -> Result<(), JsonCommitError> { + builder + .commit(null, direction) + .map(|_| ()) + .map_err(|e| JsonCommitError::new_with_source("failed to commit null", e)) + } +} + +/// Default committer for JSON values. +#[derive(Debug, Default, Clone)] +pub struct DefaultJsonCommitter {} + +impl JsonCommit for DefaultJsonCommitter {} diff --git a/crates/formats/src/json/mod.rs b/crates/formats/src/json/mod.rs new file mode 100644 index 0000000000..e635df0055 --- /dev/null +++ b/crates/formats/src/json/mod.rs @@ -0,0 +1,10 @@ +//! Tooling for working with JSON data. + +mod commit; + +use spansy::json; + +pub use commit::{DefaultJsonCommitter, JsonCommit, JsonCommitError}; +pub use json::{ + Array, Bool, JsonKey, JsonValue, JsonVisit, KeyValue, Null, Number, Object, String, +}; diff --git a/tlsn/tlsn-formats/src/lib.rs b/crates/formats/src/lib.rs similarity index 65% rename from tlsn/tlsn-formats/src/lib.rs rename to crates/formats/src/lib.rs index 54e0b0f9a9..ff0afc8944 100644 --- a/tlsn/tlsn-formats/src/lib.rs +++ b/crates/formats/src/lib.rs @@ -2,10 +2,11 @@ //! //! # Warning //! -//! This library is not yet ready for production use, and should *NOT* be considered secure. +//! This library is not yet ready for production use, and should *NOT* be +//! considered secure. //! -//! 
At present, this library does not verify that redacted data does not contain control characters which can -//! be used by a malicious prover to cheat. +//! At present, this library does not verify that redacted data does not contain +//! control characters which can be used by a malicious prover to cheat. #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] @@ -13,4 +14,7 @@ pub mod http; pub mod json; -mod unknown; + +#[doc(hidden)] +pub use spansy; +pub use spansy::ParseError; diff --git a/crates/notary/client/Cargo.toml b/crates/notary/client/Cargo.toml new file mode 100644 index 0000000000..cce1bace69 --- /dev/null +++ b/crates/notary/client/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "notary-client" +version = "0.1.0-alpha.7" +edition = "2021" + +[dependencies] +notary-server = { workspace = true } +tlsn-common = { workspace = true } + +derive_builder = { workspace = true } +futures = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["client", "http1"] } +hyper-util = { workspace = true, features = ["full"] } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = [ + "rt", + "rt-multi-thread", + "macros", + "net", + "io-std", + "fs", +] } +tokio-rustls = { workspace = true } +tracing = { workspace = true } +webpki-roots = { workspace = true } diff --git a/crates/notary/client/src/client.rs b/crates/notary/client/src/client.rs new file mode 100644 index 0000000000..fdfc9e9e18 --- /dev/null +++ b/crates/notary/client/src/client.rs @@ -0,0 +1,405 @@ +//! Notary client. +//! +//! This module sets up connection to notary server via TCP or TLS for +//! subsequent requests for notarization. 
+ +use http_body_util::{BodyExt as _, Either, Empty, Full}; +use hyper::{body::Bytes, client::conn::http1::Parts, Request, StatusCode}; +use hyper_util::rt::TokioIo; +use notary_server::{ClientType, NotarizationSessionRequest, NotarizationSessionResponse}; +use std::{ + io::Error as IoError, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::TcpStream, +}; +use tokio_rustls::{ + client::TlsStream, + rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}, + TlsConnector, +}; +use tracing::{debug, error}; + +use crate::error::{ClientError, ErrorKind}; + +/// Parameters used to configure notarization. +#[derive(Debug, Clone, derive_builder::Builder)] +pub struct NotarizationRequest { + /// Maximum number of bytes that can be sent. + max_sent_data: usize, + /// Maximum number of bytes that can be received. + max_recv_data: usize, +} + +impl NotarizationRequest { + /// Creates a new builder for `NotarizationRequest`. + pub fn builder() -> NotarizationRequestBuilder { + NotarizationRequestBuilder::default() + } +} + +/// An accepted notarization request. +#[derive(Debug)] +#[non_exhaustive] +pub struct Accepted { + /// Session identifier. + pub id: String, + /// Connection to the notary server to be used by a prover. + pub io: NotaryConnection, +} + +/// A notary server connection. +#[derive(Debug)] +pub enum NotaryConnection { + /// Unencrypted TCP connection. + Tcp(TcpStream), + /// TLS connection. 
+ Tls(TlsStream), +} + +impl AsyncRead for NotaryConnection { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + NotaryConnection::Tcp(stream) => Pin::new(stream).poll_read(cx, buf), + NotaryConnection::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for NotaryConnection { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + NotaryConnection::Tcp(stream) => Pin::new(stream).poll_write(cx, buf), + NotaryConnection::Tls(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + NotaryConnection::Tcp(stream) => Pin::new(stream).poll_flush(cx), + NotaryConnection::Tls(stream) => Pin::new(stream).poll_flush(cx), + } + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + NotaryConnection::Tcp(stream) => Pin::new(stream).poll_shutdown(cx), + NotaryConnection::Tls(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +/// Client that sets up connection to notary server. +#[derive(Debug, Clone, derive_builder::Builder)] +pub struct NotaryClient { + /// Host of the notary server endpoint, either a DNS name (if TLS is used) + /// or IP address. + #[builder(setter(into))] + host: String, + /// Port of the notary server endpoint. + #[builder(default = "self.default_port()")] + port: u16, + /// URL path prefix of the notary server endpoint, e.g. "https://://...". + #[builder(setter(into), default = "String::from(\"\")")] + path_prefix: String, + /// Flag to turn on/off using TLS with notary server. + #[builder(setter(name = "enable_tls"), default = "true")] + tls: bool, + /// Root certificate store used for establishing TLS connection with notary + /// server. 
+ #[builder(default = "default_root_store()")] + root_cert_store: RootCertStore, + /// API key used to call notary server endpoints if whitelisting is enabled + /// in notary server. + #[builder(setter(into, strip_option), default)] + api_key: Option, +} + +impl NotaryClientBuilder { + // Default setter of port. + fn default_port(&self) -> u16 { + // If port is not specified, set it to 80 if TLS is off, else 443 since TLS is + // on (including when self.tls = None, which means it's set to default + // (true)). + if let Some(false) = self.tls { + 80 + } else { + 443 + } + } +} + +impl NotaryClient { + /// Creates a new builder for `NotaryClient`. + pub fn builder() -> NotaryClientBuilder { + NotaryClientBuilder::default() + } + + /// Configures and requests a notarization, returning a connection to the + /// notary server if successful. + pub async fn request_notarization( + &self, + notarization_request: NotarizationRequest, + ) -> Result { + if self.tls { + debug!("Setting up tls connection..."); + + let notary_client_config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(self.root_cert_store.clone()) + .with_no_client_auth(); + + let notary_socket = tokio::net::TcpStream::connect((self.host.as_str(), self.port)) + .await + .map_err(|err| ClientError::new(ErrorKind::Connection, Some(Box::new(err))))?; + + let notary_connector = TlsConnector::from(Arc::new(notary_client_config)); + let notary_tls_socket = notary_connector + .connect( + self.host.as_str().try_into().map_err(|err| { + error!("Failed to parse notary server DNS name: {:?}", self.host); + ClientError::new(ErrorKind::TlsSetup, Some(Box::new(err))) + })?, + notary_socket, + ) + .await + .map_err(|err| ClientError::new(ErrorKind::TlsSetup, Some(Box::new(err))))?; + + self.send_request(notary_tls_socket, notarization_request) + .await + .map(|(connection, session_id)| Accepted { + id: session_id, + io: NotaryConnection::Tls(connection), + }) + } else { + debug!("Setting up tcp 
connection..."); + + let notary_socket = tokio::net::TcpStream::connect((self.host.as_str(), self.port)) + .await + .map_err(|err| ClientError::new(ErrorKind::Connection, Some(Box::new(err))))?; + + self.send_request(notary_socket, notarization_request) + .await + .map(|(connection, session_id)| Accepted { + id: session_id, + io: NotaryConnection::Tcp(connection), + }) + } + } + + /// Sends notarization request to the notary server. + async fn send_request( + &self, + notary_socket: S, + notarization_request: NotarizationRequest, + ) -> Result<(S, String), ClientError> { + let http_scheme = if self.tls { "https" } else { "http" }; + let path_prefix = if self.path_prefix.is_empty() { + String::new() + } else { + format!("/{}", self.path_prefix) + }; + + // Attach the hyper HTTP client to the notary connection to send request to the + // /session endpoint to configure notarization and obtain session id. + let (mut notary_request_sender, notary_connection) = + hyper::client::conn::http1::handshake(TokioIo::new(notary_socket)) + .await + .map_err(|err| { + error!("Failed to attach http client to notary socket"); + ClientError::new(ErrorKind::Connection, Some(Box::new(err))) + })?; + + // Create a future to poll the notary connection to completion before extracting + // the socket. + let notary_connection_fut = async { + // Claim back notary socket after HTTP exchange is done. + let Parts { + io: notary_socket, .. + } = notary_connection.without_shutdown().await.map_err(|err| { + error!("Failed to claim back notary socket after HTTP exchange is done"); + ClientError::new(ErrorKind::Internal, Some(Box::new(err))) + })?; + + Ok(notary_socket) + }; + + // Create a future to send configuration and notarization requests to the notary + // server using the connection established above. + let client_requests_fut = async { + // Build the HTTP request to configure notarization. 
+ let configuration_request_payload = + serde_json::to_string(&NotarizationSessionRequest { + client_type: ClientType::Tcp, + max_sent_data: Some(notarization_request.max_sent_data), + max_recv_data: Some(notarization_request.max_recv_data), + }) + .map_err(|err| { + error!("Failed to serialise http request for configuration"); + ClientError::new(ErrorKind::Internal, Some(Box::new(err))) + })?; + + let mut configuration_request_builder = Request::builder() + .uri(format!( + "{http_scheme}://{}:{}{}/session", + self.host, self.port, path_prefix + )) + .method("POST") + .header("Host", &self.host) + // Need to specify application/json for axum to parse it as json. + .header("Content-Type", "application/json"); + + if let Some(api_key) = &self.api_key { + configuration_request_builder = + configuration_request_builder.header("Authorization", api_key); + } + + let configuration_request = configuration_request_builder + .body(Either::Left(Full::new(Bytes::from( + configuration_request_payload, + )))) + .map_err(|err| { + error!("Failed to build http request for configuration"); + ClientError::new(ErrorKind::Internal, Some(Box::new(err))) + })?; + + debug!("Sending configuration request: {:?}", configuration_request); + + let configuration_response = notary_request_sender + .send_request(configuration_request) + .await + .map_err(|err| { + error!("Failed to send http request for configuration"); + ClientError::new(ErrorKind::Http, Some(Box::new(err))) + })?; + + debug!("Sent configuration request"); + + if configuration_response.status() != StatusCode::OK { + return Err(ClientError::new( + ErrorKind::Configuration, + Some( + format!( + "Configuration response status is not OK: {:?}", + configuration_response + ) + .into(), + ), + )); + } + + let configuration_response_payload = configuration_response + .into_body() + .collect() + .await + .map_err(|err| { + error!("Failed to parse configuration response"); + ClientError::new(ErrorKind::Http, Some(Box::new(err))) + })? 
+ .to_bytes(); + + let configuration_response_payload_parsed = + serde_json::from_str::(&String::from_utf8_lossy( + &configuration_response_payload, + )) + .map_err(|err| { + error!("Failed to parse configuration response payload"); + ClientError::new(ErrorKind::Internal, Some(Box::new(err))) + })?; + + debug!( + "Configuration response: {:?}", + configuration_response_payload_parsed + ); + + // Send notarization request via HTTP, where the underlying TCP/TLS connection + // will be extracted later. + let notarization_request = Request::builder() + // Need to specify the session_id so that notary server knows the right + // configuration to use as the configuration is set in the previous + // HTTP call. + .uri(format!( + "{http_scheme}://{}:{}{}/notarize?sessionId={}", + self.host, + self.port, + path_prefix, + &configuration_response_payload_parsed.session_id + )) + .method("GET") + .header("Host", &self.host) + .header("Connection", "Upgrade") + // Need to specify this upgrade header for server to extract TCP/TLS connection + // later. 
+ .header("Upgrade", "TCP") + .body(Either::Right(Empty::::new())) + .map_err(|err| { + error!("Failed to build http request for notarization"); + ClientError::new(ErrorKind::Internal, Some(Box::new(err))) + })?; + + debug!("Sending notarization request: {:?}", notarization_request); + + let notarization_response = notary_request_sender + .send_request(notarization_request) + .await + .map_err(|err| { + error!("Failed to send http request for notarization"); + ClientError::new(ErrorKind::Http, Some(Box::new(err))) + })?; + + debug!("Sent notarization request"); + + if notarization_response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(ClientError::new( + ErrorKind::Internal, + Some( + format!( + "Notarization response status is not SWITCHING_PROTOCOL: {:?}", + notarization_response + ) + .into(), + ), + )); + } + + Ok(configuration_response_payload_parsed.session_id) + }; + + // Poll both futures simultaneously to obtain the resulting socket and + // session_id. + let (notary_socket, session_id) = + futures::try_join!(notary_connection_fut, client_requests_fut)?; + + Ok((notary_socket.into_inner(), session_id)) + } +} + +/// Default root store using mozilla certs. +fn default_root_store() -> RootCertStore { + let mut root_store = RootCertStore::empty(); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject.as_ref(), + ta.subject_public_key_info.as_ref(), + ta.name_constraints.as_ref().map(|nc| nc.as_ref()), + ) + })); + + root_store +} diff --git a/crates/notary/client/src/error.rs b/crates/notary/client/src/error.rs new file mode 100644 index 0000000000..360ed03d1b --- /dev/null +++ b/crates/notary/client/src/error.rs @@ -0,0 +1,48 @@ +//! Notary client errors. +//! +//! This module handles errors that might occur during connection setup and +//! notarization requests. 
+ +use derive_builder::UninitializedFieldError; +use std::{error::Error, fmt}; + +#[derive(Debug)] +#[allow(missing_docs)] +pub(crate) enum ErrorKind { + Internal, + Builder, + Connection, + TlsSetup, + Http, + Configuration, +} + +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub struct ClientError { + kind: ErrorKind, + #[source] + source: Option>, +} + +impl ClientError { + pub(crate) fn new(kind: ErrorKind, source: Option>) -> Self { + Self { kind, source } + } +} + +impl fmt::Display for ClientError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "client error: {:?}, source: {:?}", + self.kind, self.source + ) + } +} + +impl From for ClientError { + fn from(ufe: UninitializedFieldError) -> Self { + ClientError::new(ErrorKind::Builder, Some(Box::new(ufe))) + } +} diff --git a/crates/notary/client/src/lib.rs b/crates/notary/client/src/lib.rs new file mode 100644 index 0000000000..445267a4aa --- /dev/null +++ b/crates/notary/client/src/lib.rs @@ -0,0 +1,15 @@ +//! Notary client library. +//! +//! A notary client's purpose is to establish a connection to the notary server +//! via TCP or TLS, and to configure and request notarization. +//! Note that the actual notarization is not performed by the notary client but +//! by the prover of the TLSNotary protocol. 
+#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(clippy::all)] +#![forbid(unsafe_code)] + +mod client; +mod error; + +pub use client::{Accepted, NotarizationRequest, NotaryClient, NotaryConnection}; +pub use error::ClientError; diff --git a/crates/notary/server/Cargo.toml b/crates/notary/server/Cargo.toml new file mode 100644 index 0000000000..e2e1087e0e --- /dev/null +++ b/crates/notary/server/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "notary-server" +version = "0.1.0-alpha.7" +edition = "2021" + +[dependencies] +tlsn-core = { workspace = true } +tlsn-common = { workspace = true } +tlsn-verifier = { workspace = true } + +async-trait = { workspace = true } +async-tungstenite = { workspace = true, features = ["tokio-native-tls"] } +axum = { workspace = true, features = ["ws"] } +axum-core = { version = "0.4" } +axum-macros = { version = "0.4" } +base64 = { version = "0.21" } +chrono = { version = "0.4" } +csv = { version = "1.3" } +eyre = { version = "0.6" } +futures = { workspace = true } +futures-util = { workspace = true } +http = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["client", "http1", "server"] } +hyper-util = { workspace = true, features = ["full"] } +k256 = { workspace = true } +notify = { version = "6.1.1", default-features = false, features = [ + "macos_kqueue", +] } +p256 = { workspace = true } +pkcs8 = { workspace = true, features = ["pem"] } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +serde_yaml = { version = "0.9" } +sha1 = { version = "0.10" } +structopt = { version = "0.3" } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["full"] } +tokio-rustls = { workspace = true } +tokio-util = { workspace = true, features = ["compat"] } +tower = { version = "0.4", features = ["make"] } +tower-http = { version = "0.5", features = ["cors"] } 
+tower-service = { version = "0.3" } +tower-util = { version = "0.3.1" } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +uuid = { workspace = true, features = ["v4", "fast-rng"] } +ws_stream_tungstenite = { workspace = true, features = ["tokio_io"] } +zeroize = { workspace = true } diff --git a/crates/notary/server/README.md b/crates/notary/server/README.md new file mode 100644 index 0000000000..fa5b50bc3c --- /dev/null +++ b/crates/notary/server/README.md @@ -0,0 +1,136 @@ +# notary-server + +An implementation of the notary server in Rust. + +## ⚠️ Notice + +This crate is currently under active development and should not be used in production. Expect bugs and regular major breaking changes. + +--- +## Running the server +### ⚠️ Notice +- When running this server against a prover (e.g. [Rust](../../examples/) or [browser extension](https://github.com/tlsnotary/tlsn-extension)), please ensure that the prover's version is the same as the version of this server +- When running this server in a *production environment*, please first read this [page](https://docs.tlsnotary.org/developers/notary_server.html) +- When running this server in a *local environment* with a browser extension, please turn off this server's TLS in the config (refer [here](#optional-tls)) + +### Using Cargo +1. Configure the server setting in this config [file](./config/config.yaml) — refer [here](./src/config.rs) for more information on the definition of the setting parameters. +2. Start the server by running the following in a terminal at the root of this crate. +```bash +cargo run --release +``` +3. To use a config file from a different location, run the following command to override the default config file location. 
+```bash +cargo run --release -- --config-file +``` + +### Using Docker +There are two ways to obtain the notary server's Docker image: +- [GitHub](#obtaining-the-image-via-github) +- [Building from source](#building-from-source) + +#### GitHub +1. Obtain the latest image with: +```bash +docker pull ghcr.io/tlsnotary/tlsn/notary-server:latest +``` +2. Run the docker container with: +```bash +docker run --init -p 127.0.0.1:7047:7047 ghcr.io/tlsnotary/tlsn/notary-server:latest +``` +3. If you want to change the default configuration, create a `config` folder locally, that contains a `config.yaml`, whose content follows the format of the default config file [here](./config/config.yaml). +4. Instead of step 2, run the docker container with the following (remember to change the port mapping if you have changed that in the config): +```bash +docker run --init -p 127.0.0.1:7047:7047 -v :/root/.notary-server/config ghcr.io/tlsnotary/tlsn/notary-server:latest +``` + +#### Building from source +1. Configure the server setting in this config [file](./config/config.yaml). +2. Build the docker image by running the following in a terminal at the root of this *repository*. +```bash +docker build . -t notary-server:local -f crates/notary/server/notary-server.Dockerfile +``` +3. Run the docker container and specify the port specified in the config file, e.g. for the default port 7047 +```bash +docker run --init -p 127.0.0.1:7047:7047 notary-server:local +``` + +### Using different setting files with Docker +1. Instead of changing the key/cert/auth file path(s) in the config file, create a folder containing your key/cert/auth files by following the folder structure [here](./fixture/). +2. When launching the docker container, mount your folder onto the docker container at the relevant path prefixed by `/root/.notary-server`. 
+- Example 1: Using different key, cert, and auth files: +```bash +docker run --init -p 127.0.0.1:7047:7047 -v :/root/.notary-server/fixture notary-server:local +``` +- Example 2: Using a different key for notarizations: +```bash +docker run --init -p 127.0.0.1:7047:7047 -v :/root/.notary-server/fixture/notary notary-server:local +``` +--- +## API +All APIs are TLS-protected, hence please use `https://` or `wss://`. +### HTTP APIs +Defined in the [OpenAPI specification](./openapi.yaml). + +### WebSocket APIs +#### /notarize +##### Description +To perform a notarization using a session id (an unique id returned upon calling the `/session` endpoint successfully). + +##### Query Parameter +`sessionId` + +##### Query Parameter Type +String + +--- +## Logging +The default logging strategy of this server is set to `DEBUG` verbosity level for the crates that are useful for most debugging scenarios, i.e. using the following filtering logic: + +`notary_server=DEBUG,tlsn_verifier=DEBUG,tls_mpc=DEBUG,tls_client_async=DEBUG` + +In the config [file](./config/config.yaml), one can toggle the verbosity level for these crates using the `level` field under `logging`. + +One can also provide a custom filtering logic by adding a `filter` field under `logging` in the config file above, and use a value that follows the tracing crate's [filter directive syntax](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#example-syntax). + +--- +## Architecture +### Objective +The main objective of a notary server is to perform notarizations together with a prover. In this case, the prover can either be a +1. TCP client — which has access and control over the transport layer, i.e. TCP +2. 
WebSocket client — which has no access over TCP and instead uses WebSocket for notarizations + +### Features +#### Notarization Configuration +To perform a notarization, some parameters need to be configured by the prover and the notary server (more details in the [OpenAPI specification](./openapi.yaml)), i.e. +- maximum data that can be sent and received +- unique session id + +To streamline this process, a single HTTP endpoint (`/session`) is used by both TCP and WebSocket clients. + +#### Notarization +After calling the configuration endpoint above, the prover can proceed to start the notarization. For a TCP client, that means calling the `/notarize` endpoint using HTTP (`https`), while a WebSocket client should call the same endpoint but using WebSocket (`wss`). Example implementations of these clients can be found in the [integration test](../tests-integration/tests/notary.rs). + +#### Signatures +Currently, both the private key (and cert) used to establish a TLS connection with the prover, and the private key used by the notary server to sign the notarized transcript, are hardcoded PEM keys stored in this repository. Though the paths of these keys can be changed in the config (`notary-key` field) to use different keys instead. + +#### Authorization +An optional authorization module is available to only allow requests with a valid API key attached in the authorization header. The API key whitelist path (as well as the flag to enable/disable this module) can be changed in the config (`authorization` field). + +Hot reloading of the whitelist is supported, i.e. modification of the whitelist file will be automatically applied without needing to restart the server. 
Please take note of the following +- Avoid using auto save mode when editing the whitelist to prevent spamming hot reloads +- Once the edit is saved, ensure that it has been reloaded successfully by checking the server log + +#### Optional TLS +TLS between the prover and the notary is currently manually handled in this server, though it can be turned off if any of the following is true +- This server is run locally +- TLS is to be handled by an external environment, e.g. reverse proxy, cloud setup + +The toggle to turn on/off TLS is in the config (`tls` field). + +### Design Choices +#### Web Framework +Axum is chosen as the framework to serve HTTP and WebSocket requests from the prover clients due to its rich and well supported features, e.g. native integration with Tokio/Hyper/Tower, customizable middleware, the ability to support lower level integrations of TLS ([example](https://github.com/tokio-rs/axum/blob/main/examples/low-level-rustls/src/main.rs)). To simplify the notary server setup, a single Axum router is used to support both HTTP and WebSocket connections, i.e. all requests can be made to the same port of the notary server. + +#### WebSocket +Axum's internal implementation of WebSocket uses [tokio_tungstenite](https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/), which provides a WebSocket struct that doesn't implement [AsyncRead](https://docs.rs/futures/latest/futures/io/trait.AsyncRead.html) and [AsyncWrite](https://docs.rs/futures/latest/futures/io/trait.AsyncWrite.html). Both these traits are required by the TLSN core libraries for the prover and the notary. 
To overcome this, a [slight modification](./src/service/axum_websocket.rs) of Axum's implementation of WebSocket is used, where [async_tungstenite](https://docs.rs/async-tungstenite/latest/async_tungstenite/) is used instead so that [ws_stream_tungstenite](https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html) can be used to wrap on top of the WebSocket struct to get AsyncRead and AsyncWrite implemented. diff --git a/notary-server/build.rs b/crates/notary/server/build.rs similarity index 100% rename from notary-server/build.rs rename to crates/notary/server/build.rs diff --git a/crates/notary/server/config/config.yaml b/crates/notary/server/config/config.yaml new file mode 100644 index 0000000000..18336dc567 --- /dev/null +++ b/crates/notary/server/config/config.yaml @@ -0,0 +1,32 @@ +server: + name: "notary-server" + host: "0.0.0.0" + port: 7047 + html-info: | +

Notary Server {version}!

+
    +
  • git commit hash: {git_commit_hash}
  • +
  • git commit timestamp: {git_commit_timestamp}
  • +
  • public key:
    {public_key}
  • +
+ health check - info
+ +notarization: + max-sent-data: 4096 + max-recv-data: 16384 + +tls: + enabled: true + private-key-pem-path: "./fixture/tls/notary.key" + certificate-pem-path: "./fixture/tls/notary.crt" + +notary-key: + private-key-pem-path: "./fixture/notary/notary.key" + public-key-pem-path: "./fixture/notary/notary.pub" + +logging: + level: DEBUG + +authorization: + enabled: false + whitelist-csv-path: "./fixture/auth/whitelist.csv" diff --git a/crates/notary/server/fixture/.gitignore b/crates/notary/server/fixture/.gitignore new file mode 100644 index 0000000000..74e5b1d3a9 --- /dev/null +++ b/crates/notary/server/fixture/.gitignore @@ -0,0 +1 @@ +!* \ No newline at end of file diff --git a/notary-server/fixture/auth/whitelist.csv b/crates/notary/server/fixture/auth/whitelist.csv similarity index 100% rename from notary-server/fixture/auth/whitelist.csv rename to crates/notary/server/fixture/auth/whitelist.csv diff --git a/crates/notary/server/fixture/notary/notary.key b/crates/notary/server/fixture/notary/notary.key new file mode 100644 index 0000000000..716507fdc3 --- /dev/null +++ b/crates/notary/server/fixture/notary/notary.key @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgbGCmm+WHxwlKKKRWddfO +02TmpM787BJQuoVrHeCI5v6hRANCAAR7SPGcE5toiPteODpNcsIzUYb9WFjnrnQ6 +tL+OBxsG5+j9AN8W8v+KvMi/UlKaEaJVywIcLCiWENdZyB7u/Yix +-----END PRIVATE KEY----- diff --git a/crates/notary/server/fixture/notary/notary.pub b/crates/notary/server/fixture/notary/notary.pub new file mode 100644 index 0000000000..2e44ba7010 --- /dev/null +++ b/crates/notary/server/fixture/notary/notary.pub @@ -0,0 +1,4 @@ +-----BEGIN PUBLIC KEY----- +MDYwEAYHKoZIzj0CAQYFK4EEAAoDIgADe0jxnBObaIj7Xjg6TXLCM1GG/VhY5650 +OrS/jgcbBuc= +-----END PUBLIC KEY----- diff --git a/crates/notary/server/fixture/tls/README.md b/crates/notary/server/fixture/tls/README.md new file mode 100644 index 0000000000..a3775890eb --- /dev/null +++ b/crates/notary/server/fixture/tls/README.md @@ -0,0 
+1,14 @@ +# Create a private key for the root CA +openssl genpkey -algorithm RSA -out rootCA.key -pkeyopt rsa_keygen_bits:2048 + +# Create a self-signed root CA certificate (100 years validity) +openssl req -x509 -new -nodes -key rootCA.key -sha256 -days 36525 -out rootCA.crt -subj "/C=US/ST=State/L=City/O=tlsnotary/OU=IT/CN=tlsnotary.org" + +# Create a private key for the end entity certificate +openssl genpkey -algorithm RSA -out notary.key -pkeyopt rsa_keygen_bits:2048 + +# Create a certificate signing request (CSR) for the end entity certificate +openssl req -new -key notary.key -out notary.csr -subj "/C=US/ST=State/L=City/O=tlsnotary/OU=IT/CN=tlsnotaryserver.io" + +# Sign the CSR with the root CA to create the end entity certificate (100 years validity) +openssl x509 -req -in notary.csr -CA rootCA.crt -CAkey rootCA.key -CAcreateserial -out notary.crt -days 36525 -sha256 -extfile openssl.cnf -extensions v3_req diff --git a/crates/notary/server/fixture/tls/notary.crt b/crates/notary/server/fixture/tls/notary.crt new file mode 100644 index 0000000000..8df0d5b4ef --- /dev/null +++ b/crates/notary/server/fixture/tls/notary.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIIDzTCCArWgAwIBAgIJALo+PtyTmxENMA0GCSqGSIb3DQEBCwUAMGUxCzAJBgNV +BAYTAlVTMQ4wDAYDVQQIDAVTdGF0ZTENMAsGA1UEBwwEQ2l0eTESMBAGA1UECgwJ +dGxzbm90YXJ5MQswCQYDVQQLDAJJVDEWMBQGA1UEAwwNdGxzbm90YXJ5Lm9yZzAg +Fw0yNDA4MDIxMTE1MzZaGA8yMTI0MDgwMzExMTUzNlowajELMAkGA1UEBhMCVVMx +DjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MRIwEAYDVQQKDAl0bHNub3Rh +cnkxCzAJBgNVBAsMAklUMRswGQYDVQQDDBJ0bHNub3RhcnlzZXJ2ZXIuaW8wggEi +MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDEzkZE9X7Utn3by4sFG8KcDrdV +3szzPP9eA8U4cVmrWQAS0lsrEeHDv0KGKMFKOi3FDgyF1I8OWMIvnWj4LQ1zKYny +fufOkAv4UcYY0E9/VonqPKY0Xo9lbbl5Xu/E55gfJhAPZzoV73uXjvlhSVdhaypZ +ibSZm9t5izTiK1pcKDuvubB5zhmldt1+f0wbBxhLWVlf8T8GaPVZ37NCJGeeUf6Z +GL6Fq4jBYfvjzUQl6P72Zk0FCpIq2W/z2yBfWnNRRPjQuzIxk7cB6ssVpQF52cXZ +OF5YJhc7C/hr4rfWLshGQxkmwNktBSHrQUBm3LQHaT9ccPy0xgdIAD9Avf0BAgMB 
+AAGjeTB3MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgXgMB0GA1UdEQQWMBSCEnRsc25v +dGFyeXNlcnZlci5pbzAdBgNVHQ4EFgQULo1DGRbjA/+zX9AvRk6YcO2AYHowHwYD +VR0jBBgwFoAUKmfDzMNGdJr5blSUarmhRIiI88IwDQYJKoZIhvcNAQELBQADggEB +AFgTVLHCfaswU8pTwgRK1xWTGlMDQmZU//Lbatel6HTH0zMF4wj/hVkGikHWpJLt +1UipGRPUgyjFtDoPag8jrSDeK1ahtjNzkGEuz5wXM0zYqIv1xkXPatEbCV4LLI3Q +Yxf2YI7Nh599+2I/oZ+8YKUMn6EI58PgiSjyG7vzRoQKGAoE82FpBFyEUpcUXQDa +MIr/D8Xcv+RPpdHxi4cyHJAy+irzs9ghF7WdmFEOATYNF8EsP/doiskXWl68t2Hn +sDflDIbOH1xId3zJIwE/5IG3NrNqhVm2va06TNWURo3v8h+7bnD8Rxq107ObflKj +i1MwBiwdf7/w5Dw9o3K21ic= +-----END CERTIFICATE----- diff --git a/crates/notary/server/fixture/tls/notary.csr b/crates/notary/server/fixture/tls/notary.csr new file mode 100644 index 0000000000..e8d88ef185 --- /dev/null +++ b/crates/notary/server/fixture/tls/notary.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICrzCCAZcCAQAwajELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYD +VQQHDARDaXR5MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRswGQYD +VQQDDBJ0bHNub3RhcnlzZXJ2ZXIuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQDEzkZE9X7Utn3by4sFG8KcDrdV3szzPP9eA8U4cVmrWQAS0lsrEeHD +v0KGKMFKOi3FDgyF1I8OWMIvnWj4LQ1zKYnyfufOkAv4UcYY0E9/VonqPKY0Xo9l +bbl5Xu/E55gfJhAPZzoV73uXjvlhSVdhaypZibSZm9t5izTiK1pcKDuvubB5zhml +dt1+f0wbBxhLWVlf8T8GaPVZ37NCJGeeUf6ZGL6Fq4jBYfvjzUQl6P72Zk0FCpIq +2W/z2yBfWnNRRPjQuzIxk7cB6ssVpQF52cXZOF5YJhc7C/hr4rfWLshGQxkmwNkt +BSHrQUBm3LQHaT9ccPy0xgdIAD9Avf0BAgMBAAGgADANBgkqhkiG9w0BAQsFAAOC +AQEAups2oJRV5x/BZcZvRseWpGToqr5pO3ESXUEEbCpeHDKLIav4aWfYUkY4UGGN +2m1XYN7nEytwygJmMRWS8kjJzacII9j+dCysqCmm71T2L4BszCCVYGwTAigZuZ1R +WmULhso1tXXUF7ggEdTUpxMa5VijkbpZ5iQfBbslpSo0mjgM2bL4hO3Y8dl7a1Bn +0LNasWzWaizp6SkMU2BDNVF+i5blR4p8Bk0GQpPzGqwZf2tKcqmvutPqEm4rcOOC +U5j/U6uZpCYc8VQOklOUkDUSAZzCSJxeGHykddtMFte5+HkqBZoMCQwHeZl1g0qZ +/NLvHB8YO7U2XRJTfxloHhj3WQ== +-----END CERTIFICATE REQUEST----- diff --git a/crates/notary/server/fixture/tls/notary.key b/crates/notary/server/fixture/tls/notary.key new file mode 100644 index 0000000000..71f8b04401 --- 
/dev/null +++ b/crates/notary/server/fixture/tls/notary.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDEzkZE9X7Utn3b +y4sFG8KcDrdV3szzPP9eA8U4cVmrWQAS0lsrEeHDv0KGKMFKOi3FDgyF1I8OWMIv +nWj4LQ1zKYnyfufOkAv4UcYY0E9/VonqPKY0Xo9lbbl5Xu/E55gfJhAPZzoV73uX +jvlhSVdhaypZibSZm9t5izTiK1pcKDuvubB5zhmldt1+f0wbBxhLWVlf8T8GaPVZ +37NCJGeeUf6ZGL6Fq4jBYfvjzUQl6P72Zk0FCpIq2W/z2yBfWnNRRPjQuzIxk7cB +6ssVpQF52cXZOF5YJhc7C/hr4rfWLshGQxkmwNktBSHrQUBm3LQHaT9ccPy0xgdI +AD9Avf0BAgMBAAECggEACAuCldkPOTTIik6UvT24Q9baKbl02VCaA8bVrgv8JWP6 ++8n7jhQqDW1pE8Dgvd8I9fAwFNxuiKCaN4YQv2xgC2AcUnxbj3cV9i2pkmQZi9QG +yTt3c9aVuAi3Nz3pQTxSXJuatnZ6ymDCxZxDl3V/C+1sisJ1Tn4vh5VoMQKiq/eq +sgmNF73VvKiHUeJGpMixho9AeFfE/o+HTAXcycmHlXBvJzOMsgmgTxinBNt54ROc +WKtt4GNvkxN72e/qu2rNPJvi/Hdq9cG03LmuaOn9dSbHWOdeLZFR2OkO6jvrUKv0 +doDOYUsdCBIC/LaLvtOBzts3d3BZou9MOcz9cXUBDQKBgQDrFomWXDEnqH6o4ixg +1VpZfaK+fRHasni9txZC0zK44HOL4qnUjLLQ7GPyam44exq2/B2ouTCrV+TicoKy +GYAjCL+/rhZfiMoWrXO/SIU73TiXbYQiNEV2FYuAbvCsW9rLRH2/GzYu+fb/hUBP +vYLn9gf2u/nCiqGZ4dJnmrEOIwKBgQDWT/eE6+4HQB8Zhz6Dxu6/EOlwZtC9QsiT +CEEYja7yl0MDnQsQzv2nbCs5li6359IQJR1+L6G3kcd0z6EumWpB0JVQzMtKwlVe +WodfpPOWrectftfGUr+3xZIRbSV/hj89RhG2uhQ8xJYDLWHjP3LF0menvW/1JP0K +xhlGKqRwiwKBgQCtzz3uYz8ceSEcMAxrk5J3M8JNYB8BOI64hVL6GTgZJCmJtQ2n +TlcuzHeg1Tukmq/HtmMfSbxIEnXxToR+tQfd3ywVxdpYy8POPHOlazLGberXWmsk +9sycX5WCYYOji04alwr5bl8DIGCTzqsbyZutcGO28ofYY7LTGPj9DIv3TQKBgHiM +gq5CB6IMb3HsoT1+qMzQtn6DVuceqbQK8JLfH4lVjFx7+b16sTN7pNS/pYfM3lw2 +hGB2aoDXf1o1cHTF1v8uVM8eYzuqFFr+kSc7ockgCOmOb9EeurikaYVj37Pbz7an +s08VXEzSR4+B943cIrMjpyqzZEaAh9WHmK/fTKABAoGACsgGUB84KSV+LBz/LY9M +5xYJNuf11Jucx6bahX6wPZyssLnZ5o+x7QmFIVXyPnQ7wn69C7EfvKJIgdXmjEEk +P4oUh7Osc5UwTR5s7Kr9iCqcDIR5NW68AFHEddMEXpOyFn4QrjrdYSlO4CQAiQUU +Nudz2KSI148F/vzo0I78A7c= +-----END PRIVATE KEY----- diff --git a/crates/notary/server/fixture/tls/openssl.cnf b/crates/notary/server/fixture/tls/openssl.cnf new file mode 100644 index 0000000000..dffdf04f9d --- /dev/null +++ 
b/crates/notary/server/fixture/tls/openssl.cnf @@ -0,0 +1,7 @@ +[ v3_req ] +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +subjectAltName = @alt_names + +[ alt_names ] +DNS.1 = tlsnotaryserver.io \ No newline at end of file diff --git a/crates/notary/server/fixture/tls/rootCA.crt b/crates/notary/server/fixture/tls/rootCA.crt new file mode 100644 index 0000000000..090534e1a4 --- /dev/null +++ b/crates/notary/server/fixture/tls/rootCA.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDrTCCApWgAwIBAgIUUmpF/+i9EcDpciV0s1Okh4Wx/QswDQYJKoZIhvcNAQEL +BQAwZTELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRYwFAYDVQQDDA10bHNu +b3Rhcnkub3JnMCAXDTI0MDgwMjExMDU1MloYDzIxMjQwODAzMTEwNTUyWjBlMQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxEjAQBgNV +BAoMCXRsc25vdGFyeTELMAkGA1UECwwCSVQxFjAUBgNVBAMMDXRsc25vdGFyeS5v +cmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDDNGFXBMov4HBr4F/W ++9mzM4t+ww4jURyF/7O1puyhz0gueAu5/kzh6d5r+P2xwP0tpqtITvwfo2tHCNTg +dKBNPO7NnRnW8QtommHhafHUfj+4cR7G1xxSZD34mwuBnYW3cmxCbi0l5dClWfHA +G7GRHv5aPBBYbeF2ACYBesaCJLa5OMkab/N7DwPTWuSjoQqrMeodaQ1Q5Ro09cbt +WlL+ywRVq1gKZvgs3RogwDt6NUEZ8Hkz/BZzbo2HlX1+XUpMP7ucHGUQIt7F2Z+6 +iYkMJfP+BflBR+qOzoMbgHo1SD5uIv1/iXi3UoddpCnzsretkcNs2pnpiPWoEhdA +fNuxAgMBAAGjUzBRMB0GA1UdDgQWBBQqZ8PMw0Z0mvluVJRquaFEiIjzwjAfBgNV +HSMEGDAWgBQqZ8PMw0Z0mvluVJRquaFEiIjzwjAPBgNVHRMBAf8EBTADAQH/MA0G +CSqGSIb3DQEBCwUAA4IBAQA7HR1mmHe5jT52EhSjwePvzvW7Tx6VGSUrhzkhnRVv +IbYjX0jWPSSXvc2NG3LyxyDLLOTkM0xWQLGEQ9LYYuH9Sy1ZUK4Mv7qWO23LaM2s +dYjWDKM9N23XhtgkbzFX6+X1Q93wU5KIibVMkPSzmaxDbhoiKYozznmSjOBt2HR2 +UbpjNPjzN7BL+Gv+8hBhS0UeE2zgN0XcZmyiZQlfL7XTVoszjNd6HeKyCHX1Tk4a +/vYn3B1cFK8u4gRyjPKr8QH/uju4T+0gp8GtB1eQ9erdBkehPgb8x1QwdXWKPp4m +woJDTdgJhMu3w0InHtQztCtiTPphjrN/as2rw9hyYU4C +-----END CERTIFICATE----- diff --git a/crates/notary/server/fixture/tls/rootCA.key b/crates/notary/server/fixture/tls/rootCA.key new file mode 100644 index 
0000000000..8e0fe87edc --- /dev/null +++ b/crates/notary/server/fixture/tls/rootCA.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDDNGFXBMov4HBr +4F/W+9mzM4t+ww4jURyF/7O1puyhz0gueAu5/kzh6d5r+P2xwP0tpqtITvwfo2tH +CNTgdKBNPO7NnRnW8QtommHhafHUfj+4cR7G1xxSZD34mwuBnYW3cmxCbi0l5dCl +WfHAG7GRHv5aPBBYbeF2ACYBesaCJLa5OMkab/N7DwPTWuSjoQqrMeodaQ1Q5Ro0 +9cbtWlL+ywRVq1gKZvgs3RogwDt6NUEZ8Hkz/BZzbo2HlX1+XUpMP7ucHGUQIt7F +2Z+6iYkMJfP+BflBR+qOzoMbgHo1SD5uIv1/iXi3UoddpCnzsretkcNs2pnpiPWo +EhdAfNuxAgMBAAECggEAGlol5z4e9XD9JvMMEn++wfHBcS7FPStOsyBJPcqibgMH +oY5UjEVc/QU6IPq6H5cIFsjwnTsHJQDwveQz/iErICzg/Xep7K8ZyoNHl3YFTu8Z +jGgTruWMo0AjxZNYwvoQT9WYm9c318KQn4yRlaJSHwnqGHsR/H4eTnRyrQcgE/gY +V7TNEqS7CMvuKqY+rnRhRjXlnKD0p6iT68QF5RVfWH4Qedk5t09JohfTjCK+5+Zo +TXFkpltNv6qHXpZoq5LTo4HZL/l9AnvUU5sjHbzfB6FJtZ0wYtI4q0EIchTusIw8 +cJttSsIHzDnWaw2HLRm7dIHyrk7WqbLUtJRn+Bu9SQKBgQDwKJM7EoH0ZzzG+D24 +lnSV+zjcBeMB4VLRAt5uabWZhmEa0lcb1nv73RU6vmU71UeuzlErSSRPqutr7Ajk +f37xQCeuQKFLrY1OGmiBp4CBOFLe/l2mwPjnDgccgaWPrkj7QoqMdDAK7sdWnUO1 +uo9mKhzX08DuLxlU5VxxarzFqwKBgQDQFLGH38rg2BcYV5TjE2SZHAWZSz5TQTNj +8PMqzZWqbY0WtmEnJEh6I99l0Y4MguuFVjO9WD0kssiQtL+kQvVJkR1WPaOAviFl +PppWyA4BKGcdXSGKsXY08I4KXJaWVzolYZLA/y1zT+7JSrBd2QILyrZvF4DiPv5Y +Jm1LMd6QEwKBgBOnDl1QJ2hLpnKVz98yGLpJQ57lsGzv9mn6NR+N8PluQLYELnKt +u5mhvuH+wKQD0QjiA0xqgNkwIHHFb/ja4hV17YlZ6pkZy61vhcvOXDq21DlBUYKa +2gN2Z2iSx2yZk4lUKahSvbe3UIKq/eZ6LM/sdE3JG0miew0yc70oQehfAoGAOvTy +DEabjDON76a5F9Hh2gP3jiSkpyA9OF8H9yPC+UQLCtloE5gTNRA+9vF2JxNdOi1f +gZGj2WcSrvWXqyoRp+OHBW13iz3T5oTjZB1Q4oEZHlfJ7is0C/HwvPzY6gYTAo5v +72Ed9qM6TCxuZljbXI32POnS6cfhdwaERx79KaMCgYAJzrJBGEd194gVewUsoeiL +fB8eERgvvPCZwfKMh4H0Q8i6RsECNrZJOnNq6xG/Pf1ubasxNYZwSP4yOB+syhA7 +NlvIP8Wps+c0M0oAAhF8q//eduUHyS1o/BbTL44ZkINVlmO5WuQ2pB1QdaBunrnF +GbrTaj5XbaeHwD4CKq5q0w== +-----END PRIVATE KEY----- diff --git a/crates/notary/server/fixture/tls/rootCA.srl b/crates/notary/server/fixture/tls/rootCA.srl new file mode 100644 index 0000000000..460f6f23fa --- 
/dev/null +++ b/crates/notary/server/fixture/tls/rootCA.srl @@ -0,0 +1 @@ +BA3E3EDC939B110D diff --git a/notary-server/notary-server.Dockerfile b/crates/notary/server/notary-server.Dockerfile similarity index 92% rename from notary-server/notary-server.Dockerfile rename to crates/notary/server/notary-server.Dockerfile index 077f3ed0a0..dfe6e716e9 100644 --- a/notary-server/notary-server.Dockerfile +++ b/crates/notary/server/notary-server.Dockerfile @@ -19,7 +19,7 @@ FROM rust:bookworm AS builder WORKDIR /usr/src/tlsn COPY . . -RUN cargo install --path notary-server +RUN cargo install --path crates/notary/server FROM ubuntu:latest WORKDIR /root/.notary-server @@ -30,9 +30,9 @@ RUN apt-get update && apt-get -y upgrade && apt-get install -y --no-install-reco && apt-get clean \ && rm -rf /var/lib/apt/lists/* # Copy default fixture folder for default usage -COPY --from=builder /usr/src/tlsn/notary-server/fixture ./fixture +COPY --from=builder /usr/src/tlsn/crates/notary/server/fixture ./fixture # Copy default config folder for default usage -COPY --from=builder /usr/src/tlsn/notary-server/config ./config +COPY --from=builder /usr/src/tlsn/crates/notary/server/config ./config COPY --from=builder /usr/local/cargo/bin/notary-server /usr/local/bin/notary-server # Label to link this image with the repository in Github Container Registry (https://docs.github.com/en/packages/learn-github-packages/connecting-a-repository-to-a-package#connecting-a-repository-to-a-container-image-using-the-command-line) LABEL org.opencontainers.image.source=https://github.com/tlsnotary/tlsn diff --git a/notary-server/notary-server.Dockerfile.dockerignore b/crates/notary/server/notary-server.Dockerfile.dockerignore similarity index 60% rename from notary-server/notary-server.Dockerfile.dockerignore rename to crates/notary/server/notary-server.Dockerfile.dockerignore index b326c7e88d..b9fe8dce6c 100644 --- a/notary-server/notary-server.Dockerfile.dockerignore +++ 
b/crates/notary/server/notary-server.Dockerfile.dockerignore @@ -1,12 +1,9 @@ # exclude everything * -# include notary-server -!/notary-server - -# include core library dependencies -!/tlsn -!/components +# include notary and core library dependencies +!/crates +!/Cargo.toml # include .git for program to get git info !/.git diff --git a/notary-server/openapi.yaml b/crates/notary/server/openapi.yaml similarity index 96% rename from notary-server/openapi.yaml rename to crates/notary/server/openapi.yaml index e13f74e4f0..7044fbc612 100644 --- a/notary-server/openapi.yaml +++ b/crates/notary/server/openapi.yaml @@ -3,7 +3,7 @@ openapi: 3.0.0 info: title: Notary Server description: Notary server written in Rust to provide notarization service. - version: 0.1.0-alpha.3 + version: 0.1.0-alpha.7 tags: - name: General @@ -175,12 +175,14 @@ components: enum: - "Tcp" - "Websocket" - maxTranscriptSize: - description: Maximum transcript size in bytes + maxSentData: + description: Maximum data that can be sent by the prover in bytes + type: integer + maxRecvData: + description: Maximum data that can be received by the prover in bytes type: integer required: - "clientType" - - "maxTranscriptSize" NotarizationSessionResponse: type: object properties: diff --git a/crates/notary/server/src/config.rs b/crates/notary/server/src/config.rs new file mode 100644 index 0000000000..54a8b6bac1 --- /dev/null +++ b/crates/notary/server/src/config.rs @@ -0,0 +1,78 @@ +use serde::Deserialize; + +#[derive(Clone, Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct NotaryServerProperties { + /// Name and address of the notary server + pub server: ServerProperties, + /// Setting for notarization + pub notarization: NotarizationProperties, + /// Setting for TLS connection between prover and notary + pub tls: TLSProperties, + /// File path of private key (in PEM format) used to sign the notarization + pub notary_key: NotarySigningKeyProperties, + /// Setting for logging + 
pub logging: LoggingProperties, + /// Setting for authorization + pub authorization: AuthorizationProperties, +} + +#[derive(Clone, Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct AuthorizationProperties { + /// Switch to turn on or off auth middleware + pub enabled: bool, + /// File path of the whitelist API key csv + pub whitelist_csv_path: String, +} + +#[derive(Clone, Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct NotarizationProperties { + /// Global limit for maximum number of bytes that can be sent + pub max_sent_data: usize, + /// Global limit for maximum number of bytes that can be received + pub max_recv_data: usize, +} + +#[derive(Clone, Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct ServerProperties { + /// Used for testing purpose + pub name: String, + pub host: String, + pub port: u16, + /// Static html response returned from API root endpoint "/". Default html + /// response contains placeholder strings that will be replaced with + /// actual values in server.rs, e.g. {version}, {public_key} + pub html_info: String, +} + +#[derive(Clone, Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct TLSProperties { + /// Flag to turn on/off TLS between prover and notary (should always be + /// turned on unless TLS is handled by external setup e.g. 
reverse proxy, + /// cloud) + pub enabled: bool, + pub private_key_pem_path: String, + pub certificate_pem_path: String, +} + +#[derive(Clone, Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct NotarySigningKeyProperties { + pub private_key_pem_path: String, + pub public_key_pem_path: String, +} + +#[derive(Clone, Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct LoggingProperties { + /// Log verbosity level of the default filtering logic, which is + /// notary_server=,tlsn_verifier=,tls_mpc= Must be either of + pub level: String, + /// Custom filtering logic, refer to the syntax here https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#example-syntax + /// This will override the default filtering logic above + pub filter: Option, +} diff --git a/notary-server/src/domain.rs b/crates/notary/server/src/domain.rs similarity index 100% rename from notary-server/src/domain.rs rename to crates/notary/server/src/domain.rs diff --git a/notary-server/src/domain/auth.rs b/crates/notary/server/src/domain/auth.rs similarity index 80% rename from notary-server/src/domain/auth.rs rename to crates/notary/server/src/domain/auth.rs index d31793948a..2360b9bef2 100644 --- a/notary-server/src/domain/auth.rs +++ b/crates/notary/server/src/domain/auth.rs @@ -1,8 +1,9 @@ -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; -/// Structure of each whitelisted record of the API key whitelist for authorization purpose -#[derive(Clone, Debug, Deserialize)] +/// Structure of each whitelisted record of the API key whitelist for +/// authorization purpose +#[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "PascalCase")] pub struct AuthorizationWhitelistRecord { pub name: String, @@ -10,7 +11,8 @@ pub struct AuthorizationWhitelistRecord { pub created_at: String, } -/// Convert whitelist data structure from vector to hashmap using api_key as the 
key to speed up lookup +/// Convert whitelist data structure from vector to hashmap using api_key as the +/// key to speed up lookup pub fn authorization_whitelist_vec_into_hashmap( authorization_whitelist: Vec, ) -> HashMap { diff --git a/notary-server/src/domain/cli.rs b/crates/notary/server/src/domain/cli.rs similarity index 100% rename from notary-server/src/domain/cli.rs rename to crates/notary/server/src/domain/cli.rs diff --git a/notary-server/src/domain/notary.rs b/crates/notary/server/src/domain/notary.rs similarity index 64% rename from notary-server/src/domain/notary.rs rename to crates/notary/server/src/domain/notary.rs index 3c755b432d..39923c9148 100644 --- a/notary-server/src/domain/notary.rs +++ b/crates/notary/server/src/domain/notary.rs @@ -1,9 +1,9 @@ -use std::{collections::HashMap, sync::Arc}; - -use chrono::{DateTime, Utc}; -use p256::ecdsa::SigningKey; use serde::{Deserialize, Serialize}; -use tokio::sync::Mutex; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; +use tlsn_core::CryptoProvider; use crate::{config::NotarizationProperties, domain::auth::AuthorizationWhitelistRecord}; @@ -20,8 +20,10 @@ pub struct NotarizationSessionResponse { #[serde(rename_all = "camelCase")] pub struct NotarizationSessionRequest { pub client_type: ClientType, - /// Maximum transcript size in bytes - pub max_transcript_size: Option, + /// Maximum data that can be sent by the prover + pub max_sent_data: Option, + /// Maximum data that can be received by the prover + pub max_recv_data: Option, } /// Request query of the /notarize API @@ -37,36 +39,30 @@ pub struct NotarizationRequestQuery { pub enum ClientType { /// Client that has access to the transport layer Tcp, - /// Client that cannot directly access transport layer, e.g. browser extension + /// Client that cannot directly access transport layer, e.g. 
browser + /// extension Websocket, } -/// Session configuration data to be stored in temporary storage -#[derive(Clone, Debug)] -pub struct SessionData { - pub max_transcript_size: Option, - pub created_at: DateTime, -} - /// Global data that needs to be shared with the axum handlers #[derive(Clone, Debug)] pub struct NotaryGlobals { - pub notary_signing_key: SigningKey, + pub crypto_provider: Arc, pub notarization_config: NotarizationProperties, - /// A temporary storage to store configuration data, mainly used for WebSocket client - pub store: Arc>>, + /// A temporary storage to store session_id + pub store: Arc>>, /// Whitelist of API keys for authorization purpose - pub authorization_whitelist: Option>>, + pub authorization_whitelist: Option>>>, } impl NotaryGlobals { pub fn new( - notary_signing_key: SigningKey, + crypto_provider: Arc, notarization_config: NotarizationProperties, - authorization_whitelist: Option>>, + authorization_whitelist: Option>>>, ) -> Self { Self { - notary_signing_key, + crypto_provider, notarization_config, store: Default::default(), authorization_whitelist, diff --git a/notary-server/src/error.rs b/crates/notary/server/src/error.rs similarity index 77% rename from notary-server/src/error.rs rename to crates/notary/server/src/error.rs index c90a6c015f..4b5ec6691b 100644 --- a/notary-server/src/error.rs +++ b/crates/notary/server/src/error.rs @@ -1,11 +1,10 @@ -use axum::{ - http::StatusCode, - response::{IntoResponse, Response}, -}; +use axum::http::StatusCode; +use axum_core::response::{IntoResponse as AxumCoreIntoResponse, Response}; use eyre::Report; use std::error::Error; +use tlsn_common::config::ProtocolConfigValidatorBuilderError; -use tlsn_verifier::tls::{VerifierConfigBuilderError, VerifierError}; +use tlsn_verifier::{VerifierConfigBuilderError, VerifierError}; #[derive(Debug, thiserror::Error)] pub enum NotaryServerError { @@ -33,8 +32,14 @@ impl From for NotaryServerError { } } +impl From for NotaryServerError { + fn 
from(error: ProtocolConfigValidatorBuilderError) -> Self { + Self::Notarization(Box::new(error)) + } +} + /// Trait implementation to convert this error into an axum http response -impl IntoResponse for NotaryServerError { +impl AxumCoreIntoResponse for NotaryServerError { fn into_response(self) -> Response { match self { bad_request_error @ NotaryServerError::BadProverRequest(_) => { diff --git a/notary-server/src/lib.rs b/crates/notary/server/src/lib.rs similarity index 80% rename from notary-server/src/lib.rs rename to crates/notary/server/src/lib.rs index 8a46a574ae..9353150dbd 100644 --- a/notary-server/src/lib.rs +++ b/crates/notary/server/src/lib.rs @@ -5,11 +5,12 @@ mod middleware; mod server; mod server_tracing; mod service; +mod signing; mod util; pub use config::{ - AuthorizationProperties, NotarizationProperties, NotaryServerProperties, - NotarySigningKeyProperties, ServerProperties, TLSProperties, TracingProperties, + AuthorizationProperties, LoggingProperties, NotarizationProperties, NotaryServerProperties, + NotarySigningKeyProperties, ServerProperties, TLSProperties, }; pub use domain::{ cli::CliFields, diff --git a/notary-server/src/main.rs b/crates/notary/server/src/main.rs similarity index 100% rename from notary-server/src/main.rs rename to crates/notary/server/src/main.rs diff --git a/notary-server/src/middleware.rs b/crates/notary/server/src/middleware.rs similarity index 96% rename from notary-server/src/middleware.rs rename to crates/notary/server/src/middleware.rs index e0b694e396..84f0b40a41 100644 --- a/notary-server/src/middleware.rs +++ b/crates/notary/server/src/middleware.rs @@ -33,6 +33,7 @@ where match auth_header { Some(auth_header) => { + let whitelist = whitelist.lock().unwrap(); if api_key_is_valid(auth_header, &whitelist) { trace!("Request authorized."); Ok(Self) @@ -96,9 +97,6 @@ mod test { #[test] fn test_api_key_is_absent() { let whitelist = get_whitelist_fixture(); - assert_eq!( - api_key_is_valid("test-api-keY-0", 
&Arc::new(whitelist)), - false - ); + assert!(!api_key_is_valid("test-api-keY-0", &Arc::new(whitelist))); } } diff --git a/crates/notary/server/src/server.rs b/crates/notary/server/src/server.rs new file mode 100644 index 0000000000..7006073fc5 --- /dev/null +++ b/crates/notary/server/src/server.rs @@ -0,0 +1,449 @@ +use axum::{ + extract::Request, + http::StatusCode, + middleware::from_extractor_with_state, + response::{Html, IntoResponse}, + routing::{get, post}, + Json, Router, +}; +use eyre::{ensure, eyre, Result}; +use futures_util::future::poll_fn; +use hyper::{body::Incoming, server::conn::http1}; +use hyper_util::rt::TokioIo; +use notify::{ + event::ModifyKind, Error, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher, +}; +use pkcs8::DecodePrivateKey; +use rustls::{Certificate, PrivateKey, ServerConfig}; +use std::{ + collections::HashMap, + fs::File as StdFile, + io::BufReader, + net::{IpAddr, SocketAddr}, + path::Path, + pin::Pin, + sync::{Arc, Mutex}, +}; +use tlsn_core::CryptoProvider; +use tokio::{fs::File, io::AsyncReadExt, net::TcpListener}; +use tokio_rustls::TlsAcceptor; +use tower_http::cors::CorsLayer; +use tower_service::Service; +use tracing::{debug, error, info}; +use zeroize::Zeroize; + +use crate::{ + config::{NotaryServerProperties, NotarySigningKeyProperties}, + domain::{ + auth::{authorization_whitelist_vec_into_hashmap, AuthorizationWhitelistRecord}, + notary::NotaryGlobals, + InfoResponse, + }, + error::NotaryServerError, + middleware::AuthorizationMiddleware, + service::{initialize, upgrade_protocol}, + signing::AttestationKey, + util::parse_csv_file, +}; + +/// Start a TCP server (with or without TLS) to accept notarization request for +/// both TCP and WebSocket clients +#[tracing::instrument(skip(config))] +pub async fn run_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> { + // Load the private key for notarized transcript signing + let attestation_key = 
load_attestation_key(&config.notary_key).await?; + let crypto_provider = build_crypto_provider(attestation_key); + + // Build TLS acceptor if it is turned on + let tls_acceptor = if !config.tls.enabled { + debug!("Skipping TLS setup as it is turned off."); + None + } else { + let (tls_private_key, tls_certificates) = load_tls_key_and_cert( + &config.tls.private_key_pem_path, + &config.tls.certificate_pem_path, + ) + .await?; + + let mut server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(tls_certificates, tls_private_key) + .map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?; + + // Set the http protocols we support + server_config.alpn_protocols = vec![b"http/1.1".to_vec()]; + let tls_config = Arc::new(server_config); + Some(TlsAcceptor::from(tls_config)) + }; + + // Load the authorization whitelist csv if it is turned on + let authorization_whitelist = + load_authorization_whitelist(config)?.map(|whitelist| Arc::new(Mutex::new(whitelist))); + // Enable hot reload if authorization whitelist is available + let watcher = + watch_and_reload_authorization_whitelist(config.clone(), authorization_whitelist.clone())?; + if watcher.is_some() { + debug!("Successfully setup watcher for hot reload of authorization whitelist!"); + } + + let notary_address = SocketAddr::new( + IpAddr::V4(config.server.host.parse().map_err(|err| { + eyre!("Failed to parse notary host address from server config: {err}") + })?), + config.server.port, + ); + let mut listener = TcpListener::bind(notary_address) + .await + .map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?; + + info!("Listening for TCP traffic at {}", notary_address); + + let protocol = Arc::new(http1::Builder::new()); + let notary_globals = NotaryGlobals::new( + Arc::new(crypto_provider), + config.notarization.clone(), + authorization_whitelist, + ); + + // Parameters needed for the info endpoint + let public_key = 
std::fs::read_to_string(&config.notary_key.public_key_pem_path) + .map_err(|err| eyre!("Failed to load notary public signing key for notarization: {err}"))?; + let version = env!("CARGO_PKG_VERSION").to_string(); + let git_commit_hash = env!("GIT_COMMIT_HASH").to_string(); + let git_commit_timestamp = env!("GIT_COMMIT_TIMESTAMP").to_string(); + + // Parameters needed for the root / endpoint + let html_string = config.server.html_info.clone(); + let html_info = Html( + html_string + .replace("{version}", &version) + .replace("{git_commit_hash}", &git_commit_hash) + .replace("{git_commit_timestamp}", &git_commit_timestamp) + .replace("{public_key}", &public_key), + ); + + let router = Router::new() + .route( + "/", + get(|| async move { (StatusCode::OK, html_info).into_response() }), + ) + .route( + "/healthcheck", + get(|| async move { (StatusCode::OK, "Ok").into_response() }), + ) + .route( + "/info", + get(|| async move { + ( + StatusCode::OK, + Json(InfoResponse { + version, + public_key, + git_commit_hash, + git_commit_timestamp, + }), + ) + .into_response() + }), + ) + .route("/session", post(initialize)) + // Not applying auth middleware to /notarize endpoint for now as we can rely on our + // short-lived session id generated from /session endpoint, as it is not possible + // to use header for API key for websocket /notarize endpoint due to browser restriction + // ref: https://stackoverflow.com/a/4361358; And putting it in url query param + // seems to be more insecured: https://stackoverflow.com/questions/5517281/place-api-key-in-headers-or-url + .route_layer(from_extractor_with_state::< + AuthorizationMiddleware, + NotaryGlobals, + >(notary_globals.clone())) + .route("/notarize", get(upgrade_protocol)) + .layer(CorsLayer::permissive()) + .with_state(notary_globals); + + loop { + // Poll and await for any incoming connection, ensure that all operations inside + // are infallible to prevent bringing down the server + let stream = match poll_fn(|cx| 
Pin::new(&mut listener).poll_accept(cx)).await { + Ok((stream, _)) => stream, + Err(err) => { + error!("{}", NotaryServerError::Connection(err.to_string())); + continue; + } + }; + debug!("Received a prover's TCP connection"); + + let tower_service = router.clone(); + let tls_acceptor = tls_acceptor.clone(); + let protocol = protocol.clone(); + + // Spawn a new async task to handle the new connection + tokio::spawn(async move { + // When TLS is enabled + if let Some(acceptor) = tls_acceptor { + match acceptor.accept(stream).await { + Ok(stream) => { + info!("Accepted prover's TLS-secured TCP connection"); + // Reference: https://github.com/tokio-rs/axum/blob/5201798d4e4d4759c208ef83e30ce85820c07baa/examples/low-level-rustls/src/main.rs#L67-L80 + let io = TokioIo::new(stream); + let hyper_service = + hyper::service::service_fn(move |request: Request| { + tower_service.clone().call(request) + }); + // Serve different requests using the same hyper protocol and axum router + let _ = protocol + .serve_connection(io, hyper_service) + // use with_upgrades to upgrade connection to websocket for websocket + // clients and to extract tcp connection for + // tcp clients + .with_upgrades() + .await; + } + Err(err) => { + error!("{}", NotaryServerError::Connection(err.to_string())); + } + } + } else { + // When TLS is disabled + info!("Accepted prover's TCP connection",); + // Reference: https://github.com/tokio-rs/axum/blob/5201798d4e4d4759c208ef83e30ce85820c07baa/examples/low-level-rustls/src/main.rs#L67-L80 + let io = TokioIo::new(stream); + let hyper_service = + hyper::service::service_fn(move |request: Request| { + tower_service.clone().call(request) + }); + // Serve different requests using the same hyper protocol and axum router + let _ = protocol + .serve_connection(io, hyper_service) + // use with_upgrades to upgrade connection to websocket for websocket clients + // and to extract tcp connection for tcp clients + .with_upgrades() + .await; + } + }); + } +} + +fn 
build_crypto_provider(attestation_key: AttestationKey) -> CryptoProvider { + let mut provider = CryptoProvider::default(); + provider.signer.set_signer(attestation_key.into_signer()); + provider +} + +/// Load notary signing key for attestations from static file +async fn load_attestation_key(config: &NotarySigningKeyProperties) -> Result { + debug!("Loading notary server's signing key"); + + let mut file = File::open(&config.private_key_pem_path).await?; + let mut pem = String::new(); + file.read_to_string(&mut pem) + .await + .map_err(|_| eyre!("pem file does not contain valid UTF-8"))?; + + let key = AttestationKey::from_pkcs8_pem(&pem) + .map_err(|err| eyre!("Failed to load notary signing key for notarization: {err}"))?; + + pem.zeroize(); + + debug!("Successfully loaded notary server's signing key!"); + + Ok(key) +} + +/// Read a PEM-formatted file and return its buffer reader +pub async fn read_pem_file(file_path: &str) -> Result> { + let key_file = File::open(file_path).await?.into_std().await; + Ok(BufReader::new(key_file)) +} + +/// Load notary tls private key and cert from static files +async fn load_tls_key_and_cert( + private_key_pem_path: &str, + certificate_pem_path: &str, +) -> Result<(PrivateKey, Vec)> { + debug!("Loading notary server's tls private key and certificate"); + + let mut private_key_file_reader = read_pem_file(private_key_pem_path).await?; + let mut private_keys = rustls_pemfile::pkcs8_private_keys(&mut private_key_file_reader)?; + ensure!( + private_keys.len() == 1, + "More than 1 key found in the tls private key pem file" + ); + let private_key = PrivateKey(private_keys.remove(0)); + + let mut certificate_file_reader = read_pem_file(certificate_pem_path).await?; + let certificates = rustls_pemfile::certs(&mut certificate_file_reader)? 
+ .into_iter() + .map(Certificate) + .collect(); + + debug!("Successfully loaded notary server's tls private key and certificate!"); + Ok((private_key, certificates)) +} + +/// Load authorization whitelist if it is enabled +fn load_authorization_whitelist( + config: &NotaryServerProperties, +) -> Result>> { + let authorization_whitelist = if !config.authorization.enabled { + debug!("Skipping authorization as it is turned off."); + None + } else { + // Load the csv + let whitelist_csv = parse_csv_file::( + &config.authorization.whitelist_csv_path, + ) + .map_err(|err| eyre!("Failed to parse authorization whitelist csv: {:?}", err))?; + // Convert the whitelist record into hashmap for faster lookup + let whitelist_hashmap = authorization_whitelist_vec_into_hashmap(whitelist_csv); + Some(whitelist_hashmap) + }; + Ok(authorization_whitelist) +} + +// Setup a watcher to detect any changes to authorization whitelist +// When the list file is modified, the watcher thread will reload the whitelist +// The watcher is setup in a separate thread by the notify library which is +// synchronous +fn watch_and_reload_authorization_whitelist( + config: NotaryServerProperties, + authorization_whitelist: Option>>>, +) -> Result> { + // Only setup the watcher if auth whitelist is loaded + let watcher = if let Some(authorization_whitelist) = authorization_whitelist { + let cloned_config = config.clone(); + // Setup watcher by giving it a function that will be triggered when an event is + // detected + let mut watcher = RecommendedWatcher::new( + move |event: Result| { + match event { + Ok(event) => { + // Only reload whitelist if it's an event that modified the file data + if let EventKind::Modify(ModifyKind::Data(_)) = event.kind { + debug!("Authorization whitelist is modified"); + match load_authorization_whitelist(&cloned_config) { + Ok(Some(new_authorization_whitelist)) => { + *authorization_whitelist.lock().unwrap() = new_authorization_whitelist; + info!("Successfully reloaded 
authorization whitelist!"); + } + Ok(None) => unreachable!( + "Authorization whitelist will never be None as the auth module is enabled" + ), + // Ensure that error from reloading doesn't bring the server down + Err(err) => error!("{err}"), + } + } + }, + Err(err) => { + error!("Error occured when watcher detected an event: {err}") + } + } + }, + notify::Config::default(), + ) + .map_err(|err| eyre!("Error occured when setting up watcher for hot reload: {err}"))?; + + // Start watcher to listen to any changes on the whitelist file + watcher + .watch( + Path::new(&config.authorization.whitelist_csv_path), + RecursiveMode::Recursive, + ) + .map_err(|err| eyre!("Error occured when starting up watcher for hot reload: {err}"))?; + + Some(watcher) + } else { + // Skip setup the watcher if auth whitelist is not loaded + None + }; + // Need to return the watcher to parent function, else it will be dropped and + // stop listening + Ok(watcher) +} + +#[cfg(test)] +mod test { + use std::{fs::OpenOptions, time::Duration}; + + use csv::WriterBuilder; + + use crate::AuthorizationProperties; + + use super::*; + + #[tokio::test] + async fn test_load_notary_key_and_cert() { + let private_key_pem_path = "./fixture/tls/notary.key"; + let certificate_pem_path = "./fixture/tls/notary.crt"; + let result: Result<(PrivateKey, Vec)> = + load_tls_key_and_cert(private_key_pem_path, certificate_pem_path).await; + assert!(result.is_ok(), "Could not load tls private key and cert"); + } + + #[tokio::test] + async fn test_load_attestation_key() { + let config = NotarySigningKeyProperties { + private_key_pem_path: "./fixture/notary/notary.key".to_string(), + public_key_pem_path: "./fixture/notary/notary.pub".to_string(), + }; + load_attestation_key(&config).await.unwrap(); + } + + #[tokio::test] + async fn test_watch_and_reload_authorization_whitelist() { + // Clone fixture auth whitelist for testing + let original_whitelist_csv_path = "./fixture/auth/whitelist.csv"; + let whitelist_csv_path = 
"./fixture/auth/whitelist_copied.csv".to_string(); + std::fs::copy(original_whitelist_csv_path, &whitelist_csv_path).unwrap(); + + // Setup watcher + let config = NotaryServerProperties { + authorization: AuthorizationProperties { + enabled: true, + whitelist_csv_path, + }, + ..Default::default() + }; + let authorization_whitelist = load_authorization_whitelist(&config) + .expect("Authorization whitelist csv from fixture should be able to be loaded") + .as_ref() + .map(|whitelist| Arc::new(Mutex::new(whitelist.clone()))); + let _watcher = watch_and_reload_authorization_whitelist( + config.clone(), + authorization_whitelist.as_ref().map(Arc::clone), + ) + .expect("Watcher should be able to be setup successfully") + .expect("Watcher should be set up and not None"); + + // Sleep to buy a bit of time for hot reload task and watcher thread to run + tokio::time::sleep(Duration::from_millis(50)).await; + + // Write a new record to the whitelist to trigger modify event + let new_record = AuthorizationWhitelistRecord { + name: "unit-test-name".to_string(), + api_key: "unit-test-api-key".to_string(), + created_at: "unit-test-created-at".to_string(), + }; + let file = OpenOptions::new() + .append(true) + .open(&config.authorization.whitelist_csv_path) + .unwrap(); + let mut wtr = WriterBuilder::new() + .has_headers(false) // Set to false to avoid writing header again + .from_writer(file); + wtr.serialize(new_record).unwrap(); + wtr.flush().unwrap(); + + // Sleep to buy a bit of time for updated whitelist to be hot reloaded + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!(authorization_whitelist + .unwrap() + .lock() + .unwrap() + .contains_key("unit-test-api-key")); + + // Delete the cloned whitelist + std::fs::remove_file(&config.authorization.whitelist_csv_path).unwrap(); + } +} diff --git a/crates/notary/server/src/server_tracing.rs b/crates/notary/server/src/server_tracing.rs new file mode 100644 index 0000000000..2d7e7bdcba --- /dev/null +++ 
b/crates/notary/server/src/server_tracing.rs @@ -0,0 +1,34 @@ +use eyre::Result; +use std::str::FromStr; +use tracing::Level; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry}; + +use crate::config::NotaryServerProperties; + +pub fn init_tracing(config: &NotaryServerProperties) -> Result<()> { + // Retrieve log filtering logic from config + let directives = match &config.logging.filter { + // Use custom filter that is provided by user + Some(filter) => filter.clone(), + // Use the default filter when only verbosity level is provided + None => { + let level = Level::from_str(&config.logging.level)?; + format!("notary_server={level},tlsn_verifier={level},tls_mpc={level}") + } + }; + let filter_layer = EnvFilter::builder().parse(directives)?; + + // Format the log + let format_layer = tracing_subscriber::fmt::layer() + // Use a more compact, abbreviated log format + .compact() + .with_thread_ids(true) + .with_thread_names(true); + + Registry::default() + .with(filter_layer) + .with(format_layer) + .try_init()?; + + Ok(()) +} diff --git a/notary-server/src/service.rs b/crates/notary/server/src/service.rs similarity index 52% rename from notary-server/src/service.rs rename to crates/notary/server/src/service.rs index 9d3248b6e0..dd93d36157 100644 --- a/notary-server/src/service.rs +++ b/crates/notary/server/src/service.rs @@ -2,6 +2,8 @@ pub mod axum_websocket; pub mod tcp; pub mod websocket; +use std::sync::Arc; + use async_trait::async_trait; use axum::{ extract::{rejection::JsonRejection, FromRequestParts, Query, State}, @@ -9,9 +11,9 @@ use axum::{ response::{IntoResponse, Json, Response}, }; use axum_macros::debug_handler; -use chrono::Utc; -use p256::ecdsa::{Signature, SigningKey}; -use tlsn_verifier::tls::{Verifier, VerifierConfig}; +use tlsn_common::config::ProtocolConfigValidator; +use tlsn_core::{attestation::AttestationConfig, CryptoProvider}; +use tlsn_verifier::{Verifier, VerifierConfig}; use 
tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::compat::TokioAsyncReadCompatExt; use tracing::{debug, error, info, trace}; @@ -20,7 +22,7 @@ use uuid::Uuid; use crate::{ domain::notary::{ NotarizationRequestQuery, NotarizationSessionRequest, NotarizationSessionResponse, - NotaryGlobals, SessionData, + NotaryGlobals, }, error::NotaryServerError, service::{ @@ -30,8 +32,9 @@ use crate::{ }, }; -/// A wrapper enum to facilitate extracting TCP connection for either WebSocket or TCP clients, -/// so that we can use a single endpoint and handler for notarization for both types of clients +/// A wrapper enum to facilitate extracting TCP connection for either WebSocket +/// or TCP clients, so that we can use a single endpoint and handler for +/// notarization for both types of clients pub enum ProtocolUpgrade { Tcp(TcpUpgrade), Ws(WebSocketUpgrade), @@ -65,9 +68,10 @@ where } } -/// Handler to upgrade protocol from http to either websocket or underlying tcp depending on the type of client -/// the session_id parameter is also extracted here to fetch the configuration parameters -/// that have been submitted in the previous request to /session made by the same client +/// Handler to upgrade protocol from http to either websocket or underlying tcp +/// depending on the type of client the session_id parameter is also extracted +/// here to fetch the configuration parameters that have been submitted in the +/// previous request to /session made by the same client pub async fn upgrade_protocol( protocol_upgrade: ProtocolUpgrade, State(notary_globals): State, @@ -75,28 +79,33 @@ pub async fn upgrade_protocol( ) -> Response { info!("Received upgrade protocol request"); let session_id = params.session_id; - // Fetch the configuration data from the store using the session_id - // This also removes the configuration data from the store as each session_id can only be used once - let max_transcript_size = match notary_globals.store.lock().await.remove(&session_id) { - Some(data) 
=> data.max_transcript_size, - None => { - let err_msg = format!("Session id {} does not exist", session_id); - error!(err_msg); - return NotaryServerError::BadProverRequest(err_msg).into_response(); - } + // Check if session_id exists in the store, this also removes session_id from + // the store as each session_id can only be used once + if notary_globals + .store + .lock() + .unwrap() + .remove(&session_id) + .is_none() + { + let err_msg = format!("Session id {} does not exist", session_id); + error!(err_msg); + return NotaryServerError::BadProverRequest(err_msg).into_response(); }; - // This completes the HTTP Upgrade request and returns a successful response to the client, meanwhile initiating the websocket or tcp connection + // This completes the HTTP Upgrade request and returns a successful response to + // the client, meanwhile initiating the websocket or tcp connection match protocol_upgrade { - ProtocolUpgrade::Ws(ws) => ws.on_upgrade(move |socket| { - websocket_notarize(socket, notary_globals, session_id, max_transcript_size) - }), - ProtocolUpgrade::Tcp(tcp) => tcp.on_upgrade(move |stream| { - tcp_notarize(stream, notary_globals, session_id, max_transcript_size) - }), + ProtocolUpgrade::Ws(ws) => { + ws.on_upgrade(move |socket| websocket_notarize(socket, notary_globals, session_id)) + } + ProtocolUpgrade::Tcp(tcp) => { + tcp.on_upgrade(move |stream| tcp_notarize(stream, notary_globals, session_id)) + } } } -/// Handler to initialize and configure notarization for both TCP and WebSocket clients +/// Handler to initialize and configure notarization for both TCP and WebSocket +/// clients #[debug_handler(state = NotaryGlobals)] pub async fn initialize( State(notary_globals): State, @@ -116,28 +125,45 @@ pub async fn initialize( } }; - // Ensure that the max_transcript_size submitted is not larger than the global max limit configured in notary server - if payload.max_transcript_size > Some(notary_globals.notarization_config.max_transcript_size) { - error!( 
- "Max transcript size requested {:?} exceeds the maximum threshold {:?}", - payload.max_transcript_size, notary_globals.notarization_config.max_transcript_size - ); - return NotaryServerError::BadProverRequest( - "Max transcript size requested exceeds the maximum threshold".to_string(), - ) - .into_response(); + // Ensure that the max_sent_data, max_recv_data submitted is not larger than the + // global max limits configured in notary server + if payload.max_sent_data.is_some() || payload.max_recv_data.is_some() { + if payload.max_sent_data.unwrap_or_default() + > notary_globals.notarization_config.max_sent_data + { + error!( + "Max sent data requested {:?} exceeds the global maximum threshold {:?}", + payload.max_sent_data.unwrap_or_default(), + notary_globals.notarization_config.max_sent_data + ); + return NotaryServerError::BadProverRequest( + "Max sent data requested exceeds the global maximum threshold".to_string(), + ) + .into_response(); + } + if payload.max_recv_data.unwrap_or_default() + > notary_globals.notarization_config.max_recv_data + { + error!( + "Max recv data requested {:?} exceeds the global maximum threshold {:?}", + payload.max_recv_data.unwrap_or_default(), + notary_globals.notarization_config.max_recv_data + ); + return NotaryServerError::BadProverRequest( + "Max recv data requested exceeds the global maximum threshold".to_string(), + ) + .into_response(); + } } let prover_session_id = Uuid::new_v4().to_string(); // Store the configuration data in a temporary store - notary_globals.store.lock().await.insert( - prover_session_id.clone(), - SessionData { - max_transcript_size: payload.max_transcript_size, - created_at: Utc::now(), - }, - ); + notary_globals + .store + .lock() + .unwrap() + .insert(prover_session_id.clone(), ()); trace!("Latest store state: {:?}", notary_globals.store); @@ -154,24 +180,30 @@ pub async fn initialize( /// Run the notarization pub async fn notary_service( socket: T, - signing_key: &SigningKey, + crypto_provider: 
Arc, session_id: &str, - max_transcript_size: Option, + max_sent_data: usize, + max_recv_data: usize, ) -> Result<(), NotaryServerError> { debug!(?session_id, "Starting notarization..."); - let mut config_builder = VerifierConfig::builder(); - - config_builder = config_builder.id(session_id); - - if let Some(max_transcript_size) = max_transcript_size { - config_builder = config_builder.max_transcript_size(max_transcript_size); - } - - let config = config_builder.build()?; + let att_config = AttestationConfig::builder() + .supported_signature_algs(Vec::from_iter(crypto_provider.signer.supported_algs())) + .build() + .map_err(|err| NotaryServerError::Notarization(Box::new(err)))?; + + let config = VerifierConfig::builder() + .protocol_config_validator( + ProtocolConfigValidator::builder() + .max_sent_data(max_sent_data) + .max_recv_data(max_recv_data) + .build()?, + ) + .crypto_provider(crypto_provider) + .build()?; Verifier::new(config) - .notarize::<_, Signature>(socket.compat(), signing_key) + .notarize(socket.compat(), &att_config) .await?; Ok(()) diff --git a/notary-server/src/service/axum_websocket.rs b/crates/notary/server/src/service/axum_websocket.rs similarity index 83% rename from notary-server/src/service/axum_websocket.rs rename to crates/notary/server/src/service/axum_websocket.rs index eea235871e..991391b6f9 100644 --- a/notary-server/src/service/axum_websocket.rs +++ b/crates/notary/server/src/service/axum_websocket.rs @@ -1,8 +1,9 @@ -//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.6.19/axum/src/extract/ws.rs +//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.7.3/axum/src/extract/ws.rs //! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/) //! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use //! 
ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html) -//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other modification is commented with the prefix "NOTARY_MODIFICATION:" +//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other +//! modification is commented with the prefix "NOTARY_MODIFICATION:" //! //! The code is under the following license: //! @@ -66,9 +67,7 @@ //! } //! } //! } -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; +//! # let _: Router = app; //! ``` //! //! # Passing data and/or state to an `on_upgrade` callback @@ -97,9 +96,7 @@ //! let app = Router::new() //! .route("/ws", get(handler)) //! .with_state(AppState { /* ... */ }); -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; +//! # let _: Router = app; //! ``` //! //! # Read and write concurrently @@ -128,7 +125,6 @@ //! ``` //! //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split - #![allow(unused)] use self::rejection::*; @@ -141,13 +137,8 @@ use async_tungstenite::{ }, WebSocketStream, }; -use axum::{ - body::{self, Bytes}, - extract::FromRequestParts, - response::Response, - Error, -}; - +use axum::{body::Bytes, extract::FromRequestParts, response::Response, Error}; +use axum_core::body::Body; use futures_util::{ sink::{Sink, SinkExt}, stream::{Stream, StreamExt}, @@ -157,7 +148,7 @@ use http::{ request::Parts, Method, StatusCode, }; -use hyper::upgrade::{OnUpgrade, Upgraded}; +use hyper_util::rt::TokioIo; use sha1::{Digest, Sha1}; use std::{ borrow::Cow, @@ -170,17 +161,18 @@ use tracing::error; /// Extractor for establishing WebSocket connections. /// /// Note: This extractor requires the request method to be `GET` so it should -/// always be used with [`get`](crate::routing::get). 
Requests with other methods will be -/// rejected. +/// always be used with [`get`](crate::routing::get). Requests with other +/// methods will be rejected. /// /// See the [module docs](self) for an example. #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] -pub struct WebSocketUpgrade { +pub struct WebSocketUpgrade { config: WebSocketConfig, - /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. + /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the + /// response. protocol: Option, sec_websocket_key: HeaderValue, - on_upgrade: OnUpgrade, + on_upgrade: hyper::upgrade::OnUpgrade, on_failed_upgrade: F, sec_websocket_protocol: Option, } @@ -197,9 +189,38 @@ impl std::fmt::Debug for WebSocketUpgrade { } impl WebSocketUpgrade { - /// Set the size of the internal message send queue. - pub fn max_send_queue(mut self, max: usize) -> Self { - self.config.max_send_queue = Some(max); + /// The target minimum size of the write buffer to reach before writing the + /// data to the underlying stream. + /// + /// The default value is 128 KiB. + /// + /// If set to `0` each message will be eagerly written to the underlying + /// stream. It is often more optimal to allow them to buffer a little, + /// hence the default value. + /// + /// Note: [`flush`](SinkExt::flush) will always fully write the buffer + /// regardless. + pub fn write_buffer_size(mut self, size: usize) -> Self { + self.config.write_buffer_size = size; + self + } + + /// The max size of the write buffer in bytes. Setting this can provide + /// backpressure in the case the write buffer is filling up due to write + /// errors. + /// + /// The default value is unlimited. + /// + /// Note: The write buffer only builds up past + /// [`write_buffer_size`](Self::write_buffer_size) when writes to the + /// underlying stream are failing. So the **write buffer can not fill up + /// if you are not observing write errors even if not flushing**. 
+ /// + /// Note: Should always be at least [`write_buffer_size + 1 + /// message`](Self::write_buffer_size) and probably a little more + /// depending on error handling strategy. + pub fn max_write_buffer_size(mut self, max: usize) -> Self { + self.config.max_write_buffer_size = max; self } @@ -224,12 +245,12 @@ impl WebSocketUpgrade { /// Set the known protocols. /// /// If the protocol name specified by `Sec-WebSocket-Protocol` header - /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and - /// return the protocol name. + /// to match any of them, the upgrade response will include + /// `Sec-WebSocket-Protocol` header and return the protocol name. /// - /// The protocols should be listed in decreasing order of preference: if the client offers - /// multiple protocols that the server could support, the server will pick the first one in - /// this list. + /// The protocols should be listed in decreasing order of preference: if the + /// client offers multiple protocols that the server could support, the + /// server will pick the first one in this list. /// /// # Examples /// @@ -249,9 +270,7 @@ impl WebSocketUpgrade { /// // ... /// }) /// } - /// # async { - /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; + /// # let _: Router = app; /// ``` pub fn protocols(mut self, protocols: I) -> Self where @@ -284,8 +303,8 @@ impl WebSocketUpgrade { /// Provide a callback to call if upgrading the connection fails. /// - /// The connection upgrade is performed in a background task. If that fails this callback - /// will be called. + /// The connection upgrade is performed in a background task. If that fails + /// this callback will be called. /// /// By default any errors will be silently ignored. 
/// @@ -308,7 +327,7 @@ impl WebSocketUpgrade { /// ``` pub fn on_failed_upgrade(self, callback: C) -> WebSocketUpgrade where - C: OnFailedUpdgrade, + C: OnFailedUpgrade, { WebSocketUpgrade { config: self.config, @@ -322,12 +341,12 @@ impl WebSocketUpgrade { /// Finalize upgrading the connection and call the provided callback with /// the stream. - #[must_use = "to setup the WebSocket connection, this response must be returned"] + #[must_use = "to set up the WebSocket connection, this response must be returned"] pub fn on_upgrade(self, callback: C) -> Response where C: FnOnce(WebSocket) -> Fut + Send + 'static, Fut: Future + Send + 'static, - F: OnFailedUpdgrade, + F: OnFailedUpgrade, { let on_upgrade = self.on_upgrade; let config = self.config; @@ -344,8 +363,11 @@ impl WebSocketUpgrade { return; } }; + let upgraded = TokioIo::new(upgraded); + let socket = WebSocketStream::from_raw_socket( - // NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite + // NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't + // implement futures crate's AsyncRead and AsyncWrite TokioAdapter::new(upgraded), protocol::Role::Server, Some(config), @@ -376,19 +398,19 @@ impl WebSocketUpgrade { builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); } - builder.body(body::boxed(body::Empty::new())).unwrap() + builder.body(Body::empty()).unwrap() } } /// What to do when a connection upgrade fails. /// /// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. -pub trait OnFailedUpdgrade: Send + 'static { +pub trait OnFailedUpgrade: Send + 'static { /// Call the callback. fn call(self, error: Error); } -impl OnFailedUpdgrade for F +impl OnFailedUpgrade for F where F: FnOnce(Error) + Send + 'static, { @@ -397,20 +419,20 @@ where } } -/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`. +/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`. 
/// /// It simply ignores the error. #[non_exhaustive] #[derive(Debug)] -pub struct DefaultOnFailedUpdgrade; +pub struct DefaultOnFailedUpgrade; -impl OnFailedUpdgrade for DefaultOnFailedUpdgrade { +impl OnFailedUpgrade for DefaultOnFailedUpgrade { #[inline] fn call(self, _error: Error) {} } #[async_trait] -impl FromRequestParts for WebSocketUpgrade +impl FromRequestParts for WebSocketUpgrade where S: Send + Sync, { @@ -441,7 +463,7 @@ where let on_upgrade = parts .extensions - .remove::() + .remove::() .ok_or(ConnectionNotUpgradable)?; let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); @@ -452,11 +474,12 @@ where sec_websocket_key, on_upgrade, sec_websocket_protocol, - on_failed_upgrade: DefaultOnFailedUpdgrade, + on_failed_upgrade: DefaultOnFailedUpgrade, }) } } +/// NOTARY_MODIFICATION: Made this function public to be used in service.rs pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { if let Some(header) = headers.get(&key) { header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) @@ -484,13 +507,14 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> /// See [the module level documentation](self) for more details. #[derive(Debug)] pub struct WebSocket { - inner: WebSocketStream>, + inner: WebSocketStream>>, protocol: Option, } impl WebSocket { - /// Consume `self` and get the inner [`async_tungstenite::WebSocketStream`]. - pub fn into_inner(self) -> WebSocketStream> { + /// NOTARY_MODIFICATION: Consume `self` and get the inner + /// [`async_tungstenite::WebSocketStream`]. + pub fn into_inner(self) -> WebSocketStream>> { self.inner } @@ -560,7 +584,8 @@ impl Sink for WebSocket { } } -/// Status code used to indicate why an endpoint is closing the WebSocket connection. +/// Status code used to indicate why an endpoint is closing the WebSocket +/// connection. pub type CloseCode = u16; /// A struct representing the close command. 
@@ -605,16 +630,16 @@ pub enum Message { /// /// The payload here must have a length less than 125 bytes. /// - /// Ping messages will be automatically responded to by the server, so you do not have to worry - /// about dealing with them yourself. + /// Ping messages will be automatically responded to by the server, so you + /// do not have to worry about dealing with them yourself. Ping(Vec), /// A pong message with the specified payload /// /// The payload here must have a length less than 125 bytes. /// - /// Pong messages will be automatically sent to the client if a ping message is received, so - /// you do not have to worry about constructing them yourself unless you want to implement a - /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3). + /// Pong messages will be automatically sent to the client if a ping message + /// is received, so you do not have to worry about constructing them + /// yourself unless you want to implement a [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3). Pong(Vec), /// A close message with the optional close frame. Close(Option>), @@ -805,20 +830,22 @@ pub mod close_code { //! //! [`CloseCode`]: super::CloseCode - /// Indicates a normal closure, meaning that the purpose for which the connection was - /// established has been fulfilled. + /// Indicates a normal closure, meaning that the purpose for which the + /// connection was established has been fulfilled. pub const NORMAL: u16 = 1000; - /// Indicates that an endpoint is "going away", such as a server going down or a browser having - /// navigated away from a page. + /// Indicates that an endpoint is "going away", such as a server going down + /// or a browser having navigated away from a page. pub const AWAY: u16 = 1001; - /// Indicates that an endpoint is terminating the connection due to a protocol error. + /// Indicates that an endpoint is terminating the connection due to a + /// protocol error. 
pub const PROTOCOL: u16 = 1002; - /// Indicates that an endpoint is terminating the connection because it has received a type of - /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if - /// it receives a binary message). + /// Indicates that an endpoint is terminating the connection because it has + /// received a type of data it cannot accept (e.g., an endpoint that + /// understands only text data MAY send this if it receives a binary + /// message). pub const UNSUPPORTED: u16 = 1003; /// Indicates that no status code was included in a closing frame. @@ -827,38 +854,42 @@ pub mod close_code { /// Indicates an abnormal closure. pub const ABNORMAL: u16 = 1006; - /// Indicates that an endpoint is terminating the connection because it has received data - /// within a message that was not consistent with the type of the message (e.g., non-UTF-8 - /// RFC3629 data within a text message). + /// Indicates that an endpoint is terminating the connection because it has + /// received data within a message that was not consistent with the type + /// of the message (e.g., non-UTF-8 RFC3629 data within a text message). pub const INVALID: u16 = 1007; - /// Indicates that an endpoint is terminating the connection because it has received a message - /// that violates its policy. This is a generic status code that can be returned when there is - /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to - /// hide specific details about the policy. + /// Indicates that an endpoint is terminating the connection because it has + /// received a message that violates its policy. This is a generic + /// status code that can be returned when there is no other more + /// suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a + /// need to hide specific details about the policy. 
pub const POLICY: u16 = 1008; - /// Indicates that an endpoint is terminating the connection because it has received a message - /// that is too big for it to process. + /// Indicates that an endpoint is terminating the connection because it has + /// received a message that is too big for it to process. pub const SIZE: u16 = 1009; - /// Indicates that an endpoint (client) is terminating the connection because it has expected - /// the server to negotiate one or more extension, but the server didn't return them in the - /// response message of the WebSocket handshake. The list of extensions that are needed should - /// be given as the reason for closing. Note that this status code is not used by the server, - /// because it can fail the WebSocket handshake instead. + /// Indicates that an endpoint (client) is terminating the connection + /// because it has expected the server to negotiate one or more + /// extension, but the server didn't return them in the response message + /// of the WebSocket handshake. The list of extensions that are needed + /// should be given as the reason for closing. Note that this status + /// code is not used by the server, because it can fail the WebSocket + /// handshake instead. pub const EXTENSION: u16 = 1010; - /// Indicates that a server is terminating the connection because it encountered an unexpected - /// condition that prevented it from fulfilling the request. + /// Indicates that a server is terminating the connection because it + /// encountered an unexpected condition that prevented it from + /// fulfilling the request. pub const ERROR: u16 = 1011; /// Indicates that the server is restarting. pub const RESTART: u16 = 1012; - /// Indicates that the server is overloaded and the client should either connect to a different - /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an - /// action. 
+ /// Indicates that the server is overloaded and the client should either + /// connect to a different IP (when multiple targets exist), or + /// reconnect to the same IP when a user has performed an action. pub const AGAIN: u16 = 1013; } @@ -867,7 +898,8 @@ mod tests { use super::*; use axum::{body::Body, routing::get, Router}; use http::{Request, Version}; - use tower::ServiceExt; + // NOTARY_MODIFICATION: use tower_util instead of tower to make clippy happy + use tower_util::ServiceExt; #[tokio::test] async fn rejects_http_1_0_requests() { diff --git a/notary-server/src/service/tcp.rs b/crates/notary/server/src/service/tcp.rs similarity index 74% rename from notary-server/src/service/tcp.rs rename to crates/notary/server/src/service/tcp.rs index c401d9beff..e68f79370d 100644 --- a/notary-server/src/service/tcp.rs +++ b/crates/notary/server/src/service/tcp.rs @@ -1,19 +1,21 @@ use async_trait::async_trait; use axum::{ - body, extract::FromRequestParts, http::{header, request::Parts, HeaderValue, StatusCode}, response::Response, }; +use axum_core::body::Body; use hyper::upgrade::{OnUpgrade, Upgraded}; +use hyper_util::rt::TokioIo; use std::future::Future; use tracing::{debug, error, info}; use crate::{domain::notary::NotaryGlobals, service::notary_service, NotaryServerError}; -/// Custom extractor used to extract underlying TCP connection for TCP client — using the same upgrade primitives used by -/// the WebSocket implementation where the underlying TCP connection (wrapped in an Upgraded object) only gets polled as an OnUpgrade future -/// after the ongoing HTTP request is finished (ref: https://github.com/tokio-rs/axum/blob/a6a849bb5b96a2f641fa077fe76f70ad4d20341c/axum/src/extract/ws.rs#L122) +/// Custom extractor used to extract underlying TCP connection for TCP client — +/// using the same upgrade primitives used by the WebSocket implementation where +/// the underlying TCP connection (wrapped in an Upgraded object) only gets +/// polled as an OnUpgrade 
future after the ongoing HTTP request is finished (ref: https://github.com/tokio-rs/axum/blob/a6a849bb5b96a2f641fa077fe76f70ad4d20341c/axum/src/extract/ws.rs#L122) /// /// More info on the upgrade primitives: https://docs.rs/hyper/latest/hyper/upgrade/index.html pub struct TcpUpgrade { @@ -42,11 +44,12 @@ where impl TcpUpgrade { /// Utility function to complete the http upgrade protocol by - /// (1) Return 101 switching protocol response to client to indicate the switching to TCP - /// (2) Spawn a new thread to await on the OnUpgrade object to claim the underlying TCP connection + /// (1) Return 101 switching protocol response to client to indicate the + /// switching to TCP (2) Spawn a new thread to await on the OnUpgrade + /// object to claim the underlying TCP connection pub fn on_upgrade(self, callback: C) -> Response where - C: FnOnce(Upgraded) -> Fut + Send + 'static, + C: FnOnce(TokioIo) -> Fut + Send + 'static, Fut: Future + Send + 'static, { let on_upgrade = self.on_upgrade; @@ -58,6 +61,8 @@ impl TcpUpgrade { return; } }; + let upgraded = TokioIo::new(upgraded); + callback(upgraded).await; }); @@ -71,23 +76,23 @@ impl TcpUpgrade { .header(header::CONNECTION, UPGRADE) .header(header::UPGRADE, TCP); - builder.body(body::boxed(body::Empty::new())).unwrap() + builder.body(Body::empty()).unwrap() } } /// Perform notarization using the extracted tcp connection pub async fn tcp_notarize( - stream: Upgraded, + stream: TokioIo, notary_globals: NotaryGlobals, session_id: String, - max_transcript_size: Option, ) { debug!(?session_id, "Upgraded to tcp connection"); match notary_service( stream, - ¬ary_globals.notary_signing_key, + notary_globals.crypto_provider.clone(), &session_id, - max_transcript_size, + notary_globals.notarization_config.max_sent_data, + notary_globals.notarization_config.max_recv_data, ) .await { diff --git a/notary-server/src/service/websocket.rs b/crates/notary/server/src/service/websocket.rs similarity index 81% rename from 
notary-server/src/service/websocket.rs rename to crates/notary/server/src/service/websocket.rs index 1a06c0d5ca..f0f1ea5163 100644 --- a/notary-server/src/service/websocket.rs +++ b/crates/notary/server/src/service/websocket.rs @@ -11,16 +11,17 @@ pub async fn websocket_notarize( socket: WebSocket, notary_globals: NotaryGlobals, session_id: String, - max_transcript_size: Option, ) { debug!(?session_id, "Upgraded to websocket connection"); - // Wrap the websocket in WsStream so that we have AsyncRead and AsyncWrite implemented + // Wrap the websocket in WsStream so that we have AsyncRead and AsyncWrite + // implemented let stream = WsStream::new(socket.into_inner()); match notary_service( stream, - ¬ary_globals.notary_signing_key, + notary_globals.crypto_provider.clone(), &session_id, - max_transcript_size, + notary_globals.notarization_config.max_sent_data, + notary_globals.notarization_config.max_recv_data, ) .await { diff --git a/crates/notary/server/src/signing.rs b/crates/notary/server/src/signing.rs new file mode 100644 index 0000000000..6d990485a2 --- /dev/null +++ b/crates/notary/server/src/signing.rs @@ -0,0 +1,74 @@ +use core::fmt; + +use pkcs8::{der::Encode, AssociatedOid, DecodePrivateKey, ObjectIdentifier, PrivateKeyInfo}; +use tlsn_core::signing::{Secp256k1Signer, Secp256r1Signer, SignatureAlgId, Signer}; +use tracing::error; + +/// A cryptographic key used for signing attestations. +pub struct AttestationKey { + alg_id: SignatureAlgId, + key: SigningKey, +} + +impl TryFrom> for AttestationKey { + type Error = pkcs8::Error; + + fn try_from(pkcs8: PrivateKeyInfo<'_>) -> Result { + const OID_EC_PUBLIC_KEY: ObjectIdentifier = + ObjectIdentifier::new_unwrap("1.2.840.10045.2.1"); + + // For now we only support elliptic curve keys + if pkcs8.algorithm.oid != OID_EC_PUBLIC_KEY { + error!("unsupported key algorithm OID: {:?}", pkcs8.algorithm.oid); + + return Err(pkcs8::Error::KeyMalformed); + } + + let (alg_id, key) = match pkcs8.algorithm.parameters_oid()? 
{ + k256::Secp256k1::OID => { + let key = k256::ecdsa::SigningKey::from_pkcs8_der(&pkcs8.to_der()?) + .map_err(|_| pkcs8::Error::KeyMalformed)?; + (SignatureAlgId::SECP256K1, SigningKey::Secp256k1(key)) + } + p256::NistP256::OID => { + let key = p256::ecdsa::SigningKey::from_pkcs8_der(&pkcs8.to_der()?) + .map_err(|_| pkcs8::Error::KeyMalformed)?; + (SignatureAlgId::SECP256R1, SigningKey::Secp256r1(key)) + } + oid => { + error!("unsupported curve OID: {:?}", oid); + + return Err(pkcs8::Error::KeyMalformed); + } + }; + + Ok(Self { alg_id, key }) + } +} + +impl AttestationKey { + /// Creates a new signer using this key. + pub fn into_signer(self) -> Box { + match self.key { + SigningKey::Secp256k1(key) => { + Box::new(Secp256k1Signer::new(&key.to_bytes()).expect("key should be valid")) + } + SigningKey::Secp256r1(key) => { + Box::new(Secp256r1Signer::new(&key.to_bytes()).expect("key should be valid")) + } + } + } +} + +impl fmt::Debug for AttestationKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AttestationKey") + .field("alg_id", &self.alg_id) + .finish_non_exhaustive() + } +} + +enum SigningKey { + Secp256k1(k256::ecdsa::SigningKey), + Secp256r1(p256::ecdsa::SigningKey), +} diff --git a/notary-server/src/util.rs b/crates/notary/server/src/util.rs similarity index 100% rename from notary-server/src/util.rs rename to crates/notary/server/src/util.rs diff --git a/crates/notary/tests-integration/Cargo.toml b/crates/notary/tests-integration/Cargo.toml new file mode 100644 index 0000000000..19e16de248 --- /dev/null +++ b/crates/notary/tests-integration/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "notary-tests-integration" +version = "0.0.0" +edition = "2021" +publish = false + +[dev-dependencies] +notary-client = { workspace = true } +notary-server = { workspace = true } +tls-server-fixture = { workspace = true } +tlsn-common = { workspace = true } +tlsn-prover = { workspace = true } +tlsn-tls-core = { workspace = true } +tlsn-core 
= { workspace = true } + +async-tungstenite = { workspace = true, features = ["tokio-native-tls"] } +http = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["client", "http1", "server"] } +hyper-tls = { version = "0.6", features = [ + "vendored", +] } # specify vendored feature to use statically linked copy of OpenSSL +hyper-util = { workspace = true, features = ["full"] } +rstest = { workspace = true } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true, features = ["full"] } +tokio-native-tls = { version = "0.3.1", features = ["vendored"] } +tokio-util = { workspace = true, features = ["compat"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +uuid = { workspace = true, features = ["v4", "fast-rng"] } +ws_stream_tungstenite = { workspace = true, features = ["tokio_io"] } diff --git a/crates/notary/tests-integration/tests/notary.rs b/crates/notary/tests-integration/tests/notary.rs new file mode 100644 index 0000000000..da067cd58b --- /dev/null +++ b/crates/notary/tests-integration/tests/notary.rs @@ -0,0 +1,454 @@ +use async_tungstenite::{ + tokio::connect_async_with_tls_connector_and_config, tungstenite::protocol::WebSocketConfig, +}; +use http_body_util::{BodyExt as _, Full}; +use hyper::{body::Bytes, Request, StatusCode}; +use hyper_tls::HttpsConnector; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Builder}, + rt::{TokioExecutor, TokioIo}, +}; +use notary_client::{NotarizationRequest, NotaryClient, NotaryConnection}; +use rstest::rstest; +use rustls::{Certificate, RootCertStore}; +use std::{string::String, time::Duration}; +use tls_core::verify::WebPkiVerifier; +use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN}; +use tlsn_common::config::ProtocolConfig; +use tlsn_core::{request::RequestConfig, transcript::TranscriptCommitConfig, 
CryptoProvider}; +use tlsn_prover::{Prover, ProverConfig}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tracing::debug; +use ws_stream_tungstenite::WsStream; + +use notary_server::{ + read_pem_file, run_server, AuthorizationProperties, LoggingProperties, NotarizationProperties, + NotarizationSessionRequest, NotarizationSessionResponse, NotaryServerProperties, + NotarySigningKeyProperties, ServerProperties, TLSProperties, +}; + +const MAX_SENT_DATA: usize = 1 << 13; +const MAX_RECV_DATA: usize = 1 << 13; + +const NOTARY_HOST: &str = "127.0.0.1"; +const NOTARY_DNS: &str = "tlsnotaryserver.io"; +const NOTARY_CA_CERT_PATH: &str = "../server/fixture/tls/rootCA.crt"; +const NOTARY_CA_CERT_BYTES: &[u8] = include_bytes!("../../server/fixture/tls/rootCA.crt"); +const API_KEY: &str = "test_api_key_0"; + +fn get_server_config(port: u16, tls_enabled: bool, auth_enabled: bool) -> NotaryServerProperties { + NotaryServerProperties { + server: ServerProperties { + name: NOTARY_DNS.to_string(), + host: NOTARY_HOST.to_string(), + port, + html_info: "example html response".to_string(), + }, + notarization: NotarizationProperties { + max_sent_data: 1 << 13, + max_recv_data: 1 << 14, + }, + tls: TLSProperties { + enabled: tls_enabled, + private_key_pem_path: "../server/fixture/tls/notary.key".to_string(), + certificate_pem_path: "../server/fixture/tls/notary.crt".to_string(), + }, + notary_key: NotarySigningKeyProperties { + private_key_pem_path: "../server/fixture/notary/notary.key".to_string(), + public_key_pem_path: "../server/fixture/notary/notary.pub".to_string(), + }, + logging: LoggingProperties { + level: "DEBUG".to_string(), + filter: None, + }, + authorization: AuthorizationProperties { + enabled: auth_enabled, + whitelist_csv_path: "../server/fixture/auth/whitelist.csv".to_string(), + }, + } +} + +async fn setup_config_and_server( + sleep_ms: u64, + port: u16, + tls_enabled: bool, + auth_enabled: 
bool, +) -> NotaryServerProperties { + let notary_config = get_server_config(port, tls_enabled, auth_enabled); + + let _ = tracing_subscriber::fmt::try_init(); + + let config = notary_config.clone(); + + // Run the notary server + tokio::spawn(async move { + run_server(&config).await.unwrap(); + }); + + // Sleep for a while to allow notary server to finish set up and start listening + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + + notary_config +} + +async fn tcp_prover(notary_config: NotaryServerProperties) -> (NotaryConnection, String) { + let mut notary_client_builder = NotaryClient::builder(); + + notary_client_builder + .host(¬ary_config.server.host) + .port(notary_config.server.port) + .enable_tls(false); + + if notary_config.authorization.enabled { + notary_client_builder.api_key(API_KEY); + } + + let notary_client = notary_client_builder.build().unwrap(); + + let notarization_request = NotarizationRequest::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let accepted_request = notary_client + .request_notarization(notarization_request) + .await + .unwrap(); + + (accepted_request.io, accepted_request.id) +} + +async fn tls_prover(notary_config: NotaryServerProperties) -> (NotaryConnection, String) { + let mut certificate_file_reader = read_pem_file(NOTARY_CA_CERT_PATH).await.unwrap(); + let mut certificates: Vec = rustls_pemfile::certs(&mut certificate_file_reader) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); + let certificate = certificates.remove(0); + + let mut root_cert_store = RootCertStore::empty(); + root_cert_store.add(&certificate).unwrap(); + + let notary_client = NotaryClient::builder() + .host(¬ary_config.server.name) + .port(notary_config.server.port) + .root_cert_store(root_cert_store) + .build() + .unwrap(); + + let notarization_request = NotarizationRequest::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + 
let accepted_request = notary_client + .request_notarization(notarization_request) + .await + .unwrap(); + + (accepted_request.io, accepted_request.id) +} + +#[rstest] +// For `tls_without_auth` test to pass, one needs to add " " in /etc/hosts +// so that this test programme can resolve the self-named NOTARY_DNS to NOTARY_HOST IP successfully. +#[case::tls_without_auth( + tls_prover(setup_config_and_server(100, 7047, true, false).await) +)] +#[case::tcp_with_auth( + tcp_prover(setup_config_and_server(100, 7048, false, true).await) +)] +#[case::tcp_without_auth( + tcp_prover(setup_config_and_server(100, 7049, false, false).await) +)] +#[awt] +#[tokio::test] +#[ignore = "expensive"] +async fn test_tcp_prover( + #[future] + #[case] + requested_notarization: (S, String), +) { + let (notary_socket, _) = requested_notarization; + + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let protocol_config = ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + // Prover config using the session_id returned from calling /session endpoint in + // notary client. + let prover_config = ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config(protocol_config) + .crypto_provider(provider) + .build() + .unwrap(); + + // Create a new Prover. + let prover = Prover::new(prover_config) + .setup(notary_socket.compat()) + .await + .unwrap(); + + // Connect to the Server. + let (client_socket, server_socket) = tokio::io::duplex(1 << 16); + let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); + + let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); + + // Spawn the Prover task to be run concurrently. 
+ let prover_task = tokio::spawn(prover_fut); + + let (mut request_sender, connection) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_connection.compat())) + .await + .unwrap(); + + tokio::spawn(connection); + + let request = Request::builder() + .uri(format!("https://{}/echo", SERVER_DOMAIN)) + .method("POST") + .header("Host", SERVER_DOMAIN) + .header("Connection", "close") + .body(Full::::new("echo".into())) + .unwrap(); + + debug!("Sending request to server: {:?}", request); + + let response = request_sender.send_request(request).await.unwrap(); + + assert!(response.status() == StatusCode::OK); + + let payload = response.into_body().collect().await.unwrap().to_bytes(); + debug!( + "Received response from server: {:?}", + &String::from_utf8_lossy(&payload) + ); + + server_task.await.unwrap().unwrap(); + + let mut prover = prover_task.await.unwrap().unwrap().start_notarize(); + + let (sent_len, recv_len) = prover.transcript().len(); + + let mut builder = TranscriptCommitConfig::builder(prover.transcript()); + + builder.commit_sent(&(0..sent_len)).unwrap(); + builder.commit_recv(&(0..recv_len)).unwrap(); + + let commit_config = builder.build().unwrap(); + + prover.transcript_commit(commit_config); + + let request = RequestConfig::builder().build().unwrap(); + + _ = prover.finalize(&request).await.unwrap(); + + debug!("Done notarization!"); +} + +#[tokio::test] +#[ignore = "expensive"] +async fn test_websocket_prover() { + // Notary server configuration setup + let notary_config = setup_config_and_server(100, 7050, true, false).await; + let notary_host = notary_config.server.host.clone(); + let notary_port = notary_config.server.port; + + // Connect to the notary server via TLS-WebSocket + // Try to avoid dealing with transport layer directly to mimic the limitation of + // a browser extension that uses websocket + // + // Establish TLS setup for connections later + let certificate = + 
tokio_native_tls::native_tls::Certificate::from_pem(NOTARY_CA_CERT_BYTES).unwrap(); + let notary_tls_connector = tokio_native_tls::native_tls::TlsConnector::builder() + .add_root_certificate(certificate) + .use_sni(false) + .danger_accept_invalid_certs(true) + .build() + .unwrap(); + + // Call the /session HTTP API to configure notarization and obtain session id + let mut hyper_http_connector = HttpConnector::new(); + hyper_http_connector.enforce_http(false); + let mut hyper_tls_connector = + HttpsConnector::from((hyper_http_connector, notary_tls_connector.clone().into())); + hyper_tls_connector.https_only(true); + let https_client = Builder::new(TokioExecutor::new()).build(hyper_tls_connector); + + // Build the HTTP request to configure notarization + let payload = serde_json::to_string(&NotarizationSessionRequest { + client_type: notary_server::ClientType::Websocket, + max_sent_data: Some(MAX_SENT_DATA), + max_recv_data: Some(MAX_RECV_DATA), + }) + .unwrap(); + + let request = Request::builder() + .uri(format!("https://{notary_host}:{notary_port}/session")) + .method("POST") + .header("Host", notary_host.clone()) + // Need to specify application/json for axum to parse it as json + .header("Content-Type", "application/json") + .body(Full::new(Bytes::from(payload))) + .unwrap(); + + debug!("Sending request"); + + let response = https_client.request(request).await.unwrap(); + + debug!("Sent request"); + + assert!(response.status() == StatusCode::OK); + + debug!("Response OK"); + + // Pretty printing :) + let payload = response.into_body().collect().await.unwrap().to_bytes(); + let notarization_response = + serde_json::from_str::(&String::from_utf8_lossy(&payload)) + .unwrap(); + + debug!("Notarization response: {:?}", notarization_response,); + + // Connect to the Notary via TLS-Websocket + // + // Note: This will establish a new TLS-TCP connection instead of reusing the + // previous TCP connection used in the previous HTTP POST request because we + // cannot claim 
back the tcp connection used in hyper client while using its + // high level request function — there does not seem to have a crate that can + // let you make a request without establishing TCP connection where you can + // claim the TCP connection later after making the request + let request = http::Request::builder() + // Need to specify the session_id so that notary server knows the right configuration to use + // as the configuration is set in the previous HTTP call + .uri(format!( + "wss://{}:{}/notarize?sessionId={}", + notary_host, + notary_port, + notarization_response.session_id.clone() + )) + .header("Host", notary_host.clone()) + .header("Sec-WebSocket-Key", uuid::Uuid::new_v4().to_string()) + .header("Sec-WebSocket-Version", "13") + .header("Connection", "Upgrade") + .header("Upgrade", "Websocket") + .body(()) + .unwrap(); + + let (notary_ws_stream, _) = connect_async_with_tls_connector_and_config( + request, + Some(notary_tls_connector.into()), + Some(WebSocketConfig::default()), + ) + .await + .unwrap(); + + // Wrap the socket with the adapter so that we get AsyncRead and AsyncWrite + // implemented + let notary_ws_socket = WsStream::new(notary_ws_stream); + + // Connect to the Server + let (client_socket, server_socket) = tokio::io::duplex(1 << 16); + let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); + + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let protocol_config = ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + // Basic default prover config — use the responded session id from notary server + let prover_config = ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config(protocol_config) + .crypto_provider(provider) + 
.build() + .unwrap(); + + // Bind the Prover to the sockets + let prover = Prover::new(prover_config) + .setup(notary_ws_socket) + .await + .unwrap(); + let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); + + // Spawn the Prover and Mux tasks to be run concurrently + let prover_task = tokio::spawn(prover_fut); + + let (mut request_sender, connection) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_connection.compat())) + .await + .unwrap(); + + tokio::spawn(connection); + + let request = Request::builder() + .uri(format!("https://{}/echo", SERVER_DOMAIN)) + .header("Host", SERVER_DOMAIN) + .header("Connection", "close") + .method("POST") + .body(Full::::new("echo".into())) + .unwrap(); + + debug!("Sending request to server: {:?}", request); + + let response = request_sender.send_request(request).await.unwrap(); + + assert!(response.status() == StatusCode::OK); + + let payload = response.into_body().collect().await.unwrap().to_bytes(); + debug!( + "Received response from server: {:?}", + &String::from_utf8_lossy(&payload) + ); + + server_task.await.unwrap().unwrap(); + + let mut prover = prover_task.await.unwrap().unwrap().start_notarize(); + + let (sent_len, recv_len) = prover.transcript().len(); + + let mut builder = TranscriptCommitConfig::builder(prover.transcript()); + + builder.commit_sent(&(0..sent_len)).unwrap(); + builder.commit_recv(&(0..recv_len)).unwrap(); + + let commit_config = builder.build().unwrap(); + + prover.transcript_commit(commit_config); + + let request = RequestConfig::builder().build().unwrap(); + + _ = prover.finalize(&request).await.unwrap(); + + debug!("Done notarization!"); +} diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml new file mode 100644 index 0000000000..f911090c4b --- /dev/null +++ b/crates/prover/Cargo.toml @@ -0,0 +1,64 @@ +[package] +name = "tlsn-prover" +authors = ["TLSNotary Team"] +description = "Contains the prover library" +keywords = ["tls", "mpc", 
"2pc", "prover"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0-alpha.7" +edition = "2021" + +[features] +default = ["rayon"] +rayon = ["mpz-common/rayon"] +force-st = ["mpz-common/force-st"] +# Enables the AuthDecode protocol which allows to prove zk-friendly hashes over the transcript data. +# This is an early iteration meant for gathering feedback and assessing performance. As such, this +# feature has an "_unsafe" suffix since it will leak the ranges of data committed to. +# Future iterations will get rid of the leakage at the cost of worse performance. +# This feature is EXPERIMENTAL and will be removed in future releases without prior notice. +authdecode_unsafe = [ + "tlsn-authdecode-core", + "tlsn-authdecode-transcript", + "tlsn-core/poseidon", + "tlsn-common/authdecode_unsafe_common", +] + +[dependencies] +tlsn-authdecode-core = { workspace = true, optional = true } +tlsn-authdecode-transcript = { workspace = true, optional = true } +tlsn-common = { workspace = true } +tlsn-core = { workspace = true } +tlsn-formats = { workspace = true, optional = true } +tlsn-tls-client = { workspace = true } +tlsn-tls-client-async = { workspace = true } +tlsn-tls-core = { workspace = true } +tlsn-tls-mpc = { workspace = true } +tlsn-utils = { workspace = true } +tlsn-utils-aio = { workspace = true } + +serio = { workspace = true, features = ["compat"] } +uid-mux = { workspace = true, features = ["serio"] } + +mpz-common = { workspace = true } +mpz-core = { workspace = true } +mpz-garble = { workspace = true } +mpz-garble-core = { workspace = true } +mpz-ole = { workspace = true } +mpz-ot = { workspace = true } +mpz-share-conversion = { workspace = true } + +bytes = { workspace = true } +derive_builder = { workspace = true } +futures = { workspace = true } +opaque-debug = { workspace = true } +rand = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +web-time = { workspace = 
true } +webpki-roots = { workspace = true } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2", features = ["js"] } +ring = { version = "0.17", features = ["wasm32_unknown_unknown_js"] } diff --git a/crates/prover/src/authdecode.rs b/crates/prover/src/authdecode.rs new file mode 100644 index 0000000000..6764f049da --- /dev/null +++ b/crates/prover/src/authdecode.rs @@ -0,0 +1,274 @@ +use std::{mem, ops::Range}; + +use authdecode_core::{ + backend::{ + halo2::{Bn256F, CHUNK_SIZE}, + traits::Field, + }, + prover::{ + CommitmentData, Committed, Initialized, ProofGenerated, ProverError as CoreProverError, + }, + Prover as AuthDecodeProver, SSP, +}; +use authdecode_transcript::{TranscriptData, TranscriptEncoder}; +use mpz_core::utils::blake3; +use tlsn_core::{ + hash::{Blinder, HashAlgId}, + transcript::{encoding::EncodingProvider, Direction, Idx, Transcript}, +}; +use utils::range::RangeSet; + +use crate::{error, ProverError}; + +/// Returns an AuthDecode prover for a TLS transcript based on the hashing algorithm used. 
+pub(crate) fn authdecode_prover( + inputs: Vec<(Direction, Range, HashAlgId, Blinder)>, + encoding_provider: &(dyn EncodingProvider + Send + Sync), + transcript: &Transcript, + max_plaintext: usize, +) -> Result { + if inputs.is_empty() { + return Err(ProverError::authdecode("inputs vector is empty")); + } + + let alg = inputs.first().expect("At least one input is expected").2; + + let mut total_plaitext = 0; + + let adinputs = inputs + .iter() + .map(|(dir, range, this_alg, blinder)| { + if &alg != this_alg { + return Err(ProverError::authdecode( + "more than one hash algorithms are present", + )); + } + + total_plaitext += range.len(); + if total_plaitext > max_plaintext { + return Err(ProverError::authdecode("max_plaintext exceeded")); + } + + let idx = Idx::new(RangeSet::new(&[range.clone()])); + + let mut encodings = encoding_provider.provide_bit_encodings(*dir, &idx).ok_or( + ProverError::authdecode(format!( + "direction {} and index {:?} were not found by the encoding provider", + &dir, &idx + )), + )?; + // Reverse byte encodings to MSB0. + for chunk in encodings.chunks_mut(8) { + chunk.reverse(); + } + + let plaintext = transcript + .get(*dir, &idx) + .ok_or(ProverError::authdecode(format!( + "direction {} and index {:?} were not found in the transcript", + &dir, &idx + )))? + .data() + .to_vec(); + + Ok(AuthDecodeInput::new( + blinder.clone(), + plaintext, + encodings, + range.clone(), + *dir, + )) + }) + .collect::, ProverError>>()?; + + match alg { + HashAlgId::POSEIDON_BN256_434 => Ok(PoseidonHalo2Prover::new(AuthdecodeInputs(adinputs))), + alg => Err(error::ProverError::authdecode(format!( + "unsupported hash algorithm {:?} for AuthDecode", + alg + ))), + } +} + +/// An AuthDecode prover for a TLS transcript. +pub(crate) trait TranscriptProver { + /// Creates a new prover instantiated with the given `inputs`. + /// + /// # Panics + /// + /// Panics if the `inputs` are malformed. 
+ fn new(inputs: AuthdecodeInputs) -> Self; + + /// Commits to the commitment data which the prover was instantiated with. + /// + /// Returns a message to be passed to the verifier. + fn commit(&mut self) -> Result; + + /// Creates proofs using the `seed` to generate encodings. + /// + /// Returns a message to be passed to the verifier. + fn prove(&mut self, seed: [u8; 32]) -> Result; + + /// Returns the hash algorithm used to create commitments. + fn alg(&self) -> HashAlgId; +} + +/// An AuthDecode prover for a batch of data from a TLS transcript using the +/// POSEIDON_HALO2 hash algorithm. +pub(crate) struct PoseidonHalo2Prover { + /// A batch of AuthDecode commitment data with the plaintext salt. + commitment_data: Option, Bn256F)>>, + /// The prover in the [Initialized] state. + initialized: Option>, + /// The prover in the [Committed] state. + committed: Option, Bn256F>>, + /// The prover in the [ProofGenerated] state. + proof_generated: + Option, Bn256F>>, +} + +impl TranscriptProver for PoseidonHalo2Prover { + fn new(inputs: AuthdecodeInputs) -> Self { + let inputs = inputs.into_inner(); + + for input in &inputs { + assert!(input.range.len() <= CHUNK_SIZE); + } + assert!(!inputs.is_empty()); + // All encodings must have at least SSP bitlength. + assert!(inputs[0].encodings[0].len() * 8 >= SSP); + + let commitment_data = inputs + .into_iter() + .map(|input| { + // Hash the encodings to break the correlation and truncate them. 
+ let hashed_encodings = input + .encodings + .into_iter() + .map(|enc| { + let mut enc_new = [0u8; SSP / 8]; + enc_new.copy_from_slice(&blake3(&enc)[0..SSP / 8]); + enc_new + }) + .collect::>(); + + ( + CommitmentData::new( + &input.plaintext, + &hashed_encodings, + TranscriptData::new(input.direction, &input.range), + ), + Bn256F::from_bytes_be(input.salt.as_inner().to_vec()), + ) + }) + .collect::>(); + + Self { + initialized: Some(AuthDecodeProver::new(Box::new( + authdecode_core::backend::halo2::prover::Prover::new(), + ))), + committed: None, + proof_generated: None, + commitment_data: Some(commitment_data), + } + } + + fn commit(&mut self) -> Result { + let prover = mem::take(&mut self.initialized).ok_or(TranscriptProverError::Other( + "The prover was called in the wrong state".to_string(), + ))?; + let commitment = mem::take(&mut self.commitment_data) + .ok_or(TranscriptProverError::Other( + "The commitment data was not set".to_string(), + ))? + .into_iter() + .map(|(comm_data, salt)| (comm_data, vec![salt])) + .collect::>(); + + let (prover, msg) = prover.commit_with_salt(commitment)?; + + self.committed = Some(prover); + Ok(msg) + } + + fn prove(&mut self, seed: [u8; 32]) -> Result { + let encoding_provider = TranscriptEncoder::new(seed); + + let prover = mem::take(&mut self.committed).ok_or(TranscriptProverError::Other( + "The prover was called in the wrong state".to_string(), + ))?; + + let (prover, msg) = prover.prove(&encoding_provider)?; + self.proof_generated = Some(prover); + + Ok(msg) + } + + fn alg(&self) -> HashAlgId { + HashAlgId::POSEIDON_BN256_434 + } +} + +#[derive(Debug, thiserror::Error)] +/// Error for [TranscriptProver]. +pub(crate) enum TranscriptProverError { + #[error(transparent)] + CoreProtocolError(#[from] CoreProverError), + #[error("AuthDecode prover failed with an error: {0}")] + Other(String), +} + +/// An AuthDecode input to prove a single range of a TLS transcript. 
Also contains the `salt` to be +/// used for the plaintext commitment. +struct AuthDecodeInput { + /// The salt of the plaintext commitment. + pub salt: Blinder, + /// The plaintext to commit to. + pub plaintext: Vec, + /// The encodings to commit to in MSB0 bit order. + pub encodings: Vec>, + /// The byterange of the plaintext. + pub range: Range, + /// The direction of the range in the transcript. + pub direction: Direction, +} + +impl AuthDecodeInput { + /// Creates a new `AuthDecodeInput`. + /// + /// # Panics + /// + /// Panics if some of the arguments are not correct. + fn new( + salt: Blinder, + plaintext: Vec, + encodings: Vec>, + range: Range, + direction: Direction, + ) -> Self { + assert!(!range.is_empty()); + assert!(plaintext.len() * 8 == encodings.len()); + assert!(plaintext.len() == range.len()); + // All encodings should have the same length. + for pair in encodings.windows(2) { + assert!(pair[0].len() == pair[1].len()); + } + Self { + salt, + plaintext, + encodings, + range, + direction, + } + } +} + +/// A batch of AuthDecode inputs. +pub(crate) struct AuthdecodeInputs(Vec); + +impl AuthdecodeInputs { + /// Consumes self, returning the inner vector. + fn into_inner(self) -> Vec { + self.0 + } +} diff --git a/crates/prover/src/config.rs b/crates/prover/src/config.rs new file mode 100644 index 0000000000..5c003702f3 --- /dev/null +++ b/crates/prover/src/config.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use mpz_ot::{chou_orlandi, kos}; +use tls_mpc::{MpcTlsCommonConfig, MpcTlsLeaderConfig, TranscriptConfig}; +use tlsn_common::config::ProtocolConfig; +use tlsn_core::{connection::ServerName, CryptoProvider}; + +/// Configuration for the prover +#[derive(Debug, Clone, derive_builder::Builder)] +pub struct ProverConfig { + /// The server DNS name. + #[builder(setter(into))] + server_name: ServerName, + /// Protocol configuration to be checked with the verifier. 
+ protocol_config: ProtocolConfig, + /// Whether the `deferred decryption` feature is toggled on from the start + /// of the MPC-TLS connection. + /// + /// See `defer_decryption_from_start` in [tls_mpc::MpcTlsLeaderConfig]. + #[builder(default = "true")] + defer_decryption_from_start: bool, + /// Cryptography provider. + #[builder(default, setter(into))] + crypto_provider: Arc, +} + +impl ProverConfig { + /// Create a new builder for `ProverConfig`. + pub fn builder() -> ProverConfigBuilder { + ProverConfigBuilder::default() + } + + /// Returns the server DNS name. + pub fn server_name(&self) -> &ServerName { + &self.server_name + } + + /// Returns the crypto provider. + pub fn crypto_provider(&self) -> &CryptoProvider { + &self.crypto_provider + } + + /// Returns the protocol configuration. + pub fn protocol_config(&self) -> &ProtocolConfig { + &self.protocol_config + } + + /// Returns whether the `deferred decryption` feature is toggled on from the + /// start of the MPC-TLS connection. 
+ pub fn defer_decryption_from_start(&self) -> bool { + self.defer_decryption_from_start + } + + pub(crate) fn build_mpc_tls_config(&self) -> MpcTlsLeaderConfig { + MpcTlsLeaderConfig::builder() + .common( + MpcTlsCommonConfig::builder() + .tx_config( + TranscriptConfig::default_tx() + .max_online_size(self.protocol_config.max_sent_data()) + .build() + .unwrap(), + ) + .rx_config( + TranscriptConfig::default_rx() + .max_online_size(self.protocol_config.max_recv_data_online()) + .max_offline_size( + self.protocol_config.max_recv_data() + - self.protocol_config.max_recv_data_online(), + ) + .build() + .unwrap(), + ) + .handshake_commit(true) + .build() + .unwrap(), + ) + .build() + .unwrap() + } + + pub(crate) fn build_base_ot_sender_config(&self) -> chou_orlandi::SenderConfig { + chou_orlandi::SenderConfig::builder() + .receiver_commit() + .build() + .unwrap() + } + + pub(crate) fn build_base_ot_receiver_config(&self) -> chou_orlandi::ReceiverConfig { + chou_orlandi::ReceiverConfig::default() + } + + pub(crate) fn build_ot_sender_config(&self) -> kos::SenderConfig { + kos::SenderConfig::default() + } + + pub(crate) fn build_ot_receiver_config(&self) -> kos::ReceiverConfig { + kos::ReceiverConfig::builder() + .sender_commit() + .build() + .unwrap() + } +} diff --git a/crates/prover/src/error.rs b/crates/prover/src/error.rs new file mode 100644 index 0000000000..7dfdb3bfee --- /dev/null +++ b/crates/prover/src/error.rs @@ -0,0 +1,159 @@ +use std::{error::Error, fmt}; +use tls_mpc::MpcTlsError; + +/// Error for [`Prover`](crate::Prover). 
+#[derive(Debug, thiserror::Error)] +pub struct ProverError { + kind: ErrorKind, + source: Option>, +} + +impl ProverError { + fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + pub(crate) fn config(source: E) -> Self + where + E: Into>, + { + Self::new(ErrorKind::Config, source) + } + + pub(crate) fn attestation(source: E) -> Self + where + E: Into>, + { + Self::new(ErrorKind::Attestation, source) + } + + #[cfg(feature = "authdecode_unsafe")] + pub(crate) fn authdecode(source: E) -> Self + where + E: Into>, + { + Self::new(ErrorKind::AuthDecode, source) + } +} + +#[derive(Debug)] +enum ErrorKind { + Io, + Mpc, + Config, + Attestation, + #[cfg(feature = "authdecode_unsafe")] + AuthDecode, +} + +impl fmt::Display for ProverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("prover error: ")?; + + match self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Mpc => f.write_str("mpc error")?, + ErrorKind::Config => f.write_str("config error")?, + ErrorKind::Attestation => f.write_str("attestation error")?, + #[cfg(feature = "authdecode_unsafe")] + ErrorKind::AuthDecode => f.write_str("authdecode error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for ProverError { + fn from(e: std::io::Error) -> Self { + Self::new(ErrorKind::Io, e) + } +} + +impl From for ProverError { + fn from(e: tls_client_async::ConnectionError) -> Self { + Self::new(ErrorKind::Io, e) + } +} + +impl From for ProverError { + fn from(e: uid_mux::yamux::ConnectionError) -> Self { + Self::new(ErrorKind::Io, e) + } +} + +impl From for ProverError { + fn from(e: mpz_common::ContextError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: MpcTlsError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_ot::OTError) -> Self { 
+ Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_ot::kos::SenderError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_ole::OLEError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_ot::kos::ReceiverError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_garble::VmError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_garble::protocol::deap::DEAPError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_garble::MemoryError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for ProverError { + fn from(e: mpz_garble::ProveError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +#[cfg(feature = "authdecode_unsafe")] +impl From for ProverError { + fn from(e: crate::authdecode::TranscriptProverError) -> Self { + Self::new(ErrorKind::AuthDecode, e) + } +} diff --git a/crates/prover/src/future.rs b/crates/prover/src/future.rs new file mode 100644 index 0000000000..9ea8bbfffa --- /dev/null +++ b/crates/prover/src/future.rs @@ -0,0 +1,31 @@ +//! This module collects futures which are used by the [Prover]. + +use super::{state, Prover, ProverControl, ProverError}; +use futures::Future; +use std::pin::Pin; + +/// Prover future which must be polled for the TLS connection to make progress. +pub struct ProverFuture { + #[allow(clippy::type_complexity)] + pub(crate) fut: + Pin, ProverError>> + Send + 'static>>, + pub(crate) ctrl: ProverControl, +} + +impl ProverFuture { + /// Returns a controller for the prover for advanced functionality. 
+ pub fn control(&self) -> ProverControl { + self.ctrl.clone() + } +} + +impl Future for ProverFuture { + type Output = Result, ProverError>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + self.fut.as_mut().poll(cx) + } +} diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs new file mode 100644 index 0000000000..de630b83e6 --- /dev/null +++ b/crates/prover/src/lib.rs @@ -0,0 +1,427 @@ +//! TLSNotary prover library. + +#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(clippy::all)] +#![forbid(unsafe_code)] + +#[cfg(feature = "authdecode_unsafe")] +mod authdecode; +mod config; +mod error; +mod future; +mod notarize; +mod prove; +pub mod state; + +pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError}; +pub use error::ProverError; +pub use future::ProverFuture; +use state::{Notarize, Prove}; + +use futures::{AsyncRead, AsyncWrite, TryFutureExt}; +use mpz_common::Allocate; +use mpz_garble::config::Role as DEAPRole; +use mpz_ot::{chou_orlandi, kos}; +use rand::Rng; +use serio::{SinkExt, StreamExt}; +use std::sync::Arc; +use tls_client::{ClientConnection, ServerName as TlsServerName}; +use tls_client_async::{bind_client, ClosedConnection, TlsConnection}; +use tls_mpc::{build_components, LeaderCtrl, MpcTlsLeader, TlsRole}; +use tlsn_common::{ + mux::{attach_mux, MuxControl}, + DEAPThread, Executor, OTReceiver, OTSender, Role, +}; +use tlsn_core::{ + connection::{ + ConnectionInfo, HandshakeData, HandshakeDataV1_2, ServerCertData, ServerSignature, + TranscriptLength, + }, + transcript::Transcript, +}; +use uid_mux::FramedUidMux as _; + +use tracing::{debug, info_span, instrument, Instrument, Span}; + +/// A prover instance. +#[derive(Debug)] +pub struct Prover { + config: ProverConfig, + span: Span, + state: T, +} + +impl Prover { + /// Creates a new prover. + /// + /// # Arguments + /// + /// * `config` - The configuration for the prover. 
+ pub fn new(config: ProverConfig) -> Self { + let span = info_span!("prover"); + Self { + config, + span, + state: state::Initialized, + } + } + + /// Sets up the prover. + /// + /// This performs all MPC setup prior to establishing the connection to the + /// application server. + /// + /// # Arguments + /// + /// * `socket` - The socket to the TLS verifier. + #[instrument(parent = &self.span, level = "debug", skip_all, err)] + pub async fn setup( + self, + socket: S, + ) -> Result, ProverError> { + let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover); + + let mut io = mux_fut + .poll_with(mux_ctrl.open_framed(b"tlsnotary")) + .await?; + + // Sends protocol configuration to verifier for compatibility check. + mux_fut + .poll_with(io.send(self.config.protocol_config().clone())) + .await?; + + // Maximum thread forking concurrency of 8. + // TODO: Determine the optimal number of threads. + let mut exec = Executor::new(mux_ctrl.clone(), 8); + + let (mpc_tls, vm, ot_recv) = mux_fut + .poll_with(setup_mpc_backend(&self.config, &mux_ctrl, &mut exec)) + .await?; + + let ctx = mux_fut.poll_with(exec.new_thread()).await?; + + Ok(Prover { + config: self.config, + span: self.span, + state: state::Setup { + io, + mux_ctrl, + mux_fut, + mpc_tls, + vm, + ot_recv, + ctx, + }, + }) + } +} + +impl Prover { + /// Connects to the server using the provided socket. + /// + /// Returns a handle to the TLS connection, a future which returns the + /// prover once the connection is closed. + /// + /// # Arguments + /// + /// * `socket` - The socket to the server. 
+ #[instrument(parent = &self.span, level = "debug", skip_all, err)] + pub async fn connect( + self, + socket: S, + ) -> Result<(TlsConnection, ProverFuture), ProverError> { + let state::Setup { + io, + mux_ctrl, + mut mux_fut, + mpc_tls, + vm, + ot_recv, + ctx, + } = self.state; + + let (mpc_ctrl, mpc_fut) = mpc_tls.run(); + + let server_name = + TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| { + ProverError::config(format!( + "invalid server name: {}", + self.config.server_name() + )) + })?; + + let config = tls_client::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(self.config.crypto_provider().cert.root_store().clone()) + .with_no_client_auth(); + let client = + ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name) + .map_err(ProverError::config)?; + + let (conn, conn_fut) = bind_client(socket, client); + + let start_time = web_time::UNIX_EPOCH.elapsed().unwrap().as_secs(); + + let fut = Box::pin({ + let span = self.span.clone(); + let mpc_ctrl = mpc_ctrl.clone(); + async move { + let conn_fut = async { + let ClosedConnection { sent, recv, .. 
} = mux_fut + .poll_with(conn_fut.map_err(ProverError::from)) + .await?; + + mpc_ctrl.close_connection().await?; + + Ok::<_, ProverError>((sent, recv)) + }; + + let ((sent, recv), mpc_tls_data) = futures::try_join!( + conn_fut, + mpc_fut.in_current_span().map_err(ProverError::from) + )?; + + let connection_info = ConnectionInfo { + time: start_time, + version: mpc_tls_data + .protocol_version + .try_into() + .expect("only supported version should have been accepted"), + transcript_length: TranscriptLength { + sent: sent.len() as u32, + received: recv.len() as u32, + }, + }; + + let server_cert_data = ServerCertData { + certs: mpc_tls_data + .server_cert_details + .cert_chain() + .iter() + .cloned() + .map(|c| c.into()) + .collect(), + sig: ServerSignature { + scheme: mpc_tls_data + .server_kx_details + .kx_sig() + .scheme + .try_into() + .expect("only supported signature scheme should have been accepted"), + sig: mpc_tls_data.server_kx_details.kx_sig().sig.0.clone(), + }, + handshake: HandshakeData::V1_2(HandshakeDataV1_2 { + client_random: mpc_tls_data.client_random.0, + server_random: mpc_tls_data.server_random.0, + server_ephemeral_key: mpc_tls_data + .server_public_key + .try_into() + .expect("only supported key scheme should have been accepted"), + }), + }; + + Ok(Prover { + config: self.config, + span: self.span, + state: state::Closed { + io, + mux_ctrl, + mux_fut, + vm, + ot_recv, + ctx, + connection_info, + server_cert_data, + transcript: Transcript::new(sent, recv), + }, + }) + } + .instrument(span) + }); + + Ok(( + conn, + ProverFuture { + fut, + ctrl: ProverControl { mpc_ctrl }, + }, + )) + } +} + +impl Prover { + /// Returns the transcript. + pub fn transcript(&self) -> &Transcript { + &self.state.transcript + } + + /// Starts notarization of the TLS session. + /// + /// Used when the TLS verifier is a Notary to transition the prover to the + /// next state where it can generate commitments to the transcript prior + /// to finalization. 
+ pub fn start_notarize(self) -> Prover { + Prover { + config: self.config, + span: self.span, + state: self.state.into(), + } + } + + /// Starts proving the TLS session. + /// + /// This function transitions the prover into a state where it can prove + /// content of the transcript. + pub fn start_prove(self) -> Prover { + Prover { + config: self.config, + span: self.span, + state: self.state.into(), + } + } +} + +/// Performs a setup of the various MPC subprotocols. +#[instrument(level = "debug", skip_all, err)] +async fn setup_mpc_backend( + config: &ProverConfig, + mux: &MuxControl, + exec: &mut Executor, +) -> Result<(MpcTlsLeader, DEAPThread, OTReceiver), ProverError> { + debug!("starting MPC backend setup"); + + let mut ot_sender = kos::Sender::new( + config.build_ot_sender_config(), + chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), + ); + ot_sender.alloc(config.protocol_config().ot_sender_setup_count(Role::Prover)); + + let mut ot_receiver = kos::Receiver::new( + config.build_ot_receiver_config(), + chou_orlandi::Sender::new(config.build_base_ot_sender_config()), + ); + ot_receiver.alloc( + config + .protocol_config() + .ot_receiver_setup_count(Role::Prover), + ); + + let ot_sender = OTSender::new(ot_sender); + let ot_receiver = OTReceiver::new(ot_receiver); + + let ( + ctx_vm, + ctx_ke_0, + ctx_ke_1, + ctx_prf_0, + ctx_prf_1, + ctx_encrypter_block_cipher, + ctx_encrypter_stream_cipher, + ctx_encrypter_ghash, + ctx_encrypter, + ctx_decrypter_block_cipher, + ctx_decrypter_stream_cipher, + ctx_decrypter_ghash, + ctx_decrypter, + ) = futures::try_join!( + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + )?; + + let vm = DEAPThread::new( + DEAPRole::Leader, + rand::rngs::OsRng.gen(), + ctx_vm, + ot_sender.clone(), + 
ot_receiver.clone(), + ); + + let mpc_tls_config = config.build_mpc_tls_config(); + let (ke, prf, encrypter, decrypter) = build_components( + TlsRole::Leader, + mpc_tls_config.common(), + ctx_ke_0, + ctx_encrypter, + ctx_decrypter, + ctx_encrypter_ghash, + ctx_decrypter_ghash, + vm.new_thread(ctx_ke_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_0, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread( + ctx_encrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_encrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + ot_sender.clone(), + ot_receiver.clone(), + ); + + let channel = mux.open_framed(b"mpc_tls").await?; + let mut mpc_tls = MpcTlsLeader::new( + mpc_tls_config, + Box::new(StreamExt::compat_stream(channel)), + ke, + prf, + encrypter, + decrypter, + ); + + mpc_tls.setup().await?; + + debug!("MPC backend setup complete"); + + Ok((mpc_tls, vm, ot_receiver)) +} + +/// A controller for the prover. +#[derive(Clone)] +pub struct ProverControl { + mpc_ctrl: LeaderCtrl, +} + +impl ProverControl { + /// Defers decryption of data from the server until the server has closed + /// the connection. + /// + /// This is a performance optimization which will significantly reduce the + /// amount of upload bandwidth used by the prover. + /// + /// # Notes + /// + /// * The prover may need to close the connection to the server in order for + /// it to close the connection on its end. If neither the prover or server + /// close the connection this will cause a deadlock. 
+ pub async fn defer_decryption(&self) -> Result<(), ProverError> { + self.mpc_ctrl + .defer_decryption() + .await + .map_err(ProverError::from) + } +} diff --git a/crates/prover/src/notarize.rs b/crates/prover/src/notarize.rs new file mode 100644 index 0000000000..094d40e4a0 --- /dev/null +++ b/crates/prover/src/notarize.rs @@ -0,0 +1,232 @@ +//! This module handles the notarization phase of the prover. +//! +//! The prover deals with a TLS verifier that is only a notary. + +use crate::{state::Notarize, Prover, ProverError}; + +use mpz_ot::VerifiableOTReceiver; +use serio::{stream::IoStreamExt as _, SinkExt as _}; +use tlsn_core::{ + attestation::Attestation, + request::{Request, RequestConfig}, + transcript::{encoding::EncodingTree, Transcript, TranscriptCommitConfig}, + Secrets, +}; +use tracing::{debug, instrument}; + +#[cfg(feature = "authdecode_unsafe")] +use std::ops::Range; + +#[cfg(feature = "authdecode_unsafe")] +use crate::authdecode::{authdecode_prover, TranscriptProver}; +#[cfg(feature = "authdecode_unsafe")] +use tlsn_core::{ + hash::{Blinder, HashAlgId}, + transcript::Direction, +}; + +impl Prover { + /// Returns the transcript. + pub fn transcript(&self) -> &Transcript { + &self.state.transcript + } + + /// Configures transcript commitments. + pub fn transcript_commit(&mut self, config: TranscriptCommitConfig) { + self.state.transcript_commit_config = Some(config); + } + + /// Finalizes the notarization. 
+ #[instrument(parent = &self.span, level = "debug", skip_all, err)] + pub async fn finalize( + self, + config: &RequestConfig, + ) -> Result<(Attestation, Secrets), ProverError> { + let Notarize { + mut io, + mux_ctrl, + mut mux_fut, + mut vm, + mut ot_recv, + mut ctx, + connection_info, + server_cert_data, + transcript, + encoding_provider, + transcript_commit_config, + encoding_commitments, + } = self.state; + + let provider = self.config.crypto_provider(); + + let hasher = provider.hash.get(config.hash_alg()).unwrap(); + + let mut builder = Request::builder(config); + + builder + .server_name(self.config.server_name().clone()) + .server_cert_data(server_cert_data) + .transcript(transcript.clone()); + + if let Some(config) = transcript_commit_config { + if config.has_encoding() { + let tree = match encoding_commitments { + Some(tree) => tree, + None => EncodingTree::new( + hasher, + config.iter_encoding(), + &*encoding_provider, + &connection_info.transcript_length, + ) + .map_err(ProverError::attestation)?, + }; + + builder.encoding_tree(tree); + } + + if config.has_plaintext_hashes() { + builder.plaintext_hashes(config.plaintext_hashes()); + } + } + + let (request, secrets) = builder.build(provider).map_err(ProverError::attestation)?; + + let attestation = mux_fut + .poll_with(async { + debug!("starting finalization"); + + io.send(request.clone()).await?; + + ot_recv.accept_reveal(&mut ctx).await?; + + debug!("received OT secret"); + + vm.finalize().await?; + + let attestation: Attestation = io.expect_next().await?; + + Ok::<_, ProverError>(attestation) + }) + .await?; + + // Wait for the notary to correctly close the connection. + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } + + // Check the attestation is consistent with the Prover's view. + request + .validate(&attestation) + .map_err(ProverError::attestation)?; + + Ok((attestation, secrets)) + } + + /// Finalizes the notarization and runs the AuthDecode protocol. 
+ #[instrument(parent = &self.span, level = "debug", skip_all, err)] + #[cfg(feature = "authdecode_unsafe")] + pub async fn finalize_with_authdecode( + self, + config: &RequestConfig, + authdecode_inputs: Vec<(Direction, Range, HashAlgId, Blinder)>, + ) -> Result<(Attestation, Secrets), ProverError> { + let Notarize { + mut io, + mux_ctrl, + mut mux_fut, + mut vm, + mut ot_recv, + mut ctx, + connection_info, + server_cert_data, + transcript, + encoding_provider, + transcript_commit_config, + encoding_commitments, + } = self.state; + + let provider = self.config.crypto_provider(); + + let hasher = provider.hash.get(config.hash_alg()).unwrap(); + + let mut builder = Request::builder(config); + + builder + .server_name(self.config.server_name().clone()) + .server_cert_data(server_cert_data) + .transcript(transcript.clone()); + + if let Some(config) = transcript_commit_config { + if config.has_encoding() { + let tree = match encoding_commitments { + Some(tree) => tree, + None => EncodingTree::new( + hasher, + config.iter_encoding(), + &*encoding_provider, + &connection_info.transcript_length, + ) + .map_err(ProverError::attestation)?, + }; + + builder.encoding_tree(tree); + } + + if config.has_plaintext_hashes() { + builder.plaintext_hashes(config.plaintext_hashes()); + } + } + + let (request, secrets) = builder.build(provider).map_err(ProverError::attestation)?; + + let attestation = mux_fut + .poll_with(async { + debug!("starting finalization"); + + io.send(request.clone()).await?; + + let max = self.config.protocol_config().max_authdecode_data(); + let mut authdecode_prover = + authdecode_prover(authdecode_inputs, &*encoding_provider, &transcript, max)?; + + io.send(authdecode_prover.alg()).await?; + io.send(authdecode_prover.commit()?).await?; + + debug!("sent AuthDecode commitment"); + + ot_recv.accept_reveal(&mut ctx).await?; + + debug!("received OT secret"); + + let seed = vm + .finalize() + .await? 
+ .expect("The seed should be returned to the leader"); + + // Now that the full encodings were authenticated, it is safe to proceed to the + // proof generation phase of the AuthDecode protocol. + io.send(authdecode_prover.prove(seed)?).await?; + + debug!("sent AuthDecode proof"); + + let attestation: Attestation = io.expect_next().await?; + + Ok::<_, ProverError>(attestation) + }) + .await?; + + // Wait for the notary to correctly close the connection. + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } + + // Check the attestation is consistent with the Prover's view. + request + .validate(&attestation) + .map_err(ProverError::attestation)?; + + Ok((attestation, secrets)) + } +} diff --git a/crates/prover/src/prove.rs b/crates/prover/src/prove.rs new file mode 100644 index 0000000000..8163b18d8b --- /dev/null +++ b/crates/prover/src/prove.rs @@ -0,0 +1,103 @@ +//! This module handles the proving phase of the prover. +//! +//! Here the prover deals with a verifier directly, so there is no notary +//! involved. Instead the verifier directly verifies parts of the transcript. + +use super::{state::Prove as ProveState, Prover, ProverError}; +use mpz_garble::{Memory, Prove}; +use mpz_ot::VerifiableOTReceiver; +use serio::SinkExt as _; +use tlsn_common::msg::ServerIdentityProof; +use tlsn_core::transcript::{get_value_ids, Direction, Idx, Transcript}; + +use tracing::{info, instrument}; + +impl Prover { + /// Returns the transcript. + pub fn transcript(&self) -> &Transcript { + &self.state.transcript + } + + /// Prove subsequences in the transcript to the verifier. + /// + /// # Arguments + /// + /// * `sent` - Indices of the sent data. + /// * `recv` - Indices of the received data. 
+ #[instrument(parent = &self.span, level = "debug", skip_all, err)] + pub async fn prove_transcript(&mut self, sent: Idx, recv: Idx) -> Result<(), ProverError> { + let partial_transcript = self.transcript().to_partial(sent.clone(), recv.clone()); + + let sent_value_ids = get_value_ids(Direction::Sent, &sent); + let recv_value_ids = get_value_ids(Direction::Received, &recv); + + let value_refs = sent_value_ids + .chain(recv_value_ids) + .map(|id| { + self.state + .vm + .get_value(id.as_str()) + .expect("Byte should be in VM memory") + }) + .collect::>(); + + self.state + .mux_fut + .poll_with(async { + // Send the partial transcript to the verifier. + self.state.io.send(partial_transcript).await?; + + info!("Sent partial transcript"); + + // Prove the partial transcript to the verifier. + self.state.vm.prove(value_refs.as_slice()).await?; + + info!("Proved partial transcript"); + + Ok::<_, ProverError>(()) + }) + .await?; + + Ok(()) + } + + /// Finalize the proving + #[instrument(parent = &self.span, level = "debug", skip_all, err)] + pub async fn finalize(self) -> Result<(), ProverError> { + let ProveState { + mut io, + mux_ctrl, + mut mux_fut, + mut vm, + mut ot_recv, + mut ctx, + server_cert_data, + .. + } = self.state; + + mux_fut + .poll_with(async move { + ot_recv.accept_reveal(&mut ctx).await?; + + vm.finalize().await?; + + // Send identity proof to the verifier + io.send(ServerIdentityProof { + name: self.config.server_name().clone(), + data: server_cert_data, + }) + .await?; + + Ok::<_, ProverError>(()) + }) + .await?; + + // Wait for the verifier to correctly close the connection. + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } + + Ok(()) + } +} diff --git a/crates/prover/src/state.rs b/crates/prover/src/state.rs new file mode 100644 index 0000000000..235a7a3041 --- /dev/null +++ b/crates/prover/src/state.rs @@ -0,0 +1,201 @@ +//! TLS prover states. 
+ +use mpz_core::{serialize::CanonicalSerialize, Block}; +use mpz_garble::protocol::deap::PeerEncodings; +use mpz_garble_core::{encoding_state, EncodedValue}; +use std::collections::HashMap; +use tls_mpc::MpcTlsLeader; +use tlsn_common::{ + mux::{MuxControl, MuxFuture}, + Context, DEAPThread, Io, OTReceiver, +}; +use tlsn_core::{ + connection::{ConnectionInfo, ServerCertData}, + transcript::{ + encoding::{EncodingProvider, EncodingTree}, + Direction, Idx, Transcript, TranscriptCommitConfig, + }, +}; + +/// Entry state +pub struct Initialized; + +opaque_debug::implement!(Initialized); + +/// State after MPC setup has completed. +pub struct Setup { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) mpc_tls: MpcTlsLeader, + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, +} + +opaque_debug::implement!(Setup); + +/// State after the TLS connection has been closed. +pub struct Closed { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, + + pub(crate) connection_info: ConnectionInfo, + pub(crate) server_cert_data: ServerCertData, + + pub(crate) transcript: Transcript, +} + +opaque_debug::implement!(Closed); + +/// Notarizing state. 
+pub struct Notarize { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, + + pub(crate) connection_info: ConnectionInfo, + pub(crate) server_cert_data: ServerCertData, + + pub(crate) transcript: Transcript, + pub(crate) encoding_provider: Box, + pub(crate) encoding_commitments: Option, + + pub(crate) transcript_commit_config: Option, +} + +opaque_debug::implement!(Notarize); + +impl From for Notarize { + fn from(state: Closed) -> Self { + struct HashMapProvider(HashMap>); + + impl EncodingProvider for HashMapProvider { + fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option> { + let mut encoding = Vec::new(); + let prefix = match direction { + Direction::Sent => "tx/", + Direction::Received => "rx/", + }; + for i in idx.iter() { + encoding + .extend_from_slice(&self.0.get(&format!("{}{}", prefix, i))?.to_bytes()); + } + + Some(encoding) + } + + fn provide_bit_encodings( + &self, + direction: Direction, + idx: &Idx, + ) -> Option>> { + let mut encodings = Vec::with_capacity(idx.len() * 8); + let prefix = match direction { + Direction::Sent => "tx/", + Direction::Received => "rx/", + }; + for i in idx.iter() { + for label in self.0.get(&format!("{}{}", prefix, i))?.iter() { + encodings.push(Block::to_bytes(*label.as_ref()).to_vec()) + } + } + + Some(encodings) + } + } + + let encoding_provider = HashMapProvider(collect_encodings(&state.vm, &state.transcript)); + + Self { + io: state.io, + mux_ctrl: state.mux_ctrl, + mux_fut: state.mux_fut, + vm: state.vm, + ot_recv: state.ot_recv, + ctx: state.ctx, + connection_info: state.connection_info, + server_cert_data: state.server_cert_data, + transcript: state.transcript, + encoding_provider: Box::new(encoding_provider), + encoding_commitments: None, + transcript_commit_config: None, + } + } +} + +/// Proving state. 
+pub struct Prove { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, + + pub(crate) server_cert_data: ServerCertData, + + pub(crate) transcript: Transcript, +} + +impl From for Prove { + fn from(state: Closed) -> Self { + Self { + io: state.io, + mux_ctrl: state.mux_ctrl, + mux_fut: state.mux_fut, + vm: state.vm, + ot_recv: state.ot_recv, + ctx: state.ctx, + server_cert_data: state.server_cert_data, + transcript: state.transcript, + } + } +} + +#[allow(missing_docs)] +pub trait ProverState: sealed::Sealed {} + +impl ProverState for Initialized {} +impl ProverState for Setup {} +impl ProverState for Closed {} +impl ProverState for Notarize {} +impl ProverState for Prove {} + +mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Setup {} + impl Sealed for super::Closed {} + impl Sealed for super::Notarize {} + impl Sealed for super::Prove {} +} + +fn collect_encodings( + vm: &impl PeerEncodings, + transcript: &Transcript, +) -> HashMap> { + let tx_ids = (0..transcript.sent().len()).map(|id| format!("tx/{id}")); + let rx_ids = (0..transcript.received().len()).map(|id| format!("rx/{id}")); + + let ids = tx_ids.chain(rx_ids).collect::>(); + let id_refs = ids.iter().map(|id| id.as_ref()).collect::>(); + + vm.get_peer_encodings(&id_refs) + .expect("encodings for all transcript values should be present") + .into_iter() + .zip(ids) + .map(|(encoding, id)| (id, encoding)) + .collect() +} diff --git a/crates/server-fixture/certs/Cargo.toml b/crates/server-fixture/certs/Cargo.toml new file mode 100644 index 0000000000..25ef66064f --- /dev/null +++ b/crates/server-fixture/certs/Cargo.toml @@ -0,0 +1,4 @@ +[package] +name = "tlsn-server-fixture-certs" +version = "0.1.0" +edition = "2021" diff --git a/crates/server-fixture/certs/src/lib.rs b/crates/server-fixture/certs/src/lib.rs new file mode 100644 
index 0000000000..b090b68772 --- /dev/null +++ b/crates/server-fixture/certs/src/lib.rs @@ -0,0 +1,8 @@ +/// A certificate authority certificate fixture. +pub static CA_CERT_DER: &[u8] = include_bytes!("tls/root_ca_cert.der"); +/// A server certificate (domain=test-server.io) fixture. +pub static SERVER_CERT_DER: &[u8] = include_bytes!("tls/test_server_cert.der"); +/// A server private key fixture. +pub static SERVER_KEY_DER: &[u8] = include_bytes!("tls/test_server_private_key.der"); +/// The domain name bound to the server certificate. +pub static SERVER_DOMAIN: &str = "test-server.io"; diff --git a/crates/server-fixture/certs/src/tls/README.md b/crates/server-fixture/certs/src/tls/README.md new file mode 100644 index 0000000000..01d243ef86 --- /dev/null +++ b/crates/server-fixture/certs/src/tls/README.md @@ -0,0 +1,23 @@ +# Create a private key for the root CA +openssl genpkey -algorithm RSA -out root_ca.key -pkeyopt rsa_keygen_bits:2048 + +# Create a self-signed root CA certificate (100 years validity) +openssl req -x509 -new -nodes -key root_ca.key -sha256 -days 36525 -out root_ca.crt -subj "/C=US/ST=State/L=City/O=tlsnotary/OU=IT/CN=tlsnotary.org" + +# Create a private key for the end entity certificate +openssl genpkey -algorithm RSA -out test_server.key -pkeyopt rsa_keygen_bits:2048 + +# Create a certificate signing request (CSR) for the end entity certificate +openssl req -new -key test_server.key -out test_server.csr -subj "/C=US/ST=State/L=City/O=tlsnotary/OU=IT/CN=test-server.io" + +# Sign the CSR with the root CA to create the end entity certificate (100 years validity) +openssl x509 -req -in test_server.csr -CA root_ca.crt -CAkey root_ca.key -CAcreateserial -out test_server.crt -days 36525 -sha256 -extfile openssl.cnf -extensions v3_req + +# Convert the root CA certificate to DER format +openssl x509 -in root_ca.crt -outform der -out root_ca_cert.der + +# Convert the end entity certificate to DER format +openssl x509 -in test_server.crt -outform der 
-out test_server_cert.der + +# Convert the end entity certificate private key to DER format +openssl pkcs8 -topk8 -inform PEM -outform DER -in test_server.key -out test_server_private_key.der -nocrypt \ No newline at end of file diff --git a/crates/server-fixture/certs/src/tls/openssl.cnf b/crates/server-fixture/certs/src/tls/openssl.cnf new file mode 100644 index 0000000000..6a52a8855b --- /dev/null +++ b/crates/server-fixture/certs/src/tls/openssl.cnf @@ -0,0 +1,7 @@ +[ v3_req ] +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +subjectAltName = @alt_names + +[ alt_names ] +DNS.1 = test-server.io \ No newline at end of file diff --git a/crates/server-fixture/certs/src/tls/root_ca.crt b/crates/server-fixture/certs/src/tls/root_ca.crt new file mode 100644 index 0000000000..3707fc7f0b --- /dev/null +++ b/crates/server-fixture/certs/src/tls/root_ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDrTCCApWgAwIBAgIUNBABQqSkJXdF8qtOLwP4EylSOPcwDQYJKoZIhvcNAQEL +BQAwZTELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRYwFAYDVQQDDA10bHNu +b3Rhcnkub3JnMCAXDTI0MDgwMjEwMTQ1M1oYDzIxMjQwODAzMTAxNDUzWjBlMQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxEjAQBgNV +BAoMCXRsc25vdGFyeTELMAkGA1UECwwCSVQxFjAUBgNVBAMMDXRsc25vdGFyeS5v +cmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCVvgedJ3zVE7ICYoaD +CwybhEN/6g1baoyDRVD8fpZfhdkh0uMMKBFqRa1qO9wF3Fthq6DJRaHsmZeE42Jm +aDvlRtaKDfB0MMcSeNqmP8ia7+8TFgMBY/YP7dW3d9QADFHLqyMcS6O2iaSMjBzg +4nx33TdAhQOIPHOSZbMZJGO18jn55GEeogIz6UiV8gqjQtbel/cn8jXi2rOgub+p +CZziixQ6ikppdW6a8p37B5W4/WNHDIRgRP890q0GyrEJWtj9TwyMmeC6/0mxXjZC +caLWV0072j3Dd+66XvkeL04mSe4Bp0YUs8jcTPsfOAo3FAvPgyQ6UqQfZBqOnU93 +xmYzAgMBAAGjUzBRMB0GA1UdDgQWBBTJgXIkPw2ZVkTscFx/CKZZrhymzTAfBgNV +HSMEGDAWgBTJgXIkPw2ZVkTscFx/CKZZrhymzTAPBgNVHRMBAf8EBTADAQH/MA0G +CSqGSIb3DQEBCwUAA4IBAQBP/IbIt7TheTcdhCtT5uAo4bp9Hjo5loaj0jtUkmwP 
+0RM2uA1IPu+stA+Zb8zfYZ9cIeTlYpFKZpVGmZfQYb26vYsPb40fUWAO+pYt5CGE +Kf+nwDokwT4sZUocm8sOhiLb4LWbE+e5ZmfthwUfloR2qJD9GEi4XmNt/QEbUDCK +AFch+dCRNzf0W7XJKkUsB2UQqBjHfcVbor1KrhEPgWfzfHBac6hipyekr82FVTY0 +tUIKJvvsDuCm9vUE8EpXHpT+krw9H6jSxhcovw9K4Ix4VSbY/c3g9jO4eneMama+ +IrjLekT8wUU4YHSaPcUm3VTBGrVDe6ZCGGaM3iDUJvr8 +-----END CERTIFICATE----- diff --git a/crates/server-fixture/certs/src/tls/root_ca.key b/crates/server-fixture/certs/src/tls/root_ca.key new file mode 100644 index 0000000000..712469bf3e --- /dev/null +++ b/crates/server-fixture/certs/src/tls/root_ca.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCVvgedJ3zVE7IC +YoaDCwybhEN/6g1baoyDRVD8fpZfhdkh0uMMKBFqRa1qO9wF3Fthq6DJRaHsmZeE +42JmaDvlRtaKDfB0MMcSeNqmP8ia7+8TFgMBY/YP7dW3d9QADFHLqyMcS6O2iaSM +jBzg4nx33TdAhQOIPHOSZbMZJGO18jn55GEeogIz6UiV8gqjQtbel/cn8jXi2rOg +ub+pCZziixQ6ikppdW6a8p37B5W4/WNHDIRgRP890q0GyrEJWtj9TwyMmeC6/0mx +XjZCcaLWV0072j3Dd+66XvkeL04mSe4Bp0YUs8jcTPsfOAo3FAvPgyQ6UqQfZBqO +nU93xmYzAgMBAAECggEAB+ybV4rgQCBqMlZyGtuJ/8Ld6uuBEx442wuJ2nV9J1yc +cyicq6cv1hQONh8pKMWSr8EBjGqFw/u+znaqsuj/iRsYvbaOISqhpk3Eow6guD5L +7xJ3oepfJP786S12B8ifHYGWz+ewKA1HAB8RZNSSKf+ywv8nAt3Rbzpi4h47CUT4 +Z06gLJYZNimLVPIWLzrHa+/ZOyHq/XRWsr6GTFgXfT6nudfCxzdlIdajrBvaSLBG +KbOs52tffEUHn+V1AoH6kmNp0EPSCbnR2b1KIv7loj6vi52UBpipjNFwa8PNzWfL +Cuu9N6fl7qRv9VYCnC2gJz6rTARaNJWf57UP2avygQKBgQDJC89y4lgai8qJLv3w +go+kFiWnZE0C8U69sOmNeACYhunQFKX2cG7EkTuPOnZj8XJcLYVHMSJLrEJcqyX/ +wDv1at+KqDMQsf0j7NHCSpkoG93wlffCB87VPndy7ajRN4d17tbQOJP6zmOQo1YP +7MTeVtDF3JF9IxfTb+Pxmp5nswKBgQC+rDzBN8Drr1jp6FfzZrDcr/gvlSftXupF +jTSkSxywQjophp02Hdi2t32Xq+wEuaMaJUOtywK/NVs5hJeGC584rWQjLObh7oUD +td+2V802kzsERSeDiDwtBYgjePtkeO7MXadGLwJSaZxocjcjgGj2qWPs9ihUASuB +TEtkO0jHgQKBgQCUFGXc2YhJLTOlrX4O+ytvkXx0ebUbeL8lirvLnlrZ/W0T/VFs +Xc3IbKxwx3/SB1HTQRgMosz+7ccHWGwpnt7K2cgC6faK0n6ASnsJX0bFuxjSjrMp +L/URLexvM0uHph3ZKG0CetnL/t5o91V5b0xl843cXqSuhf2Tl7NODjOkbwKBgAIn +5mP04myHxgSXCO+KmLNWFgNLt3DaouF4cEDvTHq9tPSlPf/PpJSkTHo7imafRrXT 
++AjuA7DvxIFI+4GbfghhBYHUTyP802owU0A3i+1zCrbIpWK6VpvXtStZgdYn++M5 +p9uGSotuAEO6Dt+K4yTu019phRk2DizfFPckKHWBAoGAehmqjR+5T7SpDiZwXFyN +CA4qKVoYPexmNjbECYkpLbEkPxOc145H0Y4oHOBH46jIiHumSV3N2bvywYQ2IlyV +BSGqGFAeFhpRAtMKCFMG7bNPTbskKcpUyGD2csoiYxXsFuFZX4Db9i0tpjt57C/a +9ij7zNzrAj5Iby8EMykK+aM= +-----END PRIVATE KEY----- diff --git a/crates/server-fixture/certs/src/tls/root_ca.srl b/crates/server-fixture/certs/src/tls/root_ca.srl new file mode 100644 index 0000000000..f4add3930b --- /dev/null +++ b/crates/server-fixture/certs/src/tls/root_ca.srl @@ -0,0 +1 @@ +1B924A233FDF6D40DDA57D7E4C0C37DE64BE996A diff --git a/crates/server-fixture/certs/src/tls/root_ca_cert.der b/crates/server-fixture/certs/src/tls/root_ca_cert.der new file mode 100644 index 0000000000..3aceaa592a Binary files /dev/null and b/crates/server-fixture/certs/src/tls/root_ca_cert.der differ diff --git a/crates/server-fixture/certs/src/tls/test_server.crt b/crates/server-fixture/certs/src/tls/test_server.crt new file mode 100644 index 0000000000..15e3d3e8f3 --- /dev/null +++ b/crates/server-fixture/certs/src/tls/test_server.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID0DCCArigAwIBAgIUG5JKIz/fbUDdpX1+TAw33mS+mWowDQYJKoZIhvcNAQEL +BQAwZTELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRYwFAYDVQQDDA10bHNu +b3Rhcnkub3JnMCAXDTI0MDgwMjEwMTU1N1oYDzIxMjQwODAzMTAxNTU3WjBmMQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxEjAQBgNV +BAoMCXRsc25vdGFyeTELMAkGA1UECwwCSVQxFzAVBgNVBAMMDnRlc3Qtc2VydmVy +LmlvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoeBDxxxAASDtcXx4 +07dK7YfLw2+cRz5rDdv/HHPHJLGJTvCXfZCTfV3y3KzTuLeOWHhGyG1bH075Jg/1 +TZ+nTdr/T/78mV4GXilf6hvmnwX3Pr7KLfXDEizRDKbnQqTgThs9hgHJ5pm8Jkid +5dWJEnvT5ChaBzwITpAe7qD05dVln7wkayKkT28IuV1iOglXjoBsozsL2qvj2wmL +pYQqn17Ir98CY9AUjJ/D4tAGRbxGmhQ3+kLakO2wR+TA0E51opjlWeP4qc8i1OWp +MH3fz5GddrC0BYVF0yute2VgjOXlM0PB2V4aMrqeB52hppix9XZOXymLeVQHddXQ 
+YbtPPQIDAQABo3UwczAJBgNVHRMEAjAAMAsGA1UdDwQEAwIF4DAZBgNVHREEEjAQ +gg50ZXN0LXNlcnZlci5pbzAdBgNVHQ4EFgQUXLxbOoGpjtxTs0zuIRtl74jPNokw +HwYDVR0jBBgwFoAUyYFyJD8NmVZE7HBcfwimWa4cps0wDQYJKoZIhvcNAQELBQAD +ggEBADnqTNfq6bhNYeSrT0KjpkJL+lI8g1gIUHUP9/wIMOJAaIkKj/hjoqSkrUfm +DoE9zK+7Wy27nyS6q8YEpbjWqRBUmdL6iQ0/fzgl/jaPyqfwrj0S2Xjj8mWBC2jj +//omjVrdq2RtYoL175sHEq3df9bprOPzb7mVPDP1kb7akkiqHRvvID+A1rd+hbHo +H1cjc8RvtTxfnwAwXui8GFgdGjlt59qi+RxxGFCzdQNUhVTxQUeX4Xp6I99EMWWK +1UcKR7wmU+LuW4NNZXCWHUdsDlJaGgYuNZM90IpV4XeEr3MKFom7+G2M1cpuRi86 +/c1GLXO07No9K57VJ+W3qL/uw8c= +-----END CERTIFICATE----- diff --git a/crates/server-fixture/certs/src/tls/test_server.csr b/crates/server-fixture/certs/src/tls/test_server.csr new file mode 100644 index 0000000000..6be1fbcf4c --- /dev/null +++ b/crates/server-fixture/certs/src/tls/test_server.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICqzCCAZMCAQAwZjELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYD +VQQHDARDaXR5MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRcwFQYD +VQQDDA50ZXN0LXNlcnZlci5pbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBAKHgQ8ccQAEg7XF8eNO3Su2Hy8NvnEc+aw3b/xxzxySxiU7wl32Qk31d8tys +07i3jlh4RshtWx9O+SYP9U2fp03a/0/+/JleBl4pX+ob5p8F9z6+yi31wxIs0Qym +50Kk4E4bPYYByeaZvCZIneXViRJ70+QoWgc8CE6QHu6g9OXVZZ+8JGsipE9vCLld +YjoJV46AbKM7C9qr49sJi6WEKp9eyK/fAmPQFIyfw+LQBkW8RpoUN/pC2pDtsEfk +wNBOdaKY5Vnj+KnPItTlqTB938+RnXawtAWFRdMrrXtlYIzl5TNDwdleGjK6nged +oaaYsfV2Tl8pi3lUB3XV0GG7Tz0CAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4IBAQCA +aNz5mVndHInJJloJIuFvHbQLeuglEfn1Iyjjk3ILLm29RqcVlJ1LsnZZXG4rv8JH +YWHpvsLLrR/nIkT+wxFCfYVHp8szpyLVW/mTLWb6xAB/d6i1SEmYSN0LNkmNvWFS +kDq9A3v5sa9SZ1/btgfIVa6QzZWHuqYqad3KWJcpn+PckqiG+Bihx69TGsIMJHgN +9P//ra2lWyL391KGycNrKTbydpFjRT6vwC2QZJWG47liRS/PYfm6wtdoJa7Mw9vl +ciBvDhTFF7FYl0uV1NlzIoVyChMmRv2JR66efcTfWqfP44E4dhBKHIpBxc8+4GtI +ol18bSfvVKBlIyoZPdRP +-----END CERTIFICATE REQUEST----- diff --git a/crates/server-fixture/certs/src/tls/test_server.key b/crates/server-fixture/certs/src/tls/test_server.key new file mode 
100644 index 0000000000..a3f4a433b8 --- /dev/null +++ b/crates/server-fixture/certs/src/tls/test_server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCh4EPHHEABIO1x +fHjTt0rth8vDb5xHPmsN2/8cc8cksYlO8Jd9kJN9XfLcrNO4t45YeEbIbVsfTvkm +D/VNn6dN2v9P/vyZXgZeKV/qG+afBfc+vsot9cMSLNEMpudCpOBOGz2GAcnmmbwm +SJ3l1YkSe9PkKFoHPAhOkB7uoPTl1WWfvCRrIqRPbwi5XWI6CVeOgGyjOwvaq+Pb +CYulhCqfXsiv3wJj0BSMn8Pi0AZFvEaaFDf6QtqQ7bBH5MDQTnWimOVZ4/ipzyLU +5akwfd/PkZ12sLQFhUXTK617ZWCM5eUzQ8HZXhoyup4HnaGmmLH1dk5fKYt5VAd1 +1dBhu089AgMBAAECggEAARLZCuZdEPou7k8Xs2Ub0hzRyny1r03yrSeFtvftnN4F +6HLKuRfPDUh6O+HJkF1nTEmVQ+8LE6y/v542JPWnc79oF20RhSg30pgOUsyB6GbE +ZQjO6SR1eWwNVuV50y9UwtqGEJrNGVfGWlqmRsf2c3DuztdAVvFG/NO9Nh1LLTBi +bxkyhVma4NRs3Yt9tzL/wCeIAWFXjHIctYxUw7TUd21s7/i/+dN96wEGywpuN0zs +HeiV8OOV0wEgaIldShJnDW7EF5AVl9eGZznm5gJnCqE4/Sq+0CuRLW95jFwZu0GR +tUNjSi+ypPLt0vf1LIM0hrYoFNB0xxLdnRqpEB96AQKBgQDatlglMc0KvJ2ghNGJ +TbK/jtkb7uEkdFxM+l+4EYOdpVrh7gQMekSDv6nNyze9eCBWouB8m5ty1+HKTPk8 +X+4shd7VE8joJus/R/0vqVrcK3eh/fuL8E2TgZbKmsBh168NQY5dxMgs69do3bPg +k46EUxLd3MJ9fdv2xoila7c2DQKBgQC9eVAyVlDA+VBonZA1kmVY4858JKMsPH+Y +pfC1ty+QwFxsN3AxpYolUYJAp11vDCdGP9GpOaHhfGmxLCIyUSWDUdcBjXZCz/lW +76xT4wh7LieOezhPMzoLP+vdel9xvLGQYu2a6GvRaUxfPH7fbFaGBuTvi4JjEaq0 +CjtmksDh8QKBgDXKHseXBeycEtBFmhsApvOBuFesWmbSz1iHQz9L32jIIB/sn8ZJ +08vrOWHJlv3cK2fjSv6abpLCEV/lqm500WjVy8XvxbuCxtybYeN07Um0zwliI5l5 +Ejsy5dkSUjo+B2llNBRPr0ONBT9fNzwGTkiw/bTe9F5Us+JvVXAJm9eJAoGAX04X +Jcq/AeImLQkcUaYarlSgN1eicAzaTakiY/UJyvDHTHOyTnaq/0x5jRXibIobcz2E +s29W2vnenAzMAq1Ihj5zPMewNbkw/Sa/cs6fJH65zPR0BXqJ9sCnXpdATRCR7EOm +qqXAHeyuSrU+SBnRh8cN/uQYqMZpK/h9moG03bECgYB8MtBxLxLctVYofIwyzRF/ +gu4Tm84bVBeaRlUapuXB0ZsC6JBnCM9cWDR2SxLHEbjGQ07Oaae+OQb4C5Uc0Ex0 +fmu0ATFb52BScnNFw+elQHk8ynv0sfV6LbPe2C1uNBVFtatRMgNJGkg7UauM68TJ +VOZB9TwtVa9Sr0QpQNiyvg== +-----END PRIVATE KEY----- diff --git a/crates/server-fixture/certs/src/tls/test_server_cert.der b/crates/server-fixture/certs/src/tls/test_server_cert.der new file mode 
100644 index 0000000000..765cc31b5c Binary files /dev/null and b/crates/server-fixture/certs/src/tls/test_server_cert.der differ diff --git a/crates/server-fixture/certs/src/tls/test_server_private_key.der b/crates/server-fixture/certs/src/tls/test_server_private_key.der new file mode 100644 index 0000000000..d2591402c8 Binary files /dev/null and b/crates/server-fixture/certs/src/tls/test_server_private_key.der differ diff --git a/crates/server-fixture/server/Cargo.toml b/crates/server-fixture/server/Cargo.toml new file mode 100644 index 0000000000..dd617b8815 --- /dev/null +++ b/crates/server-fixture/server/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tlsn-server-fixture" +version = "0.1.0" +edition = "2021" + +[dependencies] +axum = { workspace = true } +anyhow = { workspace = true } +futures = { workspace = true } +futures-rustls = { workspace = true } +hyper = { workspace = true } +hyper-util = { workspace = true, features = ["full"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio-util = { workspace = true, features = ["compat", "io"] } +tower-service = { version = "0.3" } + +tlsn-server-fixture-certs = { workspace = true } + +[[bin]] +name = "main" +path = "src/main.rs" diff --git a/tlsn/tlsn-server-fixture/README.md b/crates/server-fixture/server/README.md similarity index 100% rename from tlsn/tlsn-server-fixture/README.md rename to crates/server-fixture/server/README.md diff --git a/tlsn/tlsn-server-fixture/src/data/.gitignore b/crates/server-fixture/server/src/data/.gitignore similarity index 100% rename from tlsn/tlsn-server-fixture/src/data/.gitignore rename to crates/server-fixture/server/src/data/.gitignore diff --git a/tlsn/tlsn-server-fixture/src/data/1kb.json b/crates/server-fixture/server/src/data/1kb.json similarity index 100% rename from tlsn/tlsn-server-fixture/src/data/1kb.json rename to crates/server-fixture/server/src/data/1kb.json diff --git a/tlsn/tlsn-server-fixture/src/data/4kb.html 
b/crates/server-fixture/server/src/data/4kb.html similarity index 100% rename from tlsn/tlsn-server-fixture/src/data/4kb.html rename to crates/server-fixture/server/src/data/4kb.html diff --git a/tlsn/tlsn-server-fixture/src/data/4kb.json b/crates/server-fixture/server/src/data/4kb.json similarity index 100% rename from tlsn/tlsn-server-fixture/src/data/4kb.json rename to crates/server-fixture/server/src/data/4kb.json diff --git a/tlsn/tlsn-server-fixture/src/data/8kb.json b/crates/server-fixture/server/src/data/8kb.json similarity index 100% rename from tlsn/tlsn-server-fixture/src/data/8kb.json rename to crates/server-fixture/server/src/data/8kb.json diff --git a/crates/server-fixture/server/src/lib.rs b/crates/server-fixture/server/src/lib.rs new file mode 100644 index 0000000000..1ac44f3975 --- /dev/null +++ b/crates/server-fixture/server/src/lib.rs @@ -0,0 +1,127 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use axum::{ + extract::{Query, State}, + response::{Html, Json}, + routing::get, + Router, +}; +use futures::{channel::oneshot, AsyncRead, AsyncWrite}; +use futures_rustls::{ + pki_types::{CertificateDer, PrivateKeyDer}, + rustls::ServerConfig, + TlsAcceptor, +}; +use hyper::{ + body::{Bytes, Incoming}, + server::conn::http1, + Request, StatusCode, +}; +use hyper_util::rt::TokioIo; + +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tower_service::Service; + +use tlsn_server_fixture_certs::*; + +struct AppState { + shutdown: Option>, +} + +fn app(state: AppState) -> Router { + Router::new() + .route("/", get(|| async { "Hello, World!" })) + .route("/bytes", get(bytes)) + .route("/formats/json", get(json)) + .route("/formats/html", get(html)) + .with_state(Arc::new(Mutex::new(state))) +} + +/// Bind the server to the given socket. 
+pub async fn bind( + socket: T, +) -> anyhow::Result<()> { + let key = PrivateKeyDer::Pkcs8(SERVER_KEY_DER.into()); + let cert = CertificateDer::from(SERVER_CERT_DER); + + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert], key) + .unwrap(); + + let acceptor = TlsAcceptor::from(Arc::new(config)); + + let conn = acceptor.accept(socket).await?; + + let io = TokioIo::new(conn.compat()); + + let (sender, receiver) = oneshot::channel(); + let state = AppState { + shutdown: Some(sender), + }; + let tower_service = app(state); + + let hyper_service = hyper::service::service_fn(move |request: Request| { + tower_service.clone().call(request) + }); + + tokio::select! { + _ = http1::Builder::new() + .keep_alive(false) + .serve_connection(io, hyper_service) => {}, + _ = receiver => {}, + } + + Ok(()) +} + +async fn bytes( + State(state): State>>, + Query(params): Query>, +) -> Result { + let size = params + .get("size") + .and_then(|size| size.parse::().ok()) + .unwrap_or(1); + + if params.contains_key("shutdown") { + _ = state.lock().unwrap().shutdown.take().unwrap().send(()); + } + + Ok(Bytes::from(vec![0x42u8; size])) +} + +async fn json( + State(state): State>>, + Query(params): Query>, +) -> Result, StatusCode> { + let size = params + .get("size") + .and_then(|size| size.parse::().ok()) + .unwrap_or(1); + + if params.contains_key("shutdown") { + _ = state.lock().unwrap().shutdown.take().unwrap().send(()); + } + + match size { + 1 => Ok(Json(include_str!("data/1kb.json"))), + 4 => Ok(Json(include_str!("data/4kb.json"))), + 8 => Ok(Json(include_str!("data/8kb.json"))), + _ => Err(StatusCode::NOT_FOUND), + } +} + +async fn html( + State(state): State>>, + Query(params): Query>, +) -> Html<&'static str> { + if params.contains_key("shutdown") { + _ = state.lock().unwrap().shutdown.take().unwrap().send(()); + } + + Html(include_str!("data/4kb.html")) +} diff --git a/tlsn/tlsn-server-fixture/src/main.rs 
b/crates/server-fixture/server/src/main.rs similarity index 74% rename from tlsn/tlsn-server-fixture/src/main.rs rename to crates/server-fixture/server/src/main.rs index a1d67004d5..d93e2de9f1 100644 --- a/tlsn/tlsn-server-fixture/src/main.rs +++ b/crates/server-fixture/server/src/main.rs @@ -1,16 +1,16 @@ -use std::env; +use std::{env, io}; use tlsn_server_fixture::bind; use tokio::net::TcpListener; use tokio_util::compat::TokioAsyncWriteCompatExt; #[tokio::main] -async fn main() { +async fn main() -> io::Result<()> { let port = env::var("PORT").unwrap_or_else(|_| "3000".to_string()); - let listener = TcpListener::bind(&format!("0.0.0.0:{port}")).await.unwrap(); + let listener = TcpListener::bind(&format!("0.0.0.0:{port}")).await?; loop { - let (socket, _) = listener.accept().await.unwrap(); + let (socket, _) = listener.accept().await?; tokio::spawn(bind(socket.compat_write())); } } diff --git a/crates/tests-integration/Cargo.toml b/crates/tests-integration/Cargo.toml new file mode 100644 index 0000000000..8513c9d28a --- /dev/null +++ b/crates/tests-integration/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "tests-integration" +version = "0.0.0" +edition = "2021" +publish = false + +[features] +# Enables the AuthDecode protocol which allows to prove zk-friendly hashes over the transcript data. 
+authdecode_unsafe = [ + "tlsn-prover/authdecode_unsafe", + "tlsn-verifier/authdecode_unsafe" + ] + +[dev-dependencies] +tlsn-core = { workspace = true } +tlsn-common = { workspace = true } +tlsn-prover = { workspace = true } +tlsn-server-fixture = { workspace = true } +tlsn-server-fixture-certs = { workspace = true } +tlsn-tls-core = { workspace = true } +tlsn-utils = { workspace = true } +tlsn-verifier = { workspace = true } + +bincode = { workspace = true } +futures = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["client", "http1"] } +hyper-util = { workspace = true, features = ["full"] } +p256 = { workspace = true, features = ["ecdsa"] } +serde_json = { workspace = true } +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } +tokio-util = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } diff --git a/crates/tests-integration/tests/defer_decryption.rs b/crates/tests-integration/tests/defer_decryption.rs new file mode 100644 index 0000000000..a9a8e6bee5 --- /dev/null +++ b/crates/tests-integration/tests/defer_decryption.rs @@ -0,0 +1,133 @@ +use futures::{AsyncReadExt, AsyncWriteExt}; +use tls_core::verify::WebPkiVerifier; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_core::{ + attestation::AttestationConfig, request::RequestConfig, signing::SignatureAlgId, + transcript::TranscriptCommitConfig, CryptoProvider, +}; +use tlsn_prover::{Prover, ProverConfig}; +use tlsn_server_fixture::bind; +use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; +use tlsn_verifier::{Verifier, VerifierConfig}; + +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::compat::TokioAsyncReadCompatExt; +use tracing::instrument; + +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: 
usize = 1 << 14; + +#[tokio::test] +#[ignore] +async fn test_defer_decryption() { + tracing_subscriber::fmt::init(); + + let (socket_0, socket_1) = tokio::io::duplex(2 << 23); + + tokio::join!(prover(socket_0), notary(socket_1)); +} + +#[instrument(skip(notary_socket))] +async fn prover(notary_socket: T) { + let (client_socket, server_socket) = tokio::io::duplex(2 << 16); + + let server_task = tokio::spawn(bind(server_socket.compat())); + + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let prover = Prover::new( + ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(), + ) + .crypto_provider(provider) + .build() + .unwrap(), + ) + .setup(notary_socket.compat()) + .await + .unwrap(); + + let (mut tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); + let prover_task = tokio::spawn(prover_fut); + + tls_connection + .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + tls_connection.close().await.unwrap(); + + let mut response = vec![0u8; 1024]; + tls_connection.read_to_end(&mut response).await.unwrap(); + + let _ = server_task.await.unwrap(); + + let mut prover = prover_task.await.unwrap().unwrap().start_notarize(); + let sent_tx_len = prover.transcript().sent().len(); + let recv_tx_len = prover.transcript().received().len(); + + let mut builder = TranscriptCommitConfig::builder(prover.transcript()); + + // Commit to everything + builder.commit_sent(&(0..sent_tx_len)).unwrap(); + builder.commit_recv(&(0..recv_tx_len)).unwrap(); + + let config = builder.build().unwrap(); + + prover.transcript_commit(config); + + let config = RequestConfig::default(); + + 
prover.finalize(&config).await.unwrap(); +} + +#[instrument(skip(socket))] +async fn notary(socket: T) { + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let mut provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + provider.signer.set_secp256k1(&[1u8; 32]).unwrap(); + + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let verifier = Verifier::new( + VerifierConfig::builder() + .protocol_config_validator(config_validator) + .crypto_provider(provider) + .build() + .unwrap(), + ); + + let config = AttestationConfig::builder() + .supported_signature_algs(vec![SignatureAlgId::SECP256K1]) + .build() + .unwrap(); + + _ = verifier.notarize(socket.compat(), &config).await.unwrap(); +} diff --git a/crates/tests-integration/tests/notarize.rs b/crates/tests-integration/tests/notarize.rs new file mode 100644 index 0000000000..b7a34bd79a --- /dev/null +++ b/crates/tests-integration/tests/notarize.rs @@ -0,0 +1,204 @@ +use tls_core::verify::WebPkiVerifier; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_core::{ + attestation::AttestationConfig, request::RequestConfig, signing::SignatureAlgId, + transcript::TranscriptCommitConfig, CryptoProvider, +}; +use tlsn_prover::{Prover, ProverConfig}; +use tlsn_server_fixture::bind; +use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; +use tlsn_verifier::{Verifier, VerifierConfig}; + +use http_body_util::{BodyExt as _, Empty}; +use hyper::{body::Bytes, Request, StatusCode}; +use hyper_util::rt::TokioIo; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tracing::instrument; + +#[cfg(feature = "authdecode_unsafe")] +use tlsn_core::{ + hash::{HashAlgId, 
POSEIDON_MAX_INPUT_SIZE}, + transcript::{Direction, TranscriptCommitmentKind}, +}; + +// Maximum number of bytes that can be sent from prover to server. +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server. +const MAX_RECV_DATA: usize = 1 << 14; +#[cfg(feature = "authdecode_unsafe")] +// Maximum number of plaintext bytes which can be authenticated using the AuthDecode protocol. +const MAX_AUTHDECODE_DATA: usize = 1 << 10; + +#[tokio::test] +#[ignore] +async fn notarize() { + tracing_subscriber::fmt::init(); + + let (socket_0, socket_1) = tokio::io::duplex(2 << 23); + + tokio::join!(prover(socket_0), notary(socket_1)); +} + +#[instrument(skip(notary_socket))] +async fn prover(notary_socket: T) { + let (client_socket, server_socket) = tokio::io::duplex(2 << 16); + + let server_task = tokio::spawn(bind(server_socket.compat())); + + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let mut builder = ProtocolConfig::builder(); + builder + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .max_recv_data_online(MAX_RECV_DATA); + + #[cfg(feature = "authdecode_unsafe")] + builder.max_authdecode_data(MAX_AUTHDECODE_DATA); + + let protocol_config = builder.build().unwrap(); + + let prover = Prover::new( + ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .defer_decryption_from_start(false) + .protocol_config(protocol_config) + .crypto_provider(provider) + .build() + .unwrap(), + ) + .setup(notary_socket.compat()) + .await + .unwrap(); + + let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); + + let prover_task = tokio::spawn(prover_fut); + + let (mut request_sender, connection) = + 
hyper::client::conn::http1::handshake(TokioIo::new(tls_connection.compat())) + .await + .unwrap(); + + tokio::spawn(connection); + + let request = Request::builder() + .uri(format!("https://{}/bytes?size=16000", SERVER_DOMAIN)) + .header("Host", SERVER_DOMAIN) + .header("Connection", "close") + .method("GET") + .body(Empty::::new()) + .unwrap(); + + let response = request_sender.send_request(request).await.unwrap(); + + assert!(response.status() == StatusCode::OK); + + let payload = response.into_body().collect().await.unwrap().to_bytes(); + println!("{:?}", &String::from_utf8_lossy(&payload)); + + let _ = server_task.await.unwrap(); + + let mut prover = prover_task.await.unwrap().unwrap().start_notarize(); + let sent_tx_len = prover.transcript().sent().len(); + let recv_tx_len = prover.transcript().received().len(); + + let mut builder = TranscriptCommitConfig::builder(prover.transcript()); + + // Commit to a portion of the data. + builder + .commit_sent(&(sent_tx_len / 2..sent_tx_len)) + .unwrap(); + builder + .commit_recv(&(recv_tx_len / 2..recv_tx_len)) + .unwrap(); + + #[cfg(feature = "authdecode_unsafe")] + let authdecode_inputs = { + let alg = HashAlgId::POSEIDON_BN256_434; + + builder.default_kind(TranscriptCommitmentKind::Hash { alg }); + + // Currently there is a limit on commitment data length for POSEIDON_HALO2. 
+ let sent_range = 0..sent_tx_len / 2; + let recv_range = 0..POSEIDON_MAX_INPUT_SIZE; + assert!(sent_range.len() <= POSEIDON_MAX_INPUT_SIZE); + assert!(recv_range.len() <= POSEIDON_MAX_INPUT_SIZE); + + let blinder_sent = builder + .commit_with_blinder(&sent_range, Direction::Sent) + .unwrap(); + + let blinder_recv = builder + .commit_with_blinder(&recv_range, Direction::Received) + .unwrap(); + vec![ + (Direction::Sent, sent_range, alg, blinder_sent), + (Direction::Received, recv_range, alg, blinder_recv), + ] + }; + + let config = builder.build().unwrap(); + + prover.transcript_commit(config.clone()); + + let config = RequestConfig::default(); + + #[cfg(feature = "authdecode_unsafe")] + prover + .finalize_with_authdecode(&config, authdecode_inputs) + .await + .unwrap(); + + #[cfg(not(feature = "authdecode_unsafe"))] + prover.finalize(&config).await.unwrap(); +} + +#[instrument(skip(socket))] +async fn notary(socket: T) { + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let mut provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + provider.signer.set_secp256k1(&[1u8; 32]).unwrap(); + + let mut builder = ProtocolConfigValidator::builder(); + builder + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA); + + #[cfg(feature = "authdecode_unsafe")] + builder.max_authdecode_data(MAX_AUTHDECODE_DATA); + + let config_validator = builder.build().unwrap(); + + let verifier = Verifier::new( + VerifierConfig::builder() + .protocol_config_validator(config_validator) + .crypto_provider(provider) + .build() + .unwrap(), + ); + + let config = AttestationConfig::builder() + .supported_signature_algs(vec![SignatureAlgId::SECP256K1]) + .build() + .unwrap(); + + _ = verifier.notarize(socket.compat(), &config).await.unwrap(); +} diff --git a/crates/tests-integration/tests/verify.rs 
b/crates/tests-integration/tests/verify.rs new file mode 100644 index 0000000000..05628011bb --- /dev/null +++ b/crates/tests-integration/tests/verify.rs @@ -0,0 +1,150 @@ +use tls_core::{anchors::RootCertStore, verify::WebPkiVerifier}; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_core::{ + transcript::{Idx, PartialTranscript}, + CryptoProvider, +}; +use tlsn_prover::{Prover, ProverConfig}; +use tlsn_server_fixture::bind; +use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; +use tlsn_verifier::{SessionInfo, Verifier, VerifierConfig}; + +use http_body_util::{BodyExt as _, Empty}; +use hyper::{body::Bytes, Request, StatusCode}; +use hyper_util::rt::TokioIo; + +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tracing::instrument; + +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + +#[tokio::test] +#[ignore] +async fn verify() { + tracing_subscriber::fmt::init(); + + let (socket_0, socket_1) = tokio::io::duplex(1 << 23); + + let (_, (partial_transcript, info)) = tokio::join!(prover(socket_0), verifier(socket_1)); + + assert_eq!( + partial_transcript.sent_authed(), + &Idx::new(0..partial_transcript.len_sent() - 1) + ); + assert_eq!( + partial_transcript.received_authed(), + &Idx::new(2..partial_transcript.len_received()) + ); + assert_eq!(info.server_name.as_str(), SERVER_DOMAIN); +} + +#[instrument(skip(notary_socket))] +async fn prover(notary_socket: T) { + let (client_socket, server_socket) = tokio::io::duplex(1 << 16); + + let server_task = tokio::spawn(bind(server_socket.compat())); + + let mut root_store = RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: 
WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let prover = Prover::new( + ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .defer_decryption_from_start(false) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .max_recv_data_online(MAX_RECV_DATA) + .build() + .unwrap(), + ) + .crypto_provider(provider) + .build() + .unwrap(), + ) + .setup(notary_socket.compat()) + .await + .unwrap(); + + let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); + + let prover_task = tokio::spawn(prover_fut); + + let (mut request_sender, connection) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_connection.compat())) + .await + .unwrap(); + + tokio::spawn(connection); + + let request = Request::builder() + .uri(format!("https://{}", SERVER_DOMAIN)) + .header("Host", SERVER_DOMAIN) + .header("Connection", "close") + .method("GET") + .body(Empty::::new()) + .unwrap(); + + let response = request_sender.send_request(request).await.unwrap(); + + assert!(response.status() == StatusCode::OK); + + let payload = response.into_body().collect().await.unwrap().to_bytes(); + println!("{:?}", &String::from_utf8_lossy(&payload)); + + let _ = server_task.await.unwrap(); + + let mut prover = prover_task.await.unwrap().unwrap().start_prove(); + + let (sent_len, recv_len) = prover.transcript().len(); + + let idx_sent = Idx::new(0..sent_len - 1); + let idx_recv = Idx::new(2..recv_len); + + // Reveal parts of the transcript + prover.prove_transcript(idx_sent, idx_recv).await.unwrap(); + prover.finalize().await.unwrap(); +} + +#[instrument(skip(socket))] +async fn verifier( + socket: T, +) -> (PartialTranscript, SessionInfo) { + let mut root_store = RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + 
let config = VerifierConfig::builder() + .protocol_config_validator( + ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(), + ) + .crypto_provider(provider) + .build() + .unwrap(); + + let verifier = Verifier::new(config); + + verifier.verify(socket.compat()).await.unwrap() +} diff --git a/components/tls/tls-backend/Cargo.toml b/crates/tls/backend/Cargo.toml similarity index 52% rename from components/tls/tls-backend/Cargo.toml rename to crates/tls/backend/Cargo.toml index 8df8ba4f8a..5e3cf49f46 100644 --- a/components/tls/tls-backend/Cargo.toml +++ b/crates/tls/backend/Cargo.toml @@ -5,16 +5,15 @@ description = "A TLS backend trait for TLSNotary" keywords = ["tls", "mpc", "2pc"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" [lib] name = "tls_backend" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] -tlsn-tls-core = { path = "../tls-core" } -async-trait.workspace = true -thiserror.workspace = true -futures.workspace = true +tlsn-tls-core = { workspace = true } + +async-trait = { workspace = true } +thiserror = { workspace = true } +futures = { workspace = true } diff --git a/components/tls/tls-backend/src/lib.rs b/crates/tls/backend/src/lib.rs similarity index 97% rename from components/tls/tls-backend/src/lib.rs rename to crates/tls/backend/src/lib.rs index 1c6336f5f9..9db7063155 100644 --- a/components/tls/tls-backend/src/lib.rs +++ b/crates/tls/backend/src/lib.rs @@ -1,5 +1,5 @@ -//! This library provides the [Backend] trait to encapsulate the cryptography backend of the TLS -//! client. +//! This library provides the [Backend] trait to encapsulate the cryptography +//! backend of the TLS client. 
#![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] @@ -66,8 +66,8 @@ pub enum DecryptMode { Application, } -/// Core trait which manages crypto operations for the TLS connection such as key exchange, encryption -/// and decryption. +/// Core trait which manages crypto operations for the TLS connection such as +/// key exchange, encryption and decryption. #[async_trait] pub trait Backend: Send { /// Signals selected protocol version to implementor. @@ -120,8 +120,8 @@ pub trait Backend: Send { async fn buffer_incoming(&mut self, msg: OpaqueMessage) -> Result<(), BackendError>; /// Returns next incoming message ready for decryption. async fn next_incoming(&mut self) -> Result, BackendError>; - /// Returns a notification future which resolves when the backend is ready to process - /// the next message. + /// Returns a notification future which resolves when the backend is ready + /// to process the next message. async fn get_notify(&mut self) -> Result { Ok(BackendNotify::dummy()) } diff --git a/components/tls/tls-backend/src/notify.rs b/crates/tls/backend/src/notify.rs similarity index 100% rename from components/tls/tls-backend/src/notify.rs rename to crates/tls/backend/src/notify.rs diff --git a/components/tls/tls-client-async/Cargo.toml b/crates/tls/client-async/Cargo.toml similarity index 63% rename from components/tls/tls-client-async/Cargo.toml rename to crates/tls/client-async/Cargo.toml index 9247f18e0c..dea545031a 100644 --- a/components/tls/tls-client-async/Cargo.toml +++ b/crates/tls/client-async/Cargo.toml @@ -5,7 +5,7 @@ description = "An async TLS client for TLSNotary" keywords = ["tls", "mpc", "2pc", "client", "async"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" [lib] @@ -16,17 +16,21 @@ default = ["tracing"] tracing = ["dep:tracing"] [dependencies] -tlsn-tls-client = { path = "../tls-client" } -bytes.workspace = true -futures.workspace = 
true +tlsn-tls-client = { workspace = true } + +bytes = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } tokio-util = { workspace = true, features = ["io", "compat"] } tracing = { workspace = true, optional = true } -thiserror.workspace = true [dev-dependencies] -tracing-subscriber.workspace = true -tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } -webpki-roots.workspace = true +tls-server-fixture = { workspace = true } + +http-body-util = { workspace = true } hyper = { workspace = true, features = ["client", "http1"] } -tls-server-fixture = { path = "../tls-server-fixture" } +hyper-util = { workspace = true, features = ["full"] } rstest = { workspace = true } +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } +tracing-subscriber = { workspace = true } +webpki-roots = { workspace = true } diff --git a/components/tls/tls-client-async/src/conn.rs b/crates/tls/client-async/src/conn.rs similarity index 95% rename from components/tls/tls-client-async/src/conn.rs rename to crates/tls/client-async/src/conn.rs index 186a067293..aec79c92b6 100644 --- a/components/tls/tls-client-async/src/conn.rs +++ b/crates/tls/client-async/src/conn.rs @@ -19,13 +19,15 @@ type CompatSinkWriter = /// A TLS connection to a server. /// -/// This type implements `AsyncRead` and `AsyncWrite` and can be used to communicate -/// with a server using TLS. +/// This type implements `AsyncRead` and `AsyncWrite` and can be used to +/// communicate with a server using TLS. /// /// # Note /// -/// This connection is closed on a best-effort basis if this is dropped. To ensure a clean close, you should call -/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the connection. +/// This connection is closed on a best-effort basis if this is dropped. To +/// ensure a clean close, you should call +/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the +/// connection. 
#[derive(Debug)] pub struct TlsConnection { /// The data to be transmitted to the server is sent to this sink. diff --git a/components/tls/tls-client-async/src/lib.rs b/crates/tls/client-async/src/lib.rs similarity index 94% rename from components/tls/tls-client-async/src/lib.rs rename to crates/tls/client-async/src/lib.rs index ca3c6eaf3b..a7e66482d5 100644 --- a/components/tls/tls-client-async/src/lib.rs +++ b/crates/tls/client-async/src/lib.rs @@ -1,9 +1,10 @@ //! Provides a TLS client which exposes an async socket. //! -//! This library provides the [bind_client] function which attaches a TLS client to a socket -//! connection and then exposes a [TlsConnection] object, which provides an async socket API for -//! reading and writing cleartext. The TLS client will then automatically encrypt and decrypt -//! traffic and forward that to the provided socket. +//! This library provides the [bind_client] function which attaches a TLS client +//! to a socket connection and then exposes a [TlsConnection] object, which +//! provides an async socket API for reading and writing cleartext. The TLS +//! client will then automatically encrypt and decrypt traffic and forward that +//! to the provided socket. #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] @@ -71,11 +72,13 @@ impl Future for ConnectionFuture { /// Binds a client connection to the provided socket. /// -/// Returns a connection handle and a future which runs the connection to completion. +/// Returns a connection handle and a future which runs the connection to +/// completion. /// /// # Errors /// -/// Any connection errors that occur will be returned from the future, not [`TlsConnection`]. +/// Any connection errors that occur will be returned from the future, not +/// [`TlsConnection`]. 
pub fn bind_client( socket: T, mut client: ClientConnection, @@ -138,14 +141,13 @@ pub fn bind_client( #[cfg(feature = "tracing")] debug!("handshake complete"); handshake_done = true; - // Start reading application data that needs to be transmitted from the `TlsConnection`. + // Start reading application data that needs to be transmitted from the + // `TlsConnection`. tx_recv_fut = tx_receiver.next().fuse(); } - if server_closed && client.plaintext_is_empty() { - if client.buffer_len().await? == 0 { - break 'conn; - } + if server_closed && client.plaintext_is_empty() && client.buffer_len().await? == 0 { + break 'conn; } select_biased! { diff --git a/components/tls/tls-client-async/tests/test.rs b/crates/tls/client-async/tests/test.rs similarity index 92% rename from components/tls/tls-client-async/tests/test.rs rename to crates/tls/client-async/tests/test.rs index 71d50a7bfa..db8bbbf214 100644 --- a/components/tls/tls-client-async/tests/test.rs +++ b/crates/tls/client-async/tests/test.rs @@ -2,7 +2,9 @@ use std::{str, sync::Arc}; use core::future::Future; use futures::{AsyncReadExt, AsyncWriteExt}; -use hyper::{body::to_bytes, Body, Request, StatusCode}; +use http_body_util::{BodyExt as _, Full}; +use hyper::{body::Bytes, Request, StatusCode}; +use hyper_util::rt::TokioIo; use rstest::{fixture, rstest}; use tls_client::{Certificate, ClientConfig, ClientConnection, RustCryptoBackend, ServerName}; use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection}; @@ -64,8 +66,8 @@ async fn set_up_tls() -> TlsFixture { } } -// Expect the async tls client wrapped in `hyper::client` to make a successful request and receive -// the expected response and cleanly close the TLS connection +// Expect the async tls client wrapped in `hyper::client` to make a successful +// request and receive the expected response #[tokio::test] async fn test_hyper_ok() { let (client_socket, server_socket) = tokio::io::duplex(1 << 16); @@ -90,16 +92,18 @@ async fn 
test_hyper_ok() { let closed_tls_task = tokio::spawn(tls_fut); let (mut request_sender, connection) = - hyper::client::conn::handshake(conn.compat()).await.unwrap(); + hyper::client::conn::http1::handshake(TokioIo::new(conn.compat())) + .await + .unwrap(); - let http_task = tokio::spawn(connection.without_shutdown()); + tokio::spawn(connection); let request = Request::builder() .uri(format!("https://{}/echo", SERVER_DOMAIN)) .header("Host", SERVER_DOMAIN) .header("Connection", "close") .method("POST") - .body(Body::from("hello")) + .body(Full::::new("hello".into())) .unwrap(); let response = request_sender.send_request(request).await.unwrap(); @@ -107,25 +111,17 @@ async fn test_hyper_ok() { assert!(response.status() == StatusCode::OK); // Process the response body - to_bytes(response.into_body()).await.unwrap(); + response.into_body().collect().await.unwrap().to_bytes(); - let mut server_tls_conn = server_task.await.unwrap().unwrap(); - - // Make sure the server closes cleanly (sends close notify) - server_tls_conn.close().await.unwrap(); - - let http_parts = http_task.await.unwrap().unwrap(); - let mut tls_conn = http_parts.io.into_inner(); - - tls_conn.close().await.unwrap(); + let _ = server_task.await.unwrap(); let closed_conn = closed_tls_task.await.unwrap().unwrap(); assert!(closed_conn.client.received_close_notify()); } -// Expect a clean TLS connection closure when server responds to the client's close_notify but -// doesn't close the socket +// Expect a clean TLS connection closure when server responds to the client's +// close_notify but doesn't close the socket #[rstest] #[tokio::test] async fn test_ok_server_no_socket_close(set_up_tls: impl Future) { @@ -149,8 +145,8 @@ async fn test_ok_server_no_socket_close(set_up_tls: impl Future) { @@ -159,7 +155,8 @@ async fn test_ok_server_socket_close(set_up_tls: impl Future) { @@ -201,8 +198,8 @@ async fn test_ok_server_close_notify(set_up_tls: impl Future) assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), 
"hello"); } -// Expect there to be no error when server DOES NOT send close_notify but just closes the socket +// Expect there to be no error when server DOES NOT send close_notify but just +// closes the socket #[rstest] #[tokio::test] async fn test_ok_server_no_close_notify(set_up_tls: impl Future) { @@ -395,7 +393,8 @@ async fn test_err_alert(set_up_tls: impl Future) { ); } -// Expect an error when trying to write data to a connection which server closed abruptly +// Expect an error when trying to write data to a connection which server closed +// abruptly #[rstest] #[tokio::test] async fn test_err_write_after_close(set_up_tls: impl Future) { diff --git a/components/tls/tls-client/Cargo.toml b/crates/tls/client/Cargo.toml similarity index 61% rename from components/tls/tls-client/Cargo.toml rename to crates/tls/client/Cargo.toml index 6d1dbfac87..cd9e4fa383 100644 --- a/components/tls/tls-client/Cargo.toml +++ b/crates/tls/client/Cargo.toml @@ -5,47 +5,41 @@ description = "A TLS client for TLSNotary" keywords = ["tls", "mpc", "2pc", "client", "sync"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" autobenches = false -build = "build.rs" [lib] name = "tls_client" -[build-dependencies] -rustversion = { version = "1", optional = true } - [dependencies] -tlsn-tls-backend = { path = "../tls-backend" } -tlsn-tls-core = { path = "../tls-core" } +tlsn-tls-backend = { workspace = true } +tlsn-tls-core = { workspace = true } -async-trait.workspace = true +async-trait = { workspace = true } log = { workspace = true, optional = true } -ring.workspace = true -sct.workspace = true +ring = { workspace = true } +sct = { workspace = true } webpki = { workspace = true, features = ["alloc", "std"] } -aes-gcm.workspace = true +aes-gcm = { workspace = true } p256 = { workspace = true, features = ["ecdh"] } -rand.workspace = true -hmac.workspace = true +rand = { workspace = true } +hmac = { workspace = 
true } sha2 = { workspace = true, features = ["compress"] } -digest.workspace = true -futures.workspace = true -web-time.workspace = true +digest = { workspace = true } +futures = { workspace = true } +web-time = { workspace = true } [features] default = ["logging", "tls12"] logging = ["log"] -dangerous_configuration = [] tls12 = [] -read_buf = ["rustversion"] [dev-dependencies] -env_logger.workspace = true -webpki-roots.workspace = true -rustls-pemfile.workspace = true -rustls = { workspace = true, features = ["tls12"] } +env_logger = { workspace = true } +webpki-roots = { workspace = true } +rustls-pemfile = { workspace = true } +rustls = { version = "0.20", features = ["tls12"] } tokio = { workspace = true, features = ["rt", "macros"] } [[example]] diff --git a/components/tls/tls-client/README.md b/crates/tls/client/README.md similarity index 100% rename from components/tls/tls-client/README.md rename to crates/tls/client/README.md diff --git a/components/tls/tls-client/examples/internal/bench.rs b/crates/tls/client/examples/internal/bench.rs similarity index 100% rename from components/tls/tls-client/examples/internal/bench.rs rename to crates/tls/client/examples/internal/bench.rs diff --git a/components/tls/tls-client/examples/internal/bogo_shim.rs b/crates/tls/client/examples/internal/bogo_shim.rs similarity index 100% rename from components/tls/tls-client/examples/internal/bogo_shim.rs rename to crates/tls/client/examples/internal/bogo_shim.rs diff --git a/components/tls/tls-client/examples/internal/trytls_shim.rs b/crates/tls/client/examples/internal/trytls_shim.rs similarity index 100% rename from components/tls/tls-client/examples/internal/trytls_shim.rs rename to crates/tls/client/examples/internal/trytls_shim.rs diff --git a/components/tls/tls-client/src/backend/mod.rs b/crates/tls/client/src/backend/mod.rs similarity index 100% rename from components/tls/tls-client/src/backend/mod.rs rename to crates/tls/client/src/backend/mod.rs diff --git 
a/components/tls/tls-client/src/backend/standard.rs b/crates/tls/client/src/backend/standard.rs similarity index 99% rename from components/tls/tls-client/src/backend/standard.rs rename to crates/tls/client/src/backend/standard.rs index afd6b7adb0..bfa843e1dc 100644 --- a/components/tls/tls-client/src/backend/standard.rs +++ b/crates/tls/client/src/backend/standard.rs @@ -9,7 +9,7 @@ use p256::{ecdh::EphemeralSecret, EncodedPoint, PublicKey as ECDHPublicKey}; use rand::{rngs::OsRng, thread_rng, Rng}; use digest::Digest; -use std::{any::Any, convert::TryInto, collections::VecDeque}; +use std::{any::Any, collections::VecDeque, convert::TryInto}; use tls_core::{ cert::ServerCertDetails, ke::ServerKxDetails, diff --git a/components/tls/tls-client/src/bs_debug.rs b/crates/tls/client/src/bs_debug.rs similarity index 100% rename from components/tls/tls-client/src/bs_debug.rs rename to crates/tls/client/src/bs_debug.rs diff --git a/components/tls/tls-client/src/builder.rs b/crates/tls/client/src/builder.rs similarity index 100% rename from components/tls/tls-client/src/builder.rs rename to crates/tls/client/src/builder.rs diff --git a/components/tls/tls-client/src/check.rs b/crates/tls/client/src/check.rs similarity index 100% rename from components/tls/tls-client/src/check.rs rename to crates/tls/client/src/check.rs diff --git a/components/tls/tls-client/src/cipher.rs b/crates/tls/client/src/cipher.rs similarity index 87% rename from components/tls/tls-client/src/cipher.rs rename to crates/tls/client/src/cipher.rs index 1179ddc851..9dae1a2f0a 100644 --- a/components/tls/tls-client/src/cipher.rs +++ b/crates/tls/client/src/cipher.rs @@ -42,17 +42,17 @@ impl MessageDecrypter for InvalidMessageDecrypter { /// A write or read IV. 
#[derive(Default)] -pub(crate) struct Iv(pub(crate) [u8; ring::aead::NONCE_LEN]); +pub(crate) struct Iv(pub(crate) [u8; aead::NONCE_LEN]); impl Iv { #[cfg(feature = "tls12")] - fn new(value: [u8; ring::aead::NONCE_LEN]) -> Self { + fn new(value: [u8; aead::NONCE_LEN]) -> Self { Self(value) } #[cfg(feature = "tls12")] pub(crate) fn copy(value: &[u8]) -> Self { - debug_assert_eq!(value.len(), ring::aead::NONCE_LEN); + debug_assert_eq!(value.len(), aead::NONCE_LEN); let mut iv = Self::new(Default::default()); iv.0.copy_from_slice(value); iv @@ -80,8 +80,8 @@ impl From> for Iv { } } -pub(crate) fn make_nonce(iv: &Iv, seq: u64) -> ring::aead::Nonce { - let mut nonce = [0u8; ring::aead::NONCE_LEN]; +pub(crate) fn make_nonce(iv: &Iv, seq: u64) -> aead::Nonce { + let mut nonce = [0u8; aead::NONCE_LEN]; codec::put_u64(seq, &mut nonce[4..]); nonce.iter_mut().zip(iv.0.iter()).for_each(|(nonce, iv)| { diff --git a/components/tls/tls-client/src/client/builder.rs b/crates/tls/client/src/client/builder.rs similarity index 90% rename from components/tls/tls-client/src/client/builder.rs rename to crates/tls/client/src/client/builder.rs index 5b849dea8f..8ae975ff18 100644 --- a/components/tls/tls-client/src/client/builder.rs +++ b/crates/tls/client/src/client/builder.rs @@ -26,22 +26,6 @@ impl ConfigBuilder { }, } } - - #[cfg(feature = "dangerous_configuration")] - /// Set a custom certificate verifier. - pub fn with_custom_certificate_verifier( - self, - verifier: Arc, - ) -> ConfigBuilder { - ConfigBuilder { - state: WantsClientCert { - cipher_suites: self.state.cipher_suites, - kx_groups: self.state.kx_groups, - versions: self.state.versions, - verifier, - }, - } - } } /// A config builder state where the caller needs to supply a certificate transparency policy or @@ -86,7 +70,7 @@ impl ConfigBuilder { /// This function fails if `key_der` is invalid. 
pub fn with_single_cert( self, - cert_chain: Vec, + cert_chain: Vec, key_der: key::PrivateKey, ) -> Result { self.with_logs(None).with_single_cert(cert_chain, key_der) @@ -146,7 +130,7 @@ impl ConfigBuilder { /// This function fails if `key_der` is invalid. pub fn with_single_cert( self, - cert_chain: Vec, + cert_chain: Vec, key_der: key::PrivateKey, ) -> Result { let resolver = handy::AlwaysResolvesClientCert::new(cert_chain, &key_der)?; diff --git a/components/tls/tls-client/src/client/client_conn.rs b/crates/tls/client/src/client/client_conn.rs similarity index 92% rename from components/tls/tls-client/src/client/client_conn.rs rename to crates/tls/client/src/client/client_conn.rs index 1526203b85..227e9b51e7 100644 --- a/components/tls/tls-client/src/client/client_conn.rs +++ b/crates/tls/client/src/client/client_conn.rs @@ -171,13 +171,6 @@ impl ClientConfig { .any(|cs| cs.version().version == v) } - /// Access configuration options whose use is dangerous and requires - /// extra care. - #[cfg(feature = "dangerous_configuration")] - pub fn dangerous(&mut self) -> danger::DangerousClientConfig { - danger::DangerousClientConfig { cfg: self } - } - pub(super) fn find_cipher_suite(&self, suite: CipherSuite) -> Option { self.cipher_suites .iter() @@ -186,27 +179,6 @@ impl ClientConfig { } } -/// Container for unsafe APIs -#[cfg(feature = "dangerous_configuration")] -pub(super) mod danger { - use std::sync::Arc; - - use super::{verify::ServerCertVerifier, ClientConfig}; - - /// Accessor for dangerous configuration options. - pub struct DangerousClientConfig<'a> { - /// The underlying ClientConfig - pub cfg: &'a mut ClientConfig, - } - - impl<'a> DangerousClientConfig<'a> { - /// Overrides the default `ServerCertVerifier` with something else. 
- pub fn set_certificate_verifier(&mut self, verifier: Arc) { - self.cfg.verifier = verifier; - } - } -} - #[derive(Debug, PartialEq)] enum EarlyDataState { Disabled, diff --git a/components/tls/tls-client/src/client/common.rs b/crates/tls/client/src/client/common.rs similarity index 100% rename from components/tls/tls-client/src/client/common.rs rename to crates/tls/client/src/client/common.rs diff --git a/components/tls/tls-client/src/client/handy.rs b/crates/tls/client/src/client/handy.rs similarity index 98% rename from components/tls/tls-client/src/client/handy.rs rename to crates/tls/client/src/client/handy.rs index 772519f400..896c768a23 100644 --- a/components/tls/tls-client/src/client/handy.rs +++ b/crates/tls/client/src/client/handy.rs @@ -64,7 +64,7 @@ pub(super) struct AlwaysResolvesClientCert(Arc); impl AlwaysResolvesClientCert { pub(super) fn new( - chain: Vec, + chain: Vec, priv_key: &key::PrivateKey, ) -> Result { let key = sign::any_supported_type(priv_key) diff --git a/components/tls/tls-client/src/client/hs.rs b/crates/tls/client/src/client/hs.rs similarity index 100% rename from components/tls/tls-client/src/client/hs.rs rename to crates/tls/client/src/client/hs.rs diff --git a/components/tls/tls-client/src/client/tls12.rs b/crates/tls/client/src/client/tls12.rs similarity index 100% rename from components/tls/tls-client/src/client/tls12.rs rename to crates/tls/client/src/client/tls12.rs diff --git a/components/tls/tls-client/src/client/tls13.rs b/crates/tls/client/src/client/tls13.rs similarity index 100% rename from components/tls/tls-client/src/client/tls13.rs rename to crates/tls/client/src/client/tls13.rs diff --git a/components/tls/tls-client/src/conn.rs b/crates/tls/client/src/conn.rs similarity index 95% rename from components/tls/tls-client/src/conn.rs rename to crates/tls/client/src/conn.rs index 56d64da6e5..8a397a963a 100644 --- a/components/tls/tls-client/src/conn.rs +++ b/crates/tls/client/src/conn.rs @@ -111,49 +111,6 @@ impl<'a> 
io::Read for Reader<'a> { Ok(len) } - - /// Obtain plaintext data received from the peer over this TLS connection. - /// - /// If the peer closes the TLS session, this returns `Ok(())` without filling - /// any more of the buffer once all the pending data has been read. No further - /// data can be received on that connection, so the underlying TCP connection - /// should be half-closed too. - /// - /// If the peer closes the TLS session uncleanly (a TCP EOF without sending a - /// `close_notify` alert) this function returns `Err(ErrorKind::UnexpectedEof.into())` - /// once any pending data has been read. - /// - /// Note that support for `close_notify` varies in peer TLS libraries: many do not - /// support it and uncleanly close the TCP connection (this might be - /// vulnerable to truncation attacks depending on the application protocol). - /// This means applications using rustls must both handle EOF - /// from this function, *and* unexpected EOF of the underlying TCP connection. - /// - /// If there are no bytes to read, this returns `Err(ErrorKind::WouldBlock.into())`. - /// - /// You may learn the number of bytes available at any time by inspecting - /// the return of [`Connection::process_new_packets`]. 
- #[cfg(read_buf)] - fn read_buf(&mut self, buf: &mut io::ReadBuf<'_>) -> io::Result<()> { - let before = buf.filled_len(); - self.received_plaintext.read_buf(buf)?; - let len = buf.filled_len() - before; - - if len == 0 && buf.capacity() > 0 { - // No bytes available: - match (self.peer_cleanly_closed, self.has_seen_eof) { - // cleanly closed; don't care about TCP EOF: express this as Ok(0) - (true, _) => {} - // unclean closure - (false, true) => return Err(io::ErrorKind::UnexpectedEof.into()), - // connection still going, but need more data: signal `WouldBlock` so that - // the caller knows this - (false, false) => return Err(io::ErrorKind::WouldBlock.into()), - } - } - - Ok(()) - } } #[derive(Copy, Clone, Eq, PartialEq)] @@ -232,7 +189,7 @@ impl ConnectionCommon { /// Reads out any buffered plaintext received from the peer. Returns the /// number of bytes read. - pub fn read_plaintext(&mut self, buf: &mut [u8]) -> std::io::Result { + pub fn read_plaintext(&mut self, buf: &mut [u8]) -> io::Result { self.common_state.received_plaintext.read(buf) } diff --git a/components/tls/tls-client/src/crypto/mod.rs b/crates/tls/client/src/crypto/mod.rs similarity index 100% rename from components/tls/tls-client/src/crypto/mod.rs rename to crates/tls/client/src/crypto/mod.rs diff --git a/components/tls/tls-client/src/crypto/standard.rs b/crates/tls/client/src/crypto/standard.rs similarity index 100% rename from components/tls/tls-client/src/crypto/standard.rs rename to crates/tls/client/src/crypto/standard.rs diff --git a/components/tls/tls-client/src/error.rs b/crates/tls/client/src/error.rs similarity index 100% rename from components/tls/tls-client/src/error.rs rename to crates/tls/client/src/error.rs diff --git a/components/tls/tls-client/src/hash_hs.rs b/crates/tls/client/src/hash_hs.rs similarity index 100% rename from components/tls/tls-client/src/hash_hs.rs rename to crates/tls/client/src/hash_hs.rs diff --git a/components/tls/tls-client/src/key_log.rs 
b/crates/tls/client/src/key_log.rs similarity index 100% rename from components/tls/tls-client/src/key_log.rs rename to crates/tls/client/src/key_log.rs diff --git a/components/tls/tls-client/src/key_log_file.rs b/crates/tls/client/src/key_log_file.rs similarity index 100% rename from components/tls/tls-client/src/key_log_file.rs rename to crates/tls/client/src/key_log_file.rs diff --git a/components/tls/tls-client/src/kx.rs b/crates/tls/client/src/kx.rs similarity index 100% rename from components/tls/tls-client/src/kx.rs rename to crates/tls/client/src/kx.rs diff --git a/components/tls/tls-client/src/lib.rs b/crates/tls/client/src/lib.rs similarity index 90% rename from components/tls/tls-client/src/lib.rs rename to crates/tls/client/src/lib.rs index 76f31828ac..9555770dd2 100644 --- a/components/tls/tls-client/src/lib.rs +++ b/crates/tls/client/src/lib.rs @@ -252,14 +252,12 @@ // Require docs for public APIs, deny unsafe code, etc. #![forbid(unsafe_code)] #![allow(dead_code, unused_imports)] -#![cfg_attr(not(read_buf), forbid(unstable_features))] #![deny( clippy::clone_on_ref_ptr, clippy::use_self, trivial_casts, trivial_numeric_casts, missing_docs, - //unreachable_pub, unused_import_braces, unused_extern_crates, unused_qualifications @@ -274,24 +272,9 @@ // a false positive, https://github.com/rust-lang/rust-clippy/issues/5210 // - new_without_default: for internal constructors, the indirection is not // helpful -#![allow( - clippy::too_many_arguments, - clippy::new_ret_no_self, - clippy::ptr_arg, - clippy::single_component_path_imports, - clippy::new_without_default -)] +#![allow(clippy::all)] // Enable documentation for all features on docs.rs #![cfg_attr(docsrs, feature(doc_cfg))] -// XXX: Because of https://github.com/rust-lang/rust/issues/54726, we cannot -// write `#![rustversion::attr(nightly, feature(read_buf))]` here. Instead, -// build.rs set `read_buf` for (only) Rust Nightly to get the same effect. 
-// -// All the other conditional logic in the crate could use -// `#[rustversion::nightly]` instead of `#[cfg(read_buf)]`; `#[cfg(read_buf)]` -// is used to avoid needing `rustversion` to be compiled twice during -// cross-compiling. -#![cfg_attr(read_buf, feature(read_buf))] // log for logging (optional). #[cfg(feature = "logging")] @@ -366,8 +349,6 @@ pub use tls_core::{ suites::{SupportedCipherSuite, ALL_CIPHER_SUITES}, versions::{SupportedProtocolVersion, ALL_VERSIONS}, }; -//pub use crate::stream::{Stream, StreamOwned}; -//pub use crate::ticketer::Ticketer; /// Items for use in a client. pub mod client { @@ -386,16 +367,6 @@ pub mod client { ResolvesClientCert, ServerName, StoresClientSessions, }; pub use handy::{ClientSessionMemoryCache, NoClientSessionStorage}; - - #[cfg(feature = "dangerous_configuration")] - #[cfg_attr(docsrs, doc(cfg(feature = "dangerous_configuration")))] - pub use crate::verify::{ - CertificateTransparencyPolicy, HandshakeSignatureValid, ServerCertVerified, - ServerCertVerifier, WebPkiVerifier, - }; - #[cfg(feature = "dangerous_configuration")] - #[cfg_attr(docsrs, doc(cfg(feature = "dangerous_configuration")))] - pub use client_conn::danger::DangerousClientConfig; } pub use client::{ClientConfig, ClientConnection, ServerName}; @@ -422,13 +393,6 @@ pub mod sign; /// This is the rustls manual. pub mod manual; -/** Type renames. 
*/ -#[allow(clippy::upper_case_acronyms)] -#[cfg(feature = "dangerous_configuration")] -#[cfg_attr(docsrs, doc(cfg(feature = "dangerous_configuration")))] -#[doc(hidden)] -#[deprecated(since = "0.20.0", note = "Use client::WebPkiVerifier")] -pub type WebPKIVerifier = client::WebPkiVerifier; #[allow(clippy::upper_case_acronyms)] #[doc(hidden)] #[deprecated(since = "0.20.0", note = "Use Error")] diff --git a/components/tls/tls-client/src/limited_cache.rs b/crates/tls/client/src/limited_cache.rs similarity index 96% rename from components/tls/tls-client/src/limited_cache.rs rename to crates/tls/client/src/limited_cache.rs index 76c62904c3..57e5448cc9 100644 --- a/components/tls/tls-client/src/limited_cache.rs +++ b/crates/tls/client/src/limited_cache.rs @@ -54,18 +54,18 @@ where } } - pub(crate) fn get(&self, k: &Q) -> Option<&V> + pub(crate) fn get(&self, k: &Q) -> Option<&V> where K: Borrow, - Q: Hash + Eq, + Q: Hash + Eq + ?Sized, { self.map.get(k) } - pub(crate) fn remove(&mut self, k: &Q) -> Option + pub(crate) fn remove(&mut self, k: &Q) -> Option where K: Borrow, - Q: Hash + Eq, + Q: Hash + Eq + ?Sized, { if let Some(value) = self.map.remove(k) { // O(N) search, followed by O(N) removal diff --git a/components/tls/tls-client/src/manual/defaults.rs b/crates/tls/client/src/manual/defaults.rs similarity index 100% rename from components/tls/tls-client/src/manual/defaults.rs rename to crates/tls/client/src/manual/defaults.rs diff --git a/components/tls/tls-client/src/manual/features.rs b/crates/tls/client/src/manual/features.rs similarity index 100% rename from components/tls/tls-client/src/manual/features.rs rename to crates/tls/client/src/manual/features.rs diff --git a/components/tls/tls-client/src/manual/howto.rs b/crates/tls/client/src/manual/howto.rs similarity index 100% rename from components/tls/tls-client/src/manual/howto.rs rename to crates/tls/client/src/manual/howto.rs diff --git a/components/tls/tls-client/src/manual/implvulns.rs 
b/crates/tls/client/src/manual/implvulns.rs similarity index 100% rename from components/tls/tls-client/src/manual/implvulns.rs rename to crates/tls/client/src/manual/implvulns.rs diff --git a/components/tls/tls-client/src/manual/mod.rs b/crates/tls/client/src/manual/mod.rs similarity index 100% rename from components/tls/tls-client/src/manual/mod.rs rename to crates/tls/client/src/manual/mod.rs diff --git a/components/tls/tls-client/src/manual/tlsvulns.rs b/crates/tls/client/src/manual/tlsvulns.rs similarity index 100% rename from components/tls/tls-client/src/manual/tlsvulns.rs rename to crates/tls/client/src/manual/tlsvulns.rs diff --git a/components/tls/tls-client/src/msgs/mod.rs b/crates/tls/client/src/msgs/mod.rs similarity index 100% rename from components/tls/tls-client/src/msgs/mod.rs rename to crates/tls/client/src/msgs/mod.rs diff --git a/components/tls/tls-client/src/msgs/persist.rs b/crates/tls/client/src/msgs/persist.rs similarity index 100% rename from components/tls/tls-client/src/msgs/persist.rs rename to crates/tls/client/src/msgs/persist.rs diff --git a/components/tls/tls-client/src/msgs/persist_test.rs b/crates/tls/client/src/msgs/persist_test.rs similarity index 100% rename from components/tls/tls-client/src/msgs/persist_test.rs rename to crates/tls/client/src/msgs/persist_test.rs diff --git a/components/tls/tls-client/src/rand.rs b/crates/tls/client/src/rand.rs similarity index 100% rename from components/tls/tls-client/src/rand.rs rename to crates/tls/client/src/rand.rs diff --git a/components/tls/tls-client/src/record_layer.rs b/crates/tls/client/src/record_layer.rs similarity index 100% rename from components/tls/tls-client/src/record_layer.rs rename to crates/tls/client/src/record_layer.rs diff --git a/components/tls/tls-client/src/sign.rs b/crates/tls/client/src/sign.rs similarity index 98% rename from components/tls/tls-client/src/sign.rs rename to crates/tls/client/src/sign.rs index ed52e776b8..b9ac2f2853 100644 --- 
a/components/tls/tls-client/src/sign.rs +++ b/crates/tls/client/src/sign.rs @@ -38,7 +38,7 @@ pub trait Signer: Send + Sync { #[derive(Clone)] pub struct CertifiedKey { /// The certificate chain. - pub cert: Vec, + pub cert: Vec, /// The certified key. pub key: Arc, @@ -58,7 +58,7 @@ impl CertifiedKey { /// /// The cert chain must not be empty. The first certificate in the chain /// must be the end-entity certificate. - pub fn new(cert: Vec, key: Arc) -> Self { + pub fn new(cert: Vec, key: Arc) -> Self { Self { cert, key, @@ -68,7 +68,7 @@ impl CertifiedKey { } /// The end-entity certificate. - pub fn end_entity_cert(&self) -> Result<&tls_core::key::Certificate, SignError> { + pub fn end_entity_cert(&self) -> Result<&key::Certificate, SignError> { self.cert.first().ok_or(SignError(())) } diff --git a/components/tls/tls-client/src/stream.rs b/crates/tls/client/src/stream.rs similarity index 100% rename from components/tls/tls-client/src/stream.rs rename to crates/tls/client/src/stream.rs diff --git a/components/tls/tls-client/src/testdata/cert-arstechnica.0.der b/crates/tls/client/src/testdata/cert-arstechnica.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-arstechnica.0.der rename to crates/tls/client/src/testdata/cert-arstechnica.0.der diff --git a/components/tls/tls-client/src/testdata/cert-arstechnica.1.der b/crates/tls/client/src/testdata/cert-arstechnica.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-arstechnica.1.der rename to crates/tls/client/src/testdata/cert-arstechnica.1.der diff --git a/components/tls/tls-client/src/testdata/cert-arstechnica.2.der b/crates/tls/client/src/testdata/cert-arstechnica.2.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-arstechnica.2.der rename to crates/tls/client/src/testdata/cert-arstechnica.2.der diff --git a/components/tls/tls-client/src/testdata/cert-arstechnica.3.der 
b/crates/tls/client/src/testdata/cert-arstechnica.3.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-arstechnica.3.der rename to crates/tls/client/src/testdata/cert-arstechnica.3.der diff --git a/components/tls/tls-client/src/testdata/cert-duckduckgo.0.der b/crates/tls/client/src/testdata/cert-duckduckgo.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-duckduckgo.0.der rename to crates/tls/client/src/testdata/cert-duckduckgo.0.der diff --git a/components/tls/tls-client/src/testdata/cert-duckduckgo.1.der b/crates/tls/client/src/testdata/cert-duckduckgo.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-duckduckgo.1.der rename to crates/tls/client/src/testdata/cert-duckduckgo.1.der diff --git a/components/tls/tls-client/src/testdata/cert-github.0.der b/crates/tls/client/src/testdata/cert-github.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-github.0.der rename to crates/tls/client/src/testdata/cert-github.0.der diff --git a/components/tls/tls-client/src/testdata/cert-github.1.der b/crates/tls/client/src/testdata/cert-github.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-github.1.der rename to crates/tls/client/src/testdata/cert-github.1.der diff --git a/components/tls/tls-client/src/testdata/cert-google.0.der b/crates/tls/client/src/testdata/cert-google.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-google.0.der rename to crates/tls/client/src/testdata/cert-google.0.der diff --git a/components/tls/tls-client/src/testdata/cert-google.1.der b/crates/tls/client/src/testdata/cert-google.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-google.1.der rename to crates/tls/client/src/testdata/cert-google.1.der diff --git a/components/tls/tls-client/src/testdata/cert-google.2.der b/crates/tls/client/src/testdata/cert-google.2.der 
similarity index 100% rename from components/tls/tls-client/src/testdata/cert-google.2.der rename to crates/tls/client/src/testdata/cert-google.2.der diff --git a/components/tls/tls-client/src/testdata/cert-hn.0.der b/crates/tls/client/src/testdata/cert-hn.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-hn.0.der rename to crates/tls/client/src/testdata/cert-hn.0.der diff --git a/components/tls/tls-client/src/testdata/cert-hn.1.der b/crates/tls/client/src/testdata/cert-hn.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-hn.1.der rename to crates/tls/client/src/testdata/cert-hn.1.der diff --git a/components/tls/tls-client/src/testdata/cert-reddit.0.der b/crates/tls/client/src/testdata/cert-reddit.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-reddit.0.der rename to crates/tls/client/src/testdata/cert-reddit.0.der diff --git a/components/tls/tls-client/src/testdata/cert-reddit.1.der b/crates/tls/client/src/testdata/cert-reddit.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-reddit.1.der rename to crates/tls/client/src/testdata/cert-reddit.1.der diff --git a/components/tls/tls-client/src/testdata/cert-rustlang.0.der b/crates/tls/client/src/testdata/cert-rustlang.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-rustlang.0.der rename to crates/tls/client/src/testdata/cert-rustlang.0.der diff --git a/components/tls/tls-client/src/testdata/cert-rustlang.1.der b/crates/tls/client/src/testdata/cert-rustlang.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-rustlang.1.der rename to crates/tls/client/src/testdata/cert-rustlang.1.der diff --git a/components/tls/tls-client/src/testdata/cert-rustlang.2.der b/crates/tls/client/src/testdata/cert-rustlang.2.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-rustlang.2.der rename to 
crates/tls/client/src/testdata/cert-rustlang.2.der diff --git a/components/tls/tls-client/src/testdata/cert-rustlang.3.der b/crates/tls/client/src/testdata/cert-rustlang.3.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-rustlang.3.der rename to crates/tls/client/src/testdata/cert-rustlang.3.der diff --git a/components/tls/tls-client/src/testdata/cert-servo.0.der b/crates/tls/client/src/testdata/cert-servo.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-servo.0.der rename to crates/tls/client/src/testdata/cert-servo.0.der diff --git a/components/tls/tls-client/src/testdata/cert-servo.1.der b/crates/tls/client/src/testdata/cert-servo.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-servo.1.der rename to crates/tls/client/src/testdata/cert-servo.1.der diff --git a/components/tls/tls-client/src/testdata/cert-stackoverflow.0.der b/crates/tls/client/src/testdata/cert-stackoverflow.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-stackoverflow.0.der rename to crates/tls/client/src/testdata/cert-stackoverflow.0.der diff --git a/components/tls/tls-client/src/testdata/cert-stackoverflow.1.der b/crates/tls/client/src/testdata/cert-stackoverflow.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-stackoverflow.1.der rename to crates/tls/client/src/testdata/cert-stackoverflow.1.der diff --git a/components/tls/tls-client/src/testdata/cert-stackoverflow.2.der b/crates/tls/client/src/testdata/cert-stackoverflow.2.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-stackoverflow.2.der rename to crates/tls/client/src/testdata/cert-stackoverflow.2.der diff --git a/components/tls/tls-client/src/testdata/cert-twitter.0.der b/crates/tls/client/src/testdata/cert-twitter.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-twitter.0.der rename to 
crates/tls/client/src/testdata/cert-twitter.0.der diff --git a/components/tls/tls-client/src/testdata/cert-twitter.1.der b/crates/tls/client/src/testdata/cert-twitter.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-twitter.1.der rename to crates/tls/client/src/testdata/cert-twitter.1.der diff --git a/components/tls/tls-client/src/testdata/cert-wapo.0.der b/crates/tls/client/src/testdata/cert-wapo.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-wapo.0.der rename to crates/tls/client/src/testdata/cert-wapo.0.der diff --git a/components/tls/tls-client/src/testdata/cert-wapo.1.der b/crates/tls/client/src/testdata/cert-wapo.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-wapo.1.der rename to crates/tls/client/src/testdata/cert-wapo.1.der diff --git a/components/tls/tls-client/src/testdata/cert-wikipedia.0.der b/crates/tls/client/src/testdata/cert-wikipedia.0.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-wikipedia.0.der rename to crates/tls/client/src/testdata/cert-wikipedia.0.der diff --git a/components/tls/tls-client/src/testdata/cert-wikipedia.1.der b/crates/tls/client/src/testdata/cert-wikipedia.1.der similarity index 100% rename from components/tls/tls-client/src/testdata/cert-wikipedia.1.der rename to crates/tls/client/src/testdata/cert-wikipedia.1.der diff --git a/components/tls/tls-client/src/testdata/deframer-empty-applicationdata.bin b/crates/tls/client/src/testdata/deframer-empty-applicationdata.bin similarity index 100% rename from components/tls/tls-client/src/testdata/deframer-empty-applicationdata.bin rename to crates/tls/client/src/testdata/deframer-empty-applicationdata.bin diff --git a/components/tls/tls-client/src/testdata/deframer-invalid-contenttype.bin b/crates/tls/client/src/testdata/deframer-invalid-contenttype.bin similarity index 100% rename from 
components/tls/tls-client/src/testdata/deframer-invalid-contenttype.bin rename to crates/tls/client/src/testdata/deframer-invalid-contenttype.bin diff --git a/components/tls/tls-client/src/testdata/deframer-invalid-empty.bin b/crates/tls/client/src/testdata/deframer-invalid-empty.bin similarity index 100% rename from components/tls/tls-client/src/testdata/deframer-invalid-empty.bin rename to crates/tls/client/src/testdata/deframer-invalid-empty.bin diff --git a/components/tls/tls-client/src/testdata/deframer-invalid-length.bin b/crates/tls/client/src/testdata/deframer-invalid-length.bin similarity index 100% rename from components/tls/tls-client/src/testdata/deframer-invalid-length.bin rename to crates/tls/client/src/testdata/deframer-invalid-length.bin diff --git a/components/tls/tls-client/src/testdata/deframer-invalid-version.bin b/crates/tls/client/src/testdata/deframer-invalid-version.bin similarity index 100% rename from components/tls/tls-client/src/testdata/deframer-invalid-version.bin rename to crates/tls/client/src/testdata/deframer-invalid-version.bin diff --git a/components/tls/tls-client/src/testdata/deframer-test.1.bin b/crates/tls/client/src/testdata/deframer-test.1.bin similarity index 100% rename from components/tls/tls-client/src/testdata/deframer-test.1.bin rename to crates/tls/client/src/testdata/deframer-test.1.bin diff --git a/components/tls/tls-client/src/testdata/deframer-test.2.bin b/crates/tls/client/src/testdata/deframer-test.2.bin similarity index 100% rename from components/tls/tls-client/src/testdata/deframer-test.2.bin rename to crates/tls/client/src/testdata/deframer-test.2.bin diff --git a/components/tls/tls-client/src/testdata/eddsakey.der b/crates/tls/client/src/testdata/eddsakey.der similarity index 100% rename from components/tls/tls-client/src/testdata/eddsakey.der rename to crates/tls/client/src/testdata/eddsakey.der diff --git a/components/tls/tls-client/src/testdata/nistp256key.der 
b/crates/tls/client/src/testdata/nistp256key.der similarity index 100% rename from components/tls/tls-client/src/testdata/nistp256key.der rename to crates/tls/client/src/testdata/nistp256key.der diff --git a/components/tls/tls-client/src/testdata/nistp256key.pkcs8.der b/crates/tls/client/src/testdata/nistp256key.pkcs8.der similarity index 100% rename from components/tls/tls-client/src/testdata/nistp256key.pkcs8.der rename to crates/tls/client/src/testdata/nistp256key.pkcs8.der diff --git a/components/tls/tls-client/src/testdata/nistp384key.der b/crates/tls/client/src/testdata/nistp384key.der similarity index 100% rename from components/tls/tls-client/src/testdata/nistp384key.der rename to crates/tls/client/src/testdata/nistp384key.der diff --git a/components/tls/tls-client/src/testdata/nistp384key.pkcs8.der b/crates/tls/client/src/testdata/nistp384key.pkcs8.der similarity index 100% rename from components/tls/tls-client/src/testdata/nistp384key.pkcs8.der rename to crates/tls/client/src/testdata/nistp384key.pkcs8.der diff --git a/components/tls/tls-client/src/testdata/prf-result.1.bin b/crates/tls/client/src/testdata/prf-result.1.bin similarity index 100% rename from components/tls/tls-client/src/testdata/prf-result.1.bin rename to crates/tls/client/src/testdata/prf-result.1.bin diff --git a/components/tls/tls-client/src/testdata/prf-result.2.bin b/crates/tls/client/src/testdata/prf-result.2.bin similarity index 100% rename from components/tls/tls-client/src/testdata/prf-result.2.bin rename to crates/tls/client/src/testdata/prf-result.2.bin diff --git a/components/tls/tls-client/src/testdata/rsa2048key.pkcs1.der b/crates/tls/client/src/testdata/rsa2048key.pkcs1.der similarity index 100% rename from components/tls/tls-client/src/testdata/rsa2048key.pkcs1.der rename to crates/tls/client/src/testdata/rsa2048key.pkcs1.der diff --git a/components/tls/tls-client/src/testdata/rsa2048key.pkcs8.der b/crates/tls/client/src/testdata/rsa2048key.pkcs8.der similarity index 100% 
rename from components/tls/tls-client/src/testdata/rsa2048key.pkcs8.der rename to crates/tls/client/src/testdata/rsa2048key.pkcs8.der diff --git a/components/tls/tls-client/src/ticketer.rs b/crates/tls/client/src/ticketer.rs similarity index 100% rename from components/tls/tls-client/src/ticketer.rs rename to crates/tls/client/src/ticketer.rs diff --git a/components/tls/tls-client/src/vecbuf.rs b/crates/tls/client/src/vecbuf.rs similarity index 75% rename from components/tls/tls-client/src/vecbuf.rs rename to crates/tls/client/src/vecbuf.rs index c622708e96..5d73b8c775 100644 --- a/components/tls/tls-client/src/vecbuf.rs +++ b/crates/tls/client/src/vecbuf.rs @@ -96,19 +96,6 @@ impl ChunkVecBuffer { Ok(offs) } - #[cfg(read_buf)] - /// Read data out of this object, writing it into `buf`. - pub(crate) fn read_buf(&mut self, buf: &mut io::ReadBuf<'_>) -> io::Result<()> { - while !self.is_empty() && buf.remaining() > 0 { - let chunk = self.chunks[0].as_slice(); - let used = std::cmp::min(chunk.len(), buf.remaining()); - buf.append(&chunk[..used]); - self.consume(used); - } - - Ok(()) - } - fn consume(&mut self, mut used: usize) { while let Some(mut buf) = self.chunks.pop_front() { if used < buf.len() { @@ -172,38 +159,4 @@ mod test { assert_eq!(cvb.read(&mut buf).unwrap(), 12); assert_eq!(buf.to_vec(), b"helloworldhe".to_vec()); } - - #[cfg(read_buf)] - #[test] - fn read_buf() { - use std::{io::ReadBuf, mem::MaybeUninit}; - - { - let mut cvb = ChunkVecBuffer::new(None); - cvb.append(b"test ".to_vec()); - cvb.append(b"fixture ".to_vec()); - cvb.append(b"data".to_vec()); - - let mut buf = [MaybeUninit::::uninit(); 8]; - let mut buf = ReadBuf::uninit(&mut buf); - cvb.read_buf(&mut buf).unwrap(); - assert_eq!(buf.filled(), b"test fix"); - buf.clear(); - cvb.read_buf(&mut buf).unwrap(); - assert_eq!(buf.filled(), b"ture dat"); - buf.clear(); - cvb.read_buf(&mut buf).unwrap(); - assert_eq!(buf.filled(), b"a"); - } - - { - let mut cvb = ChunkVecBuffer::new(None); - 
cvb.append(b"short message".to_vec()); - - let mut buf = [MaybeUninit::::uninit(); 1024]; - let mut buf = ReadBuf::uninit(&mut buf); - cvb.read_buf(&mut buf).unwrap(); - assert_eq!(buf.filled(), b"short message"); - } - } } diff --git a/components/tls/tls-client/src/verifybench.rs b/crates/tls/client/src/verifybench.rs similarity index 100% rename from components/tls/tls-client/src/verifybench.rs rename to crates/tls/client/src/verifybench.rs diff --git a/components/tls/tls-client/test-ca/build-a-pki.sh b/crates/tls/client/test-ca/build-a-pki.sh similarity index 100% rename from components/tls/tls-client/test-ca/build-a-pki.sh rename to crates/tls/client/test-ca/build-a-pki.sh diff --git a/components/tls/tls-client/test-ca/ecdsa/ca.cert b/crates/tls/client/test-ca/ecdsa/ca.cert similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/ca.cert rename to crates/tls/client/test-ca/ecdsa/ca.cert diff --git a/components/tls/tls-client/test-ca/ecdsa/ca.der b/crates/tls/client/test-ca/ecdsa/ca.der similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/ca.der rename to crates/tls/client/test-ca/ecdsa/ca.der diff --git a/components/tls/tls-client/test-ca/ecdsa/ca.key b/crates/tls/client/test-ca/ecdsa/ca.key similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/ca.key rename to crates/tls/client/test-ca/ecdsa/ca.key diff --git a/components/tls/tls-client/test-ca/ecdsa/client.cert b/crates/tls/client/test-ca/ecdsa/client.cert similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/client.cert rename to crates/tls/client/test-ca/ecdsa/client.cert diff --git a/components/tls/tls-client/test-ca/ecdsa/client.chain b/crates/tls/client/test-ca/ecdsa/client.chain similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/client.chain rename to crates/tls/client/test-ca/ecdsa/client.chain diff --git a/components/tls/tls-client/test-ca/ecdsa/client.fullchain 
b/crates/tls/client/test-ca/ecdsa/client.fullchain similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/client.fullchain rename to crates/tls/client/test-ca/ecdsa/client.fullchain diff --git a/components/tls/tls-client/test-ca/ecdsa/client.key b/crates/tls/client/test-ca/ecdsa/client.key similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/client.key rename to crates/tls/client/test-ca/ecdsa/client.key diff --git a/components/tls/tls-client/test-ca/ecdsa/client.req b/crates/tls/client/test-ca/ecdsa/client.req similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/client.req rename to crates/tls/client/test-ca/ecdsa/client.req diff --git a/components/tls/tls-client/test-ca/ecdsa/end.cert b/crates/tls/client/test-ca/ecdsa/end.cert similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/end.cert rename to crates/tls/client/test-ca/ecdsa/end.cert diff --git a/components/tls/tls-client/test-ca/ecdsa/end.chain b/crates/tls/client/test-ca/ecdsa/end.chain similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/end.chain rename to crates/tls/client/test-ca/ecdsa/end.chain diff --git a/components/tls/tls-client/test-ca/ecdsa/end.fullchain b/crates/tls/client/test-ca/ecdsa/end.fullchain similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/end.fullchain rename to crates/tls/client/test-ca/ecdsa/end.fullchain diff --git a/components/tls/tls-client/test-ca/ecdsa/end.key b/crates/tls/client/test-ca/ecdsa/end.key similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/end.key rename to crates/tls/client/test-ca/ecdsa/end.key diff --git a/components/tls/tls-client/test-ca/ecdsa/end.req b/crates/tls/client/test-ca/ecdsa/end.req similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/end.req rename to crates/tls/client/test-ca/ecdsa/end.req diff --git a/components/tls/tls-client/test-ca/ecdsa/inter.cert 
b/crates/tls/client/test-ca/ecdsa/inter.cert similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/inter.cert rename to crates/tls/client/test-ca/ecdsa/inter.cert diff --git a/components/tls/tls-client/test-ca/ecdsa/inter.key b/crates/tls/client/test-ca/ecdsa/inter.key similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/inter.key rename to crates/tls/client/test-ca/ecdsa/inter.key diff --git a/components/tls/tls-client/test-ca/ecdsa/inter.req b/crates/tls/client/test-ca/ecdsa/inter.req similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/inter.req rename to crates/tls/client/test-ca/ecdsa/inter.req diff --git a/components/tls/tls-client/test-ca/ecdsa/nistp256.pem b/crates/tls/client/test-ca/ecdsa/nistp256.pem similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/nistp256.pem rename to crates/tls/client/test-ca/ecdsa/nistp256.pem diff --git a/components/tls/tls-client/test-ca/ecdsa/nistp384.pem b/crates/tls/client/test-ca/ecdsa/nistp384.pem similarity index 100% rename from components/tls/tls-client/test-ca/ecdsa/nistp384.pem rename to crates/tls/client/test-ca/ecdsa/nistp384.pem diff --git a/components/tls/tls-client/test-ca/eddsa/ca.cert b/crates/tls/client/test-ca/eddsa/ca.cert similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/ca.cert rename to crates/tls/client/test-ca/eddsa/ca.cert diff --git a/components/tls/tls-client/test-ca/eddsa/ca.der b/crates/tls/client/test-ca/eddsa/ca.der similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/ca.der rename to crates/tls/client/test-ca/eddsa/ca.der diff --git a/components/tls/tls-client/test-ca/eddsa/ca.key b/crates/tls/client/test-ca/eddsa/ca.key similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/ca.key rename to crates/tls/client/test-ca/eddsa/ca.key diff --git a/components/tls/tls-client/test-ca/eddsa/client.cert b/crates/tls/client/test-ca/eddsa/client.cert similarity 
index 100% rename from components/tls/tls-client/test-ca/eddsa/client.cert rename to crates/tls/client/test-ca/eddsa/client.cert diff --git a/components/tls/tls-client/test-ca/eddsa/client.chain b/crates/tls/client/test-ca/eddsa/client.chain similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/client.chain rename to crates/tls/client/test-ca/eddsa/client.chain diff --git a/components/tls/tls-client/test-ca/eddsa/client.fullchain b/crates/tls/client/test-ca/eddsa/client.fullchain similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/client.fullchain rename to crates/tls/client/test-ca/eddsa/client.fullchain diff --git a/components/tls/tls-client/test-ca/eddsa/client.key b/crates/tls/client/test-ca/eddsa/client.key similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/client.key rename to crates/tls/client/test-ca/eddsa/client.key diff --git a/components/tls/tls-client/test-ca/eddsa/client.req b/crates/tls/client/test-ca/eddsa/client.req similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/client.req rename to crates/tls/client/test-ca/eddsa/client.req diff --git a/components/tls/tls-client/test-ca/eddsa/end.cert b/crates/tls/client/test-ca/eddsa/end.cert similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/end.cert rename to crates/tls/client/test-ca/eddsa/end.cert diff --git a/components/tls/tls-client/test-ca/eddsa/end.chain b/crates/tls/client/test-ca/eddsa/end.chain similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/end.chain rename to crates/tls/client/test-ca/eddsa/end.chain diff --git a/components/tls/tls-client/test-ca/eddsa/end.fullchain b/crates/tls/client/test-ca/eddsa/end.fullchain similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/end.fullchain rename to crates/tls/client/test-ca/eddsa/end.fullchain diff --git a/components/tls/tls-client/test-ca/eddsa/end.key b/crates/tls/client/test-ca/eddsa/end.key 
similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/end.key rename to crates/tls/client/test-ca/eddsa/end.key diff --git a/components/tls/tls-client/test-ca/eddsa/end.req b/crates/tls/client/test-ca/eddsa/end.req similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/end.req rename to crates/tls/client/test-ca/eddsa/end.req diff --git a/components/tls/tls-client/test-ca/eddsa/inter.cert b/crates/tls/client/test-ca/eddsa/inter.cert similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/inter.cert rename to crates/tls/client/test-ca/eddsa/inter.cert diff --git a/components/tls/tls-client/test-ca/eddsa/inter.key b/crates/tls/client/test-ca/eddsa/inter.key similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/inter.key rename to crates/tls/client/test-ca/eddsa/inter.key diff --git a/components/tls/tls-client/test-ca/eddsa/inter.req b/crates/tls/client/test-ca/eddsa/inter.req similarity index 100% rename from components/tls/tls-client/test-ca/eddsa/inter.req rename to crates/tls/client/test-ca/eddsa/inter.req diff --git a/components/tls/tls-client/test-ca/openssl.cnf b/crates/tls/client/test-ca/openssl.cnf similarity index 100% rename from components/tls/tls-client/test-ca/openssl.cnf rename to crates/tls/client/test-ca/openssl.cnf diff --git a/components/tls/tls-client/test-ca/rsa/ca.cert b/crates/tls/client/test-ca/rsa/ca.cert similarity index 100% rename from components/tls/tls-client/test-ca/rsa/ca.cert rename to crates/tls/client/test-ca/rsa/ca.cert diff --git a/components/tls/tls-client/test-ca/rsa/ca.der b/crates/tls/client/test-ca/rsa/ca.der similarity index 100% rename from components/tls/tls-client/test-ca/rsa/ca.der rename to crates/tls/client/test-ca/rsa/ca.der diff --git a/components/tls/tls-client/test-ca/rsa/ca.key b/crates/tls/client/test-ca/rsa/ca.key similarity index 100% rename from components/tls/tls-client/test-ca/rsa/ca.key rename to crates/tls/client/test-ca/rsa/ca.key 
diff --git a/components/tls/tls-client/test-ca/rsa/client.cert b/crates/tls/client/test-ca/rsa/client.cert similarity index 100% rename from components/tls/tls-client/test-ca/rsa/client.cert rename to crates/tls/client/test-ca/rsa/client.cert diff --git a/components/tls/tls-client/test-ca/rsa/client.chain b/crates/tls/client/test-ca/rsa/client.chain similarity index 100% rename from components/tls/tls-client/test-ca/rsa/client.chain rename to crates/tls/client/test-ca/rsa/client.chain diff --git a/components/tls/tls-client/test-ca/rsa/client.fullchain b/crates/tls/client/test-ca/rsa/client.fullchain similarity index 100% rename from components/tls/tls-client/test-ca/rsa/client.fullchain rename to crates/tls/client/test-ca/rsa/client.fullchain diff --git a/components/tls/tls-client/test-ca/rsa/client.key b/crates/tls/client/test-ca/rsa/client.key similarity index 100% rename from components/tls/tls-client/test-ca/rsa/client.key rename to crates/tls/client/test-ca/rsa/client.key diff --git a/components/tls/tls-client/test-ca/rsa/client.req b/crates/tls/client/test-ca/rsa/client.req similarity index 100% rename from components/tls/tls-client/test-ca/rsa/client.req rename to crates/tls/client/test-ca/rsa/client.req diff --git a/components/tls/tls-client/test-ca/rsa/client.rsa b/crates/tls/client/test-ca/rsa/client.rsa similarity index 100% rename from components/tls/tls-client/test-ca/rsa/client.rsa rename to crates/tls/client/test-ca/rsa/client.rsa diff --git a/components/tls/tls-client/test-ca/rsa/end.cert b/crates/tls/client/test-ca/rsa/end.cert similarity index 100% rename from components/tls/tls-client/test-ca/rsa/end.cert rename to crates/tls/client/test-ca/rsa/end.cert diff --git a/components/tls/tls-client/test-ca/rsa/end.chain b/crates/tls/client/test-ca/rsa/end.chain similarity index 100% rename from components/tls/tls-client/test-ca/rsa/end.chain rename to crates/tls/client/test-ca/rsa/end.chain diff --git 
a/components/tls/tls-client/test-ca/rsa/end.fullchain b/crates/tls/client/test-ca/rsa/end.fullchain similarity index 100% rename from components/tls/tls-client/test-ca/rsa/end.fullchain rename to crates/tls/client/test-ca/rsa/end.fullchain diff --git a/components/tls/tls-client/test-ca/rsa/end.key b/crates/tls/client/test-ca/rsa/end.key similarity index 100% rename from components/tls/tls-client/test-ca/rsa/end.key rename to crates/tls/client/test-ca/rsa/end.key diff --git a/components/tls/tls-client/test-ca/rsa/end.req b/crates/tls/client/test-ca/rsa/end.req similarity index 100% rename from components/tls/tls-client/test-ca/rsa/end.req rename to crates/tls/client/test-ca/rsa/end.req diff --git a/components/tls/tls-client/test-ca/rsa/end.rsa b/crates/tls/client/test-ca/rsa/end.rsa similarity index 100% rename from components/tls/tls-client/test-ca/rsa/end.rsa rename to crates/tls/client/test-ca/rsa/end.rsa diff --git a/components/tls/tls-client/test-ca/rsa/inter.cert b/crates/tls/client/test-ca/rsa/inter.cert similarity index 100% rename from components/tls/tls-client/test-ca/rsa/inter.cert rename to crates/tls/client/test-ca/rsa/inter.cert diff --git a/components/tls/tls-client/test-ca/rsa/inter.key b/crates/tls/client/test-ca/rsa/inter.key similarity index 100% rename from components/tls/tls-client/test-ca/rsa/inter.key rename to crates/tls/client/test-ca/rsa/inter.key diff --git a/components/tls/tls-client/test-ca/rsa/inter.req b/crates/tls/client/test-ca/rsa/inter.req similarity index 100% rename from components/tls/tls-client/test-ca/rsa/inter.req rename to crates/tls/client/test-ca/rsa/inter.req diff --git a/components/tls/tls-client/tests/api.rs b/crates/tls/client/tests/api.rs similarity index 75% rename from components/tls/tls-client/tests/api.rs rename to crates/tls/client/tests/api.rs index 43dd64ae44..7f0e25d19c 100644 --- a/components/tls/tls-client/tests/api.rs +++ b/crates/tls/client/tests/api.rs @@ -14,10 +14,6 @@ use std::{ }, }; - - 
-#[cfg(feature = "quic")] -use tls_client::quic::{self, ClientQuicExt, QuicExt, ServerQuicExt}; use tls_client::{ client::ResolvesClientCert, sign, CipherSuite, ClientConfig, ClientConnection, Error, KeyLog, ProtocolVersion, RustCryptoBackend, SignatureScheme, SupportedCipherSuite, ALL_CIPHER_SUITES, @@ -44,7 +40,7 @@ async fn alpn_test_error( for version in tls_client::ALL_VERSIONS { let mut client_config = make_client_config_with_versions(KeyType::Rsa, &[version]); - client_config.alpn_protocols = client_protos.clone(); + client_config.alpn_protocols.clone_from(&client_protos); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await; @@ -891,6 +887,7 @@ async fn client_error_is_sticky() { } #[tokio::test] +#[allow(clippy::no_effect)] async fn client_is_send() { let (client, _) = make_pair(KeyType::Rsa).await; &client as &dyn Send; @@ -1376,32 +1373,6 @@ async fn server_streamowned_read() { } } -struct FailsWrites { - errkind: io::ErrorKind, - after: usize, -} - -impl io::Read for FailsWrites { - fn read(&mut self, _b: &mut [u8]) -> io::Result { - Ok(0) - } -} - -impl io::Write for FailsWrites { - fn write(&mut self, b: &[u8]) -> io::Result { - if self.after > 0 { - self.after -= 1; - Ok(b.len()) - } else { - Err(io::Error::new(self.errkind, "oops")) - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - // #[tokio::test] // async fn stream_write_reports_underlying_io_error_before_plaintext_processed() { // let (mut client, mut server) = make_pair(KeyType::Rsa).await; @@ -2347,626 +2318,6 @@ async fn tls13_stateless_resumption() { // assert_eq!(client.is_early_data_accepted(), false); // } -#[cfg(feature = "quic")] -mod test_quic { - use super::*; - use rustls::Connection; - - // Returns the sender's next secrets to use, or the receiver's error. 
- fn step( - send: &mut dyn QuicExt, - recv: &mut dyn QuicExt, - ) -> Result, Error> { - let mut buf = Vec::new(); - let change = loop { - let prev = buf.len(); - if let Some(x) = send.write_hs(&mut buf) { - break Some(x); - } - if prev == buf.len() { - break None; - } - }; - if let Err(e) = recv.read_hs(&buf) { - return Err(e); - } else { - assert_eq!(recv.alert(), None); - } - - Ok(change) - } - - #[tokio::test] - fn test_quic_handshake() { - fn equal_packet_keys(x: &quic::PacketKey, y: &quic::PacketKey) -> bool { - // Check that these two sets of keys are equal. - let mut buf = vec![0; 32]; - let (header, payload_tag) = buf.split_at_mut(8); - let (payload, tag_buf) = payload_tag.split_at_mut(8); - let tag = x.encrypt_in_place(42, &*header, payload).unwrap(); - tag_buf.copy_from_slice(tag.as_ref()); - - let result = y.decrypt_in_place(42, &*header, payload_tag); - match result { - Ok(payload) => payload == &[0; 8], - Err(_) => false, - } - } - - fn compatible_keys(x: &quic::KeyChange, y: &quic::KeyChange) -> bool { - fn keys(kc: &quic::KeyChange) -> &quic::Keys { - match kc { - quic::KeyChange::Handshake { keys } => keys, - quic::KeyChange::OneRtt { keys, .. 
} => keys, - } - } - - let (x, y) = (keys(x), keys(y)); - equal_packet_keys(&x.local.packet, &y.remote.packet) - && equal_packet_keys(&x.remote.packet, &y.local.packet) - } - - let kt = KeyType::Rsa; - let mut client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); - client_config.enable_early_data = true; - let client_config = Arc::new(client_config); - let mut server_config = make_server_config_with_versions(kt, &[&rustls::version::TLS13]); - server_config.max_early_data_size = 0xffffffff; - let server_config = Arc::new(server_config); - let client_params = &b"client params"[..]; - let server_params = &b"server params"[..]; - - // full handshake - let mut client = Connection::from( - ClientConnection::new_quic( - Arc::clone(&client_config), - quic::Version::V1, - dns_name("localhost"), - client_params.into(), - ) - .unwrap(), - ); - - let mut server = Connection::from( - ServerConnection::new_quic( - Arc::clone(&server_config), - quic::Version::V1, - server_params.into(), - ) - .unwrap(), - ); - - let client_initial = step(&mut client, &mut server).unwrap(); - assert!(client_initial.is_none()); - assert!(client.zero_rtt_keys().is_none()); - assert_eq!(server.quic_transport_parameters(), Some(client_params)); - let server_hs = step(&mut server, &mut client).unwrap().unwrap(); - assert!(server.zero_rtt_keys().is_none()); - let client_hs = step(&mut client, &mut server).unwrap().unwrap(); - assert!(compatible_keys(&server_hs, &client_hs)); - assert!(client.is_handshaking()); - let server_1rtt = step(&mut server, &mut client).unwrap().unwrap(); - assert!(!client.is_handshaking()); - assert_eq!(client.quic_transport_parameters(), Some(server_params)); - assert!(server.is_handshaking()); - let client_1rtt = step(&mut client, &mut server).unwrap().unwrap(); - assert!(!server.is_handshaking()); - assert!(compatible_keys(&server_1rtt, &client_1rtt)); - assert!(!compatible_keys(&server_hs, &server_1rtt)); - assert!(step(&mut client, &mut 
server).unwrap().is_none()); - assert!(step(&mut server, &mut client).unwrap().is_none()); - - // 0-RTT handshake - let mut client = ClientConnection::new_quic( - Arc::clone(&client_config), - quic::Version::V1, - dns_name("localhost"), - client_params.into(), - ) - .unwrap(); - assert!(client.negotiated_cipher_suite().is_some()); - - let mut server = ServerConnection::new_quic( - Arc::clone(&server_config), - quic::Version::V1, - server_params.into(), - ) - .unwrap(); - - step(&mut client, &mut server).unwrap(); - assert_eq!(client.quic_transport_parameters(), Some(server_params)); - { - let client_early = client.zero_rtt_keys().unwrap(); - let server_early = server.zero_rtt_keys().unwrap(); - assert!(equal_packet_keys( - &client_early.packet, - &server_early.packet - )); - } - step(&mut server, &mut client).unwrap().unwrap(); - step(&mut client, &mut server).unwrap().unwrap(); - step(&mut server, &mut client).unwrap().unwrap(); - assert!(client.is_early_data_accepted()); - - // 0-RTT rejection - { - let client_config = (*client_config).clone(); - let mut client = ClientConnection::new_quic( - Arc::new(client_config), - quic::Version::V1, - dns_name("localhost"), - client_params.into(), - ) - .unwrap(); - - let mut server = ServerConnection::new_quic( - Arc::clone(&server_config), - quic::Version::V1, - server_params.into(), - ) - .unwrap(); - - step(&mut client, &mut server).unwrap(); - assert_eq!(client.quic_transport_parameters(), Some(server_params)); - assert!(client.zero_rtt_keys().is_some()); - assert!(server.zero_rtt_keys().is_none()); - step(&mut server, &mut client).unwrap().unwrap(); - step(&mut client, &mut server).unwrap().unwrap(); - step(&mut server, &mut client).unwrap().unwrap(); - assert!(!client.is_early_data_accepted()); - } - - // failed handshake - let mut client = ClientConnection::new_quic( - client_config, - quic::Version::V1, - dns_name("example.com"), - client_params.into(), - ) - .unwrap(); - - let mut server = - 
ServerConnection::new_quic(server_config, quic::Version::V1, server_params.into()) - .unwrap(); - - step(&mut client, &mut server).unwrap(); - step(&mut server, &mut client).unwrap().unwrap(); - assert!(step(&mut server, &mut client).is_err()); - assert_eq!( - client.alert(), - Some(rustls::internal::msgs::enums::AlertDescription::BadCertificate) - ); - - // Key updates - - let (mut client_secrets, mut server_secrets) = match (client_1rtt, server_1rtt) { - (quic::KeyChange::OneRtt { next: c, .. }, quic::KeyChange::OneRtt { next: s, .. }) => { - (c, s) - } - _ => unreachable!(), - }; - - let mut client_next = client_secrets.next_packet_keys(); - let mut server_next = server_secrets.next_packet_keys(); - assert!(equal_packet_keys(&client_next.local, &server_next.remote)); - assert!(equal_packet_keys(&server_next.local, &client_next.remote)); - - client_next = client_secrets.next_packet_keys(); - server_next = server_secrets.next_packet_keys(); - assert!(equal_packet_keys(&client_next.local, &server_next.remote)); - assert!(equal_packet_keys(&server_next.local, &client_next.remote)); - } - - #[tokio::test] - fn test_quic_rejects_missing_alpn() { - let client_params = &b"client params"[..]; - let server_params = &b"server params"[..]; - - for &kt in ALL_KEY_TYPES.iter() { - let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); - let client_config = Arc::new(client_config); - - let mut server_config = - make_server_config_with_versions(kt, &[&rustls::version::TLS13]); - server_config.alpn_protocols = vec!["foo".into()]; - let server_config = Arc::new(server_config); - - let mut client = ClientConnection::new_quic( - client_config, - quic::Version::V1, - dns_name("localhost"), - client_params.into(), - ) - .unwrap(); - let mut server = - ServerConnection::new_quic(server_config, quic::Version::V1, server_params.into()) - .unwrap(); - - assert_eq!( - step(&mut client, &mut server).err().unwrap(), - Error::NoApplicationProtocol - ); - - 
assert_eq!( - server.alert(), - Some(rustls::internal::msgs::enums::AlertDescription::NoApplicationProtocol) - ); - } - } - - #[tokio::test] - fn test_quic_no_tls13_error() { - let mut client_config = - make_client_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS12]); - client_config.alpn_protocols = vec!["foo".into()]; - let client_config = Arc::new(client_config); - - assert!(ClientConnection::new_quic( - client_config, - quic::Version::V1, - dns_name("localhost"), - b"client params".to_vec(), - ) - .is_err()); - - let mut server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS12]); - server_config.alpn_protocols = vec!["foo".into()]; - let server_config = Arc::new(server_config); - - assert!(ServerConnection::new_quic( - server_config, - quic::Version::V1, - b"server params".to_vec(), - ) - .is_err()); - } - - #[tokio::test] - fn test_quic_invalid_early_data_size() { - let mut server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS13]); - server_config.alpn_protocols = vec!["foo".into()]; - - let cases = [ - (None, true), - (Some(0u32), true), - (Some(5), false), - (Some(0xffff_ffff), true), - ]; - - for &(size, ok) in cases.iter() { - println!("early data size case: {:?}", size); - if let Some(new) = size { - server_config.max_early_data_size = new; - } - - let wrapped = Arc::new(server_config.clone()); - assert_eq!( - ServerConnection::new_quic(wrapped, quic::Version::V1, b"server params".to_vec(),) - .is_ok(), - ok - ); - } - } - - #[tokio::test] - fn test_quic_server_no_params_received() { - let server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS13]); - let server_config = Arc::new(server_config); - - let mut server = - ServerConnection::new_quic(server_config, quic::Version::V1, b"server params".to_vec()) - .unwrap(); - - use ring::rand::SecureRandom; - use rustls::internal::msgs::{ - base::PayloadU16, - enums::{CipherSuite, 
Compression, HandshakeType, NamedGroup, SignatureScheme}, - handshake::{ - ClientHelloPayload, HandshakeMessagePayload, KeyShareEntry, Random, SessionID, - }, - message::PlainMessage, - }; - - let rng = ring::rand::SystemRandom::new(); - let mut random = [0; 32]; - rng.fill(&mut random).unwrap(); - let random = Random::from(random); - - let kx = ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng) - .unwrap() - .compute_public_key() - .unwrap(); - - let client_hello = Message { - version: ProtocolVersion::TLSv1_3, - payload: MessagePayload::Handshake(HandshakeMessagePayload { - typ: HandshakeType::ClientHello, - payload: HandshakePayload::ClientHello(ClientHelloPayload { - client_version: ProtocolVersion::TLSv1_3, - random, - session_id: SessionID::random().unwrap(), - cipher_suites: vec![CipherSuite::TLS13_AES_128_GCM_SHA256], - compression_methods: vec![Compression::Null], - extensions: vec![ - ClientExtension::SupportedVersions(vec![ProtocolVersion::TLSv1_3]), - ClientExtension::NamedGroups(vec![NamedGroup::X25519]), - ClientExtension::SignatureAlgorithms(vec![SignatureScheme::ED25519]), - ClientExtension::KeyShare(vec![KeyShareEntry { - group: NamedGroup::X25519, - payload: PayloadU16::new(kx.as_ref().to_vec()), - }]), - ], - }), - }), - }; - - let buf = PlainMessage::from(client_hello) - .into_unencrypted_opaque() - .encode(); - server.read_tls(&mut buf.as_slice()).unwrap(); - assert_eq!( - server.process_new_packets().err(), - Some(Error::PeerMisbehavedError( - "QUIC transport parameters not found".into(), - )), - ); - } - - #[tokio::test] - fn test_quic_server_no_tls12() { - let mut server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS13]); - server_config.alpn_protocols = vec!["foo".into()]; - let server_config = Arc::new(server_config); - - use ring::rand::SecureRandom; - use rustls::internal::msgs::{ - base::PayloadU16, - enums::{CipherSuite, Compression, HandshakeType, NamedGroup, 
SignatureScheme}, - handshake::{ - ClientHelloPayload, HandshakeMessagePayload, KeyShareEntry, Random, SessionID, - }, - message::PlainMessage, - }; - - let rng = ring::rand::SystemRandom::new(); - let mut random = [0; 32]; - rng.fill(&mut random).unwrap(); - let random = Random::from(random); - - let kx = ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng) - .unwrap() - .compute_public_key() - .unwrap(); - - let mut server = - ServerConnection::new_quic(server_config, quic::Version::V1, b"server params".to_vec()) - .unwrap(); - - let client_hello = Message { - version: ProtocolVersion::TLSv1_2, - payload: MessagePayload::Handshake(HandshakeMessagePayload { - typ: HandshakeType::ClientHello, - payload: HandshakePayload::ClientHello(ClientHelloPayload { - client_version: ProtocolVersion::TLSv1_2, - random: random.clone(), - session_id: SessionID::random().unwrap(), - cipher_suites: vec![CipherSuite::TLS13_AES_128_GCM_SHA256], - compression_methods: vec![Compression::Null], - extensions: vec![ - ClientExtension::NamedGroups(vec![NamedGroup::X25519]), - ClientExtension::SignatureAlgorithms(vec![SignatureScheme::ED25519]), - ClientExtension::KeyShare(vec![KeyShareEntry { - group: NamedGroup::X25519, - payload: PayloadU16::new(kx.as_ref().to_vec()), - }]), - ], - }), - }), - }; - - let buf = PlainMessage::from(client_hello) - .into_unencrypted_opaque() - .encode(); - server.read_tls(&mut buf.as_slice()).unwrap(); - assert_eq!( - server.process_new_packets().err(), - Some(Error::PeerIncompatibleError( - "Server requires TLS1.3, but client omitted versions ext".into(), - )), - ); - } - - #[tokio::test] - fn packet_key_api() { - use rustls::quic::{Keys, Version}; - - // Test vectors: https://www.rfc-editor.org/rfc/rfc9001.html#name-client-initial - const CONNECTION_ID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; - const PACKET_NUMBER: u64 = 2; - const PLAIN_HEADER: &[u8] = &[ - 0xc3, 0x00, 0x00, 0x00, 0x01, 0x08, 0x83, 0x94, 0xc8, 
0xf0, 0x3e, 0x51, 0x57, 0x08, - 0x00, 0x00, 0x44, 0x9e, 0x00, 0x00, 0x00, 0x02, - ]; - - const PAYLOAD: &[u8] = &[ - 0x06, 0x00, 0x40, 0xf1, 0x01, 0x00, 0x00, 0xed, 0x03, 0x03, 0xeb, 0xf8, 0xfa, 0x56, - 0xf1, 0x29, 0x39, 0xb9, 0x58, 0x4a, 0x38, 0x96, 0x47, 0x2e, 0xc4, 0x0b, 0xb8, 0x63, - 0xcf, 0xd3, 0xe8, 0x68, 0x04, 0xfe, 0x3a, 0x47, 0xf0, 0x6a, 0x2b, 0x69, 0x48, 0x4c, - 0x00, 0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0x01, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x0e, 0x00, 0x00, 0x0b, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, - 0x63, 0x6f, 0x6d, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x06, - 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x10, 0x00, 0x07, 0x00, 0x05, 0x04, 0x61, - 0x6c, 0x70, 0x6e, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x33, - 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x93, 0x70, 0xb2, 0xc9, 0xca, 0xa4, - 0x7f, 0xba, 0xba, 0xf4, 0x55, 0x9f, 0xed, 0xba, 0x75, 0x3d, 0xe1, 0x71, 0xfa, 0x71, - 0xf5, 0x0f, 0x1c, 0xe1, 0x5d, 0x43, 0xe9, 0x94, 0xec, 0x74, 0xd7, 0x48, 0x00, 0x2b, - 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x0d, 0x00, 0x10, 0x00, 0x0e, 0x04, 0x03, 0x05, - 0x03, 0x06, 0x03, 0x02, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x00, 0x2d, 0x00, - 0x02, 0x01, 0x01, 0x00, 0x1c, 0x00, 0x02, 0x40, 0x01, 0x00, 0x39, 0x00, 0x32, 0x04, - 0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x05, 0x04, 0x80, 0x00, 0xff, - 0xff, 0x07, 0x04, 0x80, 0x00, 0xff, 0xff, 0x08, 0x01, 0x10, 0x01, 0x04, 0x80, 0x00, - 0x75, 0x30, 0x09, 0x01, 0x10, 0x0f, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, - 0x08, 0x06, 0x04, 0x80, 0x00, 0xff, 0xff, - ]; - - let client_keys = Keys::initial(Version::V1, &CONNECTION_ID, true); - assert_eq!( - client_keys.local.packet.confidentiality_limit(), - 2u64.pow(23) - ); - assert_eq!(client_keys.local.packet.integrity_limit(), 2u64.pow(52)); - assert_eq!(client_keys.local.packet.tag_len(), 16); - - let mut buf = Vec::new(); - buf.extend(PLAIN_HEADER); - buf.extend(PAYLOAD); - let 
header_len = PLAIN_HEADER.len(); - let tag_len = client_keys.local.packet.tag_len(); - let padding_len = 1200 - header_len - PAYLOAD.len() - tag_len; - buf.extend(std::iter::repeat(0).take(padding_len)); - let (header, payload) = buf.split_at_mut(header_len); - let tag = client_keys - .local - .packet - .encrypt_in_place(PACKET_NUMBER, &*header, payload) - .unwrap(); - - let sample_len = client_keys.local.header.sample_len(); - let sample = &payload[..sample_len]; - let (first, rest) = header.split_at_mut(1); - client_keys - .local - .header - .encrypt_in_place(sample, &mut first[0], &mut rest[17..21]) - .unwrap(); - buf.extend_from_slice(tag.as_ref()); - - const PROTECTED: &[u8] = &[ - 0xc0, 0x00, 0x00, 0x00, 0x01, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, - 0x00, 0x00, 0x44, 0x9e, 0x7b, 0x9a, 0xec, 0x34, 0xd1, 0xb1, 0xc9, 0x8d, 0xd7, 0x68, - 0x9f, 0xb8, 0xec, 0x11, 0xd2, 0x42, 0xb1, 0x23, 0xdc, 0x9b, 0xd8, 0xba, 0xb9, 0x36, - 0xb4, 0x7d, 0x92, 0xec, 0x35, 0x6c, 0x0b, 0xab, 0x7d, 0xf5, 0x97, 0x6d, 0x27, 0xcd, - 0x44, 0x9f, 0x63, 0x30, 0x00, 0x99, 0xf3, 0x99, 0x1c, 0x26, 0x0e, 0xc4, 0xc6, 0x0d, - 0x17, 0xb3, 0x1f, 0x84, 0x29, 0x15, 0x7b, 0xb3, 0x5a, 0x12, 0x82, 0xa6, 0x43, 0xa8, - 0xd2, 0x26, 0x2c, 0xad, 0x67, 0x50, 0x0c, 0xad, 0xb8, 0xe7, 0x37, 0x8c, 0x8e, 0xb7, - 0x53, 0x9e, 0xc4, 0xd4, 0x90, 0x5f, 0xed, 0x1b, 0xee, 0x1f, 0xc8, 0xaa, 0xfb, 0xa1, - 0x7c, 0x75, 0x0e, 0x2c, 0x7a, 0xce, 0x01, 0xe6, 0x00, 0x5f, 0x80, 0xfc, 0xb7, 0xdf, - 0x62, 0x12, 0x30, 0xc8, 0x37, 0x11, 0xb3, 0x93, 0x43, 0xfa, 0x02, 0x8c, 0xea, 0x7f, - 0x7f, 0xb5, 0xff, 0x89, 0xea, 0xc2, 0x30, 0x82, 0x49, 0xa0, 0x22, 0x52, 0x15, 0x5e, - 0x23, 0x47, 0xb6, 0x3d, 0x58, 0xc5, 0x45, 0x7a, 0xfd, 0x84, 0xd0, 0x5d, 0xff, 0xfd, - 0xb2, 0x03, 0x92, 0x84, 0x4a, 0xe8, 0x12, 0x15, 0x46, 0x82, 0xe9, 0xcf, 0x01, 0x2f, - 0x90, 0x21, 0xa6, 0xf0, 0xbe, 0x17, 0xdd, 0xd0, 0xc2, 0x08, 0x4d, 0xce, 0x25, 0xff, - 0x9b, 0x06, 0xcd, 0xe5, 0x35, 0xd0, 0xf9, 0x20, 0xa2, 0xdb, 0x1b, 0xf3, 0x62, 0xc2, - 0x3e, 0x59, 
0x6d, 0x11, 0xa4, 0xf5, 0xa6, 0xcf, 0x39, 0x48, 0x83, 0x8a, 0x3a, 0xec, - 0x4e, 0x15, 0xda, 0xf8, 0x50, 0x0a, 0x6e, 0xf6, 0x9e, 0xc4, 0xe3, 0xfe, 0xb6, 0xb1, - 0xd9, 0x8e, 0x61, 0x0a, 0xc8, 0xb7, 0xec, 0x3f, 0xaf, 0x6a, 0xd7, 0x60, 0xb7, 0xba, - 0xd1, 0xdb, 0x4b, 0xa3, 0x48, 0x5e, 0x8a, 0x94, 0xdc, 0x25, 0x0a, 0xe3, 0xfd, 0xb4, - 0x1e, 0xd1, 0x5f, 0xb6, 0xa8, 0xe5, 0xeb, 0xa0, 0xfc, 0x3d, 0xd6, 0x0b, 0xc8, 0xe3, - 0x0c, 0x5c, 0x42, 0x87, 0xe5, 0x38, 0x05, 0xdb, 0x05, 0x9a, 0xe0, 0x64, 0x8d, 0xb2, - 0xf6, 0x42, 0x64, 0xed, 0x5e, 0x39, 0xbe, 0x2e, 0x20, 0xd8, 0x2d, 0xf5, 0x66, 0xda, - 0x8d, 0xd5, 0x99, 0x8c, 0xca, 0xbd, 0xae, 0x05, 0x30, 0x60, 0xae, 0x6c, 0x7b, 0x43, - 0x78, 0xe8, 0x46, 0xd2, 0x9f, 0x37, 0xed, 0x7b, 0x4e, 0xa9, 0xec, 0x5d, 0x82, 0xe7, - 0x96, 0x1b, 0x7f, 0x25, 0xa9, 0x32, 0x38, 0x51, 0xf6, 0x81, 0xd5, 0x82, 0x36, 0x3a, - 0xa5, 0xf8, 0x99, 0x37, 0xf5, 0xa6, 0x72, 0x58, 0xbf, 0x63, 0xad, 0x6f, 0x1a, 0x0b, - 0x1d, 0x96, 0xdb, 0xd4, 0xfa, 0xdd, 0xfc, 0xef, 0xc5, 0x26, 0x6b, 0xa6, 0x61, 0x17, - 0x22, 0x39, 0x5c, 0x90, 0x65, 0x56, 0xbe, 0x52, 0xaf, 0xe3, 0xf5, 0x65, 0x63, 0x6a, - 0xd1, 0xb1, 0x7d, 0x50, 0x8b, 0x73, 0xd8, 0x74, 0x3e, 0xeb, 0x52, 0x4b, 0xe2, 0x2b, - 0x3d, 0xcb, 0xc2, 0xc7, 0x46, 0x8d, 0x54, 0x11, 0x9c, 0x74, 0x68, 0x44, 0x9a, 0x13, - 0xd8, 0xe3, 0xb9, 0x58, 0x11, 0xa1, 0x98, 0xf3, 0x49, 0x1d, 0xe3, 0xe7, 0xfe, 0x94, - 0x2b, 0x33, 0x04, 0x07, 0xab, 0xf8, 0x2a, 0x4e, 0xd7, 0xc1, 0xb3, 0x11, 0x66, 0x3a, - 0xc6, 0x98, 0x90, 0xf4, 0x15, 0x70, 0x15, 0x85, 0x3d, 0x91, 0xe9, 0x23, 0x03, 0x7c, - 0x22, 0x7a, 0x33, 0xcd, 0xd5, 0xec, 0x28, 0x1c, 0xa3, 0xf7, 0x9c, 0x44, 0x54, 0x6b, - 0x9d, 0x90, 0xca, 0x00, 0xf0, 0x64, 0xc9, 0x9e, 0x3d, 0xd9, 0x79, 0x11, 0xd3, 0x9f, - 0xe9, 0xc5, 0xd0, 0xb2, 0x3a, 0x22, 0x9a, 0x23, 0x4c, 0xb3, 0x61, 0x86, 0xc4, 0x81, - 0x9e, 0x8b, 0x9c, 0x59, 0x27, 0x72, 0x66, 0x32, 0x29, 0x1d, 0x6a, 0x41, 0x82, 0x11, - 0xcc, 0x29, 0x62, 0xe2, 0x0f, 0xe4, 0x7f, 0xeb, 0x3e, 0xdf, 0x33, 0x0f, 0x2c, 0x60, - 0x3a, 0x9d, 0x48, 0xc0, 0xfc, 
0xb5, 0x69, 0x9d, 0xbf, 0xe5, 0x89, 0x64, 0x25, 0xc5, - 0xba, 0xc4, 0xae, 0xe8, 0x2e, 0x57, 0xa8, 0x5a, 0xaf, 0x4e, 0x25, 0x13, 0xe4, 0xf0, - 0x57, 0x96, 0xb0, 0x7b, 0xa2, 0xee, 0x47, 0xd8, 0x05, 0x06, 0xf8, 0xd2, 0xc2, 0x5e, - 0x50, 0xfd, 0x14, 0xde, 0x71, 0xe6, 0xc4, 0x18, 0x55, 0x93, 0x02, 0xf9, 0x39, 0xb0, - 0xe1, 0xab, 0xd5, 0x76, 0xf2, 0x79, 0xc4, 0xb2, 0xe0, 0xfe, 0xb8, 0x5c, 0x1f, 0x28, - 0xff, 0x18, 0xf5, 0x88, 0x91, 0xff, 0xef, 0x13, 0x2e, 0xef, 0x2f, 0xa0, 0x93, 0x46, - 0xae, 0xe3, 0x3c, 0x28, 0xeb, 0x13, 0x0f, 0xf2, 0x8f, 0x5b, 0x76, 0x69, 0x53, 0x33, - 0x41, 0x13, 0x21, 0x19, 0x96, 0xd2, 0x00, 0x11, 0xa1, 0x98, 0xe3, 0xfc, 0x43, 0x3f, - 0x9f, 0x25, 0x41, 0x01, 0x0a, 0xe1, 0x7c, 0x1b, 0xf2, 0x02, 0x58, 0x0f, 0x60, 0x47, - 0x47, 0x2f, 0xb3, 0x68, 0x57, 0xfe, 0x84, 0x3b, 0x19, 0xf5, 0x98, 0x40, 0x09, 0xdd, - 0xc3, 0x24, 0x04, 0x4e, 0x84, 0x7a, 0x4f, 0x4a, 0x0a, 0xb3, 0x4f, 0x71, 0x95, 0x95, - 0xde, 0x37, 0x25, 0x2d, 0x62, 0x35, 0x36, 0x5e, 0x9b, 0x84, 0x39, 0x2b, 0x06, 0x10, - 0x85, 0x34, 0x9d, 0x73, 0x20, 0x3a, 0x4a, 0x13, 0xe9, 0x6f, 0x54, 0x32, 0xec, 0x0f, - 0xd4, 0xa1, 0xee, 0x65, 0xac, 0xcd, 0xd5, 0xe3, 0x90, 0x4d, 0xf5, 0x4c, 0x1d, 0xa5, - 0x10, 0xb0, 0xff, 0x20, 0xdc, 0xc0, 0xc7, 0x7f, 0xcb, 0x2c, 0x0e, 0x0e, 0xb6, 0x05, - 0xcb, 0x05, 0x04, 0xdb, 0x87, 0x63, 0x2c, 0xf3, 0xd8, 0xb4, 0xda, 0xe6, 0xe7, 0x05, - 0x76, 0x9d, 0x1d, 0xe3, 0x54, 0x27, 0x01, 0x23, 0xcb, 0x11, 0x45, 0x0e, 0xfc, 0x60, - 0xac, 0x47, 0x68, 0x3d, 0x7b, 0x8d, 0x0f, 0x81, 0x13, 0x65, 0x56, 0x5f, 0xd9, 0x8c, - 0x4c, 0x8e, 0xb9, 0x36, 0xbc, 0xab, 0x8d, 0x06, 0x9f, 0xc3, 0x3b, 0xd8, 0x01, 0xb0, - 0x3a, 0xde, 0xa2, 0xe1, 0xfb, 0xc5, 0xaa, 0x46, 0x3d, 0x08, 0xca, 0x19, 0x89, 0x6d, - 0x2b, 0xf5, 0x9a, 0x07, 0x1b, 0x85, 0x1e, 0x6c, 0x23, 0x90, 0x52, 0x17, 0x2f, 0x29, - 0x6b, 0xfb, 0x5e, 0x72, 0x40, 0x47, 0x90, 0xa2, 0x18, 0x10, 0x14, 0xf3, 0xb9, 0x4a, - 0x4e, 0x97, 0xd1, 0x17, 0xb4, 0x38, 0x13, 0x03, 0x68, 0xcc, 0x39, 0xdb, 0xb2, 0xd1, - 0x98, 0x06, 0x5a, 0xe3, 0x98, 0x65, 0x47, 0x92, 
0x6c, 0xd2, 0x16, 0x2f, 0x40, 0xa2, - 0x9f, 0x0c, 0x3c, 0x87, 0x45, 0xc0, 0xf5, 0x0f, 0xba, 0x38, 0x52, 0xe5, 0x66, 0xd4, - 0x45, 0x75, 0xc2, 0x9d, 0x39, 0xa0, 0x3f, 0x0c, 0xda, 0x72, 0x19, 0x84, 0xb6, 0xf4, - 0x40, 0x59, 0x1f, 0x35, 0x5e, 0x12, 0xd4, 0x39, 0xff, 0x15, 0x0a, 0xab, 0x76, 0x13, - 0x49, 0x9d, 0xbd, 0x49, 0xad, 0xab, 0xc8, 0x67, 0x6e, 0xef, 0x02, 0x3b, 0x15, 0xb6, - 0x5b, 0xfc, 0x5c, 0xa0, 0x69, 0x48, 0x10, 0x9f, 0x23, 0xf3, 0x50, 0xdb, 0x82, 0x12, - 0x35, 0x35, 0xeb, 0x8a, 0x74, 0x33, 0xbd, 0xab, 0xcb, 0x90, 0x92, 0x71, 0xa6, 0xec, - 0xbc, 0xb5, 0x8b, 0x93, 0x6a, 0x88, 0xcd, 0x4e, 0x8f, 0x2e, 0x6f, 0xf5, 0x80, 0x01, - 0x75, 0xf1, 0x13, 0x25, 0x3d, 0x8f, 0xa9, 0xca, 0x88, 0x85, 0xc2, 0xf5, 0x52, 0xe6, - 0x57, 0xdc, 0x60, 0x3f, 0x25, 0x2e, 0x1a, 0x8e, 0x30, 0x8f, 0x76, 0xf0, 0xbe, 0x79, - 0xe2, 0xfb, 0x8f, 0x5d, 0x5f, 0xbb, 0xe2, 0xe3, 0x0e, 0xca, 0xdd, 0x22, 0x07, 0x23, - 0xc8, 0xc0, 0xae, 0xa8, 0x07, 0x8c, 0xdf, 0xcb, 0x38, 0x68, 0x26, 0x3f, 0xf8, 0xf0, - 0x94, 0x00, 0x54, 0xda, 0x48, 0x78, 0x18, 0x93, 0xa7, 0xe4, 0x9a, 0xd5, 0xaf, 0xf4, - 0xaf, 0x30, 0x0c, 0xd8, 0x04, 0xa6, 0xb6, 0x27, 0x9a, 0xb3, 0xff, 0x3a, 0xfb, 0x64, - 0x49, 0x1c, 0x85, 0x19, 0x4a, 0xab, 0x76, 0x0d, 0x58, 0xa6, 0x06, 0x65, 0x4f, 0x9f, - 0x44, 0x00, 0xe8, 0xb3, 0x85, 0x91, 0x35, 0x6f, 0xbf, 0x64, 0x25, 0xac, 0xa2, 0x6d, - 0xc8, 0x52, 0x44, 0x25, 0x9f, 0xf2, 0xb1, 0x9c, 0x41, 0xb9, 0xf9, 0x6f, 0x3c, 0xa9, - 0xec, 0x1d, 0xde, 0x43, 0x4d, 0xa7, 0xd2, 0xd3, 0x92, 0xb9, 0x05, 0xdd, 0xf3, 0xd1, - 0xf9, 0xaf, 0x93, 0xd1, 0xaf, 0x59, 0x50, 0xbd, 0x49, 0x3f, 0x5a, 0xa7, 0x31, 0xb4, - 0x05, 0x6d, 0xf3, 0x1b, 0xd2, 0x67, 0xb6, 0xb9, 0x0a, 0x07, 0x98, 0x31, 0xaa, 0xf5, - 0x79, 0xbe, 0x0a, 0x39, 0x01, 0x31, 0x37, 0xaa, 0xc6, 0xd4, 0x04, 0xf5, 0x18, 0xcf, - 0xd4, 0x68, 0x40, 0x64, 0x7e, 0x78, 0xbf, 0xe7, 0x06, 0xca, 0x4c, 0xf5, 0xe9, 0xc5, - 0x45, 0x3e, 0x9f, 0x7c, 0xfd, 0x2b, 0x8b, 0x4c, 0x8d, 0x16, 0x9a, 0x44, 0xe5, 0x5c, - 0x88, 0xd4, 0xa9, 0xa7, 0xf9, 0x47, 0x42, 0x41, 0xe2, 0x21, 0xaf, 
0x44, 0x86, 0x00, - 0x18, 0xab, 0x08, 0x56, 0x97, 0x2e, 0x19, 0x4c, 0xd9, 0x34, - ]; - - assert_eq!(&buf, PROTECTED); - - let (header, payload) = buf.split_at_mut(header_len); - let (first, rest) = header.split_at_mut(1); - let sample = &payload[..sample_len]; - - let server_keys = Keys::initial(Version::V1, &CONNECTION_ID, false); - server_keys - .remote - .header - .decrypt_in_place(sample, &mut first[0], &mut rest[17..21]) - .unwrap(); - let payload = server_keys - .remote - .packet - .decrypt_in_place(PACKET_NUMBER, &*header, payload) - .unwrap(); - - assert_eq!(&payload[..PAYLOAD.len()], PAYLOAD); - assert_eq!(payload.len(), buf.len() - header_len - tag_len); - } - - #[tokio::test] - fn test_quic_exporter() { - for &kt in ALL_KEY_TYPES.iter() { - let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); - let server_config = make_server_config_with_versions(kt, &[&rustls::version::TLS13]); - - do_exporter_test(client_config, server_config); - } - } -} // mod test_quic - #[tokio::test] async fn test_client_does_not_offer_sha1() { use tls_client::internal::msgs::{ diff --git a/components/tls/tls-client/tests/common/mod.rs b/crates/tls/client/tests/common/mod.rs similarity index 100% rename from components/tls/tls-client/tests/common/mod.rs rename to crates/tls/client/tests/common/mod.rs diff --git a/components/tls/tls-core/Cargo.toml b/crates/tls/core/Cargo.toml similarity index 74% rename from components/tls/tls-core/Cargo.toml rename to crates/tls/core/Cargo.toml index 5bf38184ca..287d6fb1c5 100644 --- a/components/tls/tls-core/Cargo.toml +++ b/crates/tls/core/Cargo.toml @@ -5,7 +5,7 @@ description = "Cryptographic operations for the TLSNotary TLS client" keywords = ["tls", "mpc", "2pc"] categories = ["cryptography"] license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" +version = "0.1.0-alpha.7" edition = "2021" [lib] @@ -21,15 +21,15 @@ logging = ["tracing"] prf = ["dep:hmac", "dep:sha2"] [dependencies] -rand.workspace = true 
-sct.workspace = true -webpki = { workspace = true, features = ["alloc", "std"] } -tracing = { workspace = true, optional = true } -ring.workspace = true -futures.workspace = true -serde = { workspace = true, optional = true, features = ["derive"] } -rustls-pemfile.workspace = true -thiserror.workspace = true -web-time.workspace = true +futures = { workspace = true } hmac = { workspace = true, optional = true } +rand = { workspace = true } +ring = { workspace = true } +rustls-pemfile = { workspace = true } +sct = { workspace = true } +serde = { workspace = true, optional = true, features = ["derive"] } sha2 = { workspace = true, optional = true } +thiserror = { workspace = true } +tracing = { workspace = true, optional = true } +web-time = { workspace = true } +webpki = { workspace = true, features = ["alloc", "std"] } diff --git a/components/tls/tls-core/src/anchors.rs b/crates/tls/core/src/anchors.rs similarity index 100% rename from components/tls/tls-core/src/anchors.rs rename to crates/tls/core/src/anchors.rs diff --git a/components/tls/tls-core/src/cert.rs b/crates/tls/core/src/cert.rs similarity index 100% rename from components/tls/tls-core/src/cert.rs rename to crates/tls/core/src/cert.rs diff --git a/components/tls/tls-core/src/cipher.rs b/crates/tls/core/src/cipher.rs similarity index 100% rename from components/tls/tls-core/src/cipher.rs rename to crates/tls/core/src/cipher.rs diff --git a/components/tls/tls-core/src/dns.rs b/crates/tls/core/src/dns.rs similarity index 100% rename from components/tls/tls-core/src/dns.rs rename to crates/tls/core/src/dns.rs diff --git a/components/tls/tls-core/src/error.rs b/crates/tls/core/src/error.rs similarity index 100% rename from components/tls/tls-core/src/error.rs rename to crates/tls/core/src/error.rs diff --git a/components/tls/tls-core/src/handshake.rs b/crates/tls/core/src/handshake.rs similarity index 100% rename from components/tls/tls-core/src/handshake.rs rename to crates/tls/core/src/handshake.rs diff 
--git a/components/tls/tls-core/src/ke.rs b/crates/tls/core/src/ke.rs similarity index 90% rename from components/tls/tls-core/src/ke.rs rename to crates/tls/core/src/ke.rs index 9cb6895333..88f7d52609 100644 --- a/components/tls/tls-core/src/ke.rs +++ b/crates/tls/core/src/ke.rs @@ -3,8 +3,8 @@ use crate::msgs::handshake::DigitallySignedStruct; #[derive(Debug, Clone)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct ServerKxDetails { - kx_params: Vec, - kx_sig: DigitallySignedStruct, + pub kx_params: Vec, + pub kx_sig: DigitallySignedStruct, } impl ServerKxDetails { diff --git a/components/tls/tls-core/src/key.rs b/crates/tls/core/src/key.rs similarity index 100% rename from components/tls/tls-core/src/key.rs rename to crates/tls/core/src/key.rs diff --git a/components/tls/tls-core/src/lib.rs b/crates/tls/core/src/lib.rs similarity index 100% rename from components/tls/tls-core/src/lib.rs rename to crates/tls/core/src/lib.rs diff --git a/components/tls/tls-core/src/msgs/alert.rs b/crates/tls/core/src/msgs/alert.rs similarity index 100% rename from components/tls/tls-core/src/msgs/alert.rs rename to crates/tls/core/src/msgs/alert.rs diff --git a/components/tls/tls-core/src/msgs/base.rs b/crates/tls/core/src/msgs/base.rs similarity index 100% rename from components/tls/tls-core/src/msgs/base.rs rename to crates/tls/core/src/msgs/base.rs diff --git a/components/tls/tls-core/src/msgs/ccs.rs b/crates/tls/core/src/msgs/ccs.rs similarity index 100% rename from components/tls/tls-core/src/msgs/ccs.rs rename to crates/tls/core/src/msgs/ccs.rs diff --git a/components/tls/tls-core/src/msgs/codec.rs b/crates/tls/core/src/msgs/codec.rs similarity index 100% rename from components/tls/tls-core/src/msgs/codec.rs rename to crates/tls/core/src/msgs/codec.rs diff --git a/components/tls/tls-core/src/msgs/deframer.rs b/crates/tls/core/src/msgs/deframer.rs similarity index 100% rename from components/tls/tls-core/src/msgs/deframer.rs rename to 
crates/tls/core/src/msgs/deframer.rs diff --git a/components/tls/tls-core/src/msgs/enums.rs b/crates/tls/core/src/msgs/enums.rs similarity index 93% rename from components/tls/tls-core/src/msgs/enums.rs rename to crates/tls/core/src/msgs/enums.rs index 49e2b06259..f8438d494a 100644 --- a/components/tls/tls-core/src/msgs/enums.rs +++ b/crates/tls/core/src/msgs/enums.rs @@ -2,7 +2,9 @@ use crate::msgs::codec::{Codec, Reader}; enum_builder! { - /// The `ProtocolVersion` TLS protocol enum. Values in this enum are taken + /// The `ProtocolVersion` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U16 @@ -20,7 +22,9 @@ enum_builder! { } enum_builder! { - /// The `HashAlgorithm` TLS protocol enum. Values in this enum are taken + /// The `HashAlgorithm` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -37,7 +41,9 @@ enum_builder! { } enum_builder! { - /// The `SignatureAlgorithm` TLS protocol enum. Values in this enum are taken + /// The `SignatureAlgorithm` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -53,7 +59,9 @@ enum_builder! { } enum_builder! { - /// The `ClientCertificateType` TLS protocol enum. Values in this enum are taken + /// The `ClientCertificateType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -73,7 +81,9 @@ enum_builder! { } enum_builder! { - /// The `Compression` TLS protocol enum. Values in this enum are taken + /// The `Compression` TLS protocol enum. 
+ /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -86,7 +96,9 @@ enum_builder! { } enum_builder! { - /// The `ContentType` TLS protocol enum. Values in this enum are taken + /// The `ContentType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -101,7 +113,9 @@ enum_builder! { } enum_builder! { - /// The `HandshakeType` TLS protocol enum. Values in this enum are taken + /// The `HandshakeType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -130,7 +144,9 @@ enum_builder! { } enum_builder! { - /// The `AlertLevel` TLS protocol enum. Values in this enum are taken + /// The `AlertLevel` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -142,7 +158,9 @@ enum_builder! { } enum_builder! { - /// The `AlertDescription` TLS protocol enum. Values in this enum are taken + /// The `AlertDescription` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -186,7 +204,9 @@ enum_builder! { } enum_builder! { - /// The `HeartbeatMessageType` TLS protocol enum. Values in this enum are taken + /// The `HeartbeatMessageType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -198,7 +218,9 @@ enum_builder! 
{ } enum_builder! { - /// The `ExtensionType` TLS protocol enum. Values in this enum are taken + /// The `ExtensionType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U16 @@ -245,7 +267,9 @@ enum_builder! { } enum_builder! { - /// The `ServerNameType` TLS protocol enum. Values in this enum are taken + /// The `ServerNameType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -256,7 +280,9 @@ enum_builder! { } enum_builder! { - /// The `NamedCurve` TLS protocol enum. Values in this enum are taken + /// The `NamedCurve` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U16 @@ -298,7 +324,9 @@ enum_builder! { } enum_builder! { - /// The `NamedGroup` TLS protocol enum. Values in this enum are taken + /// The `NamedGroup` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U16 @@ -318,7 +346,9 @@ enum_builder! { } enum_builder! { - /// The `CipherSuite` TLS protocol enum. Values in this enum are taken + /// The `CipherSuite` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U16 @@ -704,7 +734,9 @@ enum_builder! { } enum_builder! { - /// The `ECPointFormat` TLS protocol enum. Values in this enum are taken + /// The `ECPointFormat` TLS protocol enum. 
+ /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -717,7 +749,9 @@ enum_builder! { } enum_builder! { - /// The `HeartbeatMode` TLS protocol enum. Values in this enum are taken + /// The `HeartbeatMode` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -729,7 +763,9 @@ enum_builder! { } enum_builder! { - /// The `ECCurveType` TLS protocol enum. Values in this enum are taken + /// The `ECCurveType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -742,7 +778,9 @@ enum_builder! { } enum_builder! { - /// The `SignatureScheme` TLS protocol enum. Values in this enum are taken + /// The `SignatureScheme` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U16 @@ -765,7 +803,9 @@ enum_builder! { } enum_builder! { - /// The `PSKKeyExchangeMode` TLS protocol enum. Values in this enum are taken + /// The `PSKKeyExchangeMode` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 @@ -777,7 +817,9 @@ enum_builder! { } enum_builder! { - /// The `KeyUpdateRequest` TLS protocol enum. Values in this enum are taken + /// The `KeyUpdateRequest` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. 
@U8 @@ -789,7 +831,9 @@ enum_builder! { } enum_builder! { - /// The `CertificateStatusType` TLS protocol enum. Values in this enum are taken + /// The `CertificateStatusType` TLS protocol enum. + /// + /// Values in this enum are taken /// from the various RFCs covering TLS, and are listed by IANA. /// The `Unknown` item is used when processing unrecognised ordinals. @U8 diff --git a/components/tls/tls-core/src/msgs/enums_test.rs b/crates/tls/core/src/msgs/enums_test.rs similarity index 100% rename from components/tls/tls-core/src/msgs/enums_test.rs rename to crates/tls/core/src/msgs/enums_test.rs diff --git a/components/tls/tls-core/src/msgs/fragmenter.rs b/crates/tls/core/src/msgs/fragmenter.rs similarity index 100% rename from components/tls/tls-core/src/msgs/fragmenter.rs rename to crates/tls/core/src/msgs/fragmenter.rs diff --git a/components/tls/tls-core/src/msgs/handshake-test.1.bin b/crates/tls/core/src/msgs/handshake-test.1.bin similarity index 100% rename from components/tls/tls-core/src/msgs/handshake-test.1.bin rename to crates/tls/core/src/msgs/handshake-test.1.bin diff --git a/components/tls/tls-core/src/msgs/handshake.rs b/crates/tls/core/src/msgs/handshake.rs similarity index 100% rename from components/tls/tls-core/src/msgs/handshake.rs rename to crates/tls/core/src/msgs/handshake.rs diff --git a/components/tls/tls-core/src/msgs/handshake_test.rs b/crates/tls/core/src/msgs/handshake_test.rs similarity index 100% rename from components/tls/tls-core/src/msgs/handshake_test.rs rename to crates/tls/core/src/msgs/handshake_test.rs diff --git a/components/tls/tls-core/src/msgs/hsjoiner.rs b/crates/tls/core/src/msgs/hsjoiner.rs similarity index 100% rename from components/tls/tls-core/src/msgs/hsjoiner.rs rename to crates/tls/core/src/msgs/hsjoiner.rs diff --git a/components/tls/tls-core/src/msgs/macros.rs b/crates/tls/core/src/msgs/macros.rs similarity index 100% rename from components/tls/tls-core/src/msgs/macros.rs rename to 
crates/tls/core/src/msgs/macros.rs diff --git a/components/tls/tls-core/src/msgs/message.rs b/crates/tls/core/src/msgs/message.rs similarity index 100% rename from components/tls/tls-core/src/msgs/message.rs rename to crates/tls/core/src/msgs/message.rs diff --git a/components/tls/tls-core/src/msgs/message_test.rs b/crates/tls/core/src/msgs/message_test.rs similarity index 100% rename from components/tls/tls-core/src/msgs/message_test.rs rename to crates/tls/core/src/msgs/message_test.rs diff --git a/components/tls/tls-core/src/msgs/mod.rs b/crates/tls/core/src/msgs/mod.rs similarity index 88% rename from components/tls/tls-core/src/msgs/mod.rs rename to crates/tls/core/src/msgs/mod.rs index ab5e37ffdb..1d977f480e 100644 --- a/components/tls/tls-core/src/msgs/mod.rs +++ b/crates/tls/core/src/msgs/mod.rs @@ -31,8 +31,10 @@ mod test { #[test] fn smoketest() { - use super::codec::Reader; - use super::message::{Message, OpaqueMessage}; + use super::{ + codec::Reader, + message::{Message, OpaqueMessage}, + }; let bytes = include_bytes!("handshake-test.1.bin"); let mut r = Reader::init(bytes); diff --git a/components/tls/tls-core/src/prf.rs b/crates/tls/core/src/prf.rs similarity index 100% rename from components/tls/tls-core/src/prf.rs rename to crates/tls/core/src/prf.rs diff --git a/components/tls/tls-core/src/rand.rs b/crates/tls/core/src/rand.rs similarity index 100% rename from components/tls/tls-core/src/rand.rs rename to crates/tls/core/src/rand.rs diff --git a/components/tls/tls-core/src/suites/mod.rs b/crates/tls/core/src/suites/mod.rs similarity index 100% rename from components/tls/tls-core/src/suites/mod.rs rename to crates/tls/core/src/suites/mod.rs diff --git a/components/tls/tls-core/src/suites/tls12.rs b/crates/tls/core/src/suites/tls12.rs similarity index 100% rename from components/tls/tls-core/src/suites/tls12.rs rename to crates/tls/core/src/suites/tls12.rs diff --git a/components/tls/tls-core/src/suites/tls13.rs b/crates/tls/core/src/suites/tls13.rs 
similarity index 100% rename from components/tls/tls-core/src/suites/tls13.rs rename to crates/tls/core/src/suites/tls13.rs diff --git a/components/tls/tls-core/src/utils/bs_debug.rs b/crates/tls/core/src/utils/bs_debug.rs similarity index 100% rename from components/tls/tls-core/src/utils/bs_debug.rs rename to crates/tls/core/src/utils/bs_debug.rs diff --git a/components/tls/tls-core/src/utils/mod.rs b/crates/tls/core/src/utils/mod.rs similarity index 100% rename from components/tls/tls-core/src/utils/mod.rs rename to crates/tls/core/src/utils/mod.rs diff --git a/components/tls/tls-core/src/verify.rs b/crates/tls/core/src/verify.rs similarity index 99% rename from components/tls/tls-core/src/verify.rs rename to crates/tls/core/src/verify.rs index 4a6e68967f..df1634e740 100644 --- a/components/tls/tls-core/src/verify.rs +++ b/crates/tls/core/src/verify.rs @@ -339,6 +339,11 @@ impl WebPkiVerifier { Self { roots, ct_policy } } + /// Returns the root store. + pub fn root_store(&self) -> &RootCertStore { + &self.roots + } + /// Returns the signature verification methods supported by /// webpki. pub fn verification_schemes() -> Vec { diff --git a/components/tls/tls-core/src/versions.rs b/crates/tls/core/src/versions.rs similarity index 97% rename from components/tls/tls-core/src/versions.rs rename to crates/tls/core/src/versions.rs index 566805c80b..90aa75faa1 100644 --- a/components/tls/tls-core/src/versions.rs +++ b/crates/tls/core/src/versions.rs @@ -6,23 +6,21 @@ use crate::msgs::enums::ProtocolVersion; /// the [`ALL_VERSIONS`] array, as well as individually as [`TLS12`] /// and [`TLS13`]. #[derive(Debug, PartialEq)] +#[non_exhaustive] pub struct SupportedProtocolVersion { /// The TLS enumeration naming this version. 
pub version: ProtocolVersion, - is_private: (), } /// TLS1.2 #[cfg(feature = "tls12")] pub static TLS12: SupportedProtocolVersion = SupportedProtocolVersion { version: ProtocolVersion::TLSv1_2, - is_private: (), }; /// TLS1.3 pub static TLS13: SupportedProtocolVersion = SupportedProtocolVersion { version: ProtocolVersion::TLSv1_3, - is_private: (), }; /// A list of all the protocol versions supported by rustls. diff --git a/components/tls/tls-core/src/x509.rs b/crates/tls/core/src/x509.rs similarity index 100% rename from components/tls/tls-core/src/x509.rs rename to crates/tls/core/src/x509.rs diff --git a/components/tls/tls-core/testdata/cert-arstechnica.0.der b/crates/tls/core/testdata/cert-arstechnica.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-arstechnica.0.der rename to crates/tls/core/testdata/cert-arstechnica.0.der diff --git a/components/tls/tls-core/testdata/cert-arstechnica.1.der b/crates/tls/core/testdata/cert-arstechnica.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-arstechnica.1.der rename to crates/tls/core/testdata/cert-arstechnica.1.der diff --git a/components/tls/tls-core/testdata/cert-arstechnica.2.der b/crates/tls/core/testdata/cert-arstechnica.2.der similarity index 100% rename from components/tls/tls-core/testdata/cert-arstechnica.2.der rename to crates/tls/core/testdata/cert-arstechnica.2.der diff --git a/components/tls/tls-core/testdata/cert-arstechnica.3.der b/crates/tls/core/testdata/cert-arstechnica.3.der similarity index 100% rename from components/tls/tls-core/testdata/cert-arstechnica.3.der rename to crates/tls/core/testdata/cert-arstechnica.3.der diff --git a/components/tls/tls-core/testdata/cert-digicert.pem b/crates/tls/core/testdata/cert-digicert.pem similarity index 100% rename from components/tls/tls-core/testdata/cert-digicert.pem rename to crates/tls/core/testdata/cert-digicert.pem diff --git a/components/tls/tls-core/testdata/cert-duckduckgo.0.der 
b/crates/tls/core/testdata/cert-duckduckgo.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-duckduckgo.0.der rename to crates/tls/core/testdata/cert-duckduckgo.0.der diff --git a/components/tls/tls-core/testdata/cert-duckduckgo.1.der b/crates/tls/core/testdata/cert-duckduckgo.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-duckduckgo.1.der rename to crates/tls/core/testdata/cert-duckduckgo.1.der diff --git a/components/tls/tls-core/testdata/cert-github.0.der b/crates/tls/core/testdata/cert-github.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-github.0.der rename to crates/tls/core/testdata/cert-github.0.der diff --git a/components/tls/tls-core/testdata/cert-github.1.der b/crates/tls/core/testdata/cert-github.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-github.1.der rename to crates/tls/core/testdata/cert-github.1.der diff --git a/components/tls/tls-core/testdata/cert-google.0.der b/crates/tls/core/testdata/cert-google.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-google.0.der rename to crates/tls/core/testdata/cert-google.0.der diff --git a/components/tls/tls-core/testdata/cert-google.1.der b/crates/tls/core/testdata/cert-google.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-google.1.der rename to crates/tls/core/testdata/cert-google.1.der diff --git a/components/tls/tls-core/testdata/cert-google.2.der b/crates/tls/core/testdata/cert-google.2.der similarity index 100% rename from components/tls/tls-core/testdata/cert-google.2.der rename to crates/tls/core/testdata/cert-google.2.der diff --git a/components/tls/tls-core/testdata/cert-hn.0.der b/crates/tls/core/testdata/cert-hn.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-hn.0.der rename to crates/tls/core/testdata/cert-hn.0.der diff --git a/components/tls/tls-core/testdata/cert-hn.1.der 
b/crates/tls/core/testdata/cert-hn.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-hn.1.der rename to crates/tls/core/testdata/cert-hn.1.der diff --git a/components/tls/tls-core/testdata/cert-reddit.0.der b/crates/tls/core/testdata/cert-reddit.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-reddit.0.der rename to crates/tls/core/testdata/cert-reddit.0.der diff --git a/components/tls/tls-core/testdata/cert-reddit.1.der b/crates/tls/core/testdata/cert-reddit.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-reddit.1.der rename to crates/tls/core/testdata/cert-reddit.1.der diff --git a/components/tls/tls-core/testdata/cert-rustlang.0.der b/crates/tls/core/testdata/cert-rustlang.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-rustlang.0.der rename to crates/tls/core/testdata/cert-rustlang.0.der diff --git a/components/tls/tls-core/testdata/cert-rustlang.1.der b/crates/tls/core/testdata/cert-rustlang.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-rustlang.1.der rename to crates/tls/core/testdata/cert-rustlang.1.der diff --git a/components/tls/tls-core/testdata/cert-rustlang.2.der b/crates/tls/core/testdata/cert-rustlang.2.der similarity index 100% rename from components/tls/tls-core/testdata/cert-rustlang.2.der rename to crates/tls/core/testdata/cert-rustlang.2.der diff --git a/components/tls/tls-core/testdata/cert-rustlang.3.der b/crates/tls/core/testdata/cert-rustlang.3.der similarity index 100% rename from components/tls/tls-core/testdata/cert-rustlang.3.der rename to crates/tls/core/testdata/cert-rustlang.3.der diff --git a/components/tls/tls-core/testdata/cert-servo.0.der b/crates/tls/core/testdata/cert-servo.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-servo.0.der rename to crates/tls/core/testdata/cert-servo.0.der diff --git a/components/tls/tls-core/testdata/cert-servo.1.der 
b/crates/tls/core/testdata/cert-servo.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-servo.1.der rename to crates/tls/core/testdata/cert-servo.1.der diff --git a/components/tls/tls-core/testdata/cert-stackoverflow.0.der b/crates/tls/core/testdata/cert-stackoverflow.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-stackoverflow.0.der rename to crates/tls/core/testdata/cert-stackoverflow.0.der diff --git a/components/tls/tls-core/testdata/cert-stackoverflow.1.der b/crates/tls/core/testdata/cert-stackoverflow.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-stackoverflow.1.der rename to crates/tls/core/testdata/cert-stackoverflow.1.der diff --git a/components/tls/tls-core/testdata/cert-stackoverflow.2.der b/crates/tls/core/testdata/cert-stackoverflow.2.der similarity index 100% rename from components/tls/tls-core/testdata/cert-stackoverflow.2.der rename to crates/tls/core/testdata/cert-stackoverflow.2.der diff --git a/components/tls/tls-core/testdata/cert-twitter.0.der b/crates/tls/core/testdata/cert-twitter.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-twitter.0.der rename to crates/tls/core/testdata/cert-twitter.0.der diff --git a/components/tls/tls-core/testdata/cert-twitter.1.der b/crates/tls/core/testdata/cert-twitter.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-twitter.1.der rename to crates/tls/core/testdata/cert-twitter.1.der diff --git a/components/tls/tls-core/testdata/cert-wapo.0.der b/crates/tls/core/testdata/cert-wapo.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-wapo.0.der rename to crates/tls/core/testdata/cert-wapo.0.der diff --git a/components/tls/tls-core/testdata/cert-wapo.1.der b/crates/tls/core/testdata/cert-wapo.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-wapo.1.der rename to crates/tls/core/testdata/cert-wapo.1.der diff --git 
a/components/tls/tls-core/testdata/cert-wikipedia.0.der b/crates/tls/core/testdata/cert-wikipedia.0.der similarity index 100% rename from components/tls/tls-core/testdata/cert-wikipedia.0.der rename to crates/tls/core/testdata/cert-wikipedia.0.der diff --git a/components/tls/tls-core/testdata/cert-wikipedia.1.der b/crates/tls/core/testdata/cert-wikipedia.1.der similarity index 100% rename from components/tls/tls-core/testdata/cert-wikipedia.1.der rename to crates/tls/core/testdata/cert-wikipedia.1.der diff --git a/components/tls/tls-core/testdata/deframer-empty-applicationdata.bin b/crates/tls/core/testdata/deframer-empty-applicationdata.bin similarity index 100% rename from components/tls/tls-core/testdata/deframer-empty-applicationdata.bin rename to crates/tls/core/testdata/deframer-empty-applicationdata.bin diff --git a/components/tls/tls-core/testdata/deframer-invalid-contenttype.bin b/crates/tls/core/testdata/deframer-invalid-contenttype.bin similarity index 100% rename from components/tls/tls-core/testdata/deframer-invalid-contenttype.bin rename to crates/tls/core/testdata/deframer-invalid-contenttype.bin diff --git a/components/tls/tls-core/testdata/deframer-invalid-empty.bin b/crates/tls/core/testdata/deframer-invalid-empty.bin similarity index 100% rename from components/tls/tls-core/testdata/deframer-invalid-empty.bin rename to crates/tls/core/testdata/deframer-invalid-empty.bin diff --git a/components/tls/tls-core/testdata/deframer-invalid-length.bin b/crates/tls/core/testdata/deframer-invalid-length.bin similarity index 100% rename from components/tls/tls-core/testdata/deframer-invalid-length.bin rename to crates/tls/core/testdata/deframer-invalid-length.bin diff --git a/components/tls/tls-core/testdata/deframer-invalid-version.bin b/crates/tls/core/testdata/deframer-invalid-version.bin similarity index 100% rename from components/tls/tls-core/testdata/deframer-invalid-version.bin rename to crates/tls/core/testdata/deframer-invalid-version.bin diff 
--git a/components/tls/tls-core/testdata/deframer-test.1.bin b/crates/tls/core/testdata/deframer-test.1.bin similarity index 100% rename from components/tls/tls-core/testdata/deframer-test.1.bin rename to crates/tls/core/testdata/deframer-test.1.bin diff --git a/components/tls/tls-core/testdata/deframer-test.2.bin b/crates/tls/core/testdata/deframer-test.2.bin similarity index 100% rename from components/tls/tls-core/testdata/deframer-test.2.bin rename to crates/tls/core/testdata/deframer-test.2.bin diff --git a/components/tls/tls-core/testdata/eddsakey.der b/crates/tls/core/testdata/eddsakey.der similarity index 100% rename from components/tls/tls-core/testdata/eddsakey.der rename to crates/tls/core/testdata/eddsakey.der diff --git a/components/tls/tls-core/testdata/nistp256key.der b/crates/tls/core/testdata/nistp256key.der similarity index 100% rename from components/tls/tls-core/testdata/nistp256key.der rename to crates/tls/core/testdata/nistp256key.der diff --git a/components/tls/tls-core/testdata/nistp256key.pkcs8.der b/crates/tls/core/testdata/nistp256key.pkcs8.der similarity index 100% rename from components/tls/tls-core/testdata/nistp256key.pkcs8.der rename to crates/tls/core/testdata/nistp256key.pkcs8.der diff --git a/components/tls/tls-core/testdata/nistp384key.der b/crates/tls/core/testdata/nistp384key.der similarity index 100% rename from components/tls/tls-core/testdata/nistp384key.der rename to crates/tls/core/testdata/nistp384key.der diff --git a/components/tls/tls-core/testdata/nistp384key.pkcs8.der b/crates/tls/core/testdata/nistp384key.pkcs8.der similarity index 100% rename from components/tls/tls-core/testdata/nistp384key.pkcs8.der rename to crates/tls/core/testdata/nistp384key.pkcs8.der diff --git a/components/tls/tls-core/testdata/prf-result.1.bin b/crates/tls/core/testdata/prf-result.1.bin similarity index 100% rename from components/tls/tls-core/testdata/prf-result.1.bin rename to crates/tls/core/testdata/prf-result.1.bin diff --git 
a/components/tls/tls-core/testdata/prf-result.2.bin b/crates/tls/core/testdata/prf-result.2.bin similarity index 100% rename from components/tls/tls-core/testdata/prf-result.2.bin rename to crates/tls/core/testdata/prf-result.2.bin diff --git a/components/tls/tls-core/testdata/rsa2048key.pkcs1.der b/crates/tls/core/testdata/rsa2048key.pkcs1.der similarity index 100% rename from components/tls/tls-core/testdata/rsa2048key.pkcs1.der rename to crates/tls/core/testdata/rsa2048key.pkcs1.der diff --git a/components/tls/tls-core/testdata/rsa2048key.pkcs8.der b/crates/tls/core/testdata/rsa2048key.pkcs8.der similarity index 100% rename from components/tls/tls-core/testdata/rsa2048key.pkcs8.der rename to crates/tls/core/testdata/rsa2048key.pkcs8.der diff --git a/crates/tls/mpc/Cargo.toml b/crates/tls/mpc/Cargo.toml new file mode 100644 index 0000000000..bbaa661308 --- /dev/null +++ b/crates/tls/mpc/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "tlsn-tls-mpc" +authors = ["TLSNotary Team"] +description = "Implementation of the backend trait for 2PC" +keywords = ["tls", "mpc", "2pc"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0-alpha.7" +edition = "2021" + +[lib] +name = "tls_mpc" + +[features] +default = [] + +[dependencies] +tlsn-aead = { workspace = true } +tlsn-block-cipher = { workspace = true } +tlsn-hmac-sha256 = { workspace = true } +tlsn-key-exchange = { workspace = true } +tlsn-stream-cipher = { workspace = true } +tlsn-tls-backend = { workspace = true } +tlsn-tls-core = { workspace = true, features = ["serde"] } +tlsn-universal-hash = { workspace = true } +tlsn-utils-aio = { workspace = true } + +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-garble = { git = 
"https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } + +uid-mux = { version = "0.1", features = ["serio"] } +ludi = { git = "https://github.com/sinui0/ludi", rev = "b590de5" } + +async-trait = { workspace = true } +derive_builder = { workspace = true } +enum-try-as-inner = { workspace = true } +futures = { workspace = true } +p256 = { workspace = true } +rand = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +serio = { version = "0.1", features = ["compat"] } +tls-server-fixture = { workspace = true } +tlsn-tls-client = { workspace = true } +tlsn-tls-client-async = { workspace = true } + +tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } +tokio-util = { workspace = true, features = ["compat"] } +tracing-subscriber = { workspace = true } diff --git a/crates/tls/mpc/src/components.rs b/crates/tls/mpc/src/components.rs new file mode 100644 index 0000000000..b68ad3aa7b --- /dev/null +++ b/crates/tls/mpc/src/components.rs @@ -0,0 +1,208 @@ +use aead::{ + aes_gcm::{AesGcmConfig, AesGcmError, MpcAesGcm, Role as AeadRole}, + Aead, +}; +use block_cipher::{Aes128, BlockCipherConfig, MpcBlockCipher}; +use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role as PrfRole}; +use key_exchange::{KeyExchange, KeyExchangeConfig, MpcKeyExchange, Role as KeRole}; +use mpz_common::{Context, Preprocess}; +use mpz_fields::{gf2_128::Gf2_128, p256::P256}; +use mpz_garble::{Decode, DecodePrivate, Execute, Load, Memory, Prove, Thread, Verify}; +use mpz_ole::rot::{OLEReceiver, OLESender}; +use mpz_ot::{OTError, RandomOTReceiver, RandomOTSender}; +use 
mpz_share_conversion::{ShareConversionReceiver, ShareConversionSender}; +use tlsn_stream_cipher::{Aes128Ctr, MpcStreamCipher, StreamCipherConfig}; +use tlsn_universal_hash::{ + ghash::{Ghash, GhashConfig}, + UniversalHash, +}; + +use crate::{MpcTlsCommonConfig, TlsRole}; + +/// Builds the components for MPC-TLS. +// TODO: Better dependency injection!! +#[allow(clippy::too_many_arguments, clippy::type_complexity)] +pub fn build_components( + role: TlsRole, + config: &MpcTlsCommonConfig, + ctx_ke: Ctx, + ctx_encrypter: Ctx, + ctx_decrypter: Ctx, + ctx_ghash_encrypter: Ctx, + ctx_ghash_decrypter: Ctx, + thread_ke: T, + thread_prf_0: T, + thread_prf_1: T, + thread_encrypter_block_cipher: T, + thread_decrypter_block_cipher: T, + thread_encrypter_stream_cipher: T, + thread_decrypter_stream_cipher: T, + ot_send: OTS, + ot_recv: OTR, +) -> ( + Box, + Box, + Box + Send>, + Box + Send>, +) +where + Ctx: Context + 'static, + T: Thread + + Memory + + Execute + + Load + + Decode + + DecodePrivate + + Prove + + Verify + + Send + + Sync + + 'static, + OTS: Preprocess + + RandomOTSender + + RandomOTSender + + Clone + + Send + + Sync + + 'static, + OTR: Preprocess + + RandomOTReceiver + + RandomOTReceiver + + Clone + + Send + + Sync + + 'static, +{ + let ke: Box = match role { + TlsRole::Leader => Box::new(MpcKeyExchange::new( + KeyExchangeConfig::builder() + .role(KeRole::Leader) + .build() + .unwrap(), + ctx_ke, + ShareConversionSender::new(OLESender::new(ot_send.clone())), + ShareConversionReceiver::new(OLEReceiver::new(ot_recv.clone())), + thread_ke, + )), + TlsRole::Follower => Box::new(MpcKeyExchange::new( + KeyExchangeConfig::builder() + .role(KeRole::Follower) + .build() + .unwrap(), + ctx_ke, + ShareConversionReceiver::new(OLEReceiver::new(ot_recv.clone())), + ShareConversionSender::new(OLESender::new(ot_send.clone())), + thread_ke, + )), + }; + + let prf: Box = Box::new(MpcPrf::new( + PrfConfig::builder() + .role(match role { + TlsRole::Leader => PrfRole::Leader, + 
TlsRole::Follower => PrfRole::Follower, + }) + .build() + .unwrap(), + thread_prf_0, + thread_prf_1, + )); + + // Encrypter + let block_cipher = Box::new(MpcBlockCipher::::new( + BlockCipherConfig::builder() + .id("encrypter/block_cipher") + .build() + .unwrap(), + thread_encrypter_block_cipher, + )); + + let stream_cipher = Box::new(MpcStreamCipher::::new( + StreamCipherConfig::builder() + .id("encrypter/stream_cipher") + .transcript_id("tx") + .build() + .unwrap(), + thread_encrypter_stream_cipher, + )); + + let ghash: Box = match role { + TlsRole::Leader => Box::new(Ghash::new( + GhashConfig::builder().build().unwrap(), + ShareConversionSender::new(OLESender::new(ot_send.clone())), + ctx_ghash_encrypter, + )), + TlsRole::Follower => Box::new(Ghash::new( + GhashConfig::builder().build().unwrap(), + ShareConversionReceiver::new(OLEReceiver::new(ot_recv.clone())), + ctx_ghash_encrypter, + )), + }; + + let mut encrypter = Box::new(MpcAesGcm::new( + AesGcmConfig::builder() + .id("encrypter/aes_gcm") + .role(match role { + TlsRole::Leader => AeadRole::Leader, + TlsRole::Follower => AeadRole::Follower, + }) + .build() + .unwrap(), + ctx_encrypter, + block_cipher, + stream_cipher, + ghash, + )); + + encrypter.set_transcript_id(config.tx_config().opaque_id()); + + // Decrypter + let block_cipher = Box::new(MpcBlockCipher::::new( + BlockCipherConfig::builder() + .id("decrypter/block_cipher") + .build() + .unwrap(), + thread_decrypter_block_cipher, + )); + + let stream_cipher = Box::new(MpcStreamCipher::::new( + StreamCipherConfig::builder() + .id("decrypter/stream_cipher") + .transcript_id("rx") + .build() + .unwrap(), + thread_decrypter_stream_cipher, + )); + + let ghash: Box = match role { + TlsRole::Leader => Box::new(Ghash::new( + GhashConfig::builder().build().unwrap(), + ShareConversionSender::new(OLESender::new(ot_send)), + ctx_ghash_decrypter, + )), + TlsRole::Follower => Box::new(Ghash::new( + GhashConfig::builder().build().unwrap(), + 
ShareConversionReceiver::new(OLEReceiver::new(ot_recv)), + ctx_ghash_decrypter, + )), + }; + + let mut decrypter = Box::new(MpcAesGcm::new( + AesGcmConfig::builder() + .id("decrypter/aes_gcm") + .role(match role { + TlsRole::Leader => AeadRole::Leader, + TlsRole::Follower => AeadRole::Follower, + }) + .build() + .unwrap(), + ctx_decrypter, + block_cipher, + stream_cipher, + ghash, + )); + + decrypter.set_transcript_id(config.rx_config().opaque_id()); + + (ke, prf, encrypter, decrypter) +} diff --git a/crates/tls/mpc/src/config.rs b/crates/tls/mpc/src/config.rs new file mode 100644 index 0000000000..5e2afdd1d1 --- /dev/null +++ b/crates/tls/mpc/src/config.rs @@ -0,0 +1,181 @@ +use derive_builder::Builder; + +static DEFAULT_OPAQUE_TX_TRANSCRIPT_ID: &str = "opaque_tx"; +static DEFAULT_OPAQUE_RX_TRANSCRIPT_ID: &str = "opaque_rx"; +static DEFAULT_TX_TRANSCRIPT_ID: &str = "tx"; +static DEFAULT_RX_TRANSCRIPT_ID: &str = "rx"; +const DEFAULT_TRANSCRIPT_MAX_SIZE: usize = 1 << 14; + +/// Transcript configuration. +#[derive(Debug, Clone, Builder)] +pub struct TranscriptConfig { + /// The transcript id. + id: String, + /// The "opaque" transcript id, used for parts of the transcript that are + /// not part of the application data. + opaque_id: String, + /// The maximum number of bytes that can be written to the transcript during + /// the **online** phase, i.e. while the MPC-TLS connection is active. + max_online_size: usize, + /// The maximum number of bytes that can be written to the transcript during + /// the **offline** phase, i.e. after the MPC-TLS connection was closed. + max_offline_size: usize, +} + +impl TranscriptConfig { + /// Creates a new default builder for the sent transcript config. 
+ pub fn default_tx() -> TranscriptConfigBuilder { + let mut builder = TranscriptConfigBuilder::default(); + + builder + .id(DEFAULT_TX_TRANSCRIPT_ID.to_string()) + .opaque_id(DEFAULT_OPAQUE_TX_TRANSCRIPT_ID.to_string()) + .max_online_size(DEFAULT_TRANSCRIPT_MAX_SIZE) + .max_offline_size(0); + + builder + } + + /// Creates a new default builder for the received transcript config. + pub fn default_rx() -> TranscriptConfigBuilder { + let mut builder = TranscriptConfigBuilder::default(); + + builder + .id(DEFAULT_RX_TRANSCRIPT_ID.to_string()) + .opaque_id(DEFAULT_OPAQUE_RX_TRANSCRIPT_ID.to_string()) + .max_online_size(0) + .max_offline_size(DEFAULT_TRANSCRIPT_MAX_SIZE); + + builder + } + + /// Creates a new builder for `TranscriptConfig`. + pub fn builder() -> TranscriptConfigBuilder { + TranscriptConfigBuilder::default() + } + + /// Returns the transcript id. + pub fn id(&self) -> &str { + &self.id + } + + /// Returns the "opaque" transcript id. + pub fn opaque_id(&self) -> &str { + &self.opaque_id + } + + /// Returns the maximum number of bytes that can be written to the + /// transcript during the **online** phase, i.e. while the MPC-TLS + /// connection is active. + pub fn max_online_size(&self) -> usize { + self.max_online_size + } + + /// Returns the maximum number of bytes that can be written to the + /// transcript during the **offline** phase, i.e. after the MPC-TLS + /// connection was closed. + pub fn max_offline_size(&self) -> usize { + self.max_offline_size + } +} + +/// Configuration options which are common to both the leader and the follower +#[derive(Debug, Clone, Builder)] +pub struct MpcTlsCommonConfig { + /// The number of threads to use + #[builder(default = "8")] + num_threads: usize, + /// The sent data transcript configuration. + #[builder(default = "TranscriptConfig::default_tx().build().unwrap()")] + tx_config: TranscriptConfig, + /// The received data transcript configuration. 
+ #[builder(default = "TranscriptConfig::default_rx().build().unwrap()")] + rx_config: TranscriptConfig, + /// Whether the leader commits to the handshake data. + #[builder(default = "true")] + handshake_commit: bool, +} + +impl MpcTlsCommonConfig { + /// Creates a new builder for `MpcTlsCommonConfig`. + pub fn builder() -> MpcTlsCommonConfigBuilder { + MpcTlsCommonConfigBuilder::default() + } + + /// Returns the number of threads to use. + pub fn num_threads(&self) -> usize { + self.num_threads + } + + /// Returns the configuration for the sent data transcript. + pub fn tx_config(&self) -> &TranscriptConfig { + &self.tx_config + } + + /// Returns the configuration for the received data transcript. + pub fn rx_config(&self) -> &TranscriptConfig { + &self.rx_config + } + + /// Whether the leader commits to the handshake data. + pub fn handshake_commit(&self) -> bool { + self.handshake_commit + } +} + +/// Configuration for the leader +#[allow(missing_docs)] +#[derive(Debug, Clone, Builder)] +pub struct MpcTlsLeaderConfig { + common: MpcTlsCommonConfig, + /// Whether the `deferred decryption` feature is toggled on from the start + /// of the MPC-TLS connection. + /// + /// The received data will be decrypted locally without MPC, thus improving + /// bandwidth usage and performance. + /// + /// Decryption of the data received while `deferred decryption` is toggled + /// on will be deferred until after the MPC-TLS connection is closed. + /// If you need to decrypt some subset of data received from the TLS peer + /// while the MPC-TLS connection is active, you must toggle `deferred + /// decryption` **off** for that subset of data. + #[builder(default = "true")] + defer_decryption_from_start: bool, +} + +impl MpcTlsLeaderConfig { + /// Creates a new builder for `MpcTlsLeaderConfig`. + pub fn builder() -> MpcTlsLeaderConfigBuilder { + MpcTlsLeaderConfigBuilder::default() + } + + /// Returns the common config. 
+ pub fn common(&self) -> &MpcTlsCommonConfig { + &self.common + } + + /// Returns whether the `deferred decryption` feature is toggled on from the + /// start of the MPC-TLS connection. + pub fn defer_decryption_from_start(&self) -> bool { + self.defer_decryption_from_start + } +} + +/// Configuration for the follower +#[allow(missing_docs)] +#[derive(Debug, Clone, Builder)] +pub struct MpcTlsFollowerConfig { + common: MpcTlsCommonConfig, +} + +impl MpcTlsFollowerConfig { + /// Creates a new builder for `MpcTlsFollowerConfig`. + pub fn builder() -> MpcTlsFollowerConfigBuilder { + MpcTlsFollowerConfigBuilder::default() + } + + /// Returns the common config. + pub fn common(&self) -> &MpcTlsCommonConfig { + &self.common + } +} diff --git a/components/tls/tls-mpc/src/error.rs b/crates/tls/mpc/src/error.rs similarity index 92% rename from components/tls/tls-mpc/src/error.rs rename to crates/tls/mpc/src/error.rs index 3d4a88a817..9d5eae8f60 100644 --- a/components/tls/tls-mpc/src/error.rs +++ b/crates/tls/mpc/src/error.rs @@ -63,6 +63,8 @@ impl MpcTlsError { pub(crate) enum Kind { /// An unexpected state was encountered State, + /// Context error. 
+ Ctx, /// IO related error Io, /// An error occurred during MPC @@ -87,6 +89,7 @@ impl Display for Kind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Kind::State => write!(f, "State"), + Kind::Ctx => write!(f, "Context"), Kind::Io => write!(f, "Io"), Kind::Mpc => write!(f, "Mpc"), Kind::KeyExchange => write!(f, "KeyExchange"), @@ -120,6 +123,16 @@ impl From for MpcTlsError { } } +impl From for MpcTlsError { + fn from(err: mpz_common::ContextError) -> Self { + Self { + kind: Kind::Ctx, + msg: "context error".to_string(), + source: Some(Box::new(err)), + } + } +} + impl From for MpcTlsError { fn from(err: mpz_garble::VmError) -> Self { Self { diff --git a/components/tls/tls-mpc/src/follower.rs b/crates/tls/mpc/src/follower.rs similarity index 64% rename from components/tls/tls-mpc/src/follower.rs rename to crates/tls/mpc/src/follower.rs index 1d3f5ca34c..2945b38c67 100644 --- a/components/tls/tls-mpc/src/follower.rs +++ b/crates/tls/mpc/src/follower.rs @@ -5,15 +5,12 @@ use futures::{ FutureExt, StreamExt, }; -use hmac_sha256 as prf; use key_exchange as ke; use ludi::{Address, FuturesAddress}; -use mpz_core::hash::Hash; use p256::elliptic_curve::sec1::ToEncodedPoint; -use prf::SessionKeys; -use aead::Aead; +use aead::{aes_gcm::AesGcmError, Aead}; use hmac_sha256::Prf; use ke::KeyExchange; use tls_core::{ @@ -22,16 +19,18 @@ use tls_core::{ alert::AlertMessagePayload, base::Payload, codec::Codec, - enums::{AlertDescription, ContentType, NamedGroup, ProtocolVersion}, + enums::{AlertDescription, ContentType, HandshakeType, NamedGroup, ProtocolVersion}, + handshake::{HandshakeMessagePayload, HandshakePayload}, message::{OpaqueMessage, PlainMessage}, }, }; +use tracing::{debug, instrument, Instrument}; use crate::{ error::Kind, msg::{CloseConnection, Commit, MpcTlsFollowerMsg, MpcTlsMessage}, record_layer::{Decrypter, Encrypter}, - MpcTlsChannel, MpcTlsError, MpcTlsFollowerConfig, + Direction, MpcTlsChannel, MpcTlsError, 
MpcTlsFollowerConfig, }; /// Controller for MPC-TLS follower. @@ -60,8 +59,6 @@ pub struct MpcTlsFollower { /// Data collected by the MPC-TLS follower. #[derive(Debug)] pub struct MpcTlsFollowerData { - /// The prover's commitment to the handshake data - pub handshake_commitment: Option, /// The server's public key pub server_key: PublicKey, /// The total number of bytes sent @@ -75,19 +72,14 @@ impl ludi::Actor for MpcTlsFollower { type Error = MpcTlsError; async fn stopped(&mut self) -> Result { - #[cfg(feature = "tracing")] - tracing::debug!("follower actor stopped"); + debug!("follower actor stopped"); - let Closed { - handshake_commitment, - server_key, - } = self.state.take().try_into_closed()?; + let Closed { server_key } = self.state.take().try_into_closed()?; let bytes_sent = self.encrypter.sent_bytes(); let bytes_recv = self.decrypter.recv_bytes(); Ok(MpcTlsFollowerData { - handshake_commitment, server_key, bytes_sent, bytes_recv, @@ -96,24 +88,24 @@ impl ludi::Actor for MpcTlsFollower { } impl MpcTlsFollower { - /// Create a new follower instance + /// Creates a new follower. pub fn new( config: MpcTlsFollowerConfig, channel: MpcTlsChannel, ke: Box, prf: Box, - encrypter: Box, - decrypter: Box, + encrypter: Box + Send>, + decrypter: Box + Send>, ) -> Self { let encrypter = Encrypter::new( encrypter, - config.common().tx_transcript_id().to_string(), - config.common().opaque_tx_transcript_id().to_string(), + config.common().tx_config().id().to_string(), + config.common().tx_config().opaque_id().to_string(), ); let decrypter = Decrypter::new( decrypter, - config.common().rx_transcript_id().to_string(), - config.common().opaque_rx_transcript_id().to_string(), + config.common().rx_config().id().to_string(), + config.common().rx_config().opaque_id().to_string(), ); let (_sink, stream) = channel.split(); @@ -133,20 +125,39 @@ impl MpcTlsFollower { } /// Performs any one-time setup operations. 
- #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] pub async fn setup(&mut self) -> Result<(), MpcTlsError> { let pms = self.ke.setup().await?; - self.prf.setup(pms.into_value()).await?; + let session_keys = self.prf.setup(pms.into_value()).await?; + futures::try_join!(self.encrypter.setup(), self.decrypter.setup())?; + + futures::try_join!( + self.encrypter + .set_key(session_keys.client_write_key, session_keys.client_iv), + self.decrypter + .set_key(session_keys.server_write_key, session_keys.server_iv) + )?; + + self.ke.preprocess().await?; + self.prf.preprocess().await?; + + let preprocess_encrypt = self.config.common().tx_config().max_online_size(); + let preprocess_decrypt = self.config.common().rx_config().max_online_size(); + + futures::try_join!( + self.encrypter.preprocess(preprocess_encrypt), + self.decrypter.preprocess(preprocess_decrypt), + )?; + + self.prf.set_client_random(None).await?; Ok(()) } /// Runs the follower actor. /// - /// Returns a control handle and a future that resolves when the actor is stopped. + /// Returns a control handle and a future that resolves when the actor is + /// stopped. /// /// # Note /// @@ -182,34 +193,45 @@ impl MpcTlsFollower { loop { futures::select! { res = &mut remote_fut => { - if let Err(e) = res { - return Err(e); - } + res?; }, res = &mut actor_fut => return res, } } }; - (ctrl, fut) - } - - /// Returns the total number of bytes sent and received. 
- fn total_bytes_transferred(&self) -> usize { - self.encrypter.sent_bytes() + self.decrypter.recv_bytes() - } - - fn check_transcript_length(&self, len: usize) -> Result<(), MpcTlsError> { - let new_len = self.total_bytes_transferred() + len; - if new_len > self.config.common().max_transcript_size() { - return Err(MpcTlsError::new( - Kind::Config, - format!( - "max transcript size exceeded: {} > {}", - new_len, - self.config.common().max_transcript_size() - ), - )); + (ctrl, fut.in_current_span()) + } + + fn check_transcript_length(&self, direction: Direction, len: usize) -> Result<(), MpcTlsError> { + match direction { + Direction::Sent => { + let new_len = self.encrypter.sent_bytes() + len; + let max_size = self.config.common().tx_config().max_online_size(); + if new_len > max_size { + return Err(MpcTlsError::new( + Kind::Config, + format!( + "max sent transcript size exceeded: {} > {}", + new_len, max_size + ), + )); + } + } + Direction::Recv => { + let new_len = self.decrypter.recv_bytes() + len; + let max_size = self.config.common().rx_config().max_online_size() + + self.config.common().rx_config().max_offline_size(); + if new_len > max_size { + return Err(MpcTlsError::new( + Kind::Config, + format!( + "max received transcript size exceeded: {} > {}", + new_len, max_size + ), + )); + } + } } Ok(()) @@ -217,8 +239,8 @@ impl MpcTlsFollower { /// Returns an error if the follower is not accepting new messages. /// - /// This can happen if the follower has received a CloseNotify alert or if the leader has - /// committed to the transcript. + /// This can happen if the follower has received a CloseNotify alert or if + /// the leader has committed to the transcript. 
fn is_accepting_messages(&self) -> Result<(), MpcTlsError> { if self.close_notify { return Err(MpcTlsError::new( @@ -237,40 +259,10 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] - async fn compute_client_key(&mut self) -> Result<(), MpcTlsError> { + #[instrument(level = "trace", skip_all, err)] + async fn compute_key_exchange(&mut self, server_random: [u8; 32]) -> Result<(), MpcTlsError> { self.state.take().try_into_init()?; - _ = self - .ke - .compute_client_key(p256::SecretKey::random(&mut rand::rngs::OsRng)) - .await?; - - self.state = State::ClientKey; - - Ok(()) - } - - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] - async fn compute_key_exchange( - &mut self, - handshake_commitment: Option, - ) -> Result<(), MpcTlsError> { - self.state.take().try_into_client_key()?; - - if self.config.common().handshake_commit() && handshake_commitment.is_none() { - return Err(MpcTlsError::new( - Kind::PeerMisbehaved, - "handshake commitment missing", - )); - } - // Key exchange self.ke.compute_pms().await?; @@ -280,18 +272,11 @@ impl MpcTlsFollower { .expect("server key should be set after computing pms"); // PRF - let SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - } = self.prf.compute_session_keys_blind().await?; + self.prf.compute_session_keys(server_random).await?; - self.encrypter.set_key(client_write_key, client_iv).await?; - self.decrypter.set_key(server_write_key, server_iv).await?; + futures::try_join!(self.encrypter.start(), self.decrypter.start())?; self.state = State::Ke(Ke { - handshake_commitment, server_key: PublicKey::new( NamedGroup::secp256r1, server_key.to_encoded_point(false).as_bytes(), @@ -301,40 +286,41 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] - async fn client_finished_vd(&mut self) -> Result<(), 
MpcTlsError> { - let Ke { - handshake_commitment, - server_key, - } = self.state.take().try_into_ke()?; + #[instrument(level = "trace", skip_all, err)] + async fn client_finished_vd(&mut self, handshake_hash: [u8; 32]) -> Result<(), MpcTlsError> { + let Ke { server_key } = self.state.take().try_into_ke()?; - self.prf.compute_client_finished_vd_blind().await?; + let client_finished = self.prf.compute_client_finished_vd(handshake_hash).await?; self.state = State::Cf(Cf { - handshake_commitment, server_key, + client_finished, }); Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] - async fn server_finished_vd(&mut self) -> Result<(), MpcTlsError> { + #[instrument(level = "trace", skip_all, err)] + async fn server_finished_vd(&mut self, handshake_hash: [u8; 32]) -> Result<(), MpcTlsError> { let Sf { - handshake_commitment, server_key, + server_finished, } = self.state.take().try_into_sf()?; - self.prf.compute_server_finished_vd_blind().await?; + let expected_server_finished = self.prf.compute_server_finished_vd(handshake_hash).await?; + + let Some(server_finished) = server_finished else { + return Err(MpcTlsError::new(Kind::State, "server finished is not set")); + }; + + if server_finished != expected_server_finished { + return Err(MpcTlsError::new( + Kind::Prf, + "server finished does not match", + )); + } self.state = State::Active(Active { - handshake_commitment, server_key, buffer: Default::default(), }); @@ -342,32 +328,37 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] async fn encrypt_client_finished(&mut self) -> Result<(), MpcTlsError> { let Cf { - handshake_commitment, server_key, + client_finished, } = self.state.take().try_into_cf()?; + let msg = HandshakeMessagePayload { + typ: HandshakeType::Finished, + payload: HandshakePayload::Finished(Payload::new(client_finished)), + }; 
+ let mut payload = Vec::new(); + msg.encode(&mut payload); + self.encrypter - .encrypt_blind(ContentType::Handshake, ProtocolVersion::TLSv1_2, 16) + .encrypt_public(PlainMessage { + typ: ContentType::Handshake, + version: ProtocolVersion::TLSv1_2, + payload: Payload(payload), + }) .await?; self.state = State::Sf(Sf { - handshake_commitment, server_key, + server_finished: None, }); Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] async fn encrypt_alert(&mut self, msg: Vec) -> Result<(), MpcTlsError> { self.is_accepting_messages()?; if let Some(alert) = AlertMessagePayload::read_bytes(&msg) { @@ -396,13 +387,10 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] async fn encrypt_message(&mut self, len: usize) -> Result<(), MpcTlsError> { self.is_accepting_messages()?; - self.check_transcript_length(len)?; + self.check_transcript_length(Direction::Sent, len)?; self.state.try_as_active()?; self.encrypter @@ -412,13 +400,10 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] fn commit_message(&mut self, payload: Vec) -> Result<(), MpcTlsError> { self.is_accepting_messages()?; - self.check_transcript_length(payload.len())?; + self.check_transcript_length(Direction::Recv, payload.len())?; let Active { buffer, .. } = self.state.try_as_active_mut()?; buffer.push_back(OpaqueMessage { @@ -430,28 +415,37 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] async fn decrypt_server_finished(&mut self, msg: Vec) -> Result<(), MpcTlsError> { - self.state.try_as_sf()?; + let Sf { + server_finished, .. 
+ } = self.state.try_as_sf_mut()?; - self.decrypter - .decrypt_blind(OpaqueMessage { + let msg = self + .decrypter + .decrypt_public(OpaqueMessage { typ: ContentType::Handshake, version: ProtocolVersion::TLSv1_2, payload: Payload::new(msg), }) .await?; + let msg = msg.payload.0; + if msg.len() != 16 { + return Err(MpcTlsError::new( + Kind::Decrypt, + "server finished message is not 16 bytes", + )); + } + + let sf: [u8; 12] = msg[4..].try_into().expect("slice should be 12 bytes"); + + server_finished.replace(sf); + Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] async fn decrypt_alert(&mut self, msg: Vec) -> Result<(), MpcTlsError> { self.state.try_as_active()?; @@ -480,10 +474,7 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] async fn decrypt_message(&mut self) -> Result<(), MpcTlsError> { let Active { buffer, .. } = self.state.try_as_active_mut()?; @@ -492,13 +483,13 @@ impl MpcTlsFollower { "attempted to decrypt message when no messages are committed", ))?; - #[cfg(feature = "tracing")] - tracing::debug!("decrypting message"); + debug!("decrypting message"); if self.committed { - // At this point the AEAD key was revealed to the leader and the leader locally decrypted - // the TLS message and now is proving to us that they know the plaintext which encrypts - // to the ciphertext of this TLS message. + // At this point the AEAD key was revealed to the leader and the leader locally + // decrypted the TLS message and now is proving to us that they know + // the plaintext which encrypts to the ciphertext of this TLS + // message. 
self.decrypter.verify_plaintext(msg).await?; } else { self.decrypter.decrypt_blind(msg).await?; @@ -507,16 +498,9 @@ impl MpcTlsFollower { Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "trace", skip_all, err)] fn close_connection(&mut self) -> Result<(), MpcTlsError> { - let Active { - handshake_commitment, - server_key, - buffer, - } = self.state.take().try_into_active()?; + let Active { server_key, buffer } = self.state.take().try_into_active()?; if !buffer.is_empty() { return Err(MpcTlsError::new( @@ -525,10 +509,7 @@ impl MpcTlsFollower { )); } - self.state = State::Closed(Closed { - handshake_commitment, - server_key, - }); + self.state = State::Closed(Closed { server_key }); Ok(()) } @@ -536,12 +517,12 @@ impl MpcTlsFollower { async fn commit(&mut self) -> Result<(), MpcTlsError> { let Active { buffer, .. } = self.state.try_as_active()?; - #[cfg(feature = "tracing")] - tracing::debug!("leader committed transcript"); + debug!("leader committed transcript"); self.committed = true; - // Reveal the AEAD key to the leader only if there are TLS messages which need to be decrypted. + // Reveal the AEAD key to the leader only if there are TLS messages which need + // to be decrypted. 
if !buffer.is_empty() { self.decrypter.decode_key_blind().await?; } @@ -554,21 +535,19 @@ impl MpcTlsFollower { #[msg(name = "{name}")] #[msg(attrs(derive(Debug, serde::Serialize, serde::Deserialize)))] impl MpcTlsFollower { - pub async fn compute_client_key(&mut self) { - ctx.try_or_stop(|_| self.compute_client_key()).await; - } - - pub async fn compute_key_exchange(&mut self, handshake_commitment: Option) { - ctx.try_or_stop(|_| self.compute_key_exchange(handshake_commitment)) + pub async fn compute_key_exchange(&mut self, server_random: [u8; 32]) { + ctx.try_or_stop(|_| self.compute_key_exchange(server_random)) .await; } - pub async fn client_finished_vd(&mut self) { - ctx.try_or_stop(|_| self.client_finished_vd()).await; + pub async fn client_finished_vd(&mut self, handshake_hash: [u8; 32]) { + ctx.try_or_stop(|_| self.client_finished_vd(handshake_hash)) + .await; } - pub async fn server_finished_vd(&mut self) { - ctx.try_or_stop(|_| self.server_finished_vd()).await; + pub async fn server_finished_vd(&mut self, handshake_hash: [u8; 32]) { + ctx.try_or_stop(|_| self.server_finished_vd(handshake_hash)) + .await; } pub async fn encrypt_client_finished(&mut self) { @@ -626,7 +605,6 @@ mod state { #[derive_err(Debug)] pub(super) enum State { Init, - ClientKey, Ke(Ke), Cf(Cf), Sf(Sf), @@ -649,36 +627,34 @@ mod state { #[derive(Debug)] pub(super) struct Ke { - pub(super) handshake_commitment: Option, pub(super) server_key: PublicKey, } #[derive(Debug)] pub(super) struct Cf { - pub(super) handshake_commitment: Option, pub(super) server_key: PublicKey, + pub(super) client_finished: [u8; 12], } #[derive(Debug)] pub(super) struct Sf { - pub(super) handshake_commitment: Option, pub(super) server_key: PublicKey, + pub(super) server_finished: Option<[u8; 12]>, } #[derive(Debug)] pub(super) struct Active { - pub(super) handshake_commitment: Option, pub(super) server_key: PublicKey, /// TLS messages purportedly received by the leader from the server. 
/// - /// The follower must verify the authenticity of these messages with AEAD verification - /// (i.e. by verifying the authentication tag). + /// The follower must verify the authenticity of these messages with + /// AEAD verification (i.e. by verifying the authentication + /// tag). pub(super) buffer: VecDeque, } #[derive(Debug)] pub(super) struct Closed { - pub(super) handshake_commitment: Option, pub(super) server_key: PublicKey, } } diff --git a/components/tls/tls-mpc/src/leader.rs b/crates/tls/mpc/src/leader.rs similarity index 78% rename from components/tls/tls-mpc/src/leader.rs rename to crates/tls/mpc/src/leader.rs index 31e449ab56..ae4ba6101d 100644 --- a/components/tls/tls-mpc/src/leader.rs +++ b/crates/tls/mpc/src/leader.rs @@ -3,16 +3,12 @@ use std::{collections::VecDeque, future::Future}; use async_trait::async_trait; use futures::SinkExt; -use hmac_sha256 as prf; use key_exchange as ke; -use mpz_core::commit::{Decommitment, HashCommit}; -use prf::SessionKeys; -use aead::Aead; +use aead::{aes_gcm::AesGcmError, Aead}; use hmac_sha256::Prf; use ke::KeyExchange; -use p256::SecretKey; use tls_backend::{ Backend, BackendError, BackendNotifier, BackendNotify, DecryptMode, EncryptMode, }; @@ -30,17 +26,18 @@ use tls_core::{ }, suites::SupportedCipherSuite, }; +use tracing::{debug, instrument, trace, Instrument}; use crate::{ error::Kind, follower::{ - ClientFinishedVd, CommitMessage, ComputeClientKey, ComputeKeyExchange, DecryptAlert, - DecryptMessage, DecryptServerFinished, EncryptAlert, EncryptClientFinished, EncryptMessage, + ClientFinishedVd, CommitMessage, ComputeKeyExchange, DecryptAlert, DecryptMessage, + DecryptServerFinished, EncryptAlert, EncryptClientFinished, EncryptMessage, ServerFinishedVd, }, msg::{CloseConnection, Commit, MpcTlsLeaderMsg, MpcTlsMessage}, record_layer::{Decrypter, Encrypter}, - MpcTlsChannel, MpcTlsError, MpcTlsLeaderConfig, + Direction, MpcTlsChannel, MpcTlsError, MpcTlsLeaderConfig, }; /// Controller for MPC-TLS leader. 
@@ -59,7 +56,8 @@ pub struct MpcTlsLeader { encrypter: Encrypter, decrypter: Decrypter, - /// When set, notifies the backend that there are TLS messages which need to be decrypted. + /// When set, notifies the backend that there are TLS messages which need to + /// be decrypted. notifier: BackendNotifier, /// Whether the backend is ready to decrypt messages. @@ -75,8 +73,7 @@ impl ludi::Actor for MpcTlsLeader { type Error = MpcTlsError; async fn stopped(&mut self) -> Result { - #[cfg(feature = "tracing")] - tracing::debug!("leader actor stopped"); + debug!("leader actor stopped"); let state::Closed { data } = self.state.take().try_into_closed()?; @@ -91,19 +88,20 @@ impl MpcTlsLeader { channel: MpcTlsChannel, ke: Box, prf: Box, - encrypter: Box, - decrypter: Box, + encrypter: Box + Send>, + decrypter: Box + Send>, ) -> Self { let encrypter = Encrypter::new( encrypter, - config.common().tx_transcript_id().to_string(), - config.common().opaque_tx_transcript_id().to_string(), + config.common().tx_config().id().to_string(), + config.common().tx_config().opaque_id().to_string(), ); let decrypter = Decrypter::new( decrypter, - config.common().rx_transcript_id().to_string(), - config.common().opaque_rx_transcript_id().to_string(), + config.common().rx_config().id().to_string(), + config.common().rx_config().opaque_id().to_string(), ); + let is_decrypting = !config.defer_decryption_from_start(); Self { config, @@ -114,27 +112,48 @@ impl MpcTlsLeader { encrypter, decrypter, notifier: BackendNotifier::new(), - is_decrypting: true, + is_decrypting, buffer: VecDeque::new(), committed: false, } } /// Performs any one-time setup operations. 
- #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "debug", skip_all, err)] pub async fn setup(&mut self) -> Result<(), MpcTlsError> { let pms = self.ke.setup().await?; - self.prf.setup(pms.into_value()).await?; + let session_keys = self.prf.setup(pms.into_value()).await?; + futures::try_join!(self.encrypter.setup(), self.decrypter.setup())?; + + futures::try_join!( + self.encrypter + .set_key(session_keys.client_write_key, session_keys.client_iv), + self.decrypter + .set_key(session_keys.server_write_key, session_keys.server_iv) + )?; + + self.ke.preprocess().await?; + self.prf.preprocess().await?; + + let preprocess_encrypt = self.config.common().tx_config().max_online_size(); + let preprocess_decrypt = self.config.common().rx_config().max_online_size(); + + futures::try_join!( + self.encrypter.preprocess(preprocess_encrypt), + self.decrypter.preprocess(preprocess_decrypt), + )?; + + self.prf + .set_client_random(Some(self.state.try_as_ke()?.client_random.0)) + .await?; Ok(()) } /// Runs the leader actor. /// - /// Returns a control handle and a future that resolves when the actor is stopped. + /// Returns a control handle and a future that resolves when the actor is + /// stopped. /// /// # Note /// @@ -150,7 +169,7 @@ impl MpcTlsLeader { let ctrl = LeaderCtrl::from(addr); let fut = async move { ludi::run(&mut self, &mut mailbox).await }; - (ctrl, fut) + (ctrl, fut.in_current_span()) } /// Returns the number of bytes sent and received. @@ -158,31 +177,41 @@ impl MpcTlsLeader { (self.encrypter.sent_bytes(), self.decrypter.recv_bytes()) } - /// Returns the total number of bytes sent and received. 
- fn total_bytes_transferred(&self) -> usize { - self.encrypter.sent_bytes() + self.decrypter.recv_bytes() - } - - fn check_transcript_length(&self, len: usize) -> Result<(), MpcTlsError> { - let new_len = self.total_bytes_transferred() + len; - if new_len > self.config.common().max_transcript_size() { - return Err(MpcTlsError::new( - Kind::Config, - format!( - "max transcript size exceeded: {} > {}", - new_len, - self.config.common().max_transcript_size() - ), - )); - } else { - Ok(()) + fn check_transcript_length(&self, direction: Direction, len: usize) -> Result<(), MpcTlsError> { + match direction { + Direction::Sent => { + let new_len = self.encrypter.sent_bytes() + len; + let max_size = self.config.common().tx_config().max_online_size(); + if new_len > max_size { + return Err(MpcTlsError::new( + Kind::Config, + format!( + "max sent transcript size exceeded: {} > {}", + new_len, max_size + ), + )); + } + } + Direction::Recv => { + let new_len = self.decrypter.recv_bytes() + len; + let max_size = self.config.common().rx_config().max_online_size() + + self.config.common().rx_config().max_offline_size(); + if new_len > max_size { + return Err(MpcTlsError::new( + Kind::Config, + format!( + "max received transcript size exceeded: {} > {}", + new_len, max_size + ), + )); + } + } } + + Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "debug", skip_all, err)] async fn encrypt_client_finished( &mut self, msg: PlainMessage, @@ -193,17 +222,14 @@ impl MpcTlsLeader { .send(MpcTlsMessage::EncryptClientFinished(EncryptClientFinished)) .await?; - let msg = self.encrypter.encrypt_private(msg).await?; + let msg = self.encrypter.encrypt_public(msg).await?; self.state = State::Sf(Sf { data }); Ok(msg) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "debug", skip_all, err)] async fn encrypt_alert(&mut self, msg: PlainMessage) -> 
Result { if let Some(alert) = AlertMessagePayload::read_bytes(&msg.payload.0) { // We only allow CloseNotify alerts. @@ -229,16 +255,13 @@ impl MpcTlsLeader { Ok(msg) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "debug", skip_all, err)] async fn encrypt_application_data( &mut self, msg: PlainMessage, ) -> Result { self.state.try_as_active()?; - self.check_transcript_length(msg.payload.0.len())?; + self.check_transcript_length(Direction::Sent, msg.payload.0.len())?; self.channel .send(MpcTlsMessage::EncryptMessage(EncryptMessage { @@ -251,10 +274,7 @@ impl MpcTlsLeader { Ok(msg) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "debug", skip_all, err)] async fn decrypt_server_finished( &mut self, msg: OpaqueMessage, @@ -269,17 +289,14 @@ impl MpcTlsLeader { )) .await?; - let msg = self.decrypter.decrypt_private(msg).await?; + let msg = self.decrypter.decrypt_public(msg).await?; self.state = State::Active(Active { data }); Ok(msg) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "debug", skip_all, err)] async fn decrypt_alert(&mut self, msg: OpaqueMessage) -> Result { self.state.try_as_active()?; @@ -294,24 +311,22 @@ impl MpcTlsLeader { Ok(msg) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "trace", skip_all, err) - )] + #[instrument(level = "debug", skip_all, err)] async fn decrypt_application_data( &mut self, msg: OpaqueMessage, ) -> Result { self.state.try_as_active()?; - self.check_transcript_length(msg.payload.0.len())?; + self.check_transcript_length(Direction::Recv, msg.payload.0.len())?; self.channel .send(MpcTlsMessage::DecryptMessage(DecryptMessage)) .await?; let msg = if self.committed { - // At this point the AEAD key was revealed to us. 
We will locally decrypt the TLS message - // and will prove the knowledge of the plaintext to the follower. + // At this point the AEAD key was revealed to us. We will locally decrypt the + // TLS message and will prove the knowledge of the plaintext to the + // follower. self.decrypter.prove_plaintext(msg).await? } else { self.decrypter.decrypt_private(msg).await? @@ -320,11 +335,14 @@ impl MpcTlsLeader { Ok(msg) } + #[instrument(level = "debug", skip_all, err)] async fn commit(&mut self) -> Result<(), MpcTlsError> { + if self.committed { + return Ok(()); + } self.state.try_as_active()?; - #[cfg(feature = "tracing")] - tracing::debug!("committing to transcript"); + debug!("committing to transcript"); self.channel.send(MpcTlsMessage::Commit(Commit)).await?; @@ -343,14 +361,10 @@ impl MpcTlsLeader { #[ludi::implement(msg(name = "{name}"), ctrl(err))] impl MpcTlsLeader { /// Closes the connection. - #[cfg_attr( - feature = "tracing", - tracing::instrument(name = "close_connection", level = "trace", skip_all, err) - )] + #[instrument(name = "close_connection", level = "debug", skip_all, err)] #[msg(skip, name = "CloseConnection")] pub async fn close_connection(&mut self) -> Result<(), MpcTlsError> { - #[cfg(feature = "tracing")] - tracing::debug!("closing connection"); + debug!("closing connection"); self.channel .send(MpcTlsMessage::CloseConnection(CloseConnection)) @@ -379,12 +393,9 @@ impl MpcTlsLeader { /// Commits the leader to the current transcript. /// - /// This reveals the AEAD key to the leader and disables sending or receiving - /// any further messages. - #[cfg_attr( - feature = "tracing", - tracing::instrument(name = "finalize", level = "trace", skip_all, err) - )] + /// This reveals the AEAD key to the leader and disables sending or + /// receiving any further messages. 
+ #[instrument(name = "finalize", level = "debug", skip_all, err)] #[msg(skip, name = "Commit")] pub async fn commit(&mut self) -> Result<(), MpcTlsError> { self.commit().await @@ -401,6 +412,8 @@ impl Backend for MpcTlsLeader { protocol_version, .. } = self.state.try_as_ke_mut().map_err(MpcTlsError::from)?; + trace!("setting protocol version: {:?}", version); + *protocol_version = Some(version); Ok(()) @@ -409,6 +422,8 @@ impl Backend for MpcTlsLeader { async fn set_cipher_suite(&mut self, suite: SupportedCipherSuite) -> Result<(), BackendError> { let Ke { cipher_suite, .. } = self.state.try_as_ke_mut().map_err(MpcTlsError::from)?; + trace!("setting cipher suite: {:?}", suite); + *cipher_suite = Some(suite.suite()); Ok(()) @@ -432,18 +447,9 @@ impl Backend for MpcTlsLeader { Ok(*client_random) } + #[instrument(level = "debug", skip_all, err)] async fn get_client_key_share(&mut self) -> Result { - self.channel - .send(MpcTlsMessage::ComputeClientKey(ComputeClientKey)) - .await - .map_err(MpcTlsError::from)?; - - let pk = self - .ke - .compute_client_key(SecretKey::random(&mut rand::rngs::OsRng)) - .await - .map_err(MpcTlsError::from)? - .expect("client key is returned as leader"); + let pk = self.ke.client_key().await.map_err(MpcTlsError::from)?; Ok(PublicKey::new( NamedGroup::secp256r1, @@ -459,6 +465,7 @@ impl Backend for MpcTlsLeader { Ok(()) } + #[instrument(level = "debug", skip_all, err)] async fn set_server_key_share(&mut self, key: PublicKey) -> Result<(), BackendError> { let Ke { server_public_key, .. 
@@ -471,7 +478,16 @@ impl Backend for MpcTlsLeader { ) .into()) } else { + let server_key = p256::PublicKey::from_sec1_bytes(&key.key) + .map_err(|_| MpcTlsError::other("server key is not valid sec1p256"))?; + *server_public_key = Some(key); + + self.ke + .set_server_key(server_key) + .await + .map_err(MpcTlsError::from)?; + Ok(()) } } @@ -511,44 +527,51 @@ impl Backend for MpcTlsLeader { Ok(()) } + #[instrument(level = "debug", skip_all, err)] async fn get_server_finished_vd(&mut self, hash: Vec) -> Result, BackendError> { let hash: [u8; 32] = hash .try_into() .map_err(|_| MpcTlsError::other("server finished handshake hash is not 32 bytes"))?; self.channel - .send(MpcTlsMessage::ServerFinishedVd(ServerFinishedVd)) + .send(MpcTlsMessage::ServerFinishedVd(ServerFinishedVd { + handshake_hash: hash, + })) .await .map_err(|e| BackendError::InternalError(e.to_string()))?; let vd = self .prf - .compute_server_finished_vd_private(hash) + .compute_server_finished_vd(hash) .await .map_err(MpcTlsError::from)?; Ok(vd.to_vec()) } + #[instrument(level = "debug", skip_all, err)] async fn get_client_finished_vd(&mut self, hash: Vec) -> Result, BackendError> { let hash: [u8; 32] = hash .try_into() .map_err(|_| MpcTlsError::other("client finished handshake hash is not 32 bytes"))?; self.channel - .send(MpcTlsMessage::ClientFinishedVd(ClientFinishedVd)) + .send(MpcTlsMessage::ClientFinishedVd(ClientFinishedVd { + handshake_hash: hash, + })) .await .map_err(|e| BackendError::InternalError(e.to_string()))?; let vd = self .prf - .compute_client_finished_vd_private(hash) + .compute_client_finished_vd(hash) .await .map_err(MpcTlsError::from)?; Ok(vd.to_vec()) } + #[instrument(level = "debug", skip_all, err)] async fn prepare_encryption(&mut self) -> Result<(), BackendError> { let Ke { protocol_version, @@ -578,42 +601,21 @@ impl Backend for MpcTlsLeader { server_random, ); - let (handshake_decommitment, handshake_commitment) = - if self.config.common().handshake_commit() { - let 
(decommitment, commitment) = handshake_data.clone().hash_commit(); - - (Some(decommitment), Some(commitment)) - } else { - (None, None) - }; - self.channel .send(MpcTlsMessage::ComputeKeyExchange(ComputeKeyExchange { - handshake_commitment, + server_random: server_random.0, })) .await .map_err(|e| BackendError::InternalError(e.to_string()))?; - let server_key = p256::PublicKey::from_sec1_bytes(&server_public_key.key) - .map_err(|_| MpcTlsError::other("server key is not valid sec1p256"))?; - - self.ke.set_server_key(server_key); - self.ke.compute_pms().await.map_err(MpcTlsError::from)?; - let SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - } = self - .prf - .compute_session_keys_private(client_random.0, server_random.0) + self.prf + .compute_session_keys(server_random.0) .await .map_err(MpcTlsError::from)?; - self.encrypter.set_key(client_write_key, client_iv).await?; - self.decrypter.set_key(server_write_key, server_iv).await?; + futures::try_join!(self.encrypter.start(), self.decrypter.start())?; self.state = State::Cf(Cf { data: MpcTlsData { @@ -625,7 +627,6 @@ impl Backend for MpcTlsLeader { server_public_key, server_kx_details, handshake_data, - handshake_decommitment, }, }); @@ -741,8 +742,6 @@ pub struct MpcTlsData { pub server_kx_details: ServerKxDetails, /// Handshake data. pub handshake_data: HandshakeData, - /// Handshake data decommitment. - pub handshake_decommitment: Option>, } mod state { diff --git a/components/tls/tls-mpc/src/lib.rs b/crates/tls/mpc/src/lib.rs similarity index 72% rename from components/tls/tls-mpc/src/lib.rs rename to crates/tls/mpc/src/lib.rs index 57ea05acd0..8741cc2821 100644 --- a/components/tls/tls-mpc/src/lib.rs +++ b/crates/tls/mpc/src/lib.rs @@ -1,29 +1,32 @@ -//! This crate provides tooling for instantiating MPC TLS machinery for leader and follower. +//! This crate provides tooling for instantiating MPC TLS machinery for leader +//! and follower. -//! 
The main API objects are [MpcTlsLeader] and [MpcTlsFollower], which wrap the necessary -//! cryptographic machinery and also an [MpcTlsChannel] for communication. +//! The main API objects are [MpcTlsLeader] and [MpcTlsFollower], which wrap the +//! necessary cryptographic machinery and also an [MpcTlsChannel] for +//! communication. #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] #![forbid(unsafe_code)] +mod components; mod config; pub(crate) mod error; pub(crate) mod follower; pub(crate) mod leader; pub mod msg; pub(crate) mod record_layer; -pub(crate) mod setup; +pub use components::build_components; pub use config::{ MpcTlsCommonConfig, MpcTlsCommonConfigBuilder, MpcTlsCommonConfigBuilderError, MpcTlsFollowerConfig, MpcTlsFollowerConfigBuilder, MpcTlsFollowerConfigBuilderError, MpcTlsLeaderConfig, MpcTlsLeaderConfigBuilder, MpcTlsLeaderConfigBuilderError, + TranscriptConfig, TranscriptConfigBuilder, TranscriptConfigBuilderError, }; pub use error::MpcTlsError; pub use follower::{FollowerCtrl, MpcTlsFollower, MpcTlsFollowerData}; pub use leader::{LeaderCtrl, MpcTlsData, MpcTlsLeader}; -pub use setup::setup_components; use utils_aio::duplex::Duplex; /// A channel for sending and receiving messages between leader and follower @@ -36,3 +39,11 @@ pub enum TlsRole { Leader, Follower, } + +/// The direction of a message +pub(crate) enum Direction { + /// Data sent to the TLS peer + Sent, + /// Data received from the TLS peer + Recv, +} diff --git a/components/tls/tls-mpc/src/msg.rs b/crates/tls/mpc/src/msg.rs similarity index 93% rename from components/tls/tls-mpc/src/msg.rs rename to crates/tls/mpc/src/msg.rs index bfe255b898..e58a1461c1 100644 --- a/components/tls/tls-mpc/src/msg.rs +++ b/crates/tls/mpc/src/msg.rs @@ -5,8 +5,8 @@ use serde::{Deserialize, Serialize}; use crate::{ error::Kind, follower::{ - ClientFinishedVd, CommitMessage, ComputeClientKey, ComputeKeyExchange, DecryptAlert, - DecryptMessage, DecryptServerFinished, 
EncryptAlert, EncryptClientFinished, EncryptMessage, + ClientFinishedVd, CommitMessage, ComputeKeyExchange, DecryptAlert, DecryptMessage, + DecryptServerFinished, EncryptAlert, EncryptClientFinished, EncryptMessage, ServerFinishedVd, }, leader::{ @@ -26,7 +26,6 @@ use crate::{ #[allow(missing_docs)] #[derive(Debug, Serialize, Deserialize)] pub enum MpcTlsMessage { - ComputeClientKey(ComputeClientKey), ComputeKeyExchange(ComputeKeyExchange), ClientFinishedVd(ClientFinishedVd), EncryptClientFinished(EncryptClientFinished), @@ -48,7 +47,6 @@ impl TryFrom for MpcTlsFollowerMsg { fn try_from(msg: MpcTlsMessage) -> Result { #[allow(unreachable_patterns)] match msg { - MpcTlsMessage::ComputeClientKey(msg) => Ok(Self::ComputeClientKey(msg)), MpcTlsMessage::ComputeKeyExchange(msg) => Ok(Self::ComputeKeyExchange(msg)), MpcTlsMessage::ClientFinishedVd(msg) => Ok(Self::ClientFinishedVd(msg)), MpcTlsMessage::EncryptClientFinished(msg) => Ok(Self::EncryptClientFinished(msg)), @@ -105,7 +103,6 @@ pub enum MpcTlsLeaderMsg { #[allow(missing_docs)] #[ludi(return_attrs(allow(missing_docs)))] pub enum MpcTlsFollowerMsg { - ComputeClientKey(ComputeClientKey), ComputeKeyExchange(ComputeKeyExchange), ClientFinishedVd(ClientFinishedVd), EncryptClientFinished(EncryptClientFinished), diff --git a/components/tls/tls-mpc/src/record_layer.rs b/crates/tls/mpc/src/record_layer.rs similarity index 83% rename from components/tls/tls-mpc/src/record_layer.rs rename to crates/tls/mpc/src/record_layer.rs index 8742dca955..1d986933b2 100644 --- a/components/tls/tls-mpc/src/record_layer.rs +++ b/crates/tls/mpc/src/record_layer.rs @@ -1,3 +1,4 @@ +use aead::aes_gcm::AesGcmError; use mpz_garble::value::ValueRef; use tls_core::{ @@ -12,7 +13,7 @@ use tls_core::{ use crate::{error::Kind, MpcTlsError}; pub(crate) struct Encrypter { - aead: Box, + aead: Box>, seq: u64, sent_bytes: usize, transcript_id: String, @@ -21,7 +22,7 @@ pub(crate) struct Encrypter { impl Encrypter { pub(crate) fn new( - aead: Box, + 
aead: Box>, transcript_id: String, opaque_transcript_id: String, ) -> Self { @@ -47,6 +48,33 @@ impl Encrypter { Ok(()) } + pub(crate) async fn preprocess(&mut self, len: usize) -> Result<(), MpcTlsError> { + self.aead + .preprocess(len) + .await + .map_err(|e| MpcTlsError::new_with_source(Kind::Encrypt, "preprocess error", e))?; + + Ok(()) + } + + pub(crate) async fn setup(&mut self) -> Result<(), MpcTlsError> { + self.aead + .setup() + .await + .map_err(|e| MpcTlsError::new_with_source(Kind::Encrypt, "setup error", e))?; + + Ok(()) + } + + pub(crate) async fn start(&mut self) -> Result<(), MpcTlsError> { + self.aead + .start() + .await + .map_err(|e| MpcTlsError::new_with_source(Kind::Encrypt, "start error", e))?; + + Ok(()) + } + pub(crate) async fn encrypt_private( &mut self, msg: PlainMessage, @@ -158,7 +186,7 @@ impl Encrypter { } pub(crate) struct Decrypter { - aead: Box, + aead: Box>, seq: u64, recv_bytes: usize, transcript_id: String, @@ -167,7 +195,7 @@ pub(crate) struct Decrypter { impl Decrypter { pub(crate) fn new( - aead: Box, + aead: Box>, transcript_id: String, opaque_transcript_id: String, ) -> Self { @@ -193,6 +221,33 @@ impl Decrypter { Ok(()) } + pub(crate) async fn preprocess(&mut self, len: usize) -> Result<(), MpcTlsError> { + self.aead + .preprocess(len) + .await + .map_err(|e| MpcTlsError::new_with_source(Kind::Decrypt, "preprocess error", e))?; + + Ok(()) + } + + pub(crate) async fn setup(&mut self) -> Result<(), MpcTlsError> { + self.aead + .setup() + .await + .map_err(|e| MpcTlsError::new_with_source(Kind::Decrypt, "setup error", e))?; + + Ok(()) + } + + pub(crate) async fn start(&mut self) -> Result<(), MpcTlsError> { + self.aead + .start() + .await + .map_err(|e| MpcTlsError::new_with_source(Kind::Decrypt, "start error", e))?; + + Ok(()) + } + pub(crate) async fn decrypt_private( &mut self, msg: OpaqueMessage, @@ -297,8 +352,9 @@ impl Decrypter { /// Proves the plaintext of the message to the other party /// - /// This verifies the tag 
of the message and locally decrypts it. Then, this party - /// commits to the plaintext and proves it encrypts back to the ciphertext. + /// This verifies the tag of the message and locally decrypts it. Then, this + /// party commits to the plaintext and proves it encrypts back to the + /// ciphertext. pub(crate) async fn prove_plaintext( &mut self, msg: OpaqueMessage, @@ -333,8 +389,9 @@ impl Decrypter { /// Verifies the plaintext of the message /// - /// This verifies the tag of the message then has the other party decrypt it. Then, - /// the other party commits to the plaintext and proves it encrypts back to the ciphertext. + /// This verifies the tag of the message then has the other party decrypt + /// it. Then, the other party commits to the plaintext and proves it + /// encrypts back to the ciphertext. pub(crate) async fn verify_plaintext(&mut self, msg: OpaqueMessage) -> Result<(), MpcTlsError> { let OpaqueMessage { typ, diff --git a/crates/tls/mpc/tests/test.rs b/crates/tls/mpc/tests/test.rs new file mode 100644 index 0000000000..852a3bad5e --- /dev/null +++ b/crates/tls/mpc/tests/test.rs @@ -0,0 +1,334 @@ +use std::{sync::Arc, time::Duration}; + +use futures::{AsyncReadExt, AsyncWriteExt}; +use mpz_common::{executor::MTExecutor, Allocate}; +use mpz_garble::{config::Role as GarbleRole, protocol::deap::DEAPThread}; +use mpz_ot::{ + chou_orlandi::{ + Receiver as BaseReceiver, ReceiverConfig as BaseReceiverConfig, Sender as BaseSender, + SenderConfig as BaseSenderConfig, + }, + kos::{Receiver, ReceiverConfig, Sender, SenderConfig, SharedReceiver, SharedSender}, + CommittedOTSender, VerifiableOTReceiver, +}; +use serio::StreamExt; +use tls_client::Certificate; +use tls_client_async::bind_client; +use tls_mpc::{ + build_components, MpcTlsCommonConfig, MpcTlsFollower, MpcTlsFollowerConfig, MpcTlsLeader, + MpcTlsLeaderConfig, TlsRole, +}; +use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN}; +use 
tokio_util::compat::TokioAsyncReadCompatExt; +use uid_mux::{ + test_utils::{test_framed_mux, TestFramedMux}, + FramedUidMux, +}; + +const OT_SETUP_COUNT: usize = 1_000_000; + +async fn leader(config: MpcTlsCommonConfig, mux: TestFramedMux) { + let mut exec = MTExecutor::new(mux.clone(), 8); + + let mut ot_sender = Sender::new( + SenderConfig::default(), + BaseReceiver::new(BaseReceiverConfig::default()), + ); + ot_sender.alloc(OT_SETUP_COUNT); + + let mut ot_receiver = Receiver::new( + ReceiverConfig::builder().sender_commit().build().unwrap(), + BaseSender::new( + BaseSenderConfig::builder() + .receiver_commit() + .build() + .unwrap(), + ), + ); + ot_receiver.alloc(OT_SETUP_COUNT); + + let ot_sender = SharedSender::new(ot_sender); + let mut ot_receiver = SharedReceiver::new(ot_receiver); + + let mut vm = DEAPThread::new( + GarbleRole::Leader, + [0u8; 32], + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ); + + let (ke, prf, encrypter, decrypter) = build_components( + TlsRole::Leader, + &config, + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + 
ot_receiver.clone(), + ) + .unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ); + + let mut leader = MpcTlsLeader::new( + MpcTlsLeaderConfig::builder() + .common(config) + .defer_decryption_from_start(false) + .build() + .unwrap(), + Box::new(StreamExt::compat_stream( + mux.open_framed(b"mpc_tls").await.unwrap(), + )), + ke, + prf, + encrypter, + decrypter, + ); + + leader.setup().await.unwrap(); + + let (leader_ctrl, leader_fut) = leader.run(); + tokio::spawn(async { leader_fut.await.unwrap() }); + + let mut root_store = tls_client::RootCertStore::empty(); + root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap(); + let config = tls_client::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + + let server_name = SERVER_DOMAIN.try_into().unwrap(); + + let client = tls_client::ClientConnection::new( + Arc::new(config), + Box::new(leader_ctrl.clone()), + server_name, + ) + .unwrap(); + + let (client_socket, server_socket) = tokio::io::duplex(1 << 16); + + tokio::spawn(bind_test_server_hyper(server_socket.compat())); + + let (mut conn, conn_fut) = bind_client(client_socket.compat(), client); + + tokio::spawn(async { conn_fut.await.unwrap() }); + + let msg = concat!( + "POST /echo HTTP/1.1\r\n", + "Host: test-server.io\r\n", + "Connection: keep-alive\r\n", + "Accept-Encoding: identity\r\n", + "Content-Length: 5\r\n", + "\r\n", + "hello", + "\r\n" + ); + + conn.write_all(msg.as_bytes()).await.unwrap(); + + let mut buf = vec![0u8; 48]; + conn.read_exact(&mut buf).await.unwrap(); + + println!("{}", String::from_utf8_lossy(&buf)); + + leader_ctrl.defer_decryption().await.unwrap(); + + let msg = concat!( + "POST /echo HTTP/1.1\r\n", + "Host: test-server.io\r\n", + "Connection: close\r\n", + "Accept-Encoding: identity\r\n", + "Content-Length: 5\r\n", + "\r\n", + "hello", + "\r\n" + ); + + conn.write_all(msg.as_bytes()).await.unwrap(); + + // Wait for the server to reply. 
+ tokio::time::sleep(Duration::from_millis(100)).await; + + leader_ctrl.commit().await.unwrap(); + + let mut buf = vec![0u8; 1024]; + conn.read_to_end(&mut buf).await.unwrap(); + + leader_ctrl.close_connection().await.unwrap(); + conn.close().await.unwrap(); + + let mut ctx = exec.new_thread().await.unwrap(); + + ot_receiver.accept_reveal(&mut ctx).await.unwrap(); + + vm.finalize().await.unwrap(); +} + +async fn follower(config: MpcTlsCommonConfig, mux: TestFramedMux) { + let mut exec = MTExecutor::new(mux.clone(), 8); + + let mut ot_sender = Sender::new( + SenderConfig::builder().sender_commit().build().unwrap(), + BaseReceiver::new( + BaseReceiverConfig::builder() + .receiver_commit() + .build() + .unwrap(), + ), + ); + ot_sender.alloc(OT_SETUP_COUNT); + + let mut ot_receiver = Receiver::new( + ReceiverConfig::default(), + BaseSender::new(BaseSenderConfig::default()), + ); + ot_receiver.alloc(OT_SETUP_COUNT); + + let mut ot_sender = SharedSender::new(ot_sender); + let ot_receiver = SharedReceiver::new(ot_receiver); + + let mut vm = DEAPThread::new( + GarbleRole::Follower, + [0u8; 32], + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ); + + let (ke, prf, encrypter, decrypter) = build_components( + TlsRole::Follower, + &config, + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + exec.new_thread().await.unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + 
ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + vm.new_thread( + exec.new_thread().await.unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ) + .unwrap(), + ot_sender.clone(), + ot_receiver.clone(), + ); + + let mut follower = MpcTlsFollower::new( + MpcTlsFollowerConfig::builder() + .common(config) + .build() + .unwrap(), + Box::new(StreamExt::compat_stream( + mux.open_framed(b"mpc_tls").await.unwrap(), + )), + ke, + prf, + encrypter, + decrypter, + ); + + follower.setup().await.unwrap(); + + let (_, fut) = follower.run(); + fut.await.unwrap(); + + let mut ctx = exec.new_thread().await.unwrap(); + + ot_sender.reveal(&mut ctx).await.unwrap(); + + vm.finalize().await.unwrap(); +} + +#[tokio::test] +#[ignore] +async fn test() { + tracing_subscriber::fmt::init(); + + let (leader_mux, follower_mux) = test_framed_mux(8); + + let common_config = MpcTlsCommonConfig::builder().build().unwrap(); + + tokio::join!( + leader(common_config.clone(), leader_mux), + follower(common_config.clone(), follower_mux) + ); +} diff --git a/components/tls/tls-server-fixture/Cargo.toml b/crates/tls/server-fixture/Cargo.toml similarity index 60% rename from components/tls/tls-server-fixture/Cargo.toml rename to crates/tls/server-fixture/Cargo.toml index b488006b48..476381631f 100644 --- a/components/tls/tls-server-fixture/Cargo.toml +++ b/crates/tls/server-fixture/Cargo.toml @@ -10,10 +10,12 @@ edition = "2021" publish = false [dependencies] -async-rustls.workspace = true -futures.workspace = true +bytes = { workspace = true } +futures = { workspace = true } +futures-rustls = { workspace = true } +http-body-util = { workspace = true } hyper = { workspace = true, features = ["full"] } -rustls = { version = "0.21", features = ["logging"] } +hyper-util = { workspace = true, features = ["full"] } +tokio = { workspace = true } tokio-util = { workspace = true, features = ["compat", 
"io-util"] } -tracing.workspace = true -tokio.workspace = true +tracing = { workspace = true } diff --git a/crates/tls/server-fixture/src/README.md b/crates/tls/server-fixture/src/README.md new file mode 100644 index 0000000000..01d243ef86 --- /dev/null +++ b/crates/tls/server-fixture/src/README.md @@ -0,0 +1,23 @@ +# Create a private key for the root CA +openssl genpkey -algorithm RSA -out root_ca.key -pkeyopt rsa_keygen_bits:2048 + +# Create a self-signed root CA certificate (100 years validity) +openssl req -x509 -new -nodes -key root_ca.key -sha256 -days 36525 -out root_ca.crt -subj "/C=US/ST=State/L=City/O=tlsnotary/OU=IT/CN=tlsnotary.org" + +# Create a private key for the end entity certificate +openssl genpkey -algorithm RSA -out test_server.key -pkeyopt rsa_keygen_bits:2048 + +# Create a certificate signing request (CSR) for the end entity certificate +openssl req -new -key test_server.key -out test_server.csr -subj "/C=US/ST=State/L=City/O=tlsnotary/OU=IT/CN=test-server.io" + +# Sign the CSR with the root CA to create the end entity certificate (100 years validity) +openssl x509 -req -in test_server.csr -CA root_ca.crt -CAkey root_ca.key -CAcreateserial -out test_server.crt -days 36525 -sha256 -extfile openssl.cnf -extensions v3_req + +# Convert the root CA certificate to DER format +openssl x509 -in root_ca.crt -outform der -out root_ca_cert.der + +# Convert the end entity certificate to DER format +openssl x509 -in test_server.crt -outform der -out test_server_cert.der + +# Convert the end entity certificate private key to DER format +openssl pkcs8 -topk8 -inform PEM -outform DER -in test_server.key -out test_server_private_key.der -nocrypt \ No newline at end of file diff --git a/components/tls/tls-server-fixture/src/lib.rs b/crates/tls/server-fixture/src/lib.rs similarity index 79% rename from components/tls/tls-server-fixture/src/lib.rs rename to crates/tls/server-fixture/src/lib.rs index f7019b04aa..70f2a7fbe9 100644 --- 
a/components/tls/tls-server-fixture/src/lib.rs +++ b/crates/tls/server-fixture/src/lib.rs @@ -4,10 +4,21 @@ #![deny(clippy::all)] #![forbid(unsafe_code)] -use async_rustls::{server::TlsStream, TlsAcceptor}; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, TryStreamExt}; -use hyper::{server::conn::Http, service::service_fn, Body, Method, Request, Response, StatusCode}; -use rustls::{Certificate, PrivateKey, ServerConfig}; +use bytes::Bytes; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use futures_rustls::{ + pki_types::{CertificateDer, PrivateKeyDer}, + rustls::ServerConfig, + TlsAcceptor, +}; +use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; +use hyper::{ + body::{Frame, Incoming}, + server::conn::http1, + service::service_fn, + Method, Request, Response, StatusCode, +}; +use hyper_util::rt::TokioIo; use std::{io::Write, sync::Arc}; use tokio_util::{ compat::{Compat, FuturesAsyncReadCompatExt}, @@ -16,11 +27,11 @@ use tokio_util::{ use tracing::Instrument; /// A certificate authority certificate fixture. -pub static CA_CERT_DER: &[u8] = include_bytes!("rootCA.der"); +pub static CA_CERT_DER: &[u8] = include_bytes!("root_ca_cert.der"); /// A server certificate (domain=test-server.io) fixture. -pub static SERVER_CERT_DER: &[u8] = include_bytes!("domain.der"); +pub static SERVER_CERT_DER: &[u8] = include_bytes!("test_server_cert.der"); /// A server private key fixture. -pub static SERVER_KEY_DER: &[u8] = include_bytes!("domain_key.der"); +pub static SERVER_KEY_DER: &[u8] = include_bytes!("test_server_private_key.der"); /// The domain name bound to the server certificate. pub static SERVER_DOMAIN: &str = "test-server.io"; /// The length of an application record expected by the test TLS server. 
@@ -32,12 +43,11 @@ pub static CLOSE_DELAY: u64 = 1000; #[tracing::instrument(skip(socket))] pub async fn bind_test_server_hyper( socket: T, -) -> Result, hyper::Error> { - let key = PrivateKey(SERVER_KEY_DER.to_vec()); - let cert = Certificate(SERVER_CERT_DER.to_vec()); +) -> Result<(), hyper::Error> { + let key = PrivateKeyDer::Pkcs8(SERVER_KEY_DER.into()); + let cert = CertificateDer::from(SERVER_CERT_DER); let config = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![cert], key) .unwrap(); @@ -46,14 +56,13 @@ pub async fn bind_test_server_hyper( socket: Compat, ) { - let key = PrivateKey(SERVER_KEY_DER.to_vec()); - let cert = Certificate(SERVER_CERT_DER.to_vec()); + let key = PrivateKeyDer::Pkcs8(SERVER_KEY_DER.into()); + let cert = CertificateDer::from(SERVER_CERT_DER); let config = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![cert], key) .unwrap(); @@ -85,8 +93,8 @@ pub async fn bind_test_server< let mut read_buf = vec![0u8; APP_RECORD_LENGTH]; if conn.read_exact(&mut read_buf).await.is_err() { // EOF reached because client closed its tx part of the socket. - // The client's rx part of the socket is still open and waiting for a clean server - // shutdown. + // The client's rx part of the socket is still open and waiting for a clean + // server shutdown. 
if must_delay_when_closing { // delay closing the socket tokio::time::sleep(std::time::Duration::from_millis(CLOSE_DELAY)).await; @@ -171,8 +179,8 @@ pub async fn bind_test_server< break; } "send_record_with_bad_mac" => { - // send a record which a bad MAC which will trigger the `bad_record_mac` alert on - // the client side + // send a record which a bad MAC which will trigger the `bad_record_mac` alert + // on the client side let (socket, _tls) = conn.into_inner(); @@ -228,26 +236,35 @@ pub async fn bind_test_server< } } +// Adapted from https://github.com/hyperium/hyper/blob/721785efad8537513e48d900a85c05ce79483018/examples/echo.rs #[tracing::instrument] -async fn echo(req: Request) -> Result, hyper::Error> { +async fn echo( + req: Request, +) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { // Serve some instructions at / - (&Method::GET, "/") => Ok(Response::new(Body::from( + (&Method::GET, "/") => Ok(Response::new(full( "Try POSTing data to /echo such as: `curl localhost:3000/echo -XPOST -d 'hello world'`", ))), // Simply echo the body back to the client. - (&Method::POST, "/echo") => Ok(Response::new(req.into_body())), + (&Method::POST, "/echo") => Ok(Response::new(req.into_body().boxed())), // Convert to uppercase before sending back to client using a stream. (&Method::POST, "/echo/uppercase") => { - let chunk_stream = req.into_body().map_ok(|chunk| { - chunk - .iter() - .map(|byte| byte.to_ascii_uppercase()) - .collect::>() + let frame_stream = req.into_body().map_frame(|frame| { + let frame = if let Ok(data) = frame.into_data() { + data.iter() + .map(|byte| byte.to_ascii_uppercase()) + .collect::() + } else { + Bytes::new() + }; + + Frame::data(frame) }); - Ok(Response::new(Body::wrap_stream(chunk_stream))) + + Ok(Response::new(frame_stream.boxed())) } // Reverse the entire body before sending back to the client. 
@@ -257,17 +274,29 @@ async fn echo(req: Request) -> Result, hyper::Error> { // So here we do `.await` on the future, waiting on concatenating the full body, // then afterwards the content can be reversed. Only then can we return a `Response`. (&Method::POST, "/echo/reversed") => { - let whole_body = hyper::body::to_bytes(req.into_body()).await?; + let whole_body = req.collect().await?.to_bytes(); let reversed_body = whole_body.iter().rev().cloned().collect::>(); - Ok(Response::new(Body::from(reversed_body))) + Ok(Response::new(full(reversed_body))) } // Return the 404 Not Found for other routes. _ => { - let mut not_found = Response::default(); + let mut not_found = Response::new(empty()); *not_found.status_mut() = StatusCode::NOT_FOUND; Ok(not_found) } } } + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + +fn empty() -> BoxBody { + Empty::::new() + .map_err(|never| match never {}) + .boxed() +} diff --git a/crates/tls/server-fixture/src/openssl.cnf b/crates/tls/server-fixture/src/openssl.cnf new file mode 100644 index 0000000000..6a52a8855b --- /dev/null +++ b/crates/tls/server-fixture/src/openssl.cnf @@ -0,0 +1,7 @@ +[ v3_req ] +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +subjectAltName = @alt_names + +[ alt_names ] +DNS.1 = test-server.io \ No newline at end of file diff --git a/crates/tls/server-fixture/src/root_ca.crt b/crates/tls/server-fixture/src/root_ca.crt new file mode 100644 index 0000000000..e3b0cbf520 --- /dev/null +++ b/crates/tls/server-fixture/src/root_ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDrTCCApWgAwIBAgIUPO4oH+2bEenSnnIz7irzoPsDS4owDQYJKoZIhvcNAQEL +BQAwZTELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRYwFAYDVQQDDA10bHNu +b3Rhcnkub3JnMCAXDTI0MDgwMjEwNDMyN1oYDzIxMjQwODAzMTA0MzI3WjBlMQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxEjAQBgNV 
+BAoMCXRsc25vdGFyeTELMAkGA1UECwwCSVQxFjAUBgNVBAMMDXRsc25vdGFyeS5v +cmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCVvgedJ3zVE7ICYoaD +CwybhEN/6g1baoyDRVD8fpZfhdkh0uMMKBFqRa1qO9wF3Fthq6DJRaHsmZeE42Jm +aDvlRtaKDfB0MMcSeNqmP8ia7+8TFgMBY/YP7dW3d9QADFHLqyMcS6O2iaSMjBzg +4nx33TdAhQOIPHOSZbMZJGO18jn55GEeogIz6UiV8gqjQtbel/cn8jXi2rOgub+p +CZziixQ6ikppdW6a8p37B5W4/WNHDIRgRP890q0GyrEJWtj9TwyMmeC6/0mxXjZC +caLWV0072j3Dd+66XvkeL04mSe4Bp0YUs8jcTPsfOAo3FAvPgyQ6UqQfZBqOnU93 +xmYzAgMBAAGjUzBRMB0GA1UdDgQWBBTJgXIkPw2ZVkTscFx/CKZZrhymzTAfBgNV +HSMEGDAWgBTJgXIkPw2ZVkTscFx/CKZZrhymzTAPBgNVHRMBAf8EBTADAQH/MA0G +CSqGSIb3DQEBCwUAA4IBAQAPRStXfyEyQ9nThEXsCC+qDPUDM9hrE9YKzGlqcjzA +fRcWbtZQ4hR6nhoQ1wcs5o56R0xk7BvwP4Y+wX499uBDUUUWpFRasrlPjvAVoseU +IXtIqjDKoag1Q2JULKD3cWxxbXITJttdUCEf0Kfn/3tS+6Fev6fquV47Dp32SuZP +tnacPbsKC/q/K80siFoWYNRrwED+c4gnnOCKI/VCfv/oREGFpgyD2pFuLWMLVH3z +wZW1EzkzQKgCNCjzs7oh0CyA2TFdJ4xVgDqAcmD5EPl8r6Nc9joYM/zBkY/cFRvp +AVhUnBPgFr1gb5CcG+Y0nptal64ukTYrgfMaIiO9h6sx +-----END CERTIFICATE----- diff --git a/crates/tls/server-fixture/src/root_ca.key b/crates/tls/server-fixture/src/root_ca.key new file mode 100644 index 0000000000..712469bf3e --- /dev/null +++ b/crates/tls/server-fixture/src/root_ca.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCVvgedJ3zVE7IC +YoaDCwybhEN/6g1baoyDRVD8fpZfhdkh0uMMKBFqRa1qO9wF3Fthq6DJRaHsmZeE +42JmaDvlRtaKDfB0MMcSeNqmP8ia7+8TFgMBY/YP7dW3d9QADFHLqyMcS6O2iaSM +jBzg4nx33TdAhQOIPHOSZbMZJGO18jn55GEeogIz6UiV8gqjQtbel/cn8jXi2rOg +ub+pCZziixQ6ikppdW6a8p37B5W4/WNHDIRgRP890q0GyrEJWtj9TwyMmeC6/0mx +XjZCcaLWV0072j3Dd+66XvkeL04mSe4Bp0YUs8jcTPsfOAo3FAvPgyQ6UqQfZBqO +nU93xmYzAgMBAAECggEAB+ybV4rgQCBqMlZyGtuJ/8Ld6uuBEx442wuJ2nV9J1yc +cyicq6cv1hQONh8pKMWSr8EBjGqFw/u+znaqsuj/iRsYvbaOISqhpk3Eow6guD5L +7xJ3oepfJP786S12B8ifHYGWz+ewKA1HAB8RZNSSKf+ywv8nAt3Rbzpi4h47CUT4 +Z06gLJYZNimLVPIWLzrHa+/ZOyHq/XRWsr6GTFgXfT6nudfCxzdlIdajrBvaSLBG +KbOs52tffEUHn+V1AoH6kmNp0EPSCbnR2b1KIv7loj6vi52UBpipjNFwa8PNzWfL 
+Cuu9N6fl7qRv9VYCnC2gJz6rTARaNJWf57UP2avygQKBgQDJC89y4lgai8qJLv3w +go+kFiWnZE0C8U69sOmNeACYhunQFKX2cG7EkTuPOnZj8XJcLYVHMSJLrEJcqyX/ +wDv1at+KqDMQsf0j7NHCSpkoG93wlffCB87VPndy7ajRN4d17tbQOJP6zmOQo1YP +7MTeVtDF3JF9IxfTb+Pxmp5nswKBgQC+rDzBN8Drr1jp6FfzZrDcr/gvlSftXupF +jTSkSxywQjophp02Hdi2t32Xq+wEuaMaJUOtywK/NVs5hJeGC584rWQjLObh7oUD +td+2V802kzsERSeDiDwtBYgjePtkeO7MXadGLwJSaZxocjcjgGj2qWPs9ihUASuB +TEtkO0jHgQKBgQCUFGXc2YhJLTOlrX4O+ytvkXx0ebUbeL8lirvLnlrZ/W0T/VFs +Xc3IbKxwx3/SB1HTQRgMosz+7ccHWGwpnt7K2cgC6faK0n6ASnsJX0bFuxjSjrMp +L/URLexvM0uHph3ZKG0CetnL/t5o91V5b0xl843cXqSuhf2Tl7NODjOkbwKBgAIn +5mP04myHxgSXCO+KmLNWFgNLt3DaouF4cEDvTHq9tPSlPf/PpJSkTHo7imafRrXT ++AjuA7DvxIFI+4GbfghhBYHUTyP802owU0A3i+1zCrbIpWK6VpvXtStZgdYn++M5 +p9uGSotuAEO6Dt+K4yTu019phRk2DizfFPckKHWBAoGAehmqjR+5T7SpDiZwXFyN +CA4qKVoYPexmNjbECYkpLbEkPxOc145H0Y4oHOBH46jIiHumSV3N2bvywYQ2IlyV +BSGqGFAeFhpRAtMKCFMG7bNPTbskKcpUyGD2csoiYxXsFuFZX4Db9i0tpjt57C/a +9ij7zNzrAj5Iby8EMykK+aM= +-----END PRIVATE KEY----- diff --git a/crates/tls/server-fixture/src/root_ca.srl b/crates/tls/server-fixture/src/root_ca.srl new file mode 100644 index 0000000000..f4add3930b --- /dev/null +++ b/crates/tls/server-fixture/src/root_ca.srl @@ -0,0 +1 @@ +1B924A233FDF6D40DDA57D7E4C0C37DE64BE996A diff --git a/crates/tls/server-fixture/src/root_ca_cert.der b/crates/tls/server-fixture/src/root_ca_cert.der new file mode 100644 index 0000000000..10a2ab64a2 Binary files /dev/null and b/crates/tls/server-fixture/src/root_ca_cert.der differ diff --git a/crates/tls/server-fixture/src/test_server.crt b/crates/tls/server-fixture/src/test_server.crt new file mode 100644 index 0000000000..8065b0ce26 --- /dev/null +++ b/crates/tls/server-fixture/src/test_server.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID0DCCArigAwIBAgIUG5JKIz/fbUDdpX1+TAw33mS+mWowDQYJKoZIhvcNAQEL +BQAwZTELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRYwFAYDVQQDDA10bHNu 
+b3Rhcnkub3JnMCAXDTI0MDgwMjEwNDM1NloYDzIxMjQwODAzMTA0MzU2WjBmMQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxEjAQBgNV +BAoMCXRsc25vdGFyeTELMAkGA1UECwwCSVQxFzAVBgNVBAMMDnRlc3Qtc2VydmVy +LmlvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoeBDxxxAASDtcXx4 +07dK7YfLw2+cRz5rDdv/HHPHJLGJTvCXfZCTfV3y3KzTuLeOWHhGyG1bH075Jg/1 +TZ+nTdr/T/78mV4GXilf6hvmnwX3Pr7KLfXDEizRDKbnQqTgThs9hgHJ5pm8Jkid +5dWJEnvT5ChaBzwITpAe7qD05dVln7wkayKkT28IuV1iOglXjoBsozsL2qvj2wmL +pYQqn17Ir98CY9AUjJ/D4tAGRbxGmhQ3+kLakO2wR+TA0E51opjlWeP4qc8i1OWp +MH3fz5GddrC0BYVF0yute2VgjOXlM0PB2V4aMrqeB52hppix9XZOXymLeVQHddXQ +YbtPPQIDAQABo3UwczAJBgNVHRMEAjAAMAsGA1UdDwQEAwIF4DAZBgNVHREEEjAQ +gg50ZXN0LXNlcnZlci5pbzAdBgNVHQ4EFgQUXLxbOoGpjtxTs0zuIRtl74jPNokw +HwYDVR0jBBgwFoAUyYFyJD8NmVZE7HBcfwimWa4cps0wDQYJKoZIhvcNAQELBQAD +ggEBAIxHgJqh26P0XHawz8QQgYQlKHD74uuluMgStArVIydE1/gRqhuqBNt4kvdE +/lU/ZtlfQ2sZjB5a1fz3Rj4VNxlysTvp8d5fMOzcYhKTYx5eWuwejWdioarg7CUO +lHy65gRSAw9E4qiyLi3zXFK2Lu/ta29/RGH+OpyziSvoD/EQ0h+8Hr792UTkJqHB +eHQkaLTCr/QfSd/rf+0ao3/LptJTeDC7L2hN54L692SC/PXTW197d0+1HCjEmwmK +smgfAKZcIlfiRlN7HMGWaCIRpEVcdZmOBhiTDxpQVZQdbEAcME8y7ALLTYOMpyBE +a2FHrDiKtxNQZCZnaoUw3seXKHg= +-----END CERTIFICATE----- diff --git a/crates/tls/server-fixture/src/test_server.csr b/crates/tls/server-fixture/src/test_server.csr new file mode 100644 index 0000000000..6be1fbcf4c --- /dev/null +++ b/crates/tls/server-fixture/src/test_server.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICqzCCAZMCAQAwZjELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYD +VQQHDARDaXR5MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRcwFQYD +VQQDDA50ZXN0LXNlcnZlci5pbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBAKHgQ8ccQAEg7XF8eNO3Su2Hy8NvnEc+aw3b/xxzxySxiU7wl32Qk31d8tys +07i3jlh4RshtWx9O+SYP9U2fp03a/0/+/JleBl4pX+ob5p8F9z6+yi31wxIs0Qym +50Kk4E4bPYYByeaZvCZIneXViRJ70+QoWgc8CE6QHu6g9OXVZZ+8JGsipE9vCLld +YjoJV46AbKM7C9qr49sJi6WEKp9eyK/fAmPQFIyfw+LQBkW8RpoUN/pC2pDtsEfk +wNBOdaKY5Vnj+KnPItTlqTB938+RnXawtAWFRdMrrXtlYIzl5TNDwdleGjK6nged 
+oaaYsfV2Tl8pi3lUB3XV0GG7Tz0CAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4IBAQCA +aNz5mVndHInJJloJIuFvHbQLeuglEfn1Iyjjk3ILLm29RqcVlJ1LsnZZXG4rv8JH +YWHpvsLLrR/nIkT+wxFCfYVHp8szpyLVW/mTLWb6xAB/d6i1SEmYSN0LNkmNvWFS +kDq9A3v5sa9SZ1/btgfIVa6QzZWHuqYqad3KWJcpn+PckqiG+Bihx69TGsIMJHgN +9P//ra2lWyL391KGycNrKTbydpFjRT6vwC2QZJWG47liRS/PYfm6wtdoJa7Mw9vl +ciBvDhTFF7FYl0uV1NlzIoVyChMmRv2JR66efcTfWqfP44E4dhBKHIpBxc8+4GtI +ol18bSfvVKBlIyoZPdRP +-----END CERTIFICATE REQUEST----- diff --git a/crates/tls/server-fixture/src/test_server.key b/crates/tls/server-fixture/src/test_server.key new file mode 100644 index 0000000000..a3f4a433b8 --- /dev/null +++ b/crates/tls/server-fixture/src/test_server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCh4EPHHEABIO1x +fHjTt0rth8vDb5xHPmsN2/8cc8cksYlO8Jd9kJN9XfLcrNO4t45YeEbIbVsfTvkm +D/VNn6dN2v9P/vyZXgZeKV/qG+afBfc+vsot9cMSLNEMpudCpOBOGz2GAcnmmbwm +SJ3l1YkSe9PkKFoHPAhOkB7uoPTl1WWfvCRrIqRPbwi5XWI6CVeOgGyjOwvaq+Pb +CYulhCqfXsiv3wJj0BSMn8Pi0AZFvEaaFDf6QtqQ7bBH5MDQTnWimOVZ4/ipzyLU +5akwfd/PkZ12sLQFhUXTK617ZWCM5eUzQ8HZXhoyup4HnaGmmLH1dk5fKYt5VAd1 +1dBhu089AgMBAAECggEAARLZCuZdEPou7k8Xs2Ub0hzRyny1r03yrSeFtvftnN4F +6HLKuRfPDUh6O+HJkF1nTEmVQ+8LE6y/v542JPWnc79oF20RhSg30pgOUsyB6GbE +ZQjO6SR1eWwNVuV50y9UwtqGEJrNGVfGWlqmRsf2c3DuztdAVvFG/NO9Nh1LLTBi +bxkyhVma4NRs3Yt9tzL/wCeIAWFXjHIctYxUw7TUd21s7/i/+dN96wEGywpuN0zs +HeiV8OOV0wEgaIldShJnDW7EF5AVl9eGZznm5gJnCqE4/Sq+0CuRLW95jFwZu0GR +tUNjSi+ypPLt0vf1LIM0hrYoFNB0xxLdnRqpEB96AQKBgQDatlglMc0KvJ2ghNGJ +TbK/jtkb7uEkdFxM+l+4EYOdpVrh7gQMekSDv6nNyze9eCBWouB8m5ty1+HKTPk8 +X+4shd7VE8joJus/R/0vqVrcK3eh/fuL8E2TgZbKmsBh168NQY5dxMgs69do3bPg +k46EUxLd3MJ9fdv2xoila7c2DQKBgQC9eVAyVlDA+VBonZA1kmVY4858JKMsPH+Y +pfC1ty+QwFxsN3AxpYolUYJAp11vDCdGP9GpOaHhfGmxLCIyUSWDUdcBjXZCz/lW +76xT4wh7LieOezhPMzoLP+vdel9xvLGQYu2a6GvRaUxfPH7fbFaGBuTvi4JjEaq0 +CjtmksDh8QKBgDXKHseXBeycEtBFmhsApvOBuFesWmbSz1iHQz9L32jIIB/sn8ZJ +08vrOWHJlv3cK2fjSv6abpLCEV/lqm500WjVy8XvxbuCxtybYeN07Um0zwliI5l5 
+Ejsy5dkSUjo+B2llNBRPr0ONBT9fNzwGTkiw/bTe9F5Us+JvVXAJm9eJAoGAX04X +Jcq/AeImLQkcUaYarlSgN1eicAzaTakiY/UJyvDHTHOyTnaq/0x5jRXibIobcz2E +s29W2vnenAzMAq1Ihj5zPMewNbkw/Sa/cs6fJH65zPR0BXqJ9sCnXpdATRCR7EOm +qqXAHeyuSrU+SBnRh8cN/uQYqMZpK/h9moG03bECgYB8MtBxLxLctVYofIwyzRF/ +gu4Tm84bVBeaRlUapuXB0ZsC6JBnCM9cWDR2SxLHEbjGQ07Oaae+OQb4C5Uc0Ex0 +fmu0ATFb52BScnNFw+elQHk8ynv0sfV6LbPe2C1uNBVFtatRMgNJGkg7UauM68TJ +VOZB9TwtVa9Sr0QpQNiyvg== +-----END PRIVATE KEY----- diff --git a/crates/tls/server-fixture/src/test_server_cert.der b/crates/tls/server-fixture/src/test_server_cert.der new file mode 100644 index 0000000000..c6c0048492 Binary files /dev/null and b/crates/tls/server-fixture/src/test_server_cert.der differ diff --git a/crates/tls/server-fixture/src/test_server_private_key.der b/crates/tls/server-fixture/src/test_server_private_key.der new file mode 100644 index 0000000000..d2591402c8 Binary files /dev/null and b/crates/tls/server-fixture/src/test_server_private_key.der differ diff --git a/crates/verifier/Cargo.toml b/crates/verifier/Cargo.toml new file mode 100644 index 0000000000..9453bd18f2 --- /dev/null +++ b/crates/verifier/Cargo.toml @@ -0,0 +1,53 @@ +[package] +name = "tlsn-verifier" +authors = ["TLSNotary Team"] +description = "A library for the TLSNotary verifier" +keywords = ["tls", "mpc", "2pc"] +categories = ["cryptography"] +license = "MIT OR Apache-2.0" +version = "0.1.0-alpha.7" +edition = "2021" + +[features] +default = ["rayon"] +rayon = ["mpz-common/rayon"] +force-st = ["mpz-common/force-st"] +# Enables the AuthDecode protocol which allows to prove zk-friendly hashes over the transcript data. +# This is an early iteration meant for gathering feedback and assessing performance. As such, this +# feature has an "_unsafe" suffix since it will leak the ranges of data committed to. +# Future iterations will get rid of the leakage at the cost of worse performance. +# This feature is EXPERIMENTAL and will be removed in future releases without prior notice. 
+authdecode_unsafe = ["tlsn-common/authdecode_unsafe_common"] + +[dependencies] +tlsn-authdecode = { workspace = true } +tlsn-authdecode-core = { workspace = true } +tlsn-authdecode-transcript = { workspace = true } +tlsn-common = { workspace = true } +tlsn-core = { workspace = true } +tlsn-tls-core = { workspace = true } +tlsn-tls-mpc = { workspace = true } +tlsn-utils = { workspace = true } +tlsn-utils-aio = { workspace = true } + +serde = { workspace = true } +serio = { workspace = true, features = ["compat"] } +uid-mux = { workspace = true, features = ["serio"] } + +mpz-circuits = { workspace = true } +mpz-common = { workspace = true } +mpz-core = { workspace = true } +mpz-garble = { workspace = true } +mpz-garble-core = { workspace = true } +mpz-ole = { workspace = true } +mpz-ot = { workspace = true } +mpz-share-conversion = { workspace = true } + +derive_builder = { workspace = true } +futures = { workspace = true } +opaque-debug = { workspace = true } +rand = { workspace = true } +signature = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +web-time = { workspace = true } diff --git a/crates/verifier/src/authdecode.rs b/crates/verifier/src/authdecode.rs new file mode 100644 index 0000000000..203bd1b6e7 --- /dev/null +++ b/crates/verifier/src/authdecode.rs @@ -0,0 +1,157 @@ +use std::mem; + +use authdecode_core::{ + backend::{ + halo2::{Bn256F, CHUNK_SIZE}, + traits::Field, + }, + msgs::{Commit, Proofs}, + verifier::{CommitmentReceived, Initialized, Verifier, VerifierError as CoreVerifierError}, +}; +use authdecode_transcript::{TranscriptData, TranscriptEncoder}; +use tlsn_core::{ + hash::{HashAlgId, TypedHash}, + transcript::{Idx, PlaintextHash}, +}; + +/// Returns an AuthDecode verifier depending on the hash algorithm contained in the request. 
+pub(crate) fn authdecode_verifier(alg: &HashAlgId) -> impl TranscriptVerifier { + match alg { + &HashAlgId::POSEIDON_BN256_434 => PoseidonHalo2Verifier::new(), + _ => unimplemented!(), + } +} + +/// An AuthDecode verifier for a TLS transcript. +pub(crate) trait TranscriptVerifier { + type CommitmentMessage: serio::Deserialize; + type ProofMessage: serio::Deserialize; + + /// Creates a new verifier. + fn new() -> Self; + + /// Receives commitments. + /// + /// # Arguments + /// + /// * `commitments` - The commitments to receive. + /// * `max_plaintext` - The maximum bytesize of committed plaintext allowed to be contained in the + /// `commitments`. + fn receive_commitments( + &mut self, + commitments: Self::CommitmentMessage, + max_plaintext: usize, + ) -> Result<(), TranscriptVerifierError>; + + /// Verifies proofs and returns authenticated plaintext hashes. + /// + /// # Arguments + /// + /// * `proofs` - The proofs to verify. + /// * `seed` - The seed to generate encodings from. + fn verify( + &mut self, + proofs: Self::ProofMessage, + seed: [u8; 32], + ) -> Result, TranscriptVerifierError>; +} + +/// An AuthDecode verifier which uses hashes of the POSEIDON_HALO2 kind. +pub(crate) struct PoseidonHalo2Verifier { + /// The verifier in the [Initialized] state. + initialized: Option>, + /// The verifier in the [CommitmentReceived] state. 
+ commitment_received: + Option, Bn256F>>, +} + +impl TranscriptVerifier for PoseidonHalo2Verifier { + type CommitmentMessage = Commit; + type ProofMessage = Proofs; + + fn new() -> Self { + Self { + initialized: Some(Verifier::new(Box::new( + authdecode_core::backend::halo2::verifier::Verifier::new(), + ))), + commitment_received: None, + } + } + + fn receive_commitments( + &mut self, + commitments: Self::CommitmentMessage, + max_plaintext: usize, + ) -> Result<(), TranscriptVerifierError> { + let verifier = mem::take(&mut self.initialized).ok_or(TranscriptVerifierError::Other( + "The verifier was called in the wrong state".to_string(), + ))?; + + if commitments.commitment_count() != commitments.chunk_count() { + return Err(TranscriptVerifierError::Other( + "Some commitments contain more than one chunk of plaintext data".to_string(), + )); + } + + if commitments.chunk_count() * CHUNK_SIZE > max_plaintext { + return Err(TranscriptVerifierError::Other( + "The amount of data in commitments exceeded the limit".to_string(), + )); + } + + self.commitment_received = Some(verifier.receive_commitments(commitments)?); + + Ok(()) + } + + fn verify( + &mut self, + proofs: Self::ProofMessage, + seed: [u8; 32], + ) -> Result, TranscriptVerifierError> { + let verifier = + mem::take(&mut self.commitment_received).ok_or(TranscriptVerifierError::Other( + "The verifier was called in the wrong state".to_string(), + ))?; + + let encoding_provider = TranscriptEncoder::new(seed); + + let verifier = verifier.verify(proofs, &encoding_provider)?; + + let coms = verifier + .commitments() + .iter() + .map(|com| { + // Earlier we checked that each commitment has only one chunk. 
+ debug_assert!(com.chunk_commitments().len() == 1); + + let com = &com.chunk_commitments()[0]; + let range = com.ids(); + PlaintextHash { + direction: *range.direction(), + hash: TypedHash { + alg: HashAlgId::POSEIDON_BN256_434, + value: com + .plaintext_hash() + .clone() + .to_bytes_be() + .try_into() + .unwrap(), + }, + idx: Idx::new(range.range().clone()), + } + }) + .collect(); + + Ok(coms) + } +} + +#[derive(Debug, thiserror::Error)] +/// Error for [TranscriptVerifier]. +pub(crate) enum TranscriptVerifierError { + #[error(transparent)] + CoreProtocolError(#[from] CoreVerifierError), + #[error("AuthDecode verifier failed with an error: {0}")] + Other(String), +} diff --git a/crates/verifier/src/config.rs b/crates/verifier/src/config.rs new file mode 100644 index 0000000000..4ac898fa68 --- /dev/null +++ b/crates/verifier/src/config.rs @@ -0,0 +1,97 @@ +use mpz_ot::{chou_orlandi, kos}; +use std::{ + fmt::{Debug, Formatter, Result}, + sync::Arc, +}; +use tls_mpc::{MpcTlsCommonConfig, MpcTlsFollowerConfig, TranscriptConfig}; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_core::CryptoProvider; + +/// Configuration for the [`Verifier`](crate::tls::Verifier). +#[allow(missing_docs)] +#[derive(derive_builder::Builder)] +#[builder(pattern = "owned")] +pub struct VerifierConfig { + protocol_config_validator: ProtocolConfigValidator, + /// Cryptography provider. + #[builder(default, setter(into))] + crypto_provider: Arc, +} + +impl Debug for VerifierConfig { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + f.debug_struct("VerifierConfig") + .field("protocol_config_validator", &self.protocol_config_validator) + .finish_non_exhaustive() + } +} + +impl VerifierConfig { + /// Creates a new configuration builder. + pub fn builder() -> VerifierConfigBuilder { + VerifierConfigBuilder::default() + } + + /// Returns the protocol configuration validator. 
+ pub fn protocol_config_validator(&self) -> &ProtocolConfigValidator { + &self.protocol_config_validator + } + + /// Returns the cryptography provider. + pub fn crypto_provider(&self) -> &CryptoProvider { + &self.crypto_provider + } + + pub(crate) fn build_base_ot_sender_config(&self) -> chou_orlandi::SenderConfig { + chou_orlandi::SenderConfig::default() + } + + pub(crate) fn build_base_ot_receiver_config(&self) -> chou_orlandi::ReceiverConfig { + chou_orlandi::ReceiverConfig::builder() + .receiver_commit() + .build() + .unwrap() + } + + pub(crate) fn build_ot_sender_config(&self) -> kos::SenderConfig { + kos::SenderConfig::builder() + .sender_commit() + .build() + .unwrap() + } + + pub(crate) fn build_ot_receiver_config(&self) -> kos::ReceiverConfig { + kos::ReceiverConfig::default() + } + + pub(crate) fn build_mpc_tls_config( + &self, + protocol_config: &ProtocolConfig, + ) -> MpcTlsFollowerConfig { + MpcTlsFollowerConfig::builder() + .common( + MpcTlsCommonConfig::builder() + .tx_config( + TranscriptConfig::default_tx() + .max_online_size(protocol_config.max_sent_data()) + .build() + .unwrap(), + ) + .rx_config( + TranscriptConfig::default_rx() + .max_online_size(protocol_config.max_recv_data_online()) + .max_offline_size( + protocol_config.max_recv_data() + - protocol_config.max_recv_data_online(), + ) + .build() + .unwrap(), + ) + .handshake_commit(true) + .build() + .unwrap(), + ) + .build() + .unwrap() + } +} diff --git a/crates/verifier/src/error.rs b/crates/verifier/src/error.rs new file mode 100644 index 0000000000..8fc448b788 --- /dev/null +++ b/crates/verifier/src/error.rs @@ -0,0 +1,153 @@ +use std::{error::Error, fmt}; +use tls_mpc::MpcTlsError; + +/// Error for [`Verifier`](crate::Verifier). 
+#[derive(Debug, thiserror::Error)] +pub struct VerifierError { + kind: ErrorKind, + source: Option>, +} + +impl VerifierError { + fn new(kind: ErrorKind, source: E) -> Self + where + E: Into>, + { + Self { + kind, + source: Some(source.into()), + } + } + + pub(crate) fn attestation(source: E) -> Self + where + E: Into>, + { + Self::new(ErrorKind::Attestation, source) + } + + pub(crate) fn verify(source: E) -> Self + where + E: Into>, + { + Self::new(ErrorKind::Verify, source) + } +} + +#[derive(Debug)] +enum ErrorKind { + Io, + Config, + Mpc, + Attestation, + Verify, + #[cfg(feature = "authdecode_unsafe")] + AuthDecode, +} + +impl fmt::Display for VerifierError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("verifier error: ")?; + + match self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Config => f.write_str("config error")?, + ErrorKind::Mpc => f.write_str("mpc error")?, + ErrorKind::Attestation => f.write_str("attestation error")?, + ErrorKind::Verify => f.write_str("verification error")?, + #[cfg(feature = "authdecode_unsafe")] + ErrorKind::AuthDecode => f.write_str("authdecode error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source)?; + } + + Ok(()) + } +} + +impl From for VerifierError { + fn from(e: std::io::Error) -> Self { + Self::new(ErrorKind::Io, e) + } +} + +impl From for VerifierError { + fn from(e: tlsn_common::config::ProtocolConfigError) -> Self { + Self::new(ErrorKind::Config, e) + } +} + +impl From for VerifierError { + fn from(e: uid_mux::yamux::ConnectionError) -> Self { + Self::new(ErrorKind::Io, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_common::ContextError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: MpcTlsError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_ot::OTError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for 
VerifierError { + fn from(e: mpz_ot::kos::SenderError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_ole::OLEError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_ot::kos::ReceiverError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_garble::VmError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_garble::protocol::deap::DEAPError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_garble::MemoryError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +impl From for VerifierError { + fn from(e: mpz_garble::VerifyError) -> Self { + Self::new(ErrorKind::Mpc, e) + } +} + +#[cfg(feature = "authdecode_unsafe")] +impl From for VerifierError { + fn from(e: authdecode_core::verifier::VerifierError) -> Self { + Self::new(ErrorKind::AuthDecode, e) + } +} diff --git a/crates/verifier/src/lib.rs b/crates/verifier/src/lib.rs new file mode 100644 index 0000000000..007d446813 --- /dev/null +++ b/crates/verifier/src/lib.rs @@ -0,0 +1,386 @@ +//! TLSNotary verifier library. 
+ +#![deny(missing_docs, unreachable_pub, unused_must_use)] +#![deny(clippy::all)] +#![forbid(unsafe_code)] + +#[cfg(feature = "authdecode_unsafe")] +pub(crate) mod authdecode; +pub(crate) mod config; +mod error; +mod notarize; +pub mod state; +mod verify; + +pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError}; +pub use error::VerifierError; +use mpz_common::Allocate; +use serio::{stream::IoStreamExt, StreamExt}; +use uid_mux::FramedUidMux; + +use web_time::{SystemTime, UNIX_EPOCH}; + +use futures::{AsyncRead, AsyncWrite}; +use mpz_garble::config::Role as DEAPRole; +use mpz_ot::{chou_orlandi, kos}; +use rand::Rng; +use state::{Notarize, Verify}; +use tls_mpc::{build_components, MpcTlsFollower, MpcTlsFollowerData, TlsRole}; +use tlsn_common::{ + config::ProtocolConfig, + mux::{attach_mux, MuxControl}, + DEAPThread, Executor, OTReceiver, OTSender, Role, +}; +use tlsn_core::{ + attestation::{Attestation, AttestationConfig}, + connection::{ConnectionInfo, ServerName, TlsVersion, TranscriptLength}, + transcript::PartialTranscript, +}; + +use tracing::{debug, info, info_span, instrument, Span}; + +/// Information about the TLS session. +#[derive(Debug)] +pub struct SessionInfo { + /// Server's name. + pub server_name: ServerName, + /// Connection information. + pub connection_info: ConnectionInfo, +} + +/// A Verifier instance. +pub struct Verifier { + config: VerifierConfig, + span: Span, + state: T, +} + +impl Verifier { + /// Creates a new verifier. + pub fn new(config: VerifierConfig) -> Self { + let span = info_span!("verifier"); + Self { + config, + span, + state: state::Initialized, + } + } + + /// Sets up the verifier. + /// + /// This performs all MPC setup. + /// + /// # Arguments + /// + /// * `socket` - The socket to the prover. 
+ #[instrument(parent = &self.span, level = "info", skip_all, err)] + pub async fn setup( + self, + socket: S, + ) -> Result, VerifierError> { + let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier); + + // Maximum thread forking concurrency of 8. + // TODO: Determine the optimal number of threads. + let mut exec = Executor::new(mux_ctrl.clone(), 8); + + let mut io = mux_fut + .poll_with(mux_ctrl.open_framed(b"tlsnotary")) + .await?; + + // Receives protocol configuration from prover to perform compatibility check. + let protocol_config = mux_fut + .poll_with(async { + let peer_configuration: ProtocolConfig = io.expect_next().await?; + self.config + .protocol_config_validator() + .validate(&peer_configuration)?; + + Ok::<_, VerifierError>(peer_configuration) + }) + .await?; + + #[cfg(feature = "authdecode_unsafe")] + let wants_authdecode = protocol_config.max_authdecode_data() > 0; + + let encoder_seed: [u8; 32] = rand::rngs::OsRng.gen(); + let (mpc_tls, vm, ot_send) = mux_fut + .poll_with(setup_mpc_backend( + &self.config, + protocol_config, + &mux_ctrl, + &mut exec, + encoder_seed, + )) + .await?; + + let ctx = mux_fut.poll_with(exec.new_thread()).await?; + + Ok(Verifier { + config: self.config, + span: self.span, + state: state::Setup { + io, + mux_ctrl, + mux_fut, + mpc_tls, + vm, + ot_send, + ctx, + encoder_seed, + #[cfg(feature = "authdecode_unsafe")] + wants_authdecode, + }, + }) + } + + /// Runs the TLS verifier to completion, notarizing the TLS session. + /// + /// This is a convenience method which runs all the steps needed for + /// notarization. + /// + /// # Arguments + /// + /// * `socket` - The socket to the prover. + /// * `signer` - The signer used to sign the notarization result. + #[instrument(parent = &self.span, level = "info", skip_all, err)] + pub async fn notarize( + self, + socket: S, + config: &AttestationConfig, + ) -> Result { + self.setup(socket) + .await? + .run() + .await? 
+ .start_notarize() + .finalize(config) + .await + } + + /// Runs the TLS verifier to completion, verifying the TLS session. + /// + /// This is a convenience method which runs all the steps needed for + /// verification. + /// + /// # Arguments + /// + /// * `socket` - The socket to the prover. + #[instrument(parent = &self.span, level = "info", skip_all, err)] + pub async fn verify( + self, + socket: S, + ) -> Result<(PartialTranscript, SessionInfo), VerifierError> { + let mut verifier = self.setup(socket).await?.run().await?.start_verify(); + let transcript = verifier.receive().await?; + + let session_info = verifier.finalize().await?; + Ok((transcript, session_info)) + } +} + +impl Verifier { + /// Runs the verifier until the TLS connection is closed. + #[instrument(parent = &self.span, level = "info", skip_all, err)] + pub async fn run(self) -> Result, VerifierError> { + let state::Setup { + io, + mux_ctrl, + mut mux_fut, + mpc_tls, + vm, + ot_send, + ctx, + encoder_seed, + #[cfg(feature = "authdecode_unsafe")] + wants_authdecode, + } = self.state; + + let start_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let MpcTlsFollowerData { + server_key, + bytes_sent, + bytes_recv, + .. + } = mux_fut.poll_with(mpc_tls.run().1).await?; + + info!("Finished TLS session"); + + let connection_info = ConnectionInfo { + time: start_time, + version: TlsVersion::V1_2, + transcript_length: TranscriptLength { + sent: bytes_sent as u32, + received: bytes_recv as u32, + }, + }; + + Ok(Verifier { + config: self.config, + span: self.span, + state: state::Closed { + io, + mux_ctrl, + mux_fut, + vm, + ot_send, + ctx, + encoder_seed, + server_ephemeral_key: server_key + .try_into() + .expect("only supported key type should have been accepted"), + connection_info, + #[cfg(feature = "authdecode_unsafe")] + wants_authdecode, + }, + }) + } +} + +impl Verifier { + /// Starts notarization of the TLS session. 
+ /// + /// If the verifier is a Notary, this function will transition the verifier + /// to the next state where it can sign the prover's commitments to the + /// transcript. + pub fn start_notarize(self) -> Verifier { + Verifier { + config: self.config, + span: self.span, + state: self.state.into(), + } + } + + /// Starts verification of the TLS session. + /// + /// This function transitions the verifier into a state where it can verify + /// content of the transcript. + pub fn start_verify(self) -> Verifier { + Verifier { + config: self.config, + span: self.span, + state: self.state.into(), + } + } +} + +/// Performs a setup of the various MPC subprotocols. +#[instrument(level = "debug", skip_all, err)] +async fn setup_mpc_backend( + config: &VerifierConfig, + protocol_config: ProtocolConfig, + mux: &MuxControl, + exec: &mut Executor, + encoder_seed: [u8; 32], +) -> Result<(MpcTlsFollower, DEAPThread, OTSender), VerifierError> { + debug!("starting MPC backend setup"); + + let mut ot_sender = kos::Sender::new( + config.build_ot_sender_config(), + chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), + ); + ot_sender.alloc(protocol_config.ot_sender_setup_count(Role::Verifier)); + + let mut ot_receiver = kos::Receiver::new( + config.build_ot_receiver_config(), + chou_orlandi::Sender::new(config.build_base_ot_sender_config()), + ); + ot_receiver.alloc(protocol_config.ot_receiver_setup_count(Role::Verifier)); + + let ot_sender = OTSender::new(ot_sender); + let ot_receiver = OTReceiver::new(ot_receiver); + + let ( + ctx_vm, + ctx_ke_0, + ctx_ke_1, + ctx_prf_0, + ctx_prf_1, + ctx_encrypter_block_cipher, + ctx_encrypter_stream_cipher, + ctx_encrypter_ghash, + ctx_encrypter, + ctx_decrypter_block_cipher, + ctx_decrypter_stream_cipher, + ctx_decrypter_ghash, + ctx_decrypter, + ) = futures::try_join!( + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + 
exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + )?; + + let vm = DEAPThread::new( + DEAPRole::Follower, + encoder_seed, + ctx_vm, + ot_sender.clone(), + ot_receiver.clone(), + ); + + let mpc_tls_config = config.build_mpc_tls_config(&protocol_config); + let (ke, prf, encrypter, decrypter) = build_components( + TlsRole::Follower, + mpc_tls_config.common(), + ctx_ke_0, + ctx_encrypter, + ctx_decrypter, + ctx_encrypter_ghash, + ctx_decrypter_ghash, + vm.new_thread(ctx_ke_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_0, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread( + ctx_encrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_encrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + ot_sender.clone(), + ot_receiver.clone(), + ); + + let channel = mux.open_framed(b"mpc_tls").await?; + let mut mpc_tls = MpcTlsFollower::new( + mpc_tls_config, + Box::new(StreamExt::compat_stream(channel)), + ke, + prf, + encrypter, + decrypter, + ); + + mpc_tls.setup().await?; + + debug!("MPC backend setup complete"); + + Ok((mpc_tls, vm, ot_sender)) +} diff --git a/crates/verifier/src/notarize.rs b/crates/verifier/src/notarize.rs new file mode 100644 index 0000000000..e60f634be1 --- /dev/null +++ b/crates/verifier/src/notarize.rs @@ -0,0 +1,124 @@ +//! This module handles the notarization phase of the verifier. +//! +//! The TLS verifier is only a notary. 
+ +use super::{state::Notarize, Verifier, VerifierError}; +use mpz_ot::CommittedOTSender; +use serio::{stream::IoStreamExt, SinkExt as _}; + +use tlsn_core::{ + attestation::{Attestation, AttestationConfig}, + request::Request, +}; +use tracing::{debug, info, instrument}; + +#[cfg(feature = "authdecode_unsafe")] +use crate::authdecode::{authdecode_verifier, TranscriptVerifier}; +#[cfg(feature = "authdecode_unsafe")] +use tlsn_core::hash::HashAlgId; + +impl Verifier { + /// Notarizes the TLS session. + /// + /// # Arguments + /// + /// * `signer` - The signer used to sign the notarization result. + #[instrument(parent = &self.span, level = "debug", skip_all, err)] + pub async fn finalize(self, config: &AttestationConfig) -> Result { + let Notarize { + mut io, + mux_ctrl, + mut mux_fut, + mut vm, + mut ot_send, + mut ctx, + encoder_seed, + server_ephemeral_key, + connection_info, + #[cfg(feature = "authdecode_unsafe")] + wants_authdecode, + } = self.state; + + let attestation = mux_fut + .poll_with(async { + // Receive attestation request, which also contains commitments required before + // finalization. + let request: Request = io.expect_next().await?; + + #[cfg(feature = "authdecode_unsafe")] + let authdecode_verifier = if wants_authdecode { + let alg: HashAlgId = io.expect_next().await?; + + let mut verifier = authdecode_verifier(&alg); + + let max = self + .config + .protocol_config_validator() + .max_authdecode_data(); + + verifier + .receive_commitments(io.expect_next().await?, max) + .unwrap(); + + debug!("received Authdecode commitment"); + // Now that the commitments are received, it is safe to reveal MPC secrets. + Some(verifier) + } else { + None + }; + + // Finalize all MPC before attesting. 
+ ot_send.reveal(&mut ctx).await?; + + debug!("revealed OT secret"); + + vm.finalize().await?; + + info!("Finalized all MPC"); + + #[allow(unused_mut)] + let mut builder = Attestation::builder(config); + + #[cfg(feature = "authdecode_unsafe")] + if wants_authdecode { + let mut authdecode_verifier = authdecode_verifier + .expect("AuthDecode verifier should be Some when wants_authdecode is set"); + + let hashes = authdecode_verifier + .verify(io.expect_next().await?, encoder_seed) + .unwrap(); + + debug!("verified Authdecode proofs"); + + builder.plaintext_hashes(hashes); + } + + let mut builder = builder + .accept_request(request) + .map_err(VerifierError::attestation)?; + + builder + .connection_info(connection_info) + .server_ephemeral_key(server_ephemeral_key) + .encoding_seed(encoder_seed.to_vec()); + + let attestation = builder + .build(self.config.crypto_provider()) + .map_err(VerifierError::attestation)?; + + io.send(attestation.clone()).await?; + + info!("Sent session header"); + + Ok::<_, VerifierError>(attestation) + }) + .await?; + + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } + + Ok(attestation) + } +} diff --git a/crates/verifier/src/state.rs b/crates/verifier/src/state.rs new file mode 100644 index 0000000000..a98e575080 --- /dev/null +++ b/crates/verifier/src/state.rs @@ -0,0 +1,137 @@ +//! TLS Verifier state. + +use tls_mpc::MpcTlsFollower; +use tlsn_common::{ + mux::{MuxControl, MuxFuture}, + Context, DEAPThread, Io, OTSender, +}; +use tlsn_core::connection::{ConnectionInfo, ServerEphemKey}; + +/// TLS Verifier state. +pub trait VerifierState: sealed::Sealed {} + +/// Initialized state. +pub struct Initialized; + +opaque_debug::implement!(Initialized); + +/// State after MPC setup has completed. 
+pub struct Setup { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) mpc_tls: MpcTlsFollower, + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, + + pub(crate) encoder_seed: [u8; 32], + /// Whether the Prover wants to run the AuthDecode protocol. + #[cfg(feature = "authdecode_unsafe")] + pub(crate) wants_authdecode: bool, +} + +/// State after the TLS connection has been closed. +pub struct Closed { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, + + pub(crate) encoder_seed: [u8; 32], + pub(crate) server_ephemeral_key: ServerEphemKey, + pub(crate) connection_info: ConnectionInfo, + /// Whether the Prover wants to run the AuthDecode protocol. + #[cfg(feature = "authdecode_unsafe")] + pub(crate) wants_authdecode: bool, +} + +opaque_debug::implement!(Closed); + +/// Notarizing state. +pub struct Notarize { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, + + pub(crate) encoder_seed: [u8; 32], + pub(crate) server_ephemeral_key: ServerEphemKey, + pub(crate) connection_info: ConnectionInfo, + /// Whether the Prover wants to run the AuthDecode protocol. + #[cfg(feature = "authdecode_unsafe")] + pub(crate) wants_authdecode: bool, +} + +opaque_debug::implement!(Notarize); + +impl From for Notarize { + fn from(value: Closed) -> Self { + Self { + io: value.io, + mux_ctrl: value.mux_ctrl, + mux_fut: value.mux_fut, + vm: value.vm, + ot_send: value.ot_send, + ctx: value.ctx, + encoder_seed: value.encoder_seed, + server_ephemeral_key: value.server_ephemeral_key, + connection_info: value.connection_info, + #[cfg(feature = "authdecode_unsafe")] + wants_authdecode: value.wants_authdecode, + } + } +} + +/// Verifying state. 
+pub struct Verify { + pub(crate) io: Io, + pub(crate) mux_ctrl: MuxControl, + pub(crate) mux_fut: MuxFuture, + + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, + + pub(crate) server_ephemeral_key: ServerEphemKey, + pub(crate) connection_info: ConnectionInfo, +} + +opaque_debug::implement!(Verify); + +impl From for Verify { + fn from(value: Closed) -> Self { + Self { + io: value.io, + mux_ctrl: value.mux_ctrl, + mux_fut: value.mux_fut, + vm: value.vm, + ot_send: value.ot_send, + ctx: value.ctx, + server_ephemeral_key: value.server_ephemeral_key, + connection_info: value.connection_info, + } + } +} + +impl VerifierState for Initialized {} +impl VerifierState for Setup {} +impl VerifierState for Closed {} +impl VerifierState for Notarize {} +impl VerifierState for Verify {} + +mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Setup {} + impl Sealed for super::Closed {} + impl Sealed for super::Notarize {} + impl Sealed for super::Verify {} +} diff --git a/crates/verifier/src/verify.rs b/crates/verifier/src/verify.rs new file mode 100644 index 0000000000..83141d4b50 --- /dev/null +++ b/crates/verifier/src/verify.rs @@ -0,0 +1,131 @@ +//! This module handles the verification phase of the verifier. +//! +//! The TLS verifier is an application-specific verifier. + +use crate::SessionInfo; + +use super::{state::Verify as VerifyState, Verifier, VerifierError}; +use mpz_circuits::types::Value; +use mpz_garble::{Memory, Verify}; +use mpz_ot::CommittedOTSender; +use serio::stream::IoStreamExt; +use tlsn_common::msg::ServerIdentityProof; +use tlsn_core::transcript::{get_value_ids, Direction, PartialTranscript}; + +use tracing::{info, instrument}; + +impl Verifier { + /// Receives the **purported** transcript from the Prover. + /// + /// # Warning + /// + /// The content of the received transcripts can not be considered authentic + /// until after finalization. 
+ #[instrument(parent = &self.span, level = "info", skip_all, err)] + pub async fn receive(&mut self) -> Result { + self.state + .mux_fut + .poll_with(async { + // Receive partial transcript from the prover + let partial_transcript: PartialTranscript = self.state.io.expect_next().await?; + + info!("Received partial transcript from prover"); + + // Check ranges + if partial_transcript.len_sent() + != self.state.connection_info.transcript_length.sent as usize + || partial_transcript.len_received() + != self.state.connection_info.transcript_length.received as usize + { + return Err(VerifierError::verify( + "prover sent transcript with incorrect length", + )); + } + + // Now verify the transcript parts which the prover wants to reveal + let sent_value_ids = + get_value_ids(Direction::Sent, partial_transcript.sent_authed()); + let recv_value_ids = + get_value_ids(Direction::Received, partial_transcript.received_authed()); + + let value_refs = sent_value_ids + .chain(recv_value_ids) + .map(|id| { + self.state + .vm + .get_value(id.as_str()) + .expect("Byte should be in VM memory") + }) + .collect::>(); + + let values = partial_transcript + .iter(Direction::Sent) + .chain(partial_transcript.iter(Direction::Received)) + .map(Value::U8) + .collect::>(); + + // Check that purported values are correct + self.state.vm.verify(&value_refs, &values).await?; + + info!("Successfully verified purported cleartext"); + + Ok::<_, VerifierError>(partial_transcript) + }) + .await + } + + /// Verifies the TLS session. + #[instrument(parent = &self.span, level = "info", skip_all, err)] + pub async fn finalize(self) -> Result { + let VerifyState { + mut io, + mux_ctrl, + mut mux_fut, + mut vm, + mut ot_send, + mut ctx, + server_ephemeral_key, + connection_info, + .. 
+ } = self.state; + + let ServerIdentityProof { + name: server_name, + data, + } = mux_fut + .poll_with(async { + // Finalize all MPC + ot_send.reveal(&mut ctx).await?; + + vm.finalize().await?; + + info!("Finalized all MPC"); + + let identity_proof: ServerIdentityProof = io.expect_next().await?; + + Ok::<_, VerifierError>(identity_proof) + }) + .await?; + + // Verify the server identity data. + data.verify_with_provider( + self.config.crypto_provider(), + connection_info.time, + &server_ephemeral_key, + &server_name, + ) + .map_err(VerifierError::verify)?; + + info!("Successfully verified session"); + + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } + + Ok(SessionInfo { + server_name, + connection_info, + }) + } +} diff --git a/crates/wasm-test-runner/Cargo.toml b/crates/wasm-test-runner/Cargo.toml new file mode 100644 index 0000000000..c3bf210b2e --- /dev/null +++ b/crates/wasm-test-runner/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "tlsn-wasm-test-runner" +version = "0.0.0" +edition = "2021" +publish = false + +[dependencies] +tlsn-common = { workspace = true } +tlsn-core = { workspace = true } +tlsn-prover = { workspace = true } +tlsn-server-fixture = { workspace = true } +tlsn-server-fixture-certs = { workspace = true } +tlsn-tls-core = { workspace = true } +tlsn-verifier = { workspace = true } + +websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "73a6be1" } + +anyhow = { workspace = true } +axum = { workspace = true } +chromiumoxide = { version = "0.6", features = ["tokio-runtime"] } +futures = { workspace = true } +once_cell = { workspace = true } +k256 = { workspace = true } +rand = { workspace = true } +serde = { workspace = true, features = ["derive"] } +tokio = { workspace = true, features = ["full"] } +tokio-tungstenite = { version = "0.23", features = ["url"] } +tokio-util = { workspace = true, features = ["compat"] } +tower = { version = "0.4" } +tower-http = { version = "0.5", features = ["fs", 
"set-header"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/crates/wasm-test-runner/run.sh b/crates/wasm-test-runner/run.sh new file mode 100755 index 0000000000..a43d98c284 --- /dev/null +++ b/crates/wasm-test-runner/run.sh @@ -0,0 +1,4 @@ +RUSTFLAGS='-C target-feature=+atomics,+bulk-memory,+mutable-globals' \ +rustup run nightly \ +wasm-pack build ../wasm --target web --no-pack --out-dir=../wasm-test-runner/static/generated -- -Zbuild-std=panic_abort,std --features test,no-bundler \ +&& RUST_LOG=debug cargo run --release diff --git a/crates/wasm-test-runner/src/chrome_driver.rs b/crates/wasm-test-runner/src/chrome_driver.rs new file mode 100644 index 0000000000..fc192c8c42 --- /dev/null +++ b/crates/wasm-test-runner/src/chrome_driver.rs @@ -0,0 +1,90 @@ +use anyhow::{anyhow, Result}; +use chromiumoxide::{ + cdp::{ + browser_protocol::log::{EventEntryAdded, LogEntryLevel}, + js_protocol::runtime::EventExceptionThrown, + }, + Browser, BrowserConfig, Page, +}; +use futures::{Future, FutureExt, StreamExt}; +use std::{env, time::Duration}; +use tracing::{debug, error, instrument}; + +use crate::{TestResult, DEFAULT_SERVER_IP, DEFAULT_WASM_PORT}; + +#[instrument] +pub async fn run() -> Result> { + let config = BrowserConfig::builder() + .request_timeout(Duration::from_secs(60)) + .incognito() // Run in incognito mode to avoid unexplained WS connection errors in chromiumoxide. 
+ .build() + .map_err(|s| anyhow!(s))?; + + debug!("launching chromedriver"); + + let (mut browser, mut handler) = Browser::launch(config).await?; + + debug!("chromedriver started"); + + tokio::spawn(async move { + while let Some(res) = handler.next().await { + res.unwrap(); + } + }); + + let wasm_port: u16 = env::var("WASM_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_WASM_PORT); + let wasm_addr: String = env::var("WASM_IP").unwrap_or_else(|_| DEFAULT_SERVER_IP.to_string()); + + let page = browser + .new_page(&format!("http://{}:{}/index.html", wasm_addr, wasm_port)) + .await?; + + tokio::spawn(register_listeners(&page).await?); + + page.wait_for_navigation().await?; + let results: Vec = page + .evaluate( + r#" + (async () => { + await window.testWorker.init(); + return await window.testWorker.run(); + })(); + "#, + ) + .await? + .into_value()?; + + browser.close().await?; + browser.wait().await?; + + Ok(results) +} + +async fn register_listeners(page: &Page) -> Result> { + let mut logs = page.event_listener::().await?.fuse(); + let mut exceptions = page.event_listener::().await?.fuse(); + + Ok(futures::future::join( + async move { + while let Some(event) = logs.next().await { + let entry = &event.entry; + match entry.level { + LogEntryLevel::Error => { + error!("{:?}", entry); + } + _ => { + debug!("{:?}: {}", entry.timestamp, entry.text); + } + } + } + }, + async move { + while let Some(event) = exceptions.next().await { + error!("{:?}", event); + } + }, + ) + .map(|_| ())) +} diff --git a/crates/wasm-test-runner/src/lib.rs b/crates/wasm-test-runner/src/lib.rs new file mode 100644 index 0000000000..c727704b83 --- /dev/null +++ b/crates/wasm-test-runner/src/lib.rs @@ -0,0 +1,37 @@ +use std::fmt::Display; + +pub mod chrome_driver; +pub mod server_fixture; +pub mod tlsn_fixture; +pub mod wasm_server; +pub mod ws; + +pub static DEFAULT_SERVER_IP: &str = "127.0.0.1"; +pub static DEFAULT_WASM_PORT: u16 = 8013; +pub 
static DEFAULT_WS_PORT: u16 = 8080; +pub static DEFAULT_SERVER_PORT: u16 = 8083; +pub static DEFAULT_VERIFIER_PORT: u16 = 8010; +pub static DEFAULT_NOTARY_PORT: u16 = 8011; +pub static DEFAULT_PROVER_PORT: u16 = 8012; + +#[derive(Debug, serde::Deserialize)] +pub struct TestResult { + pub name: String, + pub passed: bool, + pub error: Option, +} + +impl Display for TestResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.passed { + write!(f, "{}: passed", self.name)?; + } else { + write!(f, "{}: failed", self.name)?; + if let Some(error) = &self.error { + write!(f, "\ncaused by: {}", error)?; + } + } + + Ok(()) + } +} diff --git a/crates/wasm-test-runner/src/main.rs b/crates/wasm-test-runner/src/main.rs new file mode 100644 index 0000000000..7b8fb5fa73 --- /dev/null +++ b/crates/wasm-test-runner/src/main.rs @@ -0,0 +1,42 @@ +use anyhow::Result; + +fn init_tracing() { + use tracing_subscriber::EnvFilter; + + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); +} + +#[tokio::main] +async fn main() -> Result<()> { + init_tracing(); + + let fut_wasm = tlsn_wasm_test_runner::wasm_server::start().await?; + let fut_proxy = tlsn_wasm_test_runner::ws::start().await?; + let fut_tlsn = tlsn_wasm_test_runner::tlsn_fixture::start().await?; + let fut_server = tlsn_wasm_test_runner::server_fixture::start().await?; + + tokio::spawn(async move { + futures::future::try_join4(fut_wasm, fut_proxy, fut_tlsn, fut_server) + .await + .unwrap() + }); + + let results = tlsn_wasm_test_runner::chrome_driver::run().await?; + + for result in &results { + println!("{}", result); + } + + let passed = results.iter().filter(|r| r.passed).count(); + let failed = results.iter().filter(|r| !r.passed).count(); + + println!("{} passed, {} failed", passed, failed); + + if results.iter().any(|r| !r.passed) { + std::process::exit(1); + } + + Ok(()) +} diff --git a/crates/wasm-test-runner/src/server_fixture.rs 
b/crates/wasm-test-runner/src/server_fixture.rs new file mode 100644 index 0000000000..6a56fb6c08 --- /dev/null +++ b/crates/wasm-test-runner/src/server_fixture.rs @@ -0,0 +1,34 @@ +use std::{env, net::IpAddr}; + +use tlsn_server_fixture; + +use anyhow::Result; +use futures::Future; +use tokio::net::TcpListener; +use tokio_util::compat::TokioAsyncReadCompatExt; +use tracing::{info, instrument}; + +use crate::{DEFAULT_SERVER_IP, DEFAULT_SERVER_PORT}; + +#[instrument] +pub async fn start() -> Result>> { + let port: u16 = env::var("SERVER_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_SERVER_PORT); + let addr: IpAddr = env::var("SERVER_IP") + .map(|addr| addr.parse().expect("should be valid IP address")) + .unwrap_or(IpAddr::V4(DEFAULT_SERVER_IP.parse().unwrap())); + + let listener = TcpListener::bind((addr, port)).await?; + + info!("listening on: {}", listener.local_addr()?); + + Ok(async move { + loop { + let (socket, addr) = listener.accept().await?; + info!("accepted connection from: {}", addr); + + tokio::spawn(tlsn_server_fixture::bind(socket.compat())); + } + }) +} diff --git a/crates/wasm-test-runner/src/tlsn_fixture.rs b/crates/wasm-test-runner/src/tlsn_fixture.rs new file mode 100644 index 0000000000..d2a1366e8d --- /dev/null +++ b/crates/wasm-test-runner/src/tlsn_fixture.rs @@ -0,0 +1,193 @@ +use std::{env, net::IpAddr}; + +use anyhow::Result; +use futures::{AsyncReadExt, AsyncWriteExt, Future}; +use tls_core::{anchors::RootCertStore, verify::WebPkiVerifier}; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_core::{ + attestation::AttestationConfig, signing::SignatureAlgId, transcript::Idx, CryptoProvider, +}; +use tlsn_prover::{Prover, ProverConfig}; +use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; +use tlsn_verifier::{Verifier, VerifierConfig}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_util::compat::TokioAsyncReadCompatExt; +use tracing::{info, 
instrument}; + +use crate::{ + DEFAULT_NOTARY_PORT, DEFAULT_PROVER_PORT, DEFAULT_SERVER_IP, DEFAULT_SERVER_PORT, + DEFAULT_VERIFIER_PORT, +}; + +#[instrument] +pub async fn start() -> Result>> { + let verifier_port: u16 = env::var("VERIFIER_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_VERIFIER_PORT); + let notary_port: u16 = env::var("NOTARY_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_NOTARY_PORT); + let prover_port: u16 = env::var("PROVER_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_PROVER_PORT); + let addr: IpAddr = env::var("TLSN_IP") + .map(|addr| addr.parse().expect("should be valid IP address")) + .unwrap_or(IpAddr::V4(DEFAULT_SERVER_IP.parse().unwrap())); + + let verifier_listener = TcpListener::bind((addr, verifier_port)).await?; + let notary_listener = TcpListener::bind((addr, notary_port)).await?; + let prover_listener = TcpListener::bind((addr, prover_port)).await?; + + Ok(async move { + loop { + tokio::select! 
{ + res = verifier_listener.accept() => { + let (socket, addr) = res?; + info!("verifier accepted connection from: {}", addr); + + tokio::spawn(handle_verifier(socket)); + }, + res = notary_listener.accept() => { + let (socket, addr) = res?; + info!("notary accepted connection from: {}", addr); + + tokio::spawn(handle_notary(socket)); + }, + res = prover_listener.accept() => { + let (socket, addr) = res?; + info!("prover accepted connection from: {}", addr); + + tokio::spawn(handle_prover(socket)); + }, + } + } + }) +} + +#[instrument(level = "debug", skip_all, err)] +async fn handle_verifier(io: TcpStream) -> Result<()> { + let mut root_store = RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(1024) + .max_recv_data(1024) + .build() + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let config = VerifierConfig::builder() + .crypto_provider(provider) + .protocol_config_validator(config_validator) + .build() + .unwrap(); + + let verifier = Verifier::new(config); + + verifier.verify(io.compat()).await?; + + Ok(()) +} + +#[instrument(level = "debug", skip_all, err)] +async fn handle_notary(io: TcpStream) -> Result<()> { + let mut provider = CryptoProvider::default(); + + provider.signer.set_secp256k1(&[1u8; 32]).unwrap(); + + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(1024) + .max_recv_data(1024) + .build() + .unwrap(); + + let config = VerifierConfig::builder() + .protocol_config_validator(config_validator) + .crypto_provider(provider) + .build() + .unwrap(); + + let verifier = Verifier::new(config); + + let mut builder = AttestationConfig::builder(); + builder.supported_signature_algs(vec![SignatureAlgId::SECP256K1]); + + let attestation_config = builder.build().unwrap(); + + verifier.notarize(io.compat(), 
&attestation_config).await?; + + Ok(()) +} + +#[instrument(level = "debug", skip_all, err)] +async fn handle_prover(io: TcpStream) -> Result<()> { + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let protocol_config = ProtocolConfig::builder() + .max_sent_data(1024) + .max_recv_data(1024) + .build() + .unwrap(); + + let prover = Prover::new( + ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config(protocol_config) + .crypto_provider(provider) + .build() + .unwrap(), + ) + .setup(io.compat()) + .await + .unwrap(); + + let port: u16 = env::var("SERVER_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_SERVER_PORT); + let addr: IpAddr = env::var("SERVER_IP") + .map(|addr| addr.parse().expect("should be valid IP address")) + .unwrap_or(IpAddr::V4(DEFAULT_SERVER_IP.parse().unwrap())); + + let client_socket = TcpStream::connect((addr, port)).await.unwrap(); + + let (mut tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); + let prover_task = tokio::spawn(prover_fut); + + tls_connection + .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + tls_connection.close().await.unwrap(); + + let mut response = vec![0u8; 1024]; + tls_connection.read_to_end(&mut response).await.unwrap(); + + let mut prover = prover_task.await.unwrap().unwrap().start_prove(); + + let sent_transcript_len = prover.transcript().sent().len(); + let recv_transcript_len = prover.transcript().received().len(); + + let sent_idx = Idx::new(0..sent_transcript_len - 1); + let recv_idx = Idx::new(2..recv_transcript_len); + + // Reveal parts of the transcript + prover.prove_transcript(sent_idx, recv_idx).await.unwrap(); + + prover.finalize().await.unwrap(); + + Ok(()) +} diff 
--git a/crates/wasm-test-runner/src/wasm_server.rs b/crates/wasm-test-runner/src/wasm_server.rs new file mode 100644 index 0000000000..9e01d903e6 --- /dev/null +++ b/crates/wasm-test-runner/src/wasm_server.rs @@ -0,0 +1,49 @@ +use std::{env, net::IpAddr}; + +use anyhow::Result; +use axum::{ + http::{HeaderName, HeaderValue}, + Router, +}; +use futures::Future; +use tokio::net::TcpListener; +use tower::ServiceBuilder; +use tower_http::{services::ServeDir, set_header::SetResponseHeaderLayer}; +use tracing::{info, instrument}; + +use crate::{DEFAULT_SERVER_IP, DEFAULT_WASM_PORT}; + +#[instrument] +pub async fn start() -> Result>> { + let port: u16 = env::var("WASM_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_WASM_PORT); + let addr: IpAddr = env::var("WASM_IP") + .map(|addr| addr.parse().expect("should be valid IP address")) + .unwrap_or(IpAddr::V4(DEFAULT_SERVER_IP.parse().unwrap())); + + let files = ServeDir::new("static"); + + let service = ServiceBuilder::new() + .layer(SetResponseHeaderLayer::if_not_present( + HeaderName::from_static("cross-origin-embedder-policy"), + HeaderValue::from_static("require-corp"), + )) + .layer(SetResponseHeaderLayer::if_not_present( + HeaderName::from_static("cross-origin-opener-policy"), + HeaderValue::from_static("same-origin"), + )) + .service(files); + + // build our application with a single route + let app = Router::new().fallback_service(service); + + let listener = TcpListener::bind((addr, port)).await?; + + info!("listening on {}", listener.local_addr()?); + + Ok(async move { + axum::serve(listener, app).await?; + Ok(()) + }) +} diff --git a/crates/wasm-test-runner/src/ws.rs b/crates/wasm-test-runner/src/ws.rs new file mode 100644 index 0000000000..0cc294ec54 --- /dev/null +++ b/crates/wasm-test-runner/src/ws.rs @@ -0,0 +1,26 @@ +use std::{env, net::IpAddr}; + +use anyhow::{Context, Result}; +use futures::Future; +use tokio::net::TcpListener; +use tracing::{info, instrument}; 
+ +use crate::{DEFAULT_SERVER_IP, DEFAULT_WS_PORT}; + +#[instrument] +pub async fn start() -> Result>> { + let port: u16 = env::var("PROXY_PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(DEFAULT_WS_PORT); + let addr: IpAddr = env::var("PROXY_IP") + .map(|addr| addr.parse().expect("should be valid IP address")) + .unwrap_or(IpAddr::V4(DEFAULT_SERVER_IP.parse().unwrap())); + + let listener = TcpListener::bind((addr, port)) + .await + .context("failed to bind to address")?; + + info!("listening on: {}", listener.local_addr()?); + + Ok(websocket_relay::run(listener)) +} diff --git a/crates/wasm-test-runner/static/favicon.ico b/crates/wasm-test-runner/static/favicon.ico new file mode 100644 index 0000000000..0431afc3bd Binary files /dev/null and b/crates/wasm-test-runner/static/favicon.ico differ diff --git a/crates/wasm-test-runner/static/index.html b/crates/wasm-test-runner/static/index.html new file mode 100644 index 0000000000..becde77cd4 --- /dev/null +++ b/crates/wasm-test-runner/static/index.html @@ -0,0 +1,7 @@ + + + + + + + diff --git a/crates/wasm-test-runner/static/index.js b/crates/wasm-test-runner/static/index.js new file mode 100644 index 0000000000..b8fe8e2fc0 --- /dev/null +++ b/crates/wasm-test-runner/static/index.js @@ -0,0 +1,5 @@ +import * as Comlink from "https://unpkg.com/comlink/dist/esm/comlink.mjs"; + +const testWorker = Comlink.wrap(new Worker("worker.js", { type: "module" })); + +window.testWorker = testWorker; diff --git a/crates/wasm-test-runner/static/worker.js b/crates/wasm-test-runner/static/worker.js new file mode 100644 index 0000000000..1853e0a9b2 --- /dev/null +++ b/crates/wasm-test-runner/static/worker.js @@ -0,0 +1,41 @@ +import * as Comlink from "https://unpkg.com/comlink/dist/esm/comlink.mjs"; +import init_wasm, { init_logging, initThreadPool } from "./generated/tlsn_wasm.js"; + +const module = await import("./generated/tlsn_wasm.js"); + +class TestWorker { + async init() { + try { + await 
init_wasm(); + init_logging(); + console.log("initialized logging"); + await initThreadPool(8); + console.log("initialized worker"); + } catch (e) { + console.error(e); + throw e; + } + } + + run() { + let promises = []; + for (const [name, func] of Object.entries(module)) { + + if(name.startsWith("test_") && (typeof func === 'function')) { + promises.push(func().then(_ => { return { + name: name, + passed: true, + } }).catch(error => { return { + name: name, + passed: false, + error: error.toString(), + } })); + } + } + return Promise.all(promises); + } +} + +const worker = new TestWorker(); + +Comlink.expose(worker); diff --git a/crates/wasm/.cargo/config.toml b/crates/wasm/.cargo/config.toml new file mode 100644 index 0000000000..bdc14ff398 --- /dev/null +++ b/crates/wasm/.cargo/config.toml @@ -0,0 +1,8 @@ +[build] +target = "wasm32-unknown-unknown" + +[unstable] +build-std = ["panic_abort", "std"] + +[target.wasm32-unknown-unknown] +rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals"] diff --git a/crates/wasm/Cargo.toml b/crates/wasm/Cargo.toml new file mode 100644 index 0000000000..d87ccbc089 --- /dev/null +++ b/crates/wasm/Cargo.toml @@ -0,0 +1,55 @@ +[package] +name = "tlsn-wasm" +version = "0.1.0-alpha.7" +edition = "2021" + +[lib] +crate-type = ["cdylib", "rlib"] + +[package.metadata.wasm-pack.profile.debug] +debug = false + +[package.metadata.wasm-pack.profile.release] +opt-level = "z" +wasm-opt = true + +[features] +default = [] +test = [] +no-bundler = ["wasm-bindgen-rayon/no-bundler"] + +[dependencies] +tlsn-common = { path = "../common" } +tlsn-core = { path = "../core" } +tlsn-prover = { path = "../prover" } +tlsn-server-fixture-certs = { workspace = true } +tlsn-tls-client-async = { path = "../tls/client-async" } +tlsn-tls-core = { path = "../tls/core" } +tlsn-verifier = { path = "../verifier" } + +bincode = { workspace = true } +console_error_panic_hook = { version = "0.1" } +enum-try-as-inner = { workspace = true } +futures 
= { workspace = true } +getrandom = { version = "0.2", features = ["js"] } +http-body-util = { version = "0.1" } +hyper = { workspace = true, features = ["client", "http1"] } +p256 = { workspace = true } +parking_lot = { version = "0.12", features = ["nightly"] } +pin-project-lite = { workspace = true } +ring = { version = "0.17", features = ["wasm32_unknown_unknown_js"] } +serde = { workspace = true, features = ["derive"] } +serde-wasm-bindgen = { version = "0.6" } +serde_json = { version = "1.0" } +time = { version = "0.3", features = ["wasm-bindgen"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["time"] } +tracing-web = { version = "0.1" } +tsify-next = { version = "0.5", default-features = false, features = ["js"] } +wasm-bindgen = { version = "=0.2.92" } +wasm-bindgen-futures = { version = "0.4" } +# Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958 +ws_stream_wasm = { git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51" } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen-rayon = { version = "1.0" } diff --git a/crates/wasm/README.md b/crates/wasm/README.md new file mode 100644 index 0000000000..b3849ef45c --- /dev/null +++ b/crates/wasm/README.md @@ -0,0 +1,13 @@ +# TLSNotary WASM bindings + +## Build + +This crate must be built using the nightly rust compiler with the following flags: + +```bash +RUSTFLAGS='-C target-feature=+atomics,+bulk-memory,+mutable-globals' \ + rustup run nightly \ + wasm-pack build --target web . -- -Zbuild-std=panic_abort,std +``` + + diff --git a/crates/wasm/build.sh b/crates/wasm/build.sh new file mode 100755 index 0000000000..0a9b965865 --- /dev/null +++ b/crates/wasm/build.sh @@ -0,0 +1,3 @@ +RUSTFLAGS='-C target-feature=+atomics,+bulk-memory,+mutable-globals' \ + rustup run nightly \ + wasm-pack build --target web . 
-- -Zbuild-std=panic_abort,std diff --git a/crates/wasm/rust-toolchain b/crates/wasm/rust-toolchain new file mode 100644 index 0000000000..5d56faf9ae --- /dev/null +++ b/crates/wasm/rust-toolchain @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" diff --git a/crates/wasm/src/io.rs b/crates/wasm/src/io.rs new file mode 100644 index 0000000000..91df76dd35 --- /dev/null +++ b/crates/wasm/src/io.rs @@ -0,0 +1,88 @@ +use core::slice; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use pin_project_lite::pin_project; + +pin_project! { + #[derive(Debug)] + pub(crate) struct FuturesIo { + #[pin] + inner: T, + } +} + +impl FuturesIo { + /// Create a new `FuturesIo` wrapping the given I/O object. + /// + /// # Safety + /// + /// This wrapper is only safe to use if the inner I/O object does not under + /// any circumstance read from the buffer passed to `poll_read` in the + /// `futures::AsyncRead` implementation. + pub(crate) fn new(inner: T) -> Self { + Self { inner } + } +} + +impl hyper::rt::Write for FuturesIo +where + T: futures::AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().inner.poll_close(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } +} + +// Adapted from https://github.com/hyperium/hyper-util/blob/99b77a5a6f75f24bc0bcb4ca74b5f26a07b19c80/src/rt/tokio.rs +impl hyper::rt::Read for FuturesIo +where + T: futures::AsyncRead + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + // Safety: buf_slice should only be 
written to, so it's safe to convert `&mut + // [MaybeUninit]` to `&mut [u8]`. + let buf_slice = unsafe { + slice::from_raw_parts_mut(buf.as_mut().as_mut_ptr() as *mut u8, buf.as_mut().len()) + }; + + let n = match futures::AsyncRead::poll_read(self.project().inner, cx, buf_slice) { + Poll::Ready(Ok(n)) => n, + other => return other.map_ok(|_| ()), + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs new file mode 100644 index 0000000000..c7338c532c --- /dev/null +++ b/crates/wasm/src/lib.rs @@ -0,0 +1,96 @@ +//! TLSNotary WASM bindings. + +#![deny(unreachable_pub, unused_must_use, clippy::all)] +#![allow(non_snake_case)] + +pub(crate) mod io; +mod log; +pub mod prover; +#[cfg(feature = "test")] +pub mod tests; +pub mod types; +pub mod verifier; + +use log::LoggingConfig; +use tlsn_core::{transcript::Direction, CryptoProvider}; +use tracing::error; +use tracing_subscriber::{ + filter::FilterFn, + fmt::{format::FmtSpan, time::UtcTime}, + layer::SubscriberExt, + util::SubscriberInitExt, +}; +use tracing_web::MakeWebConsoleWriter; +use wasm_bindgen::prelude::*; + +use crate::types::{Attestation, Presentation, Reveal, Secrets}; + +#[cfg(feature = "test")] +pub use tests::*; + +#[cfg(target_arch = "wasm32")] +pub use wasm_bindgen_rayon::init_thread_pool; + +/// Initializes logging. 
+#[wasm_bindgen] +pub fn init_logging(config: Option) { + let mut config = config.unwrap_or_default(); + + // Default is NONE + let fmt_span = config + .span_events + .take() + .unwrap_or_default() + .into_iter() + .map(FmtSpan::from) + .fold(FmtSpan::NONE, |acc, span| acc | span); + + let fmt_layer = tracing_subscriber::fmt::layer() + .with_ansi(false) // Only partially supported across browsers + .with_timer(UtcTime::rfc_3339()) // std::time is not available in browsers + .with_span_events(fmt_span) + .without_time() + .with_writer(MakeWebConsoleWriter::new()); // write events to the console + + tracing_subscriber::registry() + .with(FilterFn::new(log::filter(config))) + .with(fmt_layer) + .init(); + + // https://github.com/rustwasm/console_error_panic_hook + std::panic::set_hook(Box::new(|info| { + error!("panic occurred: {:?}", info); + console_error_panic_hook::hook(info); + })); +} + +/// Builds a presentation. +#[wasm_bindgen] +pub fn build_presentation( + attestation: &Attestation, + secrets: &Secrets, + reveal: Reveal, +) -> Result { + let provider = CryptoProvider::default(); + + let mut builder = attestation.0.presentation_builder(&provider); + + builder.identity_proof(secrets.0.identity_proof()); + + let mut proof_builder = secrets.0.transcript_proof_builder(); + + for range in reveal.sent.iter() { + proof_builder.reveal(range, Direction::Sent)?; + } + + for range in reveal.recv.iter() { + proof_builder.reveal(range, Direction::Received)?; + } + + builder.transcript_proof(proof_builder.build()?); + + builder + .build() + .map(Presentation::from) + .map_err(JsError::from) +} diff --git a/crates/wasm/src/log.rs b/crates/wasm/src/log.rs new file mode 100644 index 0000000000..81cf2d6b33 --- /dev/null +++ b/crates/wasm/src/log.rs @@ -0,0 +1,88 @@ +use serde::Deserialize; +use tracing::{Level, Metadata}; +use tracing_subscriber::fmt::format::FmtSpan; +use tsify_next::Tsify; + +#[derive(Debug, Clone, Copy, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub 
enum LoggingLevel { + Trace, + Debug, + Info, + Warn, + Error, +} + +impl From for Level { + fn from(value: LoggingLevel) -> Self { + match value { + LoggingLevel::Trace => Level::TRACE, + LoggingLevel::Debug => Level::DEBUG, + LoggingLevel::Info => Level::INFO, + LoggingLevel::Warn => Level::WARN, + LoggingLevel::Error => Level::ERROR, + } + } +} + +#[derive(Debug, Clone, Copy, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub enum SpanEvent { + New, + Close, + Active, +} + +impl From for FmtSpan { + fn from(value: SpanEvent) -> Self { + match value { + SpanEvent::New => FmtSpan::NEW, + SpanEvent::Close => FmtSpan::CLOSE, + SpanEvent::Active => FmtSpan::ACTIVE, + } + } +} + +#[derive(Debug, Default, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub struct LoggingConfig { + pub level: Option, + pub crate_filters: Option>, + pub span_events: Option>, +} + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub struct CrateLogFilter { + pub level: LoggingLevel, + pub name: String, +} + +pub(crate) fn filter(config: LoggingConfig) -> impl Fn(&Metadata) -> bool { + let default_level: Level = config.level.unwrap_or(LoggingLevel::Info).into(); + let crate_filters = config + .crate_filters + .unwrap_or_default() + .into_iter() + .map(|filter| (filter.name, Level::from(filter.level))) + .collect::>(); + + move |meta| { + let level = if let Some(crate_name) = meta.target().split("::").next() { + crate_filters + .iter() + .find_map(|(filter_name, filter_level)| { + if crate_name.eq_ignore_ascii_case(filter_name) { + Some(filter_level) + } else { + None + } + }) + .unwrap_or(&default_level) + } else { + &default_level + }; + + meta.level() <= level + } +} diff --git a/crates/wasm/src/prover/config.rs b/crates/wasm/src/prover/config.rs new file mode 100644 index 0000000000..76603cf2c9 --- /dev/null +++ b/crates/wasm/src/prover/config.rs @@ -0,0 +1,39 @@ +use serde::Deserialize; +use tlsn_common::config::ProtocolConfig; +use tsify_next::Tsify; + +#[derive(Debug, 
Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub struct ProverConfig { + pub server_name: String, + pub max_sent_data: usize, + pub max_recv_data_online: Option, + pub max_recv_data: usize, + pub defer_decryption_from_start: Option, +} + +impl From for tlsn_prover::ProverConfig { + fn from(value: ProverConfig) -> Self { + let mut builder = ProtocolConfig::builder(); + + builder.max_sent_data(value.max_sent_data); + builder.max_recv_data(value.max_recv_data); + + if let Some(value) = value.max_recv_data_online { + builder.max_recv_data_online(value); + } + + let protocol_config = builder.build().unwrap(); + + let mut builder = tlsn_prover::ProverConfig::builder(); + builder + .server_name(value.server_name.as_ref()) + .protocol_config(protocol_config); + + if let Some(value) = value.defer_decryption_from_start { + builder.defer_decryption_from_start(value); + } + + builder.build().unwrap() + } +} diff --git a/crates/wasm/src/prover/mod.rs b/crates/wasm/src/prover/mod.rs new file mode 100644 index 0000000000..dc32924466 --- /dev/null +++ b/crates/wasm/src/prover/mod.rs @@ -0,0 +1,202 @@ +mod config; + +pub use config::ProverConfig; + +use enum_try_as_inner::EnumTryAsInner; +use futures::TryFutureExt; +use http_body_util::{BodyExt, Full}; +use hyper::body::Bytes; +use tls_client_async::TlsConnection; +use tlsn_core::{ + request::RequestConfig, + transcript::{Idx, TranscriptCommitConfigBuilder}, +}; +use tlsn_prover::{state, Prover}; +use tracing::info; +use wasm_bindgen::{prelude::*, JsError}; +use wasm_bindgen_futures::spawn_local; +use ws_stream_wasm::WsMeta; + +use crate::{io::FuturesIo, types::*}; + +type Result = std::result::Result; + +#[wasm_bindgen(js_name = Prover)] +pub struct JsProver { + state: State, +} + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +enum State { + Initialized(Prover), + Setup(Prover), + Closed(Prover), + Complete, + Error, +} + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + 
+#[wasm_bindgen(js_class = Prover)] +impl JsProver { + #[wasm_bindgen(constructor)] + pub fn new(config: ProverConfig) -> JsProver { + JsProver { + state: State::Initialized(Prover::new(config.into())), + } + } + + /// Set up the prover. + /// + /// This performs all MPC setup prior to establishing the connection to the + /// application server. + pub async fn setup(&mut self, verifier_url: &str) -> Result<()> { + let prover = self.state.take().try_into_initialized()?; + + info!("connecting to verifier"); + + let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?; + + info!("connected to verifier"); + + let prover = prover.setup(verifier_conn.into_io()).await?; + + self.state = State::Setup(prover); + + Ok(()) + } + + /// Send the HTTP request to the server. + pub async fn send_request( + &mut self, + ws_proxy_url: &str, + request: HttpRequest, + ) -> Result { + let prover = self.state.take().try_into_setup()?; + + info!("connecting to server"); + + let (_, server_conn) = WsMeta::connect(ws_proxy_url, None).await?; + + info!("connected to server"); + + let (tls_conn, prover_fut) = prover.connect(server_conn.into_io()).await?; + + info!("sending request"); + + let (response, prover) = futures::try_join!( + send_request(tls_conn, request), + prover_fut.map_err(Into::into) + )?; + + info!("response received"); + + self.state = State::Closed(prover); + + Ok(response) + } + + /// Returns the transcript. + pub fn transcript(&self) -> Result { + let prover = self.state.try_as_closed()?; + + Ok(Transcript::from(prover.transcript())) + } + + /// Runs the notarization protocol. 
+ pub async fn notarize(&mut self, commit: Commit) -> Result { + let mut prover = self.state.take().try_into_closed()?.start_notarize(); + + info!("starting notarization"); + + let mut builder = TranscriptCommitConfigBuilder::new(prover.transcript()); + + for range in commit.sent { + builder.commit_sent(&range)?; + } + + for range in commit.recv { + builder.commit_recv(&range)?; + } + + let config = builder.build()?; + + prover.transcript_commit(config); + + let request_config = RequestConfig::default(); + let (attestation, secrets) = prover.finalize(&request_config).await?; + + info!("notarization complete"); + + self.state = State::Complete; + + Ok(NotarizationOutput { + attestation: attestation.into(), + secrets: secrets.into(), + }) + } + + /// Reveals data to the verifier and finalizes the protocol. + pub async fn reveal(&mut self, reveal: Reveal) -> Result<()> { + let mut prover = self.state.take().try_into_closed()?.start_prove(); + + info!("revealing data"); + + let sent = Idx::new(reveal.sent); + let recv = Idx::new(reveal.recv); + + prover.prove_transcript(sent, recv).await?; + prover.finalize().await?; + + info!("Finalized"); + + self.state = State::Complete; + + Ok(()) + } +} + +impl From> for JsProver { + fn from(value: Prover) -> Self { + JsProver { + state: State::Initialized(value), + } + } +} + +async fn send_request(conn: TlsConnection, request: HttpRequest) -> Result { + let conn = FuturesIo::new(conn); + let request = hyper::Request::>::try_from(request)?; + + let (mut request_sender, conn) = hyper::client::conn::http1::handshake(conn).await?; + + spawn_local(async move { conn.await.expect("connection runs to completion") }); + + let response = request_sender.send_request(request).await?; + + let (response, body) = response.into_parts(); + + // TODO: return the body + let _body = body.collect().await?; + + let headers = response + .headers + .into_iter() + .map(|(k, v)| { + ( + k.map(|k| k.to_string()).unwrap_or_default(), + 
v.as_bytes().to_vec(), + ) + }) + .collect(); + + Ok(HttpResponse { + status: response.status.as_u16(), + headers, + }) +} diff --git a/crates/wasm/src/tests.rs b/crates/wasm/src/tests.rs new file mode 100644 index 0000000000..80689ed6c4 --- /dev/null +++ b/crates/wasm/src/tests.rs @@ -0,0 +1,186 @@ +#![allow(clippy::single_range_in_vec_init)] + +use std::collections::HashMap; + +use tls_core::verify::WebPkiVerifier; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_core::CryptoProvider; +use tlsn_prover::{Prover, ProverConfig}; +use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; +use tlsn_verifier::{Verifier, VerifierConfig}; +use wasm_bindgen::prelude::*; + +use crate::{ + build_presentation, + prover::JsProver, + types::{ + Attestation, Commit, HttpRequest, Method, NotarizationOutput, Presentation, Reveal, Secrets, + }, + verifier::JsVerifier, +}; + +#[wasm_bindgen] +pub async fn test_prove() -> Result<(), JsValue> { + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let prover = Prover::new( + ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(1024) + .max_recv_data(1024) + .build() + .unwrap(), + ) + .crypto_provider(provider) + .build() + .unwrap(), + ); + + let mut prover = JsProver::from(prover); + + let uri = format!("https://{}/bytes?size=512", SERVER_DOMAIN); + + prover + .setup("ws://localhost:8080/tcp?addr=localhost%3A8010") + .await?; + + prover + .send_request( + "ws://localhost:8080/tcp?addr=localhost%3A8083", + HttpRequest { + method: Method::GET, + uri, + headers: HashMap::from([("Accept".to_string(), b"*".to_vec())]), + body: None, + }, + ) + .await?; + + prover + .reveal(Reveal { + sent: vec![0..10], + recv: vec![0..10], + }) + 
.await?; + + Ok(()) +} + +#[wasm_bindgen] +pub async fn test_notarize() -> Result<(), JsValue> { + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let prover = Prover::new( + ProverConfig::builder() + .server_name(SERVER_DOMAIN) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(1024) + .max_recv_data(1024) + .build() + .unwrap(), + ) + .crypto_provider(provider) + .build() + .unwrap(), + ); + + let mut prover = JsProver::from(prover); + + let uri = format!("https://{SERVER_DOMAIN}/bytes?size=512"); + + prover + .setup("ws://localhost:8080/tcp?addr=localhost%3A8011") + .await?; + + prover + .send_request( + "ws://localhost:8080/tcp?addr=localhost%3A8083", + HttpRequest { + method: Method::GET, + uri, + headers: HashMap::from([("Accept".to_string(), b"*".to_vec())]), + body: None, + }, + ) + .await?; + + let _ = prover.transcript()?; + + let NotarizationOutput { + attestation, + secrets, + } = prover + .notarize(Commit { + sent: vec![0..10], + recv: vec![0..10], + }) + .await?; + + let attestation = Attestation::deserialize(attestation.serialize())?; + let secrets = Secrets::deserialize(secrets.serialize())?; + + let presentation = build_presentation( + &attestation, + &secrets, + Reveal { + sent: vec![(0..10)], + recv: vec![(0..10)], + }, + )?; + + let _ = Presentation::deserialize(presentation.serialize())?; + + Ok(()) +} + +#[wasm_bindgen] +pub async fn test_verifier() -> Result<(), JsValue> { + let mut root_store = tls_core::anchors::RootCertStore::empty(); + root_store + .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) + .unwrap(); + + let provider = CryptoProvider { + cert: WebPkiVerifier::new(root_store, None), + ..Default::default() + }; + + let config = VerifierConfig::builder() + .protocol_config_validator( + 
ProtocolConfigValidator::builder() + .max_sent_data(1024) + .max_recv_data(1024) + .build() + .unwrap(), + ) + .crypto_provider(provider) + .build() + .unwrap(); + + let mut verifier = JsVerifier::from(Verifier::new(config)); + verifier + .connect("ws://localhost:8080/tcp?addr=localhost%3A8012") + .await?; + verifier.verify().await?; + + Ok(()) +} diff --git a/crates/wasm/src/types.rs b/crates/wasm/src/types.rs new file mode 100644 index 0000000000..2f7b0938c0 --- /dev/null +++ b/crates/wasm/src/types.rs @@ -0,0 +1,326 @@ +use std::{collections::HashMap, ops::Range}; + +use http_body_util::Full; +use hyper::body::Bytes; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use tlsn_core::CryptoProvider; +use tsify_next::Tsify; +use wasm_bindgen::prelude::*; + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +#[serde(untagged)] +#[non_exhaustive] +pub enum Body { + Json(JsonValue), +} + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub enum Method { + GET, + POST, + PUT, + DELETE, +} + +impl From for hyper::Method { + fn from(value: Method) -> Self { + match value { + Method::GET => hyper::Method::GET, + Method::POST => hyper::Method::POST, + Method::PUT => hyper::Method::PUT, + Method::DELETE => hyper::Method::DELETE, + } + } +} + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub struct HttpRequest { + pub uri: String, + pub method: Method, + pub headers: HashMap>, + pub body: Option, +} + +impl TryFrom for hyper::Request> { + type Error = JsError; + + fn try_from(value: HttpRequest) -> Result { + let mut builder = hyper::Request::builder(); + builder = builder.uri(value.uri).method(value.method); + for (name, value) in value.headers { + builder = builder.header(name, value); + } + + if let Some(body) = value.body { + let body = match body { + Body::Json(value) => Full::new(Bytes::from(serde_json::to_vec(&value).unwrap())), + }; + builder.body(body).map_err(Into::into) + } else { + 
builder.body(Full::new(Bytes::new())).map_err(Into::into) + } + } +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct HttpResponse { + pub status: u16, + pub headers: Vec<(String, Vec)>, +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub enum TlsVersion { + V1_2, + V1_3, +} + +impl From for TlsVersion { + fn from(value: tlsn_core::connection::TlsVersion) -> Self { + match value { + tlsn_core::connection::TlsVersion::V1_2 => Self::V1_2, + tlsn_core::connection::TlsVersion::V1_3 => Self::V1_3, + } + } +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct TranscriptLength { + pub sent: usize, + pub recv: usize, +} + +impl From for TranscriptLength { + fn from(value: tlsn_core::connection::TranscriptLength) -> Self { + Self { + sent: value.sent as usize, + recv: value.received as usize, + } + } +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct ConnectionInfo { + time: u64, + version: TlsVersion, + transcript_length: TranscriptLength, +} + +impl From for ConnectionInfo { + fn from(value: tlsn_core::connection::ConnectionInfo) -> Self { + Self { + time: value.time, + version: value.version.into(), + transcript_length: value.transcript_length.into(), + } + } +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct Transcript { + pub sent: Vec, + pub recv: Vec, +} + +impl From<&tlsn_core::transcript::Transcript> for Transcript { + fn from(value: &tlsn_core::transcript::Transcript) -> Self { + Self { + sent: value.sent().to_vec(), + recv: value.received().to_vec(), + } + } +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct PartialTranscript { + pub sent: Vec, + pub sent_authed: Vec>, + pub recv: Vec, + pub recv_authed: Vec>, +} + +impl From for PartialTranscript { + fn from(value: tlsn_core::transcript::PartialTranscript) -> Self { + Self { + sent: value.sent_unsafe().to_vec(), + sent_authed: 
value.sent_authed().iter_ranges().collect(), + recv: value.received_unsafe().to_vec(), + recv_authed: value.received_authed().iter_ranges().collect(), + } + } +} + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub struct Commit { + pub sent: Vec>, + pub recv: Vec>, +} + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub struct Reveal { + pub sent: Vec>, + pub recv: Vec>, +} + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub enum KeyType { + P256, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[wasm_bindgen] +#[serde(transparent)] +pub struct Attestation(pub(crate) tlsn_core::attestation::Attestation); + +#[wasm_bindgen] +impl Attestation { + pub fn verifying_key(&self) -> VerifyingKey { + self.0.body.verifying_key().into() + } + + /// Serializes to a byte array. + pub fn serialize(&self) -> Vec { + bincode::serialize(self).expect("Attestation should be serializable") + } + + /// Deserializes from a byte array. + pub fn deserialize(bytes: Vec) -> Result { + Ok(bincode::deserialize(&bytes)?) + } +} + +impl From for Attestation { + fn from(value: tlsn_core::attestation::Attestation) -> Self { + Self(value) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[wasm_bindgen] +#[serde(transparent)] +pub struct Secrets(pub(crate) tlsn_core::Secrets); + +#[wasm_bindgen] +impl Secrets { + /// Returns the transcript. + pub fn transcript(&self) -> Transcript { + self.0.transcript().into() + } + + /// Serializes to a byte array. + pub fn serialize(&self) -> Vec { + bincode::serialize(self).expect("Secrets should be serializable") + } + + /// Deserializes from a byte array. + pub fn deserialize(bytes: Vec) -> Result { + Ok(bincode::deserialize(&bytes)?) 
+ } +} + +impl From for Secrets { + fn from(value: tlsn_core::Secrets) -> Self { + Self(value) + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[wasm_bindgen] +#[serde(transparent)] +pub struct Presentation(tlsn_core::presentation::Presentation); + +#[wasm_bindgen] +impl Presentation { + /// Returns the verifying key. + pub fn verifying_key(&self) -> VerifyingKey { + self.0.verifying_key().into() + } + + /// Verifies the presentation. + pub fn verify(&self) -> Result { + let provider = CryptoProvider::default(); + + self.0 + .clone() + .verify(&provider) + .map(PresentationOutput::from) + .map_err(JsError::from) + } + + pub fn serialize(&self) -> Vec { + bincode::serialize(self).expect("Presentation should be serializable") + } + + pub fn deserialize(bytes: Vec) -> Result { + Ok(bincode::deserialize(&bytes)?) + } +} + +impl From for Presentation { + fn from(value: tlsn_core::presentation::Presentation) -> Self { + Self(value) + } +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct PresentationOutput { + pub attestation: Attestation, + pub server_name: Option, + pub connection_info: ConnectionInfo, + pub transcript: Option, +} + +impl From for PresentationOutput { + fn from(value: tlsn_core::presentation::PresentationOutput) -> Self { + Self { + attestation: value.attestation.into(), + server_name: value.server_name.map(|name| name.as_str().to_string()), + connection_info: value.connection_info.into(), + transcript: value.transcript.map(PartialTranscript::from), + } + } +} + +#[derive(Debug, Serialize)] +#[wasm_bindgen(getter_with_clone)] +pub struct NotarizationOutput { + pub attestation: Attestation, + pub secrets: Secrets, +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct VerifierOutput { + pub server_name: String, + pub connection_info: ConnectionInfo, + pub transcript: PartialTranscript, +} + +#[derive(Debug, Tsify, Serialize)] +#[tsify(into_wasm_abi)] +pub struct VerifyingKey { + pub alg: u8, + pub 
data: Vec, +} + +impl From<&tlsn_core::signing::VerifyingKey> for VerifyingKey { + fn from(value: &tlsn_core::signing::VerifyingKey) -> Self { + Self { + alg: value.alg.as_u8(), + data: value.data.clone(), + } + } +} diff --git a/crates/wasm/src/verifier/config.rs b/crates/wasm/src/verifier/config.rs new file mode 100644 index 0000000000..6c55338c74 --- /dev/null +++ b/crates/wasm/src/verifier/config.rs @@ -0,0 +1,26 @@ +use serde::Deserialize; +use tlsn_common::config::ProtocolConfigValidator; +use tsify_next::Tsify; + +#[derive(Debug, Tsify, Deserialize)] +#[tsify(from_wasm_abi)] +pub struct VerifierConfig { + pub max_sent_data: usize, + pub max_recv_data: usize, +} + +impl From for tlsn_verifier::VerifierConfig { + fn from(value: VerifierConfig) -> Self { + let mut builder = ProtocolConfigValidator::builder(); + + builder.max_sent_data(value.max_sent_data); + builder.max_recv_data(value.max_recv_data); + + let validator = builder.build().unwrap(); + + tlsn_verifier::VerifierConfig::builder() + .protocol_config_validator(validator) + .build() + .unwrap() + } +} diff --git a/crates/wasm/src/verifier/mod.rs b/crates/wasm/src/verifier/mod.rs new file mode 100644 index 0000000000..6af7ae8933 --- /dev/null +++ b/crates/wasm/src/verifier/mod.rs @@ -0,0 +1,90 @@ +mod config; + +pub use config::VerifierConfig; + +use enum_try_as_inner::EnumTryAsInner; +use tlsn_verifier::{ + state::{self, Initialized}, + Verifier, +}; +use tracing::info; +use wasm_bindgen::prelude::*; +use ws_stream_wasm::{WsMeta, WsStream}; + +use crate::types::VerifierOutput; + +type Result = std::result::Result; + +#[wasm_bindgen(js_name = Verifier)] +pub struct JsVerifier { + state: State, +} + +#[derive(EnumTryAsInner)] +#[derive_err(Debug)] +enum State { + Initialized(Verifier), + Connected((Verifier, WsStream)), + Complete, + Error, +} + +impl std::fmt::Debug for State { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "State") + } +} + +impl State { + fn 
take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + +#[wasm_bindgen(js_class = Verifier)] +impl JsVerifier { + #[wasm_bindgen(constructor)] + pub fn new(config: VerifierConfig) -> JsVerifier { + JsVerifier { + state: State::Initialized(Verifier::new(config.into())), + } + } + + /// Connect to the prover. + pub async fn connect(&mut self, prover_url: &str) -> Result<()> { + let verifier = self.state.take().try_into_initialized()?; + + info!("Connecting to prover"); + + let (_, prover_conn) = WsMeta::connect(prover_url, None).await?; + + info!("Connected to prover"); + + self.state = State::Connected((verifier, prover_conn)); + + Ok(()) + } + + /// Verifies the connection and finalizes the protocol. + pub async fn verify(&mut self) -> Result { + let (verifier, prover_conn) = self.state.take().try_into_connected()?; + + let (transcript, info) = verifier.verify(prover_conn.into_io()).await?; + + self.state = State::Complete; + + Ok(VerifierOutput { + server_name: info.server_name.as_str().to_string(), + connection_info: info.connection_info.into(), + transcript: transcript.into(), + }) + } +} + +impl From> for JsVerifier { + fn from(value: tlsn_verifier::Verifier) -> Self { + Self { + state: State::Initialized(value), + } + } +} diff --git a/notary-server/Cargo.toml b/notary-server/Cargo.toml deleted file mode 100644 index a62e0b9dd8..0000000000 --- a/notary-server/Cargo.toml +++ /dev/null @@ -1,49 +0,0 @@ -[package] -name = "notary-server" -version = "0.1.0-alpha.3" -edition = "2021" - -[dependencies] -async-trait = "0.1.67" -async-tungstenite = { version = "0.22.2", features = ["tokio-native-tls"] } -axum = { version = "0.6.18", features = ["ws"] } -axum-core = "0.3.4" -axum-macros = "0.3.8" -base64 = "0.21.0" -chrono = "0.4.31" -csv = "1.3.0" -eyre = "0.6.8" -futures = "0.3" -futures-util = "0.3.28" -http = "0.2.9" -hyper = { version = "0.14", features = ["client", "http1", "server", "tcp"] } -opentelemetry = { version = "0.19" } -p256 = 
"0.13" -rstest = "0.18" -rustls = { version = "0.21" } -rustls-pemfile = { version = "1.0.2" } -serde = { version = "1.0.147", features = ["derive"] } -serde_json = "1.0" -serde_yaml = "0.9.21" -sha1 = "0.10" -structopt = "0.3.26" -thiserror = "1" -tlsn-verifier = { path = "../tlsn/tlsn-verifier" } -tlsn-tls-core = { path = "../components/tls/tls-core" } -tokio = { version = "1", features = ["full"] } -tokio-rustls = { version = "0.24.1" } -tokio-util = { version = "0.7", features = ["compat"] } -tower = { version = "0.4.12", features = ["make"] } -tower-http = { version = "0.4.4", features = ["cors"] } -tracing = "0.1" -tracing-opentelemetry = "0.19" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -uuid = { version = "1.4.1", features = ["v4", "fast-rng"] } -ws_stream_tungstenite = { version = "0.10.0", features = ["tokio_io"] } - -[dev-dependencies] -# specify vendored feature to use statically linked copy of OpenSSL -hyper-tls = { version = "0.5.0", features = ["vendored"] } -tls-server-fixture = { path = "../components/tls/tls-server-fixture" } -tlsn-prover = { path = "../tlsn/tlsn-prover" } -tokio-native-tls = { version = "0.3.1", features = ["vendored"] } diff --git a/notary-server/README.md b/notary-server/README.md deleted file mode 100644 index bf268d0294..0000000000 --- a/notary-server/README.md +++ /dev/null @@ -1,117 +0,0 @@ -# notary-server - -An implementation of the notary server in Rust. - -## ⚠️ Notice - -This crate is currently under active development and should not be used in production. Expect bugs and regular major breaking changes. - ---- -## Running the server -⚠️ To run this server in a *production environment*, please first read this [page](https://docs.tlsnotary.org/developers/notary_server.html). - -### Using Cargo -1. Configure the server setting in this config [file](./config/config.yaml) — refer [here](./src/config.rs) for more information on the definition of the setting parameters. -2. 
Start the server by running the following in a terminal at the root of this crate. -```bash -cargo run --release -``` -3. To use a config file from a different location, run the following command to override the default config file location. -```bash -cargo run --release -- --config-file -``` - -### Using Docker -There are two ways to obtain the notary server's Docker image: -- [GitHub](#obtaining-the-image-via-github) -- [Building from source](#building-from-source) - -#### GitHub -1. Obtain the latest image with: -```bash -docker pull ghcr.io/tlsnotary/tlsn/notary-server:latest -``` -2. Run the docker container with: -```bash -docker run --init -p 127.0.0.1:7047:7047 ghcr.io/tlsnotary/tlsn/notary-server:latest -``` -3. If you want to change the default configuration, create a `config` folder locally, that contains a `config.yaml`, whose content follows the format of the default config file [here](./config/config.yaml). -4. Instead of step 2, run the docker container with the following (remember to change the port mapping if you have changed that in the config): -```bash -docker run --init -p 127.0.0.1:7047:7047 -v :/root/.notary-server/config ghcr.io/tlsnotary/tlsn/notary-server:latest -``` - -⚠️ When running this notary-server image against a [prover](https://github.com/tlsnotary/tlsn/tree/3e0dcc77d5b8b7d6739ca725f36345108ebecd75/tlsn/examples), please ensure that the prover's tagged version is the same as the version tag of this image. - -#### Building from source -1. Configure the server setting in this config [file](./config/config.yaml). -2. Build the docker image by running the following in a terminal at the root of this *repository*. -```bash -docker build . -t notary-server:local -f notary-server/notary-server.Dockerfile -``` -3. Run the docker container and specify the port specified in the config file, e.g. 
for the default port 7047 -```bash -docker run --init -p 127.0.0.1:7047:7047 notary-server:local -``` - -### Using different setting files with Docker -1. Instead of changing the key/cert/auth file path(s) in the config file, create a folder containing your key/cert/auth files by following the folder structure [here](./fixture/). -2. When launching the docker container, mount your folder onto the docker container at the relevant path prefixed by `/root/.notary-server`. -- Example 1: Using different key, cert, and auth files: -```bash -docker run --init -p 127.0.0.1:7047:7047 -v :/root/.notary-server/fixture notary-server:local -``` -- Example 2: Using different key for notarization: -```bash -docker run --init -p 127.0.0.1:7047:7047 -v :/root/.notary-server/fixture/notary notary-server:local -``` ---- -## API -All APIs are TLS-protected, hence please use `https://` or `wss://`. -### HTTP APIs -Defined in the [OpenAPI specification](./openapi.yaml). - -### WebSocket APIs -#### /notarize -##### Description -To perform notarization using the session id (unique id returned upon calling the `/session` endpoint successfully). - -##### Query Parameter -`sessionId` - -##### Query Parameter Type -String - ---- -## Architecture -### Objective -The main objective of a notary server is to perform notarization together with a prover. In this case, the prover can either be -1. TCP client — which has access and control over the transport layer, i.e. TCP -2. WebSocket client — which has no access over TCP and instead uses WebSocket for notarization - -### Features -#### Notarization Configuration -To perform notarization, some parameters need to be configured by the prover and notary server (more details in the [OpenAPI specification](./openapi.yaml)), i.e. -- maximum transcript size -- unique session id - -To streamline this process, a single HTTP endpoint (`/session`) is used by both TCP and WebSocket clients. 
- -#### Notarization -After calling the configuration endpoint above, prover can proceed to start notarization. For TCP client, that means calling the `/notarize` endpoint using HTTP (`https`), while WebSocket client should call the same endpoint but using WebSocket (`wss`). Example implementations of these clients can be found in the [integration test](./tests/integration_test.rs). - -#### Signatures -Currently, both the private key (and cert) used to establish TLS connection with prover, and the private key used by notary server to sign the notarized transcript, are hardcoded PEM keys stored in this repository. Though the paths of these keys can be changed in the config to use different keys instead. - -#### Authorization -An optional authorization module is available to only allow requests with valid API key attached in the authorization header. The API key whitelist path (as well as the flag to enable/disable this module) is configurable [here](./config/config.yaml). - -#### Optional TLS -TLS between prover and notary is currently manually handled in the server, though it can be turned off if TLS is to be handled by an external environment, e.g. reverse proxy, cloud setup (configurable [here](./config/config.yaml)). - -### Design Choices -#### Web Framework -Axum is chosen as the framework to serve HTTP and WebSocket requests from the prover clients due to its rich and well supported features, e.g. native integration with Tokio/Hyper/Tower, customizable middleware, ability to support lower level integration of TLS ([example](https://github.com/tokio-rs/axum/blob/main/examples/low-level-rustls/src/main.rs)). To simplify the notary server setup, a single Axum router is used to support both HTTP and WebSocket connections, i.e. all requests can be made to the same port of the notary server. 
- -#### WebSocket -Axum's internal implementation of WebSocket uses [tokio_tungstenite](https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/), which provides a WebSocket struct that doesn't implement [AsyncRead](https://docs.rs/futures/latest/futures/io/trait.AsyncRead.html) and [AsyncWrite](https://docs.rs/futures/latest/futures/io/trait.AsyncWrite.html). Both these traits are required by TLSN core libraries for prover and notary. To overcome this, a [slight modification](./src/service/axum_websocket.rs) of Axum's implementation of WebSocket is used, where [async_tungstenite](https://docs.rs/async-tungstenite/latest/async_tungstenite/) is used instead so that [ws_stream_tungstenite](https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html) can be used to wrap on top of the WebSocket struct to get AsyncRead and AsyncWrite implemented. diff --git a/notary-server/config/config.yaml b/notary-server/config/config.yaml deleted file mode 100644 index 71d7577fec..0000000000 --- a/notary-server/config/config.yaml +++ /dev/null @@ -1,23 +0,0 @@ -server: - name: "notary-server" - host: "0.0.0.0" - port: 7047 - -notarization: - max-transcript-size: 16384 - -tls: - enabled: true - private-key-pem-path: "./fixture/tls/notary.key" - certificate-pem-path: "./fixture/tls/notary.crt" - -notary-key: - private-key-pem-path: "./fixture/notary/notary.key" - public-key-pem-path: "./fixture/notary/notary.pub" - -tracing: - default-level: DEBUG - -authorization: - enabled: false - whitelist-csv-path: "./fixture/auth/whitelist.csv" diff --git a/notary-server/fixture/notary/notary.key b/notary-server/fixture/notary/notary.key deleted file mode 100644 index a88cf51f80..0000000000 --- a/notary-server/fixture/notary/notary.key +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgEvBc/VMWn3E4PGfe -ETc/ekdTRmRwNN9J6eKDPxJ98ZmhRANCAAQG/foUjhkWzMlrQNAUnfBYJe9UsWtx 
-HMwbmRpN4cahLMO7pwWrHe4RZikUajoLQQ5SB/6YSBuS0utehy/nIfMq ------END PRIVATE KEY----- diff --git a/notary-server/fixture/notary/notary.pub b/notary-server/fixture/notary/notary.pub deleted file mode 100644 index fa63c8d282..0000000000 --- a/notary-server/fixture/notary/notary.pub +++ /dev/null @@ -1,4 +0,0 @@ ------BEGIN PUBLIC KEY----- -MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEBv36FI4ZFszJa0DQFJ3wWCXvVLFr -cRzMG5kaTeHGoSzDu6cFqx3uEWYpFGo6C0EOUgf+mEgbktLrXocv5yHzKg== ------END PUBLIC KEY----- diff --git a/notary-server/fixture/tls/notary.crt b/notary-server/fixture/tls/notary.crt deleted file mode 100644 index 17c52c3558..0000000000 --- a/notary-server/fixture/tls/notary.crt +++ /dev/null @@ -1,20 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDTzCCAjegAwIBAgIJALo+PtyTmxELMA0GCSqGSIb3DQEBCwUAMCgxEjAQBgNV -BAoMCXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5MB4XDTIzMDYyNjE2MTI1 -N1oXDTI0MDYyNTE2MTI1N1owNDEYMBYGA1UECgwPdGxzbm90YXJ5c2VydmVyMRgw -FgYDVQQDDA90bHNub3RhcnlzZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw -ggEKAoIBAQCqo+rOvL/l3ehVLrOBpzQrjWClV03rl+xiDIElEcVSz017gvoHX0ti -+etBHX+plJOhVRQrO+a3QeYv7NqDnQKIMozsStClkK4MagU1114JO+z4eArfQFDv -Czq2VYwDYBmLj4Lz0y54oQLyy/O8ON/ganYaW/3quGhufo+d774m8qCjSvhdTBnL -h1GxiZKfM8PFaRmBCMGa4mViTlpmnZq7eDzLumlh8WeOTFIWmbjNL4DkMJKu/gHK -5uOCtIUkFPezIN88Pq6wC88jRihXM7hrJUofPZZKzkDpwydGxol9fS0kiMANG6L8 -CUIeQhMDElCV/XiAXHi4MtH93XWjTR3VAgMBAAGjcDBuMEIGA1UdIwQ7MDmhLKQq -MCgxEjAQBgNVBAoMCXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5ggkAwxok -9FN4wLMwCQYDVR0TBAIwADAdBgNVHREEFjAUghJ0bHNub3RhcnlzZXJ2ZXIuaW8w -DQYJKoZIhvcNAQELBQADggEBAByvWsHE5qZYAJT1io1mwVQdXkDnlVjT/GAdu/Mx -EoUPJ9Pt/1XiS1dWXJMIZFbfOiZJBnX+sKxPpy/flaI4kbnXJY8nB5gFPkLWI7ok -V+r2iqEapsX3zrLx7x3AAM2kJbTieMLaGWe9g40wkGzmnpFJf5W8SgI2JEc4KlDo -joQJtsJa85PeOGtMsKLXnqUofDHbvDR0ab9obkh4Ngw+D1CGVXEGduCx1+SwB1jO -eDysCo+8ikyrrlzyDR1OyFJW28WVzLRJH0Z2bwldekM1RvCXqBYeLtAgNtS3Xb1w -RVP9VAx7KlmNF6kG52R2dQ1Z7J7i8JIZEkBcKjITEmpKrfE= ------END CERTIFICATE----- diff --git a/notary-server/fixture/tls/notary.csr 
b/notary-server/fixture/tls/notary.csr deleted file mode 100644 index 069d205b66..0000000000 --- a/notary-server/fixture/tls/notary.csr +++ /dev/null @@ -1,16 +0,0 @@ ------BEGIN CERTIFICATE REQUEST----- -MIICeTCCAWECAQAwNDEYMBYGA1UECgwPdGxzbm90YXJ5c2VydmVyMRgwFgYDVQQD -DA90bHNub3RhcnlzZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB -AQCqo+rOvL/l3ehVLrOBpzQrjWClV03rl+xiDIElEcVSz017gvoHX0ti+etBHX+p -lJOhVRQrO+a3QeYv7NqDnQKIMozsStClkK4MagU1114JO+z4eArfQFDvCzq2VYwD -YBmLj4Lz0y54oQLyy/O8ON/ganYaW/3quGhufo+d774m8qCjSvhdTBnLh1GxiZKf -M8PFaRmBCMGa4mViTlpmnZq7eDzLumlh8WeOTFIWmbjNL4DkMJKu/gHK5uOCtIUk -FPezIN88Pq6wC88jRihXM7hrJUofPZZKzkDpwydGxol9fS0kiMANG6L8CUIeQhMD -ElCV/XiAXHi4MtH93XWjTR3VAgMBAAGgADANBgkqhkiG9w0BAQsFAAOCAQEAqkWE -FyI9r3cY3tXt6j0/xGYZCX3X1AGje7vcEUeYzlED32putmH96Fkia+X2CMpEwcn7 -jaojJWvtAKGAk46p/cRpbPEOhLLebXn4znaeBVF5ph283WmeExRlhQml0e7kwTs9 -MwSniEFKBtvq4cSqO7BM1+NXDpjauVpaACl2+E9KTE8LcGG0BvH2eJOM/yW6wZmG -ykgyMeSg5UV/i5STWlryeaGBLCCmXx4jVfkBgaXw2Zq4ve1F/qU/eQFNUPk/iRSh -aQEQIfEC0hwqEe2Nc7X6PoVd7Py/x7Bke1JP9mRI7EPoN/IT0XHanJ08tusDCcYG -omGrHBGk9mELh39TXQ== ------END CERTIFICATE REQUEST----- diff --git a/notary-server/fixture/tls/notary.ext b/notary-server/fixture/tls/notary.ext deleted file mode 100644 index b9a398df9b..0000000000 --- a/notary-server/fixture/tls/notary.ext +++ /dev/null @@ -1,5 +0,0 @@ -authorityKeyIdentifier=keyid,issuer -basicConstraints=CA:FALSE -subjectAltName = @alt_names -[alt_names] -DNS.1 = tlsnotaryserver.io diff --git a/notary-server/fixture/tls/notary.key b/notary-server/fixture/tls/notary.key deleted file mode 100644 index 9e52bce222..0000000000 --- a/notary-server/fixture/tls/notary.key +++ /dev/null @@ -1,28 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCqo+rOvL/l3ehV -LrOBpzQrjWClV03rl+xiDIElEcVSz017gvoHX0ti+etBHX+plJOhVRQrO+a3QeYv -7NqDnQKIMozsStClkK4MagU1114JO+z4eArfQFDvCzq2VYwDYBmLj4Lz0y54oQLy -y/O8ON/ganYaW/3quGhufo+d774m8qCjSvhdTBnLh1GxiZKfM8PFaRmBCMGa4mVi 
-TlpmnZq7eDzLumlh8WeOTFIWmbjNL4DkMJKu/gHK5uOCtIUkFPezIN88Pq6wC88j -RihXM7hrJUofPZZKzkDpwydGxol9fS0kiMANG6L8CUIeQhMDElCV/XiAXHi4MtH9 -3XWjTR3VAgMBAAECggEBAIYDgk+nMVbIdsUfjl8PAAwMVpDEBjA2+rDufSat1Dj7 -EjEkZlUP5FbxTG+xSSfXxjH4bYSe4M2f9bZB4ENpNinc+YxCHadJ/0dEpJ7qa7H4 -3F0veepnyqhSO2Qjv3iPKsDOjtwLSP34BibFQsDaMgk/001UXhDPj0ToJMa3GLHg -pw1G2ri4WO4NxQA354y61jBNy0D4mjHHlcnofi4iLOFr2Kf538f0RgUyw8EkZ2sE -QyqL5HpHE93qIuLzl3/NxjQNHfO99dNNl6oWzPmXGi0nPGCMith3+8dMH7QiR/sS -r2bjdusIccV3tlZqCJUdWDC/RVgVQKDV3pWBx+i/1gECgYEA19UrQAwgxaimcs7E -NXISBzm2XgOOg9e5/W5EJvObu9zflqB3CBvrdZhhl1ZR+hTAbP5rHZQHWkMMAFbD -dT+VYIqTWUCIDkFpcB7vNa41A5eIbSdz1W+V4ZdAOUuwFcaG2xQA/F37S5DvH12V -JZ4ktJQklYUKmlUXSDTDgRUiApUCgYEAymWqTJjlSi1ADa3bAENaSKGaTazoeWBF -OesnwTLYCT+Aap+3aMjnG5+gxlSbdfJ1odrahXA3VSZwUL0IvCg8HcJsvwc0Bw3/ -LUpwCk2yLlq+OsOtpQSgsOVOzKzXJTEnjHBZxyInJsuTb+Kf5qn7/zvGQSNvbePT -h+YMAwGpHkECgYEAmg3DkzGU6sCYHfZLwkIrcBDXhH9RZ/XBAY2FA7B6Bjt/NApR -K+6RwBwF/HlWhgPt3V4zoqcYIGse09caKEQ8IO6Igfo3osU5txe9cjloCapNbGvu -l/fPqXfGFZ9ajhBoDVNX6MpEJgnLRD4NyQ358RKUkkyl5sa5mYZfzXECF4kCgYA+ -PcmDSLmqeAPssPxaNlw7XccQAA511Q804oYVOceKAIdDQt6qUK4RpqNQmpA8U1Wt -cpok0v+RJgMAMUHQaychl7rNfC+Zw8onaW7PHFmhO7Koa6ioyKWKANqcwsJe46Df -5WUWggA8Q/qRO8Ykrz2Zng431efciWVxs2MaQZZ6gQKBgA+QdsMadsrWoqh3tCZA -uruQ7hXCfALJfgexWFAtLIwlHujXI81+YVICCutCe6riktPZ/zTz/nbEAyt7kIiz -6BF7UYGT29qu1rC0MLzjfwK9ExuMSkvy9ZXGM0bCEgANIkZ0A/zTtTeRaFFGbx6l -F1Y+ihMVuZ4rOGQbVUfQxz1F ------END PRIVATE KEY----- diff --git a/notary-server/fixture/tls/rootCA.crt b/notary-server/fixture/tls/rootCA.crt deleted file mode 100644 index d12936ff51..0000000000 --- a/notary-server/fixture/tls/rootCA.crt +++ /dev/null @@ -1,17 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICzDCCAbQCCQDDGiT0U3jAszANBgkqhkiG9w0BAQsFADAoMRIwEAYDVQQKDAl0 -bHNub3RhcnkxEjAQBgNVBAMMCXRsc25vdGFyeTAeFw0yMzA2MjYxMTE2MTZaFw0y -ODA2MjQxMTE2MTZaMCgxEjAQBgNVBAoMCXRsc25vdGFyeTESMBAGA1UEAwwJdGxz -bm90YXJ5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7Vf+O9l4WNXE 
-Xh48MwjnvZ9wGN/Ls+jzzF1Q+J/QfXAYR/REQgJQmuk6sBgJyXUW7Dr5dKAY5tfL -rjfSaLhdMSxBH/tMepf5HVfEo6jvgk1bdR43DIZw7Z0hfuGUo6qOue8LZry2Nl+9 -VZpG64quRZ///4LdMBQyXcS2yeWKU10yVNBvstKW0i8krqQfbWOIG1nu5nDg5onB -paKUvbyrLyuHLz8gzKDFezxADTugq2KRXYKIZmyRucK+kmnJnZ/k46GZ84Vju15v -ktC0CvaR9IfvLfJMAo1Y0lUR4HjQkEAfjnDFYj5B18KFxXABraVD8UxjeMbAHTjf -i1lV0yp+qQIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQABxRni6FZIFeK0KCS1Nrks -ONLVPfvDSNEKpImWFFoJbaSAAankTiQM1nKTY9SRIhqG2t+xJ6c8+qe905lFFvOy -r85LMb3z2ZWs4ez6Uy6IdpSdkTULk+1huE/Y9ZqRJ5aQy7PqiHTe+mNDFmHXGdcS -azHywd4hQeRQhCBXlAG7I18uZR9DPtGaJnvZlfbpD6Iq7x3ocfGhQiV9VJS1JaQ3 -Z7CJs2pa4da5FXQMAbKI2f7V5kbn3bjMp57yeYFo5wJMhEeSFqkrojR0oZDzfxW9 -b0W/PI4R4d2hUvX0fwrQyXbGo8HvYDFUhlMMSF60gUNcbpF6P93tXxR2FM/hnu+T ------END CERTIFICATE----- diff --git a/notary-server/fixture/tls/rootCA.key b/notary-server/fixture/tls/rootCA.key deleted file mode 100644 index ca615a9799..0000000000 --- a/notary-server/fixture/tls/rootCA.key +++ /dev/null @@ -1,30 +0,0 @@ ------BEGIN ENCRYPTED PRIVATE KEY----- -MIIFHzBJBgkqhkiG9w0BBQ0wPDAbBgkqhkiG9w0BBQwwDgQIODItYjQ2oGICAggA -MB0GCWCGSAFlAwQBKgQQOVlpTszSmqOQ13RqJ0k9vQSCBNDOS9KT7QzKhXaSquKQ -vNylu9+hwkO2+SyVqkf82kHNezEr45r6DuxW0tQJhd93v6DMqGKS9LvFf0qshM2t -OhMD52PVvwHA4Fg4xMQORvvmHOHw7sxzrsQAWIjZ2cPpCRX2zcHYC5zwIaICBBdl -1qzUiO7n0nw77GRBUFoX0eSJPTkH42Nbc/nB+oa5t+n8mxY89Hqdh2wO5k569vkW -7ZC3FphzjPXIWk885qpZ7/O9eeR1OEhBf07PYqzar8RPDU02/lrUZ6Y+QMh69W/y -zNl7CGjy0IuDao77IFsu7rk6cBFq3uTXSyoyDoLmpkLbJEnS+ydjRyfbOjHJfG10 -Ca82mBA51IFs2Werf4+hzyM1EsvlGaGz2vhUffMK2lvJjoWocCi//E2XGQb+7hdY -HOP8sEWuDRqHGtZlMpIxp1JJs5jm+7RGCtgtO/tTb5hCdCB8msqNmCugzjUtg9ZJ -z3BygoRaBVgjBPKtDNo6NHf0Hfq1NzwImed3kTJAgNqkTxpCWYxvuJOWFROvps2t -jGkpGN0S17Ee9JAdsmsz3Mr92OXi9ncig3YTCI+WwN06xjjyIVPvtGDhvP30dPft -A6B1TgeS8g3yEFPkXjDqpKYsiZm38iN+GIcncgU1ng8r1gUnDzwZHwDFPAq/PwYC -sB4CA097eEaFLz49ttcO5ZXl7/pxMYKsBsUJQIG5MQcnyUBugQB2hVT+8WPjAjQE -ocScxHfMQAofMXn+Nwv+J5WlswZjoaXdqa/GG2AKHMPbrlMBdPDADMe0KGWxSKvt -HIbvhd/i+mXS2bwNDzHRCk3GEgdgEzOBrUbxzzJiN5EhRLqxuI2VnMUu5JcjxO/k 
-X2X6Ekpuc0D3PKezCGC98JbtwfVb5A1vBmVaD7ZMXSLCdBU+27qq96txsKP+WRzL -kcnp3EMfTscIgJOApslwkncGar6lsgfzBVD0bQy3luPEvEfhY+UOisApYsJEkuXy -HEbALLAvSibf1+1YCq00KTd7ERboCJ4N4e5ONE8wJBLoRMvRvxbi4zZsW4sApSjP -3r2P5FuB5x2VGlYo3BFd5yzAYzPQl+dFc7wg8yDohKNOVa9XAuNgDhXG+RS7imNL -PP3BIMFuj8+uH0rLRtse+pVhXKQ9pqvgvTpvGAKzqHjFHmH9GDKTnntL3B7kjRa4 -n/0DKT99iHEekvBAEN2qjIuYq0/xiSKSkVeGVksPUtZ/8V2SKdBRXnuDNlItaICW -bgZH/qCisNKx53jSv25yfoq0+rJtxLDNVuPFdKcQa97SsiUqElnxH/gwqSDLViYn -B3XsX+RKZJVNm1rqaoYf7yRFpddld4BgqUYQEj3rAyIelMGuTPSaiJxMysENnJl6 -MPdfug0NoUYJC+xle4YBeMjLj2qout7kq08414ZOEF4MdR6y1oqcM1IW14nSPAdG -O40Kma3V+aTJDIOY3cbS13f9yFY0n7KrD7dZ7faXDRT81LEhEVo9RZ01PBvES6lx -amL1CxQmVrhUD3YbeKvc17tvfOa67YU+6g0ELKFoWf2/5OJyj6wchACU98JyYonr -RGKSaL2zSDiFPTUKYscsFDMZuRtXs98okwoIbK8TujHTWZnI5DrBj2XLiVwPRYiu -dkFsSOjDwjmd11CglKpiBh5a7A== ------END ENCRYPTED PRIVATE KEY----- diff --git a/notary-server/fixture/tls/rootCA.srl b/notary-server/fixture/tls/rootCA.srl deleted file mode 100644 index 4c6854ba78..0000000000 --- a/notary-server/fixture/tls/rootCA.srl +++ /dev/null @@ -1 +0,0 @@ -BA3E3EDC939B110B diff --git a/notary-server/src/config.rs b/notary-server/src/config.rs deleted file mode 100644 index b4b0c37267..0000000000 --- a/notary-server/src/config.rs +++ /dev/null @@ -1,66 +0,0 @@ -use serde::Deserialize; - -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct NotaryServerProperties { - /// Name and address of the notary server - pub server: ServerProperties, - /// Setting for notarization - pub notarization: NotarizationProperties, - /// Setting for TLS connection between prover and notary - pub tls: TLSProperties, - /// File path of private key (in PEM format) used to sign the notarization - pub notary_key: NotarySigningKeyProperties, - /// Setting for logging/tracing - pub tracing: TracingProperties, - /// Setting for authorization - pub authorization: AuthorizationProperties, -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = 
"kebab-case")] -pub struct AuthorizationProperties { - /// Switch to turn on or off auth middleware - pub enabled: bool, - /// File path of the whitelist API key csv - pub whitelist_csv_path: String, -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct NotarizationProperties { - /// Global limit for maximum transcript size in bytes - pub max_transcript_size: usize, -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct ServerProperties { - /// Used for testing purpose - pub name: String, - pub host: String, - pub port: u16, -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct TLSProperties { - /// Flag to turn on/off TLS between prover and notary (should always be turned on unless TLS is handled by external setup e.g. reverse proxy, cloud) - pub enabled: bool, - pub private_key_pem_path: String, - pub certificate_pem_path: String, -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct NotarySigningKeyProperties { - pub private_key_pem_path: String, - pub public_key_pem_path: String, -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct TracingProperties { - /// The minimum logging level, must be either of - pub default_level: String, -} diff --git a/notary-server/src/server.rs b/notary-server/src/server.rs deleted file mode 100644 index 0e149d6c38..0000000000 --- a/notary-server/src/server.rs +++ /dev/null @@ -1,273 +0,0 @@ -use axum::{ - http::{Request, StatusCode}, - middleware::from_extractor_with_state, - response::IntoResponse, - routing::{get, post}, - Json, Router, -}; -use eyre::{ensure, eyre, Result}; -use futures_util::future::poll_fn; -use hyper::server::{ - accept::Accept, - conn::{AddrIncoming, Http}, -}; -use p256::{ecdsa::SigningKey, pkcs8::DecodePrivateKey}; -use rustls::{Certificate, PrivateKey, ServerConfig}; -use std::{ - fs::File as StdFile, - io::BufReader, 
- net::{IpAddr, SocketAddr}, - pin::Pin, - sync::Arc, -}; -use tower_http::cors::CorsLayer; - -use tokio::{fs::File, net::TcpListener}; -use tokio_rustls::TlsAcceptor; -use tower::MakeService; -use tracing::{debug, error, info}; - -use crate::{ - config::{NotaryServerProperties, NotarySigningKeyProperties}, - domain::{ - auth::{authorization_whitelist_vec_into_hashmap, AuthorizationWhitelistRecord}, - notary::NotaryGlobals, - InfoResponse, - }, - error::NotaryServerError, - middleware::AuthorizationMiddleware, - service::{initialize, upgrade_protocol}, - util::parse_csv_file, -}; - -/// Start a TCP server (with or without TLS) to accept notarization request for both TCP and WebSocket clients -#[tracing::instrument(skip(config))] -pub async fn run_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> { - // Load the private key for notarized transcript signing - let notary_signing_key = load_notary_signing_key(&config.notary_key).await?; - // Build TLS acceptor if it is turned on - let tls_acceptor = if !config.tls.enabled { - debug!("Skipping TLS setup as it is turned off."); - None - } else { - let (tls_private_key, tls_certificates) = load_tls_key_and_cert( - &config.tls.private_key_pem_path, - &config.tls.certificate_pem_path, - ) - .await?; - - let mut server_config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(tls_certificates, tls_private_key) - .map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?; - - // Set the http protocols we support - server_config.alpn_protocols = vec![b"http/1.1".to_vec()]; - let tls_config = Arc::new(server_config); - Some(TlsAcceptor::from(tls_config)) - }; - - // Load the authorization whitelist csv if it is turned on - let authorization_whitelist = if !config.authorization.enabled { - debug!("Skipping authorization as it is turned off."); - None - } else { - // Load the csv - let whitelist_csv = parse_csv_file::( - 
&config.authorization.whitelist_csv_path, - ) - .map_err(|err| eyre!("Failed to parse authorization whitelist csv: {:?}", err))?; - // Convert the whitelist record into hashmap for faster lookup - Some(authorization_whitelist_vec_into_hashmap(whitelist_csv)) - }; - - let notary_address = SocketAddr::new( - IpAddr::V4(config.server.host.parse().map_err(|err| { - eyre!("Failed to parse notary host address from server config: {err}") - })?), - config.server.port, - ); - let listener = TcpListener::bind(notary_address) - .await - .map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?; - let mut listener = AddrIncoming::from_listener(listener) - .map_err(|err| eyre!("Failed to build hyper tcp listener: {err}"))?; - - info!("Listening for TCP traffic at {}", notary_address); - - let protocol = Arc::new(Http::new()); - let notary_globals = NotaryGlobals::new( - notary_signing_key, - config.notarization.clone(), - // Use Arc to prevent cloning the whitelist for every request - authorization_whitelist.map(Arc::new), - ); - - // Parameters needed for the info endpoint - let public_key = std::fs::read_to_string(&config.notary_key.public_key_pem_path) - .map_err(|err| eyre!("Failed to load notary public signing key for notarization: {err}"))?; - let version = env!("CARGO_PKG_VERSION").to_string(); - let git_commit_hash = env!("GIT_COMMIT_HASH").to_string(); - let git_commit_timestamp = env!("GIT_COMMIT_TIMESTAMP").to_string(); - - let router = Router::new() - .route( - "/healthcheck", - get(|| async move { (StatusCode::OK, "Ok").into_response() }), - ) - .route( - "/info", - get(|| async move { - ( - StatusCode::OK, - Json(InfoResponse { - version, - public_key, - git_commit_hash, - git_commit_timestamp, - }), - ) - .into_response() - }), - ) - .route("/session", post(initialize)) - // Not applying auth middleware to /notarize endpoint for now as we can rely on our - // short-lived session id generated from /session endpoint, as it is not possible - // 
to use header for API key for websocket /notarize endpoint due to browser restriction - // ref: https://stackoverflow.com/a/4361358; And putting it in url query param - // seems to be more insecured: https://stackoverflow.com/questions/5517281/place-api-key-in-headers-or-url - .route_layer(from_extractor_with_state::< - AuthorizationMiddleware, - NotaryGlobals, - >(notary_globals.clone())) - .route("/notarize", get(upgrade_protocol)) - .layer(CorsLayer::permissive()) - .with_state(notary_globals); - let mut app = router.into_make_service(); - - loop { - // Poll and await for any incoming connection, ensure that all operations inside are infallible to prevent bringing down the server - let (prover_address, stream) = - match poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx)).await { - Some(Ok(connection)) => (connection.remote_addr(), connection), - Some(Err(err)) => { - error!("{}", NotaryServerError::Connection(err.to_string())); - continue; - } - None => unreachable!("The poll_accept method should never return None"), - }; - debug!(?prover_address, "Received a prover's TCP connection"); - - let tls_acceptor = tls_acceptor.clone(); - let protocol = protocol.clone(); - let service = MakeService::<_, Request>::make_service(&mut app, &stream); - - // Spawn a new async task to handle the new connection - tokio::spawn(async move { - // When TLS is enabled - if let Some(acceptor) = tls_acceptor { - match acceptor.accept(stream).await { - Ok(stream) => { - info!( - ?prover_address, - "Accepted prover's TLS-secured TCP connection", - ); - // Serve different requests using the same hyper protocol and axum router - let _ = protocol - // Can unwrap because it's infallible - .serve_connection(stream, service.await.unwrap()) - // use with_upgrades to upgrade connection to websocket for websocket clients - // and to extract tcp connection for tcp clients - .with_upgrades() - .await; - } - Err(err) => { - error!( - ?prover_address, - "{}", - 
NotaryServerError::Connection(err.to_string()) - ); - } - } - } else { - // When TLS is disabled - info!(?prover_address, "Accepted prover's TCP connection",); - // Serve different requests using the same hyper protocol and axum router - let _ = protocol - // Can unwrap because it's infallible - .serve_connection(stream, service.await.unwrap()) - // use with_upgrades to upgrade connection to websocket for websocket clients - // and to extract tcp connection for tcp clients - .with_upgrades() - .await; - } - }); - } -} - -/// Temporary function to load notary signing key from static file -async fn load_notary_signing_key(config: &NotarySigningKeyProperties) -> Result { - debug!("Loading notary server's signing key"); - - let notary_signing_key = SigningKey::read_pkcs8_pem_file(&config.private_key_pem_path) - .map_err(|err| eyre!("Failed to load notary signing key for notarization: {err}"))?; - - debug!("Successfully loaded notary server's signing key!"); - Ok(notary_signing_key) -} - -/// Read a PEM-formatted file and return its buffer reader -pub async fn read_pem_file(file_path: &str) -> Result> { - let key_file = File::open(file_path).await?.into_std().await; - Ok(BufReader::new(key_file)) -} - -/// Load notary tls private key and cert from static files -async fn load_tls_key_and_cert( - private_key_pem_path: &str, - certificate_pem_path: &str, -) -> Result<(PrivateKey, Vec)> { - debug!("Loading notary server's tls private key and certificate"); - - let mut private_key_file_reader = read_pem_file(private_key_pem_path).await?; - let mut private_keys = rustls_pemfile::pkcs8_private_keys(&mut private_key_file_reader)?; - ensure!( - private_keys.len() == 1, - "More than 1 key found in the tls private key pem file" - ); - let private_key = PrivateKey(private_keys.remove(0)); - - let mut certificate_file_reader = read_pem_file(certificate_pem_path).await?; - let certificates = rustls_pemfile::certs(&mut certificate_file_reader)? 
- .into_iter() - .map(Certificate) - .collect(); - - debug!("Successfully loaded notary server's tls private key and certificate!"); - Ok((private_key, certificates)) -} - -#[cfg(test)] -mod test { - use super::*; - - #[tokio::test] - async fn test_load_notary_key_and_cert() { - let private_key_pem_path = "./fixture/tls/notary.key"; - let certificate_pem_path = "./fixture/tls/notary.crt"; - let result: Result<(PrivateKey, Vec)> = - load_tls_key_and_cert(private_key_pem_path, certificate_pem_path).await; - assert!(result.is_ok(), "Could not load tls private key and cert"); - } - - #[tokio::test] - async fn test_load_notary_signing_key() { - let config = NotarySigningKeyProperties { - private_key_pem_path: "./fixture/notary/notary.key".to_string(), - public_key_pem_path: "./fixture/notary/notary.pub".to_string(), - }; - let result: Result = load_notary_signing_key(&config).await; - assert!(result.is_ok(), "Could not load notary private key"); - } -} diff --git a/notary-server/src/server_tracing.rs b/notary-server/src/server_tracing.rs deleted file mode 100644 index 65b5aa8c2a..0000000000 --- a/notary-server/src/server_tracing.rs +++ /dev/null @@ -1,37 +0,0 @@ -use eyre::Result; -use opentelemetry::{ - global, - sdk::{export::trace::stdout, propagation::TraceContextPropagator}, -}; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry}; - -use crate::config::NotaryServerProperties; - -pub fn init_tracing(config: &NotaryServerProperties) -> Result<()> { - // Create a new OpenTelemetry pipeline - let tracer = stdout::new_pipeline().install_simple(); - - // Create a tracing layer with the configured tracer - let tracing_layer = tracing_opentelemetry::layer().with_tracer(tracer); - - // Set the log level - let env_filter_layer = EnvFilter::new(&config.tracing.default_level); - - // Format the log - let format_layer = tracing_subscriber::fmt::layer() - // Use a more compact, abbreviated log format - .compact() - .with_thread_ids(true) 
- .with_thread_names(true); - - // Set up context propagation - global::set_text_map_propagator(TraceContextPropagator::default()); - - Registry::default() - .with(tracing_layer) - .with(env_filter_layer) - .with(format_layer) - .try_init()?; - - Ok(()) -} diff --git a/notary-server/tests/integration_test.rs b/notary-server/tests/integration_test.rs deleted file mode 100644 index 6aa0e0ca0a..0000000000 --- a/notary-server/tests/integration_test.rs +++ /dev/null @@ -1,481 +0,0 @@ -use async_tungstenite::{ - tokio::connect_async_with_tls_connector_and_config, tungstenite::protocol::WebSocketConfig, -}; -use futures::AsyncWriteExt; -use hyper::{ - body::to_bytes, - client::{conn::Parts, HttpConnector}, - Body, Client, Request, StatusCode, -}; -use hyper_tls::HttpsConnector; -use rstest::rstest; -use rustls::{Certificate, ClientConfig, RootCertStore}; -use std::{ - net::{IpAddr, SocketAddr}, - sync::Arc, - time::Duration, -}; -use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN}; -use tlsn_prover::tls::{Prover, ProverConfig}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - net::TcpStream, -}; -use tokio_rustls::{client::TlsStream, TlsConnector}; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -use tracing::debug; -use ws_stream_tungstenite::WsStream; - -use notary_server::{ - read_pem_file, run_server, AuthorizationProperties, NotarizationProperties, - NotarizationSessionRequest, NotarizationSessionResponse, NotaryServerProperties, - NotarySigningKeyProperties, ServerProperties, TLSProperties, TracingProperties, -}; - -const NOTARY_CA_CERT_PATH: &str = "./fixture/tls/rootCA.crt"; -const NOTARY_CA_CERT_BYTES: &[u8] = include_bytes!("../fixture/tls/rootCA.crt"); - -fn get_server_config(port: u16, tls_enabled: bool) -> NotaryServerProperties { - NotaryServerProperties { - server: ServerProperties { - name: "tlsnotaryserver.io".to_string(), - host: "127.0.0.1".to_string(), - port, - }, - notarization: 
NotarizationProperties { - max_transcript_size: 1 << 14, - }, - tls: TLSProperties { - enabled: tls_enabled, - private_key_pem_path: "./fixture/tls/notary.key".to_string(), - certificate_pem_path: "./fixture/tls/notary.crt".to_string(), - }, - notary_key: NotarySigningKeyProperties { - private_key_pem_path: "./fixture/notary/notary.key".to_string(), - public_key_pem_path: "./fixture/notary/notary.pub".to_string(), - }, - tracing: TracingProperties { - default_level: "DEBUG".to_string(), - }, - authorization: AuthorizationProperties { - enabled: false, - whitelist_csv_path: "./fixture/auth/whitelist.csv".to_string(), - }, - } -} - -async fn setup_config_and_server( - sleep_ms: u64, - port: u16, - tls_enabled: bool, -) -> NotaryServerProperties { - let notary_config = get_server_config(port, tls_enabled); - - let _ = tracing_subscriber::fmt::try_init(); - - let config = notary_config.clone(); - - // Run the notary server - tokio::spawn(async move { - run_server(&config).await.unwrap(); - }); - - // Sleep for a while to allow notary server to finish set up and start listening - tokio::time::sleep(Duration::from_millis(sleep_ms)).await; - - notary_config -} - -async fn tcp_socket(notary_config: NotaryServerProperties) -> TcpStream { - tokio::net::TcpStream::connect(SocketAddr::new( - IpAddr::V4(notary_config.server.host.parse().unwrap()), - notary_config.server.port, - )) - .await - .unwrap() -} - -async fn tls_socket(notary_config: NotaryServerProperties) -> TlsStream { - let notary_tcp_socket = tokio::net::TcpStream::connect(SocketAddr::new( - IpAddr::V4(notary_config.server.host.parse().unwrap()), - notary_config.server.port, - )) - .await - .unwrap(); - - // Connect to the Notary via TLS-TCP - let mut certificate_file_reader = read_pem_file(NOTARY_CA_CERT_PATH).await.unwrap(); - let mut certificates: Vec = rustls_pemfile::certs(&mut certificate_file_reader) - .unwrap() - .into_iter() - .map(Certificate) - .collect(); - let certificate = certificates.remove(0); - - 
let mut root_store = RootCertStore::empty(); - root_store.add(&certificate).unwrap(); - - let client_notary_config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - - let notary_connector = TlsConnector::from(Arc::new(client_notary_config)); - notary_connector - .connect( - notary_config.server.name.as_str().try_into().unwrap(), - notary_tcp_socket, - ) - .await - .unwrap() -} - -#[rstest] -#[case::with_tls( - setup_config_and_server(100, 7048, true), - tls_socket(get_server_config(7048, true)) -)] -#[case::without_tls( - setup_config_and_server(100, 7049, false), - tcp_socket(get_server_config(7049, false)) -)] -#[awt] -#[tokio::test] -async fn test_tcp_prover( - #[future] - #[case] - notary_config: NotaryServerProperties, - #[future] - #[case] - notary_socket: S, -) { - let notary_host = notary_config.server.host; - let notary_port = notary_config.server.port; - let http_scheme = if notary_config.tls.enabled { - "https" - } else { - "http" - }; - - // Attach the hyper HTTP client to the notary connection to send request to the /session endpoint to configure notarization and obtain session id - let (mut request_sender, connection) = - hyper::client::conn::handshake(notary_socket).await.unwrap(); - - // Spawn the HTTP task to be run concurrently - let connection_task = tokio::spawn(connection.without_shutdown()); - - // Build the HTTP request to configure notarization - let payload = serde_json::to_string(&NotarizationSessionRequest { - client_type: notary_server::ClientType::Tcp, - max_transcript_size: Some(notary_config.notarization.max_transcript_size), - }) - .unwrap(); - let request = Request::builder() - .uri(format!( - "{http_scheme}://{notary_host}:{notary_port}/session" - )) - .method("POST") - .header("Host", notary_host.clone()) - // Need to specify application/json for axum to parse it as json - .header("Content-Type", "application/json") - .body(Body::from(payload)) - .unwrap(); - - 
debug!("Sending configuration request"); - - let response = request_sender.send_request(request).await.unwrap(); - - debug!("Sent configuration request"); - - assert!(response.status() == StatusCode::OK); - - debug!("Response OK"); - - // Pretty printing :) - let payload = to_bytes(response.into_body()).await.unwrap().to_vec(); - let notarization_response = - serde_json::from_str::(&String::from_utf8_lossy(&payload)) - .unwrap(); - - debug!("Notarization response: {:?}", notarization_response,); - - // Send notarization request via HTTP, where the underlying TCP connection will be extracted later - let request = Request::builder() - // Need to specify the session_id so that notary server knows the right configuration to use - // as the configuration is set in the previous HTTP call - .uri(format!( - "{http_scheme}://{}:{}/notarize?sessionId={}", - notary_host, - notary_port, - notarization_response.session_id.clone() - )) - .method("GET") - .header("Host", notary_host) - .header("Connection", "Upgrade") - // Need to specify this upgrade header for server to extract tcp connection later - .header("Upgrade", "TCP") - .body(Body::empty()) - .unwrap(); - - debug!("Sending notarization request"); - - let response = request_sender.send_request(request).await.unwrap(); - - debug!("Sent notarization request"); - - assert!(response.status() == StatusCode::SWITCHING_PROTOCOLS); - - debug!("Switched protocol OK"); - - // Claim back the socket after HTTP exchange is done so that client can use it for notarization - let Parts { - io: notary_socket, .. 
- } = connection_task.await.unwrap().unwrap(); - - // Connect to the Server - let (client_socket, server_socket) = tokio::io::duplex(2 << 16); - let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); - - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - // Basic default prover config — use the responded session id from notary server - let prover_config = ProverConfig::builder() - .id(notarization_response.session_id) - .server_dns(SERVER_DOMAIN) - .root_cert_store(root_store) - .build() - .unwrap(); - - // Bind the Prover to the sockets - let prover = Prover::new(prover_config) - .setup(notary_socket.compat()) - .await - .unwrap(); - let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - - // Spawn the Prover task to be run concurrently - let prover_task = tokio::spawn(prover_fut); - - let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat()) - .await - .unwrap(); - - let connection_task = tokio::spawn(connection.without_shutdown()); - - let request = Request::builder() - .uri(format!("https://{}/echo", SERVER_DOMAIN)) - .header("Host", SERVER_DOMAIN) - .header("Connection", "close") - .method("POST") - .body(Body::from("echo")) - .unwrap(); - - debug!("Sending request to server: {:?}", request); - - let response = request_sender.send_request(request).await.unwrap(); - - assert!(response.status() == StatusCode::OK); - - debug!( - "Received response from server: {:?}", - String::from_utf8_lossy(&to_bytes(response.into_body()).await.unwrap()) - ); - - let mut server_tls_conn = server_task.await.unwrap().unwrap(); - - // Make sure the server closes cleanly (sends close notify) - server_tls_conn.close().await.unwrap(); - - let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner(); - - client_socket.close().await.unwrap(); - - let mut prover = 
prover_task.await.unwrap().unwrap().start_notarize(); - - let sent_len = prover.sent_transcript().data().len(); - let recv_len = prover.recv_transcript().data().len(); - - let builder = prover.commitment_builder(); - - builder.commit_sent(0..sent_len).unwrap(); - builder.commit_recv(0..recv_len).unwrap(); - - _ = prover.finalize().await.unwrap(); - - debug!("Done notarization!"); -} - -#[tokio::test] -async fn test_websocket_prover() { - // Notary server configuration setup - let notary_config = setup_config_and_server(100, 7050, true).await; - let notary_host = notary_config.server.host.clone(); - let notary_port = notary_config.server.port; - - // Connect to the notary server via TLS-WebSocket - // Try to avoid dealing with transport layer directly to mimic the limitation of a browser extension that uses websocket - // - // Establish TLS setup for connections later - let certificate = - tokio_native_tls::native_tls::Certificate::from_pem(NOTARY_CA_CERT_BYTES).unwrap(); - let notary_tls_connector = tokio_native_tls::native_tls::TlsConnector::builder() - .add_root_certificate(certificate) - .use_sni(false) - .danger_accept_invalid_certs(true) - .build() - .unwrap(); - - // Call the /session HTTP API to configure notarization and obtain session id - let mut hyper_http_connector = HttpConnector::new(); - hyper_http_connector.enforce_http(false); - let mut hyper_tls_connector = - HttpsConnector::from((hyper_http_connector, notary_tls_connector.clone().into())); - hyper_tls_connector.https_only(true); - let https_client = Client::builder().build::<_, hyper::Body>(hyper_tls_connector); - - // Build the HTTP request to configure notarization - let payload = serde_json::to_string(&NotarizationSessionRequest { - client_type: notary_server::ClientType::Websocket, - max_transcript_size: Some(notary_config.notarization.max_transcript_size), - }) - .unwrap(); - - let request = Request::builder() - .uri(format!("https://{notary_host}:{notary_port}/session")) - .method("POST") - 
.header("Host", notary_host.clone()) - // Need to specify application/json for axum to parse it as json - .header("Content-Type", "application/json") - .body(Body::from(payload)) - .unwrap(); - - debug!("Sending request"); - - let response = https_client.request(request).await.unwrap(); - - debug!("Sent request"); - - assert!(response.status() == StatusCode::OK); - - debug!("Response OK"); - - // Pretty printing :) - let payload = to_bytes(response.into_body()).await.unwrap().to_vec(); - let notarization_response = - serde_json::from_str::(&String::from_utf8_lossy(&payload)) - .unwrap(); - - debug!("Notarization response: {:?}", notarization_response,); - - // Connect to the Notary via TLS-Websocket - // - // Note: This will establish a new TLS-TCP connection instead of reusing the previous TCP connection - // used in the previous HTTP POST request because we cannot claim back the tcp connection used in hyper - // client while using its high level request function — there does not seem to have a crate that can let you - // make a request without establishing TCP connection where you can claim the TCP connection later after making the request - let request = http::Request::builder() - // Need to specify the session_id so that notary server knows the right configuration to use - // as the configuration is set in the previous HTTP call - .uri(format!( - "wss://{}:{}/notarize?sessionId={}", - notary_host, - notary_port, - notarization_response.session_id.clone() - )) - .header("Host", notary_host.clone()) - .header("Sec-WebSocket-Key", uuid::Uuid::new_v4().to_string()) - .header("Sec-WebSocket-Version", "13") - .header("Connection", "Upgrade") - .header("Upgrade", "Websocket") - .body(()) - .unwrap(); - - let (notary_ws_stream, _) = connect_async_with_tls_connector_and_config( - request, - Some(notary_tls_connector.into()), - Some(WebSocketConfig::default()), - ) - .await - .unwrap(); - - // Wrap the socket with the adapter so that we get AsyncRead and AsyncWrite 
implemented - let notary_ws_socket = WsStream::new(notary_ws_stream); - - // Connect to the Server - let (client_socket, server_socket) = tokio::io::duplex(2 << 16); - let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); - - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - // Basic default prover config — use the responded session id from notary server - let prover_config = ProverConfig::builder() - .id(notarization_response.session_id) - .server_dns(SERVER_DOMAIN) - .root_cert_store(root_store) - .build() - .unwrap(); - - // Bind the Prover to the sockets - let prover = Prover::new(prover_config) - .setup(notary_ws_socket) - .await - .unwrap(); - let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - - // Spawn the Prover and Mux tasks to be run concurrently - let prover_task = tokio::spawn(prover_fut); - - let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat()) - .await - .unwrap(); - - let connection_task = tokio::spawn(connection.without_shutdown()); - - let request = Request::builder() - .uri(format!("https://{}/echo", SERVER_DOMAIN)) - .header("Host", SERVER_DOMAIN) - .header("Connection", "close") - .method("POST") - .body(Body::from("echo")) - .unwrap(); - - debug!("Sending request to server: {:?}", request); - - let response = request_sender.send_request(request).await.unwrap(); - - assert!(response.status() == StatusCode::OK); - - debug!( - "Received response from server: {:?}", - String::from_utf8_lossy(&to_bytes(response.into_body()).await.unwrap()) - ); - - let mut server_tls_conn = server_task.await.unwrap().unwrap(); - - // Make sure the server closes cleanly (sends close notify) - server_tls_conn.close().await.unwrap(); - - let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner(); - - client_socket.close().await.unwrap(); - 
- let mut prover = prover_task.await.unwrap().unwrap().start_notarize(); - - let sent_len = prover.sent_transcript().data().len(); - let recv_len = prover.recv_transcript().data().len(); - - let builder = prover.commitment_builder(); - - builder.commit_sent(0..sent_len).unwrap(); - builder.commit_recv(0..recv_len).unwrap(); - - _ = prover.finalize().await.unwrap(); - - debug!("Done notarization!"); -} diff --git a/rustfmt.toml b/rustfmt.toml index 5a25ea3f21..c1a72c71b1 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,3 +1,4 @@ -ignore = ["tls-core/", "tls-client/"] +ignore = ["crates/tls/core", "crates/tls/client"] imports_granularity = "Crate" +wrap_comments = true diff --git a/tlsn/Cargo.toml b/tlsn/Cargo.toml deleted file mode 100644 index 1fe45db246..0000000000 --- a/tlsn/Cargo.toml +++ /dev/null @@ -1,68 +0,0 @@ -[workspace] -members = [ - "tlsn-core", - "tlsn-common", - "tlsn-verifier", - "tlsn-prover", - # "tlsn-formats", - "tlsn-server-fixture", - "tests-integration", - "examples", - "benches", -] -resolver = "2" - -[workspace.dependencies] -tlsn-core = { path = "tlsn-core" } -tlsn-common = { path = "tlsn-common" } -tlsn-prover = { path = "tlsn-prover" } -tlsn-verifier = { path = "tlsn-verifier" } -tlsn-server-fixture = { path = "tlsn-server-fixture" } - -#tlsn-formats = { path = "tlsn-formats" } - -tlsn-tls-core = { path = "../components/tls/tls-core" } -tlsn-tls-mpc = { path = "../components/tls/tls-mpc" } -tlsn-tls-client = { path = "../components/tls/tls-client" } -tlsn-tls-client-async = { path = "../components/tls/tls-client-async" } -tls-server-fixture = { path = "../components/tls/tls-server-fixture" } -uid-mux = { path = "../components/uid-mux" } - -tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } - -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-circuits = { git = 
"https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" } - - -futures = "0.3" -tokio-util = "0.7" -hyper = "<=0.14.26" -tokio = "1" - -signature = "2" -p256 = "0.13" -rs_merkle = "1" -rand_chacha = "0.3" -rand = "0.8" -rand_core = "0.6" -webpki-roots = "0.26" - -derive_builder = "0.12" -thiserror = "1" -serde = "1" -bincode = "1" -hex = "0.4" -bytes = "1.4" -opaque-debug = "0.3" -spansy = { git = "https://github.com/sinui0/spansy", rev = "becb33d" } - -tracing = "0.1" -tracing-subscriber = "0.3" -rstest = "0.17" - -web-time = "0.2" diff --git a/tlsn/benches/Cargo.toml b/tlsn/benches/Cargo.toml deleted file mode 100644 index 76286bcbdc..0000000000 --- a/tlsn/benches/Cargo.toml +++ /dev/null @@ -1,39 +0,0 @@ -[package] -name = "tlsn-benches" -version = "0.0.0" -edition = "2021" -publish = false - -[dependencies] -tlsn-core.workspace = true -tlsn-prover.workspace = true -tlsn-verifier.workspace = true -tlsn-server-fixture.workspace = true -tlsn-tls-core.workspace = true -futures.workspace = true -tokio = { workspace = true, features = [ - "rt", - "rt-multi-thread", - "macros", - "net", - "io-std", - "fs", -] } -tokio-util.workspace = true -tracing-subscriber = { workspace = true, features = ["env-filter"] } - -[[bin]] -name = "setup_network" -path = "src/setup_network.rs" - -[[bin]] -name = "cleanup_network" -path = "src/cleanup_network.rs" - -[[bin]] -name = "prover" -path = "src/prover.rs" - -[[bin]] -name = "verifier" -path = "src/verifier.rs" diff --git a/tlsn/benches/README.md b/tlsn/benches/README.md deleted file mode 100644 index 3f3ac17269..0000000000 --- 
a/tlsn/benches/README.md +++ /dev/null @@ -1,121 +0,0 @@ -# TLSNotary bench utilities - -This crate provides utilities for benchmarking protocol performance under various network conditions and usage patterns. - -As the protocol is mostly IO bound, it's important to track how it performs in low bandwidth and/or high latency environments. To do this we set up temporary network namespaces and add virtual ethernet interfaces which we can control using the linux `tc` (Traffic Control) utility. - -## Setup - -To start we must create network namespaces for the prover and verifier, respectively. - -```sh -ip netns add prover-ns -ip netns add verifier-ns -``` - -Then we create a pair of virtual ethernet interfaces and add them to their respective namespaces. - -```sh -ip link add prover-veth type veth peer name verifier-veth -ip link set prover-veth netns prover-ns -ip link set verifier-veth netns verifier-ns -``` - -If successful you should be able to see each interface in its namespace. For example, to see the prover interface: - -```sh -ip netns exec prover-ns ip link -``` - -Then, activate each interface (bring it up). - -```sh -ip netns exec prover-ns ip link set prover-veth up -ip netns exec verifier-ns ip link set verifier-veth up -``` - -Next we'll assign IP addresses to each interface and set default routes: - -```sh -ip netns exec prover-ns ip addr add 10.10.1.0/24 dev prover-veth -ip netns exec prover-ns ip route add default via 10.10.1.0 dev prover-veth -ip netns exec verifier-ns ip addr add 10.10.1.1/24 dev verifier-veth -ip netns exec verifier-ns ip route add default via 10.10.1.1 dev verifier-veth -``` - -Verify that everything worked by pinging between them: - -```sh -ip netns exec prover-ns ping 10.10.1.1 -``` - -## Clean up - -For future reference, you can clean up this configuration as shown below. 
- -First, delete the interface pair (this removes both): - -```sh -ip netns exec prover-ns ip link delete prover-veth -``` - -Finally, delete each namespace: - -```sh -ip netns del prover-ns -ip netns del verifier-ns -``` - -## Configuration binaries - -Alternatively, instead of doing the above configuration manually, you can build the `setup_network` and `cleanup_network` binaries and execute them instead. Though they haven't been tested and you have to run them as root, so use at your own risk. - -## Configuring network - -To simulate different network conditions we use the linux utility `tc`. Typically, only the egress performance of an interface is configured. So we will configure the egress of both the prover and verifier to simulate the conditions we want. - -### Adding rules - -For example, to add both an egress bandwidth limit and delay to the prover we can do this: - -```sh -ip netns exec prover-ns tc qdisc add dev prover-veth root handle 1: tbf rate 10mbit burst 1mbit latency 60s -ip netns exec prover-ns tc qdisc add dev prover-veth parent 1:1 handle 10: netem delay 50ms -``` - -The above command will chain a bandwidth filter with a delay filter. The bandwidth filter will cap prover "upload" at 10Mbps with 1Mbps bursts, and drops packets not sent within 60. The delay filter will cause all packets to wait 50ms before arriving at the verifier's network interface. - -To simulate a prover with 10Mbps up and 100Mbps down @100ms latency with the verifier, one would also add the following filters to the verifier interface: - -```sh -ip netns exec verifier-ns tc qdisc add dev verifier-veth root handle 1: tbf rate 100mbit burst 1mbit latency 60s -ip netns exec verifier-ns tc qdisc add dev verifier-veth parent 1:1 handle 10: netem delay 50ms -``` - -### Modifying rules - -To modify a rule you have to delete the existing one and re-add a new one. 
- -### Deleting rules - -You can delete all rules on a device like so: - -```sh -ip netns exec prover-ns tc qdisc del dev prover-veth root -``` - -## Running benches - -In order to run a binary in another network namespace you need to run as root, and this won't place nice with cargo. The simplest way to run the bench is to first compile the binaries and run them directly. - -```sh -cargo b --bin prover --release -cargo b --bin verifier --release -``` - -Run these separately: - -```sh -ip netns exec prover-ns ../target/release/prover -ip netns exec verifier-ns ../target/release/verifier -``` \ No newline at end of file diff --git a/tlsn/benches/src/cleanup_network.rs b/tlsn/benches/src/cleanup_network.rs deleted file mode 100644 index 52c7841cd1..0000000000 --- a/tlsn/benches/src/cleanup_network.rs +++ /dev/null @@ -1,30 +0,0 @@ -// Clean up the network namespaces and interface pair created by setup_network.rs - -use std::process::Command; -use tlsn_benches::*; - -fn main() -> Result<(), std::io::Error> { - // Delete interface pair - Command::new("sudo") - .args(&[ - "ip", - "netns", - "exec", - PROVER_NAMESPACE, - "ip", - "link", - "delete", - PROVER_INTERFACE, - ]) - .status()?; - - // Delete namespaces - Command::new("sudo") - .args(&["ip", "netns", "del", PROVER_NAMESPACE]) - .status()?; - Command::new("sudo") - .args(&["ip", "netns", "del", VERIFIER_NAMESPACE]) - .status()?; - - Ok(()) -} diff --git a/tlsn/benches/src/lib.rs b/tlsn/benches/src/lib.rs deleted file mode 100644 index 729625aa1b..0000000000 --- a/tlsn/benches/src/lib.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub const PROVER_NAMESPACE: &str = "prover-ns"; -pub const PROVER_INTERFACE: &str = "prover-veth"; -pub const PROVER_SUBNET: &str = "10.10.1.0/24"; -pub const VERIFIER_NAMESPACE: &str = "verifier-ns"; -pub const VERIFIER_INTERFACE: &str = "verifier-veth"; -pub const VERIFIER_SUBNET: &str = "10.10.1.1/24"; diff --git a/tlsn/benches/src/prover.rs b/tlsn/benches/src/prover.rs deleted file mode 100644 
index 294e71f509..0000000000 --- a/tlsn/benches/src/prover.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::time::Instant; - -use futures::{AsyncReadExt, AsyncWriteExt}; -use tlsn_core::Direction; -use tlsn_server_fixture::{CA_CERT_DER, SERVER_DOMAIN}; -use tokio_util::compat::TokioAsyncReadCompatExt; - -use tlsn_prover::tls::{Prover, ProverConfig}; -use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; - -#[tokio::main] -async fn main() { - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) - .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) - .init(); - - let (client_conn, server_conn) = tokio::io::duplex(2 << 16); - let server_task = tokio::spawn(tlsn_server_fixture::bind(server_conn.compat())); - - let ip = std::env::var("VERIFIER_IP").unwrap_or_else(|_| "10.10.1.1".to_string()); - let port: u16 = std::env::var("VERIFIER_PORT") - .map(|port| port.parse().expect("port is valid u16")) - .unwrap_or(8000); - let verifier_host = (ip.as_str(), port); - let verifier_conn = tokio::net::TcpStream::connect(verifier_host).await.unwrap(); - - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - let start_time = Instant::now(); - - let prover = Prover::new( - ProverConfig::builder() - .id("test") - .server_dns(SERVER_DOMAIN) - .root_cert_store(root_store) - .build() - .unwrap(), - ) - .setup(verifier_conn.compat()) - .await - .unwrap(); - let (mut mpc_tls_connection, prover_fut) = prover.connect(client_conn.compat()).await.unwrap(); - let prover_task = tokio::spawn(async { prover_fut.await.unwrap() }); - - mpc_tls_connection - .write_all(b"GET /formats/json?size=8 HTTP/1.1\r\nConnection: close\r\n\r\n") - .await - .unwrap(); - - mpc_tls_connection.close().await.unwrap(); - - let mut response = vec![0u8; 1024]; - mpc_tls_connection.read_to_end(&mut response).await.unwrap(); - - server_task.await.unwrap(); - - let mut prover = 
prover_task.await.unwrap().start_prove(); - - prover - .reveal(0..prover.sent_transcript().data().len(), Direction::Sent) - .unwrap(); - prover - .reveal( - 0..prover.recv_transcript().data().len(), - Direction::Received, - ) - .unwrap(); - prover.prove().await.unwrap(); - prover.finalize().await.unwrap(); - - println!( - "completed: {} seconds", - Instant::now().duration_since(start_time).as_secs() - ); -} diff --git a/tlsn/benches/src/setup_network.rs b/tlsn/benches/src/setup_network.rs deleted file mode 100644 index fecd825ff6..0000000000 --- a/tlsn/benches/src/setup_network.rs +++ /dev/null @@ -1,171 +0,0 @@ -// Set up network namespaces and veth pairs for benchmarking - -use std::process::Command; - -use tlsn_benches::{ - PROVER_INTERFACE, PROVER_NAMESPACE, PROVER_SUBNET, VERIFIER_INTERFACE, VERIFIER_NAMESPACE, - VERIFIER_SUBNET, -}; - -fn main() -> Result<(), std::io::Error> { - // Create network namespaces - create_network_namespace(PROVER_NAMESPACE)?; - create_network_namespace(VERIFIER_NAMESPACE)?; - - // Create veth pair and attach to namespaces - create_veth_pair( - PROVER_NAMESPACE, - PROVER_INTERFACE, - VERIFIER_NAMESPACE, - VERIFIER_INTERFACE, - )?; - - // Set devices up - set_device_up(PROVER_NAMESPACE, PROVER_INTERFACE)?; - set_device_up(VERIFIER_NAMESPACE, VERIFIER_INTERFACE)?; - - // Assign IPs - assign_ip_to_interface(PROVER_NAMESPACE, PROVER_INTERFACE, PROVER_SUBNET)?; - assign_ip_to_interface(VERIFIER_NAMESPACE, VERIFIER_INTERFACE, VERIFIER_SUBNET)?; - - // Set default routes - set_default_route( - PROVER_NAMESPACE, - PROVER_INTERFACE, - PROVER_SUBNET.split('/').nth(0).unwrap(), - )?; - set_default_route( - VERIFIER_NAMESPACE, - VERIFIER_INTERFACE, - VERIFIER_SUBNET.split('/').nth(0).unwrap(), - )?; - - Ok(()) -} - -/// Create a network namespace with the given name if it does not already exist. 
-fn create_network_namespace(name: &str) -> Result<(), std::io::Error> { - // Check if namespace already exists - if Command::new("sudo") - .args(&["ip", "netns", "list"]) - .output()? - .stdout - .windows(name.len()) - .any(|ns| ns == name.as_bytes()) - { - println!("Namespace {} already exists", name); - return Ok(()); - } else { - println!("Creating namespace {}", name); - Command::new("sudo") - .args(&["ip", "netns", "add", name]) - .status()?; - } - - Ok(()) -} - -fn create_veth_pair( - left_namespace: &str, - left_interface: &str, - right_namespace: &str, - right_interface: &str, -) -> Result<(), std::io::Error> { - // Check if interfaces are already present in namespaces - if is_interface_present_in_namespace(left_namespace, left_interface)? - || is_interface_present_in_namespace(right_namespace, right_interface)? - { - println!("Virtual interface already exists."); - return Ok(()); - } - - // Create veth pair - Command::new("sudo") - .args(&[ - "ip", - "link", - "add", - left_interface, - "type", - "veth", - "peer", - "name", - right_interface, - ]) - .status()?; - - println!( - "Created veth pair {} and {}", - left_interface, right_interface - ); - - // Attach veth pair to namespaces - attach_interface_to_namespace(left_namespace, left_interface)?; - attach_interface_to_namespace(right_namespace, right_interface)?; - - Ok(()) -} - -fn attach_interface_to_namespace(namespace: &str, interface: &str) -> Result<(), std::io::Error> { - Command::new("sudo") - .args(&["ip", "link", "set", interface, "netns", namespace]) - .status()?; - - println!("Attached {} to namespace {}", interface, namespace); - - Ok(()) -} - -fn set_default_route(namespace: &str, interface: &str, ip: &str) -> Result<(), std::io::Error> { - Command::new("sudo") - .args(&[ - "ip", "netns", "exec", namespace, "ip", "route", "add", "default", "via", ip, "dev", - interface, - ]) - .status()?; - - println!( - "Set default route for namespace {} ip {} to {}", - namespace, ip, interface - ); - - 
Ok(()) -} - -fn is_interface_present_in_namespace( - namespace: &str, - interface: &str, -) -> Result { - Ok(Command::new("sudo") - .args(&[ - "ip", "netns", "exec", namespace, "ip", "link", "list", "dev", interface, - ]) - .output()? - .stdout - .windows(interface.len()) - .any(|ns| ns == interface.as_bytes())) -} - -fn set_device_up(namespace: &str, interface: &str) -> Result<(), std::io::Error> { - Command::new("sudo") - .args(&[ - "ip", "netns", "exec", namespace, "ip", "link", "set", interface, "up", - ]) - .status()?; - - Ok(()) -} - -fn assign_ip_to_interface( - namespace: &str, - interface: &str, - ip: &str, -) -> Result<(), std::io::Error> { - Command::new("sudo") - .args(&[ - "ip", "netns", "exec", namespace, "ip", "addr", "add", ip, "dev", interface, - ]) - .status()?; - - Ok(()) -} diff --git a/tlsn/benches/src/verifier.rs b/tlsn/benches/src/verifier.rs deleted file mode 100644 index 365800562a..0000000000 --- a/tlsn/benches/src/verifier.rs +++ /dev/null @@ -1,43 +0,0 @@ -use tls_core::verify::WebPkiVerifier; -use tlsn_server_fixture::CA_CERT_DER; -use tokio_util::compat::TokioAsyncReadCompatExt; - -use tlsn_verifier::tls::{Verifier, VerifierConfig}; -use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; - -#[tokio::main] -async fn main() { - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) - .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) - .init(); - - let ip = std::env::var("VERIFIER_IP").unwrap_or_else(|_| "10.10.1.1".to_string()); - let port: u16 = std::env::var("VERIFIER_PORT") - .map(|port| port.parse().expect("port is valid u16")) - .unwrap_or(8000); - let host = (ip.as_str(), port); - - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - let verifier = Verifier::new( - VerifierConfig::builder() - .id("test") - .cert_verifier(WebPkiVerifier::new(root_store, None)) - .build() - .unwrap(), - ); - - let listener = 
tokio::net::TcpListener::bind(host).await.unwrap(); - - let (prover_conn, _) = listener.accept().await.unwrap(); - - println!("connected to prover"); - - verifier.verify(prover_conn.compat()).await.unwrap(); - - println!("success"); -} diff --git a/tlsn/examples/Cargo.toml b/tlsn/examples/Cargo.toml deleted file mode 100644 index f6536191f1..0000000000 --- a/tlsn/examples/Cargo.toml +++ /dev/null @@ -1,66 +0,0 @@ -[package] -name = "tlsn-examples" -version = "0.0.0" -edition = "2021" -publish = false - -[dev-dependencies] -tlsn-prover = { workspace = true, features = ["tracing"] } -tlsn-verifier.workspace = true -tlsn-core.workspace = true -tlsn-tls-core.workspace = true -tlsn-tls-client.workspace = true -notary-server = { path = "../../notary-server" } -mpz-core.workspace = true - -futures.workspace = true -tokio = { workspace = true, features = [ - "rt", - "rt-multi-thread", - "macros", - "net", - "io-std", - "fs", -] } -tokio-util.workspace = true - -tracing.workspace = true -tracing-subscriber.workspace = true - -hyper = { version = "0.14", features = ["client", "http1"] } -chrono = "0.4" -p256 = { workspace = true, features = ["ecdsa"] } -elliptic-curve = { version = "0.13.5", features = ["pkcs8"] } -webpki-roots.workspace = true - -async-tls = { version = "0.12", default-features = false, features = [ - "client", -] } - -serde = { version = "1.0.147", features = ["derive"] } -serde_json = "1.0" -eyre = "0.6.8" -rustls = { version = "0.21" } -rustls-pemfile = { version = "1.0.2" } -tokio-rustls = { version = "0.24.1" } -dotenv = "0.15.0" - -[[example]] -name = "simple_prover" -path = "simple/simple_prover.rs" - -[[example]] -name = "simple_verifier" -path = "simple/simple_verifier.rs" - -[[example]] -name = "twitter_dm" -path = "twitter/twitter_dm.rs" - -[[example]] -name = "discord_dm" -path = "discord/discord_dm.rs" - -[[example]] -name = "discord_dm_verifier" -path = "discord/discord_dm_verifier.rs" diff --git a/tlsn/examples/discord/discord_dm.rs 
b/tlsn/examples/discord/discord_dm.rs deleted file mode 100644 index 3ac3acc145..0000000000 --- a/tlsn/examples/discord/discord_dm.rs +++ /dev/null @@ -1,341 +0,0 @@ -/// This example shows how to notarize Discord DMs. -/// -/// The example uses the notary server implemented in ../../../notary-server -use futures::AsyncWriteExt; -use hyper::{body::to_bytes, client::conn::Parts, Body, Request, StatusCode}; -use rustls::{Certificate, ClientConfig, RootCertStore}; -use serde::{Deserialize, Serialize}; -use std::{env, ops::Range, str, sync::Arc}; -use tlsn_core::proof::TlsProof; -use tokio::{io::AsyncWriteExt as _, net::TcpStream}; -use tokio_rustls::TlsConnector; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -use tracing::debug; - -use tlsn_prover::tls::{Prover, ProverConfig}; - -// Setting of the application server -const SERVER_DOMAIN: &str = "discord.com"; - -// Setting of the notary server — make sure these are the same with those in ../../../notary-server -const NOTARY_HOST: &str = "127.0.0.1"; -const NOTARY_PORT: u16 = 7047; - -// Configuration of notarization -const NOTARY_MAX_TRANSCRIPT_SIZE: usize = 16384; - -/// Response object of the /session API -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct NotarizationSessionResponse { - pub session_id: String, -} - -/// Request object of the /session API -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct NotarizationSessionRequest { - pub client_type: ClientType, - /// Maximum transcript size in bytes - pub max_transcript_size: Option, -} - -/// Types of client that the prover is using -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum ClientType { - /// Client that has access to the transport layer - Tcp, - /// Client that cannot directly access transport layer, e.g. 
browser extension - Websocket, -} - -#[tokio::main] -async fn main() { - tracing_subscriber::fmt::init(); - - // Load secret variables frome environment for discord server connection - dotenv::dotenv().ok(); - let channel_id = env::var("CHANNEL_ID").unwrap(); - let auth_token = env::var("AUTHORIZATION").unwrap(); - let user_agent = env::var("USER_AGENT").unwrap(); - - let (notary_tls_socket, session_id) = setup_notary_connection().await; - - // Basic default prover config using the session_id returned from /session endpoint just now - let config = ProverConfig::builder() - .id(session_id) - .server_dns(SERVER_DOMAIN) - .build() - .unwrap(); - - // Create a new prover and set up the MPC backend. - let prover = Prover::new(config) - .setup(notary_tls_socket.compat()) - .await - .unwrap(); - - let client_socket = tokio::net::TcpStream::connect((SERVER_DOMAIN, 443)) - .await - .unwrap(); - - // Bind the Prover to server connection - let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - - // Spawn the Prover to be run concurrently - let prover_task = tokio::spawn(prover_fut); - - // Attach the hyper HTTP client to the TLS connection - let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat()) - .await - .unwrap(); - - // Spawn the HTTP task to be run concurrently - let connection_task = tokio::spawn(connection.without_shutdown()); - - // Build the HTTP request to fetch the DMs - let request = Request::builder() - .uri(format!( - "https://{SERVER_DOMAIN}/api/v9/channels/{channel_id}/messages?limit=2" - )) - .header("Host", SERVER_DOMAIN) - .header("Accept", "*/*") - .header("Accept-Language", "en-US,en;q=0.5") - .header("Accept-Encoding", "identity") - .header("User-Agent", user_agent) - .header("Authorization", &auth_token) - .header("Connection", "close") - .body(Body::empty()) - .unwrap(); - - debug!("Sending request"); - - let response = request_sender.send_request(request).await.unwrap(); - - 
debug!("Sent request"); - - assert!(response.status() == StatusCode::OK, "{}", response.status()); - - debug!("Request OK"); - - // Pretty printing :) - let payload = to_bytes(response.into_body()).await.unwrap().to_vec(); - let parsed = - serde_json::from_str::(&String::from_utf8_lossy(&payload)).unwrap(); - debug!("{}", serde_json::to_string_pretty(&parsed).unwrap()); - - // Close the connection to the server - let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner(); - client_socket.close().await.unwrap(); - - // The Prover task should be done now, so we can grab it. - let prover = prover_task.await.unwrap().unwrap(); - - // Prepare for notarization - let mut prover = prover.start_notarize(); - - // Identify the ranges in the transcript that contain secrets - let (public_ranges, private_ranges) = - find_ranges(prover.sent_transcript().data(), &[auth_token.as_bytes()]); - - let recv_len = prover.recv_transcript().data().len(); - - let builder = prover.commitment_builder(); - - // Collect commitment ids for the outbound transcript - let mut commitment_ids = public_ranges - .iter() - .chain(private_ranges.iter()) - .map(|range| builder.commit_sent(range.clone()).unwrap()) - .collect::>(); - - // Commit to the full received transcript in one shot, as we don't need to redact anything - commitment_ids.push(builder.commit_recv(0..recv_len).unwrap()); - - // Finalize, returning the notarized session - let notarized_session = prover.finalize().await.unwrap(); - - debug!("Notarization complete!"); - - // Dump the notarized session to a file - let mut file = tokio::fs::File::create("discord_dm_notarized_session.json") - .await - .unwrap(); - file.write_all( - serde_json::to_string_pretty(¬arized_session) - .unwrap() - .as_bytes(), - ) - .await - .unwrap(); - - let session_proof = notarized_session.session_proof(); - - let mut proof_builder = notarized_session.data().build_substrings_proof(); - - // Reveal everything but the auth token (which was 
assigned commitment id 2) - proof_builder.reveal(commitment_ids[0]).unwrap(); - proof_builder.reveal(commitment_ids[1]).unwrap(); - proof_builder.reveal(commitment_ids[3]).unwrap(); - - let substrings_proof = proof_builder.build().unwrap(); - - let proof = TlsProof { - session: session_proof, - substrings: substrings_proof, - }; - - // Dump the proof to a file. - let mut file = tokio::fs::File::create("discord_dm_proof.json") - .await - .unwrap(); - file.write_all(serde_json::to_string_pretty(&proof).unwrap().as_bytes()) - .await - .unwrap(); -} - -async fn setup_notary_connection() -> (tokio_rustls::client::TlsStream, String) { - // Connect to the Notary via TLS-TCP - let pem_file = str::from_utf8(include_bytes!( - "../../../notary-server/fixture/tls/rootCA.crt" - )) - .unwrap(); - let mut reader = std::io::BufReader::new(pem_file.as_bytes()); - let mut certificates: Vec = rustls_pemfile::certs(&mut reader) - .unwrap() - .into_iter() - .map(Certificate) - .collect(); - let certificate = certificates.remove(0); - - let mut root_store = RootCertStore::empty(); - root_store.add(&certificate).unwrap(); - - let client_notary_config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - let notary_connector = TlsConnector::from(Arc::new(client_notary_config)); - - let notary_socket = tokio::net::TcpStream::connect((NOTARY_HOST, NOTARY_PORT)) - .await - .unwrap(); - - let notary_tls_socket = notary_connector - // Require the domain name of notary server to be the same as that in the server cert - .connect("tlsnotaryserver.io".try_into().unwrap(), notary_socket) - .await - .unwrap(); - - // Attach the hyper HTTP client to the notary TLS connection to send request to the /session endpoint to configure notarization and obtain session id - let (mut request_sender, connection) = hyper::client::conn::handshake(notary_tls_socket) - .await - .unwrap(); - - // Spawn the HTTP task to be run concurrently - let 
connection_task = tokio::spawn(connection.without_shutdown()); - - // Build the HTTP request to configure notarization - let payload = serde_json::to_string(&NotarizationSessionRequest { - client_type: ClientType::Tcp, - max_transcript_size: Some(NOTARY_MAX_TRANSCRIPT_SIZE), - }) - .unwrap(); - - let request = Request::builder() - .uri(format!("https://{NOTARY_HOST}:{NOTARY_PORT}/session")) - .method("POST") - .header("Host", NOTARY_HOST) - // Need to specify application/json for axum to parse it as json - .header("Content-Type", "application/json") - .body(Body::from(payload)) - .unwrap(); - - debug!("Sending configuration request"); - - let configuration_response = request_sender.send_request(request).await.unwrap(); - - debug!("Sent configuration request"); - - assert!(configuration_response.status() == StatusCode::OK); - - debug!("Response OK"); - - // Pretty printing :) - let payload = to_bytes(configuration_response.into_body()) - .await - .unwrap() - .to_vec(); - let notarization_response = - serde_json::from_str::(&String::from_utf8_lossy(&payload)) - .unwrap(); - - debug!("Notarization response: {:?}", notarization_response,); - - // Send notarization request via HTTP, where the underlying TCP connection will be extracted later - let request = Request::builder() - // Need to specify the session_id so that notary server knows the right configuration to use - // as the configuration is set in the previous HTTP call - .uri(format!( - "https://{}:{}/notarize?sessionId={}", - NOTARY_HOST, - NOTARY_PORT, - notarization_response.session_id.clone() - )) - .method("GET") - .header("Host", NOTARY_HOST) - .header("Connection", "Upgrade") - // Need to specify this upgrade header for server to extract tcp connection later - .header("Upgrade", "TCP") - .body(Body::empty()) - .unwrap(); - - debug!("Sending notarization request"); - - let response = request_sender.send_request(request).await.unwrap(); - - debug!("Sent notarization request"); - - assert!(response.status() 
== StatusCode::SWITCHING_PROTOCOLS); - - debug!("Switched protocol OK"); - - // Claim back the TLS socket after HTTP exchange is done - let Parts { - io: notary_tls_socket, - .. - } = connection_task.await.unwrap().unwrap(); - - (notary_tls_socket, notarization_response.session_id) -} - -/// Find the ranges of the public and private parts of a sequence. -/// -/// Returns a tuple of `(public, private)` ranges. -fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (Vec>, Vec>) { - let mut private_ranges = Vec::new(); - for s in sub_seq { - for (idx, w) in seq.windows(s.len()).enumerate() { - if w == *s { - private_ranges.push(idx..(idx + w.len())); - } - } - } - - let mut sorted_ranges = private_ranges.clone(); - sorted_ranges.sort_by_key(|r| r.start); - - let mut public_ranges = Vec::new(); - let mut last_end = 0; - for r in sorted_ranges { - if r.start > last_end { - public_ranges.push(last_end..r.start); - } - last_end = r.end; - } - - if last_end < seq.len() { - public_ranges.push(last_end..seq.len()); - } - - (public_ranges, private_ranges) -} diff --git a/tlsn/examples/discord/discord_dm_verifier.rs b/tlsn/examples/discord/discord_dm_verifier.rs deleted file mode 100644 index bd40d1a46a..0000000000 --- a/tlsn/examples/discord/discord_dm_verifier.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::{str, time::Duration}; - -use elliptic_curve::pkcs8::DecodePublicKey; - -use tlsn_core::proof::{SessionProof, TlsProof}; - -/// A simple verifier which reads a proof generated by `discord_dm.rs` from "discord_dm_proof.json", verifies -/// it and prints the verified data to the console. -fn main() { - // Deserialize the proof - let proof = std::fs::read_to_string("discord_dm_proof.json").unwrap(); - let proof: TlsProof = serde_json::from_str(proof.as_str()).unwrap(); - - let TlsProof { - // The session proof establishes the identity of the server and the commitments - // to the TLS transcript. 
- session, - // The substrings proof proves select portions of the transcript, while redacting - // anything the Prover chose not to disclose. - substrings, - } = proof; - - // Verify the session proof against the Notary's public key - // - // This verifies the identity of the server using a default certificate verifier which trusts - // the root certificates from the `webpki-roots` crate. - session - .verify_with_default_cert_verifier(notary_pubkey()) - .unwrap(); - - let SessionProof { - // The session header that was signed by the Notary is a succinct commitment to the TLS transcript. - header, - // This is the session_info, which contains the server_name, that is checked against the - // certificate chain shared in the TLS handshake. - session_info, - .. - } = session; - - // The time at which the session was recorded - let time = chrono::DateTime::UNIX_EPOCH + Duration::from_secs(header.time()); - - // Verify the substrings proof against the session header. - // - // This returns the redacted transcripts - let (mut sent, mut recv) = substrings.verify(&header).unwrap(); - - // Replace the bytes which the Prover chose not to disclose with 'X' - sent.set_redacted(b'X'); - recv.set_redacted(b'X'); - - println!("-------------------------------------------------------------------"); - println!( - "Successfully verified that the bytes below came from a session with {:?} at {}.", - session_info.server_name, time - ); - println!("Note that the bytes which the Prover chose not to disclose are shown as X."); - println!(); - println!("Bytes sent:"); - println!(); - print!("{}", String::from_utf8(sent.data().to_vec()).unwrap()); - println!(); - println!("Bytes received:"); - println!(); - println!("{}", String::from_utf8(recv.data().to_vec()).unwrap()); - println!("-------------------------------------------------------------------"); -} - -/// Returns a Notary pubkey trusted by this Verifier -fn notary_pubkey() -> p256::PublicKey { - let pem_file = 
str::from_utf8(include_bytes!( - "../../../notary-server/fixture/notary/notary.pub" - )) - .unwrap(); - p256::PublicKey::from_public_key_pem(pem_file).unwrap() -} diff --git a/tlsn/examples/simple/README.md b/tlsn/examples/simple/README.md deleted file mode 100644 index 9c5d1bf454..0000000000 --- a/tlsn/examples/simple/README.md +++ /dev/null @@ -1,90 +0,0 @@ -## Simple Example: Notarize Public Data from example.com (Rust) - -This example demonstrates the simplest possible use case for TLSNotary: -1. Notarize: Fetch and create a proof of its content. -2. Verify the proof. - -Next, we will redact the content and verify it again: -1. Redact the `USER_AGENT` and titles. -2. Verify the redacted proof. - -### 1. Notarize - -Run a simple prover: - -```shell -cargo run --release --example simple_prover -``` - -If the notarization was successful, you should see this output in the console: - -```log -Starting an MPC TLS connection with the server -Got a response from the server -Notarization completed successfully! -The proof has been written to `simple_proof.json` -``` - -⚠️ In this simple example the `Notary` server is automatically started in the background. Note that this is for demonstration purposes only. In a real work example, the notary should be run by a neutral party or the verifier of the proofs. Consult the [Notary Server Docs](https://docs.tlsnotary.org/developers/notary_server.html) for more details on how to run a notary server. - -### 2. Verify the Proof - -When you open `simple_proof.json` in an editor, you will see a JSON file with lots of non-human-readable byte arrays. You can decode this file by running: - -```shell -cargo run --release --example simple_verifier -``` - -This will output the TLS-transaction in clear text: - -```log -Successfully verified that the bytes below came from a session with Dns("example.com") at 2023-11-03 08:48:20 UTC. -Note that the bytes which the Prover chose not to disclose are shown as X. - -Bytes sent: -... 
-``` - -### 3. Redact Information - -Open `simple/src/examples/simple_prover.rs` and locate the line with: - -```rust -let redact = false; -``` - -and change it to: - -```rust -let redact = true; -``` - -Next, if you run the `simple_prover` and `simple_verifier` again, you'll notice redacted `X`'s in the output: - -```shell -cargo run --release --example simple_prover -cargo run --release --example simple_verifier -``` - -```log - - - - XXXXXXXXXXXXXX -... -``` - -You can also use to inspect your proofs. Simply drag and drop `simple_proof.json` from your proof file explorer into the drop zone. Redacted bytes are marked with red Xs characters. - -### (Optional) Extra Experiments - -Feel free to try these extra challenges: - -- [ ] Modify the `server_name` (or any other data) in `simple_proof.json` and verify that the proof is no longer valid. -- [ ] Modify the `build_proof_with_redactions` function in `simple_prover.rs` to redact more or different data. - -### Next steps - -Try out the [Discord example](../Discord/README.md) and notarize a Discord conversations. - - diff --git a/tlsn/examples/simple/simple_prover.rs b/tlsn/examples/simple/simple_prover.rs deleted file mode 100644 index d9dabcadb5..0000000000 --- a/tlsn/examples/simple/simple_prover.rs +++ /dev/null @@ -1,251 +0,0 @@ -/// Runs a simple Prover which connects to the Notary and notarizes a request/response from -/// example.com. The Prover then generates a proof and writes it to disk. 
-/// -/// The example uses the notary server implemented in ./simple_notary.rs -use futures::AsyncWriteExt; -use hyper::{Body, Request, StatusCode}; -use std::ops::Range; -use tlsn_core::proof::TlsProof; -use tokio::io::{AsyncWriteExt as _, DuplexStream}; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; - -use tlsn_prover::tls::{state::Notarize, Prover, ProverConfig}; - -// Setting of the application server -const SERVER_DOMAIN: &str = "example.com"; -const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"; - -use p256::pkcs8::DecodePrivateKey; -use std::str; - -use tlsn_verifier::tls::{Verifier, VerifierConfig}; - -#[tokio::main] -async fn main() { - tracing_subscriber::fmt::init(); - - let (prover_socket, notary_socket) = tokio::io::duplex(1 << 16); - - // Start a local simple notary service - start_notary_thread(prover_socket).await; - - // A Prover configuration - let config = ProverConfig::builder() - .id("example") - .server_dns(SERVER_DOMAIN) - .build() - .unwrap(); - - // Create a Prover and set it up with the Notary - // This will set up the MPC backend prior to connecting to the server. - let prover = Prover::new(config) - .setup(notary_socket.compat()) - .await - .unwrap(); - - // Connect to the Server via TCP. This is the TLS client socket. - let client_socket = tokio::net::TcpStream::connect((SERVER_DOMAIN, 443)) - .await - .unwrap(); - - // Bind the Prover to the server connection. - // The returned `mpc_tls_connection` is an MPC TLS connection to the Server: all data written - // to/read from it will be encrypted/decrypted using MPC with the Notary. 
- let (mpc_tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - - // Spawn the Prover task to be run concurrently - let prover_task = tokio::spawn(prover_fut); - - // Attach the hyper HTTP client to the MPC TLS connection - let (mut request_sender, connection) = - hyper::client::conn::handshake(mpc_tls_connection.compat()) - .await - .unwrap(); - - // Spawn the HTTP task to be run concurrently - let connection_task = tokio::spawn(connection.without_shutdown()); - - // Build a simple HTTP request with common headers - let request = Request::builder() - .uri("/") - .header("Host", SERVER_DOMAIN) - .header("Accept", "*/*") - // Using "identity" instructs the Server not to use compression for its HTTP response. - // TLSNotary tooling does not support compression. - .header("Accept-Encoding", "identity") - .header("Connection", "close") - .header("User-Agent", USER_AGENT) - .body(Body::empty()) - .unwrap(); - - println!("Starting an MPC TLS connection with the server"); - - // Send the request to the Server and get a response via the MPC TLS connection - let response = request_sender.send_request(request).await.unwrap(); - - println!("Got a response from the server"); - - assert!(response.status() == StatusCode::OK); - - // Close the connection to the server - let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner(); - client_socket.close().await.unwrap(); - - // The Prover task should be done now, so we can grab the Prover. - let prover = prover_task.await.unwrap().unwrap(); - - // Prepare for notarization. 
- let prover = prover.start_notarize(); - - // Build proof (with or without redactions) - let redact = false; - let proof = if !redact { - build_proof_without_redactions(prover).await - } else { - build_proof_with_redactions(prover).await - }; - - // Write the proof to a file - let mut file = tokio::fs::File::create("simple_proof.json").await.unwrap(); - file.write_all(serde_json::to_string_pretty(&proof).unwrap().as_bytes()) - .await - .unwrap(); - - println!("Notarization completed successfully!"); - println!("The proof has been written to `simple_proof.json`"); -} - -/// Find the ranges of the public and private parts of a sequence. -/// -/// Returns a tuple of `(public, private)` ranges. -fn find_ranges(seq: &[u8], private_seq: &[&[u8]]) -> (Vec>, Vec>) { - let mut private_ranges = Vec::new(); - for s in private_seq { - for (idx, w) in seq.windows(s.len()).enumerate() { - if w == *s { - private_ranges.push(idx..(idx + w.len())); - } - } - } - - let mut sorted_ranges = private_ranges.clone(); - sorted_ranges.sort_by_key(|r| r.start); - - let mut public_ranges = Vec::new(); - let mut last_end = 0; - for r in sorted_ranges { - if r.start > last_end { - public_ranges.push(last_end..r.start); - } - last_end = r.end; - } - - if last_end < seq.len() { - public_ranges.push(last_end..seq.len()); - } - - (public_ranges, private_ranges) -} - -async fn build_proof_without_redactions(mut prover: Prover) -> TlsProof { - let sent_len = prover.sent_transcript().data().len(); - let recv_len = prover.recv_transcript().data().len(); - - let builder = prover.commitment_builder(); - let sent_commitment = builder.commit_sent(0..sent_len).unwrap(); - let recv_commitment = builder.commit_recv(0..recv_len).unwrap(); - - // Finalize, returning the notarized session - let notarized_session = prover.finalize().await.unwrap(); - - // Create a proof for all committed data in this session - let mut proof_builder = notarized_session.data().build_substrings_proof(); - - // Reveal all the 
public ranges - proof_builder.reveal(sent_commitment).unwrap(); - proof_builder.reveal(recv_commitment).unwrap(); - - let substrings_proof = proof_builder.build().unwrap(); - - TlsProof { - session: notarized_session.session_proof(), - substrings: substrings_proof, - } -} - -async fn build_proof_with_redactions(mut prover: Prover) -> TlsProof { - // Identify the ranges in the outbound data which contain data which we want to disclose - let (sent_public_ranges, _) = find_ranges( - prover.sent_transcript().data(), - &[ - // Redact the value of the "User-Agent" header. It will NOT be disclosed. - USER_AGENT.as_bytes(), - ], - ); - - // Identify the ranges in the inbound data which contain data which we want to disclose - let (recv_public_ranges, _) = find_ranges( - prover.recv_transcript().data(), - &[ - // Redact the value of the title. It will NOT be disclosed. - "Example Domain".as_bytes(), - ], - ); - - let builder = prover.commitment_builder(); - - // Commit to each range of the public outbound data which we want to disclose - let sent_commitments: Vec<_> = sent_public_ranges - .iter() - .map(|r| builder.commit_sent(r.clone()).unwrap()) - .collect(); - // Commit to each range of the public inbound data which we want to disclose - let recv_commitments: Vec<_> = recv_public_ranges - .iter() - .map(|r| builder.commit_recv(r.clone()).unwrap()) - .collect(); - - // Finalize, returning the notarized session - let notarized_session = prover.finalize().await.unwrap(); - - // Create a proof for all committed data in this session - let mut proof_builder = notarized_session.data().build_substrings_proof(); - - // Reveal all the public ranges - for commitment_id in sent_commitments { - proof_builder.reveal(commitment_id).unwrap(); - } - for commitment_id in recv_commitments { - proof_builder.reveal(commitment_id).unwrap(); - } - - let substrings_proof = proof_builder.build().unwrap(); - - TlsProof { - session: notarized_session.session_proof(), - substrings: 
substrings_proof, - } -} - -async fn start_notary_thread(socket: DuplexStream) { - tokio::spawn(async { - // Load the notary signing key - let signing_key_str = str::from_utf8(include_bytes!( - "../../../notary-server/fixture/notary/notary.key" - )) - .unwrap(); - let signing_key = p256::ecdsa::SigningKey::from_pkcs8_pem(signing_key_str).unwrap(); - - // Spawn notarization task to be run concurrently - tokio::spawn(async move { - // Setup default config. Normally a different ID would be generated - // for each notarization. - let config = VerifierConfig::builder().id("example").build().unwrap(); - - Verifier::new(config) - .notarize::<_, p256::ecdsa::Signature>(socket.compat(), &signing_key) - .await - .unwrap(); - }); - }); -} diff --git a/tlsn/examples/simple/simple_verifier.rs b/tlsn/examples/simple/simple_verifier.rs deleted file mode 100644 index c3d65f7aa0..0000000000 --- a/tlsn/examples/simple/simple_verifier.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::{str, time::Duration}; - -use elliptic_curve::pkcs8::DecodePublicKey; - -use tlsn_core::proof::{SessionProof, TlsProof}; - -/// A simple verifier which reads a proof generated by `simple_prover.rs` from "proof.json", verifies -/// it and prints the verified data to the console. -fn main() { - // Deserialize the proof - let proof = std::fs::read_to_string("simple_proof.json").unwrap(); - let proof: TlsProof = serde_json::from_str(proof.as_str()).unwrap(); - - let TlsProof { - // The session proof establishes the identity of the server and the commitments - // to the TLS transcript. - session, - // The substrings proof proves select portions of the transcript, while redacting - // anything the Prover chose not to disclose. - substrings, - } = proof; - - // Verify the session proof against the Notary's public key - // - // This verifies the identity of the server using a default certificate verifier which trusts - // the root certificates from the `webpki-roots` crate. 
- session - .verify_with_default_cert_verifier(notary_pubkey()) - .unwrap(); - - let SessionProof { - // The session header that was signed by the Notary is a succinct commitment to the TLS transcript. - header, - // This is the session_info, which contains the server_name, that is checked against the - // certificate chain shared in the TLS handshake. - session_info, - .. - } = session; - - // The time at which the session was recorded - let time = chrono::DateTime::UNIX_EPOCH + Duration::from_secs(header.time()); - - // Verify the substrings proof against the session header. - // - // This returns the redacted transcripts - let (mut sent, mut recv) = substrings.verify(&header).unwrap(); - - // Replace the bytes which the Prover chose not to disclose with 'X' - sent.set_redacted(b'X'); - recv.set_redacted(b'X'); - - println!("-------------------------------------------------------------------"); - println!( - "Successfully verified that the bytes below came from a session with {:?} at {}.", - session_info.server_name, time - ); - println!("Note that the bytes which the Prover chose not to disclose are shown as X."); - println!(); - println!("Bytes sent:"); - println!(); - print!("{}", String::from_utf8(sent.data().to_vec()).unwrap()); - println!(); - println!("Bytes received:"); - println!(); - println!("{}", String::from_utf8(recv.data().to_vec()).unwrap()); - println!("-------------------------------------------------------------------"); -} - -/// Returns a Notary pubkey trusted by this Verifier -fn notary_pubkey() -> p256::PublicKey { - let pem_file = str::from_utf8(include_bytes!( - "../../../notary-server/fixture/notary/notary.pub" - )) - .unwrap(); - p256::PublicKey::from_public_key_pem(pem_file).unwrap() -} diff --git a/tlsn/examples/twitter/.env.example b/tlsn/examples/twitter/.env.example deleted file mode 100644 index afda2fd549..0000000000 --- a/tlsn/examples/twitter/.env.example +++ /dev/null @@ -1,5 +0,0 @@ 
-CONVERSATION_ID="20124652-973145016511139841" -CLIENT_UUID="e6f00000-cccc-dddd-bbbb-eeeeeefaaa27" -AUTH_TOKEN="670ccccccbe2bbbbbbbc1025aaaaaafa55555551" -ACCESS_TOKEN="AAAAAAAAAAAAAAAAAAAAANRILgAA...4puTs%3D1Zv7...WjCpTnA" -CSRF_TOKEN="77d8ef46bd57f722ea7e9f...f4235a713040bfcaac1cd6909" diff --git a/tlsn/examples/twitter/README.md b/tlsn/examples/twitter/README.md deleted file mode 100644 index 394895e00f..0000000000 --- a/tlsn/examples/twitter/README.md +++ /dev/null @@ -1,106 +0,0 @@ -# Notarize Twitter DMs - -The `twtter_dm.rs` example sets up a TLS connection with Twitter and notarizes the requested DMs. The full received transcript is notarized in one commitment, so nothing is redacted. The resulting proof is written to a local JSON file (`twitter_dm_proof.json`) for easier inspection. - -This involves 3 steps: -1. Configure the inputs -2. Start the (local) notary server -3. Notarize - -## Inputs - -In this tlsn/examples/twitter folder, create a `.env` file. -Then in that `.env` file, set the values of the following constants by following the format shown in this [example env file](./.env.example). 
- -| Name | Example | Location in Request Headers Section (within Network Tab of Developer Tools) | -| --------------- | ------------------------------------------------------- | -------------------------------------------------------------------------------- | -| CONVERSATION_ID | `20124652-973145016511139841` | Look for `Referer`, then extract the `ID` in `https://twitter.com/messages/` | -| CLIENT_UUID | `e6f00000-cccc-dddd-bbbb-eeeeeefaaa27` | Look for `X-Client-Uuid`, then copy the entire value | -| AUTH_TOKEN | `670ccccccbe2bbbbbbbc1025aaaaaafa55555551` | Look for `Cookie`, then extract the `token` in `;auth_token=;` | -| ACCESS_TOKEN | `AAAAAAAAAAAAAAAAAAAAANRILgAA...4puTs%3D1Zv7...WjCpTnA` | Look for `Authorization`, then extract the `token` in `Bearer ` | -| CSRF_TOKEN | `77d8ef46bd57f722ea7e9f...f4235a713040bfcaac1cd6909` | Look for `X-Csrf-Token`, then copy the entire value | - -You can obtain these parameters by opening [Twitter](https://twitter.com/messages/) in your browser and accessing the message history you want to notarize. Please note that notarizing only works for short transcripts at the moment, so choose a contact with a short history. - -Next, open the **Developer Tools**, go to the **Network** tab, and refresh the page. Then, click on **Search** and type `uuid` as shown in the screenshot below — all of these constants should be under the **Request Headers** section. Refer to the table above on where to find each of the constant value. - -![Screenshot](twitter_dm_browser.png) - -## Start the notary server -At the root level of this repository, run -```sh -cd notary-server -cargo run --release -``` - -The notary server will now be running in the background waiting for connections. - -For more information on how to configure the notary server, please refer to [this](../../../notary-server/README.md#running-the-server). 
- -## Notarize - -In this tlsn/examples/twitter folder, run the following command: - -```sh -RUST_LOG=debug,yamux=info cargo run --release --example twitter_dm -``` - -If everything goes well, you should see output similar to the following: - -```log - Compiling tlsn-examples v0.0.0 (/Users/heeckhau/tlsnotary/tlsn/tlsn/examples) - Finished release [optimized] target(s) in 8.52s - Running `/Users/heeckhau/tlsnotary/tlsn/tlsn/target/release/examples/twitter_dm` -2023-08-15T12:49:38.532924Z DEBUG rustls::client::hs: No cached session for DnsName("tlsnotaryserver.io") -2023-08-15T12:49:38.533384Z DEBUG rustls::client::hs: Not resuming any session -2023-08-15T12:49:38.543493Z DEBUG rustls::client::hs: Using ciphersuite TLS13_AES_256_GCM_SHA384 -2023-08-15T12:49:38.543632Z DEBUG rustls::client::tls13: Not resuming -2023-08-15T12:49:38.543792Z DEBUG rustls::client::tls13: TLS1.3 encrypted extensions: [ServerNameAck] -2023-08-15T12:49:38.543803Z DEBUG rustls::client::hs: ALPN protocol is None -2023-08-15T12:49:38.544305Z DEBUG twitter_dm: Sending configuration request -2023-08-15T12:49:38.544556Z DEBUG hyper::proto::h1::io: flushed 163 bytes -2023-08-15T12:49:38.546069Z DEBUG hyper::proto::h1::io: parsed 3 headers -2023-08-15T12:49:38.546078Z DEBUG hyper::proto::h1::conn: incoming body is content-length (52 bytes) -2023-08-15T12:49:38.546168Z DEBUG hyper::proto::h1::conn: incoming body completed -2023-08-15T12:49:38.546187Z DEBUG twitter_dm: Sent configuration request -2023-08-15T12:49:38.546192Z DEBUG twitter_dm: Response OK -2023-08-15T12:49:38.546224Z DEBUG twitter_dm: Notarization response: NotarizationSessionResponse { session_id: "2675e0f9-d06c-499b-8e9e-2b893a6d7356" } -2023-08-15T12:49:38.546257Z DEBUG twitter_dm: Sending notarization request -2023-08-15T12:49:38.546291Z DEBUG hyper::proto::h1::io: flushed 152 bytes -2023-08-15T12:49:38.546743Z DEBUG hyper::proto::h1::io: parsed 3 headers -2023-08-15T12:49:38.546748Z DEBUG hyper::proto::h1::conn: incoming body is 
empty -2023-08-15T12:49:38.546766Z DEBUG twitter_dm: Sent notarization request -2023-08-15T12:49:38.546772Z DEBUG twitter_dm: Switched protocol OK -2023-08-15T12:49:40.088422Z DEBUG twitter_dm: Sending request -2023-08-15T12:49:40.088464Z DEBUG hyper::proto::h1::io: flushed 950 bytes -2023-08-15T12:49:40.143884Z DEBUG tls_client::client::hs: ALPN protocol is None -2023-08-15T12:49:40.143893Z DEBUG tls_client::client::hs: Using ciphersuite Tls12(Tls12CipherSuite { suite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, algorithm: AES_128_GCM }) -2023-08-15T12:49:40.144666Z DEBUG tls_client::client::tls12: ECDHE curve is ECParameters { curve_type: NamedCurve, named_group: secp256r1 } -2023-08-15T12:49:40.144687Z DEBUG tls_client::client::tls12: Server DNS name is DnsName(DnsName(DnsName("twitter.com"))) -2023-08-15T12:51:01.336491Z DEBUG hyper::proto::h1::io: parsed 31 headers -2023-08-15T12:51:01.336507Z DEBUG hyper::proto::h1::conn: incoming body is content-length (4330 bytes) -2023-08-15T12:51:01.336516Z DEBUG hyper::proto::h1::conn: incoming body completed -2023-08-15T12:51:01.336528Z DEBUG twitter_dm: Sent request -2023-08-15T12:51:01.336537Z DEBUG twitter_dm: Request OK -2023-08-15T12:51:01.336585Z DEBUG twitter_dm: { - "conversation_timeline": { - "entries": [ - { - "message": { - "conversation_id": "20124652-45653288", - ... - "withheld_in_countries": [] - } - } - } -} -2023-08-15T12:51:08.854818Z DEBUG twitter_dm: Notarization complete! -``` - -If the transcript was too long, you may encounter the following error: - -``` -thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: IOError(Custom { kind: InvalidData, error: BackendError(DecryptionError("Other: KOSReceiverActor is not setup")) })', /Users/heeckhau/tlsnotary/tlsn/tlsn/tlsn-prover/src/lib.rs:173:50 -``` - -> **_NOTE:_** ℹ️ hosts a generic proof visualizer. Drag and drop your proof into the drop zone to check and render your proof. 
\ No newline at end of file diff --git a/tlsn/examples/twitter/twitter_dm.rs b/tlsn/examples/twitter/twitter_dm.rs deleted file mode 100644 index fd78a2a306..0000000000 --- a/tlsn/examples/twitter/twitter_dm.rs +++ /dev/null @@ -1,358 +0,0 @@ -/// This example shows how to notarize Twitter DMs. -/// -/// The example uses the notary server implemented in ../../../notary-server -use futures::AsyncWriteExt; -use hyper::{body::to_bytes, client::conn::Parts, Body, Request, StatusCode}; -use rustls::{Certificate, ClientConfig, RootCertStore}; -use serde::{Deserialize, Serialize}; -use std::{env, ops::Range, str, sync::Arc}; -use tlsn_core::proof::TlsProof; -use tokio::{io::AsyncWriteExt as _, net::TcpStream}; -use tokio_rustls::TlsConnector; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -use tracing::debug; - -use tlsn_prover::tls::{Prover, ProverConfig}; - -// Setting of the application server -const SERVER_DOMAIN: &str = "twitter.com"; -const ROUTE: &str = "i/api/1.1/dm/conversation"; -const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"; - -// Setting of the notary server — make sure these are the same with those in ../../../notary-server -const NOTARY_HOST: &str = "127.0.0.1"; -const NOTARY_PORT: u16 = 7047; - -// Configuration of notarization -const NOTARY_MAX_TRANSCRIPT_SIZE: usize = 16384; - -/// Response object of the /session API -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct NotarizationSessionResponse { - pub session_id: String, -} - -/// Request object of the /session API -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct NotarizationSessionRequest { - pub client_type: ClientType, - /// Maximum transcript size in bytes - pub max_transcript_size: Option, -} - -/// Types of client that the prover is using -#[derive(Debug, Clone, Serialize, Deserialize, 
PartialEq, Eq)] -pub enum ClientType { - /// Client that has access to the transport layer - Tcp, - /// Client that cannot directly access transport layer, e.g. browser extension - Websocket, -} - -#[tokio::main] -async fn main() { - tracing_subscriber::fmt::init(); - - // Load secret variables frome environment for twitter server connection - dotenv::dotenv().ok(); - let conversation_id = env::var("CONVERSATION_ID").unwrap(); - let client_uuid = env::var("CLIENT_UUID").unwrap(); - let auth_token = env::var("AUTH_TOKEN").unwrap(); - let access_token = env::var("ACCESS_TOKEN").unwrap(); - let csrf_token = env::var("CSRF_TOKEN").unwrap(); - - let (notary_tls_socket, session_id) = setup_notary_connection().await; - - // Basic default prover config using the session_id returned from /session endpoint just now - let config = ProverConfig::builder() - .id(session_id) - .server_dns(SERVER_DOMAIN) - .build() - .unwrap(); - - // Create a new prover and set up the MPC backend. - let prover = Prover::new(config) - .setup(notary_tls_socket.compat()) - .await - .unwrap(); - - let client_socket = tokio::net::TcpStream::connect((SERVER_DOMAIN, 443)) - .await - .unwrap(); - - // Bind the Prover to server connection - let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - - // Spawn the Prover to be run concurrently - let prover_task = tokio::spawn(prover_fut); - - // Attach the hyper HTTP client to the TLS connection - let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat()) - .await - .unwrap(); - - // Spawn the HTTP task to be run concurrently - let connection_task = tokio::spawn(connection.without_shutdown()); - - // Build the HTTP request to fetch the DMs - let request = Request::builder() - .uri(format!( - "https://{SERVER_DOMAIN}/{ROUTE}/{conversation_id}.json" - )) - .header("Host", SERVER_DOMAIN) - .header("Accept", "*/*") - .header("Accept-Encoding", "identity") - .header("Connection", "close") - 
.header("User-Agent", USER_AGENT) - .header("Authorization", format!("Bearer {access_token}")) - .header( - "Cookie", - format!("auth_token={auth_token}; ct0={csrf_token}"), - ) - .header("Authority", SERVER_DOMAIN) - .header("X-Twitter-Auth-Type", "OAuth2Session") - .header("x-twitter-active-user", "yes") - .header("X-Client-Uuid", client_uuid.clone()) - .header("X-Csrf-Token", csrf_token.clone()) - .body(Body::empty()) - .unwrap(); - - debug!("Sending request"); - - let response = request_sender.send_request(request).await.unwrap(); - - debug!("Sent request"); - - assert!(response.status() == StatusCode::OK, "{}", response.status()); - - debug!("Request OK"); - - // Pretty printing :) - let payload = to_bytes(response.into_body()).await.unwrap().to_vec(); - let parsed = - serde_json::from_str::(&String::from_utf8_lossy(&payload)).unwrap(); - debug!("{}", serde_json::to_string_pretty(&parsed).unwrap()); - - // Close the connection to the server - let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner(); - client_socket.close().await.unwrap(); - - // The Prover task should be done now, so we can grab it. - let prover = prover_task.await.unwrap().unwrap(); - - // Prepare for notarization - let mut prover = prover.start_notarize(); - - // Identify the ranges in the transcript that contain secrets - let (public_ranges, private_ranges) = find_ranges( - prover.sent_transcript().data(), - &[ - access_token.as_bytes(), - auth_token.as_bytes(), - csrf_token.as_bytes(), - client_uuid.as_bytes(), - ], - ); - - let recv_len = prover.recv_transcript().data().len(); - - let builder = prover.commitment_builder(); - - // Commit to send public data and collect commitment ids for the outbound transcript - let mut commitment_ids = public_ranges - .iter() - .map(|range| builder.commit_sent(range.clone()).unwrap()) - .collect::>(); - // Commit to private data. 
This is not needed for proof creation but ensures the data - // is in the notarized session file for optional future disclosure. - private_ranges.iter().for_each(|range| { - builder.commit_sent(range.clone()).unwrap(); - }); - // Commit to the received (public) data. - commitment_ids.push(builder.commit_recv(0..recv_len).unwrap()); - - // Finalize, returning the notarized session - let notarized_session = prover.finalize().await.unwrap(); - - debug!("Notarization complete!"); - - // Dump the notarized session to a file - let mut file = tokio::fs::File::create("twitter_dm.json").await.unwrap(); - file.write_all( - serde_json::to_string_pretty(¬arized_session) - .unwrap() - .as_bytes(), - ) - .await - .unwrap(); - - let session_proof = notarized_session.session_proof(); - - let mut proof_builder = notarized_session.data().build_substrings_proof(); - for commitment_id in commitment_ids { - proof_builder.reveal(commitment_id).unwrap(); - } - let substrings_proof = proof_builder.build().unwrap(); - - let proof = TlsProof { - session: session_proof, - substrings: substrings_proof, - }; - - // Dump the proof to a file. 
- let mut file = tokio::fs::File::create("twitter_dm_proof.json") - .await - .unwrap(); - file.write_all(serde_json::to_string_pretty(&proof).unwrap().as_bytes()) - .await - .unwrap(); -} - -async fn setup_notary_connection() -> (tokio_rustls::client::TlsStream, String) { - // Connect to the Notary via TLS-TCP - let pem_file = str::from_utf8(include_bytes!( - "../../../notary-server/fixture/tls/rootCA.crt" - )) - .unwrap(); - let mut reader = std::io::BufReader::new(pem_file.as_bytes()); - let mut certificates: Vec = rustls_pemfile::certs(&mut reader) - .unwrap() - .into_iter() - .map(Certificate) - .collect(); - let certificate = certificates.remove(0); - - let mut root_store = RootCertStore::empty(); - root_store.add(&certificate).unwrap(); - - let client_notary_config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - let notary_connector = TlsConnector::from(Arc::new(client_notary_config)); - - let notary_socket = tokio::net::TcpStream::connect((NOTARY_HOST, NOTARY_PORT)) - .await - .unwrap(); - - let notary_tls_socket = notary_connector - // Require the domain name of notary server to be the same as that in the server cert - .connect("tlsnotaryserver.io".try_into().unwrap(), notary_socket) - .await - .unwrap(); - - // Attach the hyper HTTP client to the notary TLS connection to send request to the /session endpoint to configure notarization and obtain session id - let (mut request_sender, connection) = hyper::client::conn::handshake(notary_tls_socket) - .await - .unwrap(); - - // Spawn the HTTP task to be run concurrently - let connection_task = tokio::spawn(connection.without_shutdown()); - - // Build the HTTP request to configure notarization - let payload = serde_json::to_string(&NotarizationSessionRequest { - client_type: ClientType::Tcp, - max_transcript_size: Some(NOTARY_MAX_TRANSCRIPT_SIZE), - }) - .unwrap(); - - let request = Request::builder() - 
.uri(format!("https://{NOTARY_HOST}:{NOTARY_PORT}/session")) - .method("POST") - .header("Host", NOTARY_HOST) - // Need to specify application/json for axum to parse it as json - .header("Content-Type", "application/json") - .body(Body::from(payload)) - .unwrap(); - - debug!("Sending configuration request"); - - let configuration_response = request_sender.send_request(request).await.unwrap(); - - debug!("Sent configuration request"); - - assert!(configuration_response.status() == StatusCode::OK); - - debug!("Response OK"); - - // Pretty printing :) - let payload = to_bytes(configuration_response.into_body()) - .await - .unwrap() - .to_vec(); - let notarization_response = - serde_json::from_str::(&String::from_utf8_lossy(&payload)) - .unwrap(); - - debug!("Notarization response: {:?}", notarization_response,); - - // Send notarization request via HTTP, where the underlying TCP connection will be extracted later - let request = Request::builder() - // Need to specify the session_id so that notary server knows the right configuration to use - // as the configuration is set in the previous HTTP call - .uri(format!( - "https://{}:{}/notarize?sessionId={}", - NOTARY_HOST, - NOTARY_PORT, - notarization_response.session_id.clone() - )) - .method("GET") - .header("Host", NOTARY_HOST) - .header("Connection", "Upgrade") - // Need to specify this upgrade header for server to extract tcp connection later - .header("Upgrade", "TCP") - .body(Body::empty()) - .unwrap(); - - debug!("Sending notarization request"); - - let response = request_sender.send_request(request).await.unwrap(); - - debug!("Sent notarization request"); - - assert!(response.status() == StatusCode::SWITCHING_PROTOCOLS); - - debug!("Switched protocol OK"); - - // Claim back the TLS socket after HTTP exchange is done - let Parts { - io: notary_tls_socket, - .. 
- } = connection_task.await.unwrap().unwrap(); - - (notary_tls_socket, notarization_response.session_id) -} - -/// Find the ranges of the public and private parts of a sequence. -/// -/// Returns a tuple of `(public, private)` ranges. -fn find_ranges(seq: &[u8], sub_seq: &[&[u8]]) -> (Vec>, Vec>) { - let mut private_ranges = Vec::new(); - for s in sub_seq { - for (idx, w) in seq.windows(s.len()).enumerate() { - if w == *s { - private_ranges.push(idx..(idx + w.len())); - } - } - } - - let mut sorted_ranges = private_ranges.clone(); - sorted_ranges.sort_by_key(|r| r.start); - - let mut public_ranges = Vec::new(); - let mut last_end = 0; - for r in sorted_ranges { - if r.start > last_end { - public_ranges.push(last_end..r.start); - } - last_end = r.end; - } - - if last_end < seq.len() { - public_ranges.push(last_end..seq.len()); - } - - (public_ranges, private_ranges) -} diff --git a/tlsn/examples/twitter/twitter_dm_browser.png b/tlsn/examples/twitter/twitter_dm_browser.png deleted file mode 100644 index 96df52498c..0000000000 Binary files a/tlsn/examples/twitter/twitter_dm_browser.png and /dev/null differ diff --git a/tlsn/tests-integration/Cargo.toml b/tlsn/tests-integration/Cargo.toml deleted file mode 100644 index 708e1f6170..0000000000 --- a/tlsn/tests-integration/Cargo.toml +++ /dev/null @@ -1,26 +0,0 @@ -[package] -name = "tests-integration" -version = "0.0.0" -edition = "2021" -publish = false - -[dev-dependencies] -tlsn-core.workspace = true -tlsn-tls-core.workspace = true -tlsn-prover = { workspace = true, features = ["tracing"] } -tlsn-verifier = { workspace = true, features = ["tracing"] } -tlsn-server-fixture.workspace = true -tlsn-utils.workspace = true - -p256 = { workspace = true, features = ["ecdsa"] } -hyper = { workspace = true, features = ["client", "http1"] } - -futures.workspace = true -tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } -tokio-util.workspace = true - -tracing.workspace = true 
-tracing-subscriber.workspace = true - -serde_json = "1.0" -bincode = "*" diff --git a/tlsn/tests-integration/tests/defer_decryption.rs b/tlsn/tests-integration/tests/defer_decryption.rs deleted file mode 100644 index daccea1671..0000000000 --- a/tlsn/tests-integration/tests/defer_decryption.rs +++ /dev/null @@ -1,82 +0,0 @@ -use futures::{AsyncReadExt, AsyncWriteExt}; -use tlsn_prover::tls::{Prover, ProverConfig}; -use tlsn_server_fixture::{CA_CERT_DER, SERVER_DOMAIN}; -use tlsn_verifier::tls::{Verifier, VerifierConfig}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::compat::TokioAsyncReadCompatExt; -use tracing::instrument; - -#[tokio::test] -#[ignore] -async fn test_defer_decryption() { - tracing_subscriber::fmt::init(); - - let (socket_0, socket_1) = tokio::io::duplex(2 << 23); - - tokio::join!(prover(socket_0), notary(socket_1)); -} - -#[instrument(skip(notary_socket))] -async fn prover(notary_socket: T) { - let (client_socket, server_socket) = tokio::io::duplex(2 << 16); - - let server_task = tokio::spawn(tlsn_server_fixture::bind(server_socket.compat())); - - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - let prover = Prover::new( - ProverConfig::builder() - .id("test") - .server_dns(SERVER_DOMAIN) - .root_cert_store(root_store) - .build() - .unwrap(), - ) - .setup(notary_socket.compat()) - .await - .unwrap(); - - let (mut tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - let prover_ctrl = prover_fut.control(); - let prover_task = tokio::spawn(prover_fut); - - // Defer decryption until after the server closes the connection. 
- prover_ctrl.defer_decryption().await.unwrap(); - - tls_connection - .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") - .await - .unwrap(); - tls_connection.close().await.unwrap(); - - let mut response = vec![0u8; 1024]; - tls_connection.read_to_end(&mut response).await.unwrap(); - - server_task.await.unwrap(); - - let mut prover = prover_task.await.unwrap().unwrap().start_notarize(); - let sent_tx_len = prover.sent_transcript().data().len(); - let recv_tx_len = prover.recv_transcript().data().len(); - - let builder = prover.commitment_builder(); - - // Commit to everything - builder.commit_sent(0..sent_tx_len).unwrap(); - builder.commit_recv(0..recv_tx_len).unwrap(); - - let _notarized_session = prover.finalize().await.unwrap(); -} - -#[instrument(skip(socket))] -async fn notary(socket: T) { - let verifier = Verifier::new(VerifierConfig::builder().id("test").build().unwrap()); - let signing_key = p256::ecdsa::SigningKey::from_bytes(&[1u8; 32].into()).unwrap(); - - _ = verifier - .notarize::<_, p256::ecdsa::Signature>(socket.compat(), &signing_key) - .await - .unwrap(); -} diff --git a/tlsn/tests-integration/tests/notarize.rs b/tlsn/tests-integration/tests/notarize.rs deleted file mode 100644 index 0462d8ee23..0000000000 --- a/tlsn/tests-integration/tests/notarize.rs +++ /dev/null @@ -1,98 +0,0 @@ -use futures::AsyncWriteExt; -use hyper::{body::to_bytes, Body, Request, StatusCode}; -use tlsn_prover::tls::{Prover, ProverConfig}; -use tlsn_server_fixture::{CA_CERT_DER, SERVER_DOMAIN}; -use tlsn_verifier::tls::{Verifier, VerifierConfig}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -use tracing::instrument; - -#[tokio::test] -#[ignore] -async fn notarize() { - tracing_subscriber::fmt::init(); - - let (socket_0, socket_1) = tokio::io::duplex(2 << 23); - - tokio::join!(prover(socket_0), notary(socket_1)); -} - -#[instrument(skip(notary_socket))] -async fn prover(notary_socket: T) { - 
let (client_socket, server_socket) = tokio::io::duplex(2 << 16); - - let server_task = tokio::spawn(tlsn_server_fixture::bind(server_socket.compat())); - - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - let prover = Prover::new( - ProverConfig::builder() - .id("test") - .server_dns(SERVER_DOMAIN) - .root_cert_store(root_store) - .build() - .unwrap(), - ) - .setup(notary_socket.compat()) - .await - .unwrap(); - - let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - - let prover_task = tokio::spawn(prover_fut); - - let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat()) - .await - .unwrap(); - - let connection_task = tokio::spawn(connection.without_shutdown()); - - let request = Request::builder() - .uri(format!("https://{}", SERVER_DOMAIN)) - .header("Host", SERVER_DOMAIN) - .header("Connection", "close") - .method("GET") - .body(Body::empty()) - .unwrap(); - - let response = request_sender.send_request(request).await.unwrap(); - - assert!(response.status() == StatusCode::OK); - - println!( - "{:?}", - String::from_utf8_lossy(&to_bytes(response.into_body()).await.unwrap()) - ); - - server_task.await.unwrap(); - - let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner(); - - client_socket.close().await.unwrap(); - - let mut prover = prover_task.await.unwrap().unwrap().start_notarize(); - let sent_tx_len = prover.sent_transcript().data().len(); - let recv_tx_len = prover.recv_transcript().data().len(); - - let builder = prover.commitment_builder(); - - // Commit to everything - builder.commit_sent(0..sent_tx_len).unwrap(); - builder.commit_recv(0..recv_tx_len).unwrap(); - - let _notarized_session = prover.finalize().await.unwrap(); -} - -#[instrument(skip(socket))] -async fn notary(socket: T) { - let verifier = 
Verifier::new(VerifierConfig::builder().id("test").build().unwrap()); - let signing_key = p256::ecdsa::SigningKey::from_bytes(&[1u8; 32].into()).unwrap(); - - _ = verifier - .notarize::<_, p256::ecdsa::Signature>(socket.compat(), &signing_key) - .await - .unwrap(); -} diff --git a/tlsn/tests-integration/tests/verify.rs b/tlsn/tests-integration/tests/verify.rs deleted file mode 100644 index 15b84163bf..0000000000 --- a/tlsn/tests-integration/tests/verify.rs +++ /dev/null @@ -1,119 +0,0 @@ -use futures::AsyncWriteExt; -use hyper::{body::to_bytes, Body, Request, StatusCode}; -use tls_core::{anchors::RootCertStore, verify::WebPkiVerifier}; -use tlsn_core::{proof::SessionInfo, Direction, RedactedTranscript}; -use tlsn_prover::tls::{Prover, ProverConfig}; -use tlsn_server_fixture::{CA_CERT_DER, SERVER_DOMAIN}; -use tlsn_verifier::tls::{Verifier, VerifierConfig}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -use tracing::instrument; -use utils::range::RangeSet; - -#[tokio::test] -#[ignore] -async fn verify() { - tracing_subscriber::fmt::init(); - - let (socket_0, socket_1) = tokio::io::duplex(2 << 23); - - let (_, (sent, received, _session_info)) = tokio::join!(prover(socket_0), verifier(socket_1)); - - assert_eq!(sent.authed(), &RangeSet::from(0..sent.data().len() - 1)); - assert_eq!( - sent.redacted(), - &RangeSet::from(sent.data().len() - 1..sent.data().len()) - ); - - assert_eq!(received.authed(), &RangeSet::from(2..received.data().len())); - assert_eq!(received.redacted(), &RangeSet::from(0..2)); -} - -#[instrument(skip(notary_socket))] -async fn prover(notary_socket: T) { - let (client_socket, server_socket) = tokio::io::duplex(2 << 16); - - let server_task = tokio::spawn(tlsn_server_fixture::bind(server_socket.compat())); - - let mut root_store = RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - let prover = Prover::new( - 
ProverConfig::builder() - .id("test") - .server_dns(SERVER_DOMAIN) - .root_cert_store(root_store) - .build() - .unwrap(), - ) - .setup(notary_socket.compat()) - .await - .unwrap(); - - let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - - let prover_task = tokio::spawn(prover_fut); - - let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat()) - .await - .unwrap(); - - let connection_task = tokio::spawn(connection.without_shutdown()); - - let request = Request::builder() - .uri(format!("https://{}", SERVER_DOMAIN)) - .header("Host", SERVER_DOMAIN) - .header("Connection", "close") - .method("GET") - .body(Body::empty()) - .unwrap(); - - let response = request_sender.send_request(request).await.unwrap(); - - assert!(response.status() == StatusCode::OK); - - println!( - "{:?}", - String::from_utf8_lossy(&to_bytes(response.into_body()).await.unwrap()) - ); - - server_task.await.unwrap(); - - let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner(); - - client_socket.close().await.unwrap(); - - let mut prover = prover_task.await.unwrap().unwrap().start_prove(); - - let sent_transcript_len = prover.sent_transcript().data().len(); - let recv_transcript_len = prover.recv_transcript().data().len(); - - // Reveal parts of the transcript - _ = prover.reveal(0..sent_transcript_len - 1, Direction::Sent); - _ = prover.reveal(2..recv_transcript_len, Direction::Received); - prover.prove().await.unwrap(); - - prover.finalize().await.unwrap() -} - -#[instrument(skip(socket))] -async fn verifier( - socket: T, -) -> (RedactedTranscript, RedactedTranscript, SessionInfo) { - let mut root_store = RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - - let verifier_config = VerifierConfig::builder() - .id("test") - .cert_verifier(WebPkiVerifier::new(root_store, None)) - .build() - .unwrap(); - let verifier = 
Verifier::new(verifier_config); - - let (sent, received, session_info) = verifier.verify(socket.compat()).await.unwrap(); - (sent, received, session_info) -} diff --git a/tlsn/tlsn-common/Cargo.toml b/tlsn/tlsn-common/Cargo.toml deleted file mode 100644 index b70c14e2a2..0000000000 --- a/tlsn/tlsn-common/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "tlsn-common" -description = "Common code shared between tlsn-prover and tlsn-verifier" -version = "0.1.0-alpha.3" -edition = "2021" - -[features] -default = ["tracing"] -tracing = ["uid-mux/tracing"] - -[dependencies] -tlsn-utils-aio.workspace = true - -futures.workspace = true -uid-mux.workspace = true diff --git a/tlsn/tlsn-common/src/lib.rs b/tlsn/tlsn-common/src/lib.rs deleted file mode 100644 index bcdfc7afa9..0000000000 --- a/tlsn/tlsn-common/src/lib.rs +++ /dev/null @@ -1,17 +0,0 @@ -//! Common code shared between `tlsn-prover` and `tlsn-verifier`. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -pub mod mux; - -/// The party's role in the TLSN protocol. -/// -/// A Notary is classified as a Verifier. -pub enum Role { - /// The prover. - Prover, - /// The verifier. - Verifier, -} diff --git a/tlsn/tlsn-common/src/mux.rs b/tlsn/tlsn-common/src/mux.rs deleted file mode 100644 index 67a9be1a7c..0000000000 --- a/tlsn/tlsn-common/src/mux.rs +++ /dev/null @@ -1,45 +0,0 @@ -//! Multiplexer used in the TLSNotary protocol. - -use utils_aio::codec::BincodeMux; - -use futures::{AsyncRead, AsyncWrite}; -use uid_mux::{yamux, UidYamux, UidYamuxControl}; - -use crate::Role; - -/// Multiplexer supporting unique deterministic stream IDs. -pub type Mux = UidYamux; -/// Multiplexer controller providing streams with a codec attached. -pub type MuxControl = BincodeMux; - -const KB: usize = 1024; -const MB: usize = 1024 * KB; - -/// Attaches a multiplexer to the provided socket. 
-/// -/// Returns the multiplexer and a controller for creating streams with a codec attached. -/// -/// # Arguments -/// -/// * `socket` - The socket to attach the multiplexer to. -/// * `role` - The role of the party using the multiplexer. -pub fn attach_mux( - socket: T, - role: Role, -) -> (Mux, MuxControl) { - let mut mux_config = yamux::Config::default(); - // See PR #418 - mux_config.set_max_num_streams(40); - mux_config.set_max_buffer_size(16 * MB); - mux_config.set_receive_window(16 * MB as u32); - - let mux_role = match role { - Role::Prover => yamux::Mode::Client, - Role::Verifier => yamux::Mode::Server, - }; - - let mux = UidYamux::new(mux_config, socket, mux_role); - let ctrl = BincodeMux::new(mux.control()); - - (mux, ctrl) -} diff --git a/tlsn/tlsn-core/Cargo.toml b/tlsn/tlsn-core/Cargo.toml deleted file mode 100644 index ec3854adb3..0000000000 --- a/tlsn/tlsn-core/Cargo.toml +++ /dev/null @@ -1,51 +0,0 @@ -[package] -name = "tlsn-core" -authors = ["TLSNotary Team"] -description = "Core types for TLSNotary" -keywords = ["tls", "mpc", "2pc", "types"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[features] -default = [] -fixtures = ["dep:hex"] - -[dependencies] -tlsn-tls-core = { workspace = true, features = ["serde"] } - -tlsn-utils.workspace = true - -mpz-core.workspace = true -mpz-garble-core.workspace = true -mpz-circuits.workspace = true - -thiserror.workspace = true -serde.workspace = true -p256 = { workspace = true, features = ["serde"] } -webpki-roots.workspace = true -rs_merkle.workspace = true -rstest = { workspace = true, optional = true } -hex = { workspace = true, optional = true } -bytes = { workspace = true, features = ["serde"] } -opaque-debug.workspace = true - -bimap = { version = "0.6.3", features = ["serde"] } - -web-time.workspace = true - -[dev-dependencies] -rstest.workspace = true -hex.workspace = true -rand_core.workspace = true -rand_chacha.workspace = true 
-bincode.workspace = true - -[[test]] -name = "api" -required-features = ["fixtures"] - -[target.'cfg(target_arch = "wasm32")'.dependencies] -ring = { version = "0.17", features = ["wasm32_unknown_unknown_js"] } -getrandom = { version = "0.2", features = ["js"] } diff --git a/tlsn/tlsn-core/src/commitment/blake3.rs b/tlsn/tlsn-core/src/commitment/blake3.rs deleted file mode 100644 index 900ed99235..0000000000 --- a/tlsn/tlsn-core/src/commitment/blake3.rs +++ /dev/null @@ -1,105 +0,0 @@ -use crate::commitment::{Commitment, CommitmentOpening}; -use mpz_core::{ - commit::{Decommitment, HashCommit, Nonce}, - hash::Hash, -}; -use mpz_garble_core::{encoding_state, encoding_state::Full, EncodedValue}; -use serde::{Deserialize, Serialize}; - -/// A Blake3 commitment to the encodings of the substrings of a [`Transcript`](crate::Transcript). -#[derive(Clone, Copy, Serialize, Deserialize)] -pub struct Blake3Commitment { - hash: Hash, - nonce: Nonce, -} - -opaque_debug::implement!(Blake3Commitment); - -impl Blake3Commitment { - /// Creates a new Blake3 commitment - pub fn new(encodings: &[EncodedValue]) -> Self { - let (decommitment, hash) = encodings.hash_commit(); - - Self { - hash, - nonce: *decommitment.nonce(), - } - } - - /// Returns the hash of this commitment - pub fn hash(&self) -> &Hash { - &self.hash - } - - /// Returns the nonce of this commitment - pub fn nonce(&self) -> &Nonce { - &self.nonce - } - - /// Opens this commitment - pub fn open(&self, data: Vec) -> Blake3Opening { - Blake3Opening::new(data, self.nonce) - } -} - -impl From for Commitment { - fn from(value: Blake3Commitment) -> Self { - Self::Blake3(value) - } -} - -/// A substring opening using Blake3 -#[derive(Serialize, Deserialize, Clone)] -pub struct Blake3Opening { - data: Vec, - nonce: Nonce, -} - -impl Blake3Opening { - pub(crate) fn new(data: Vec, nonce: Nonce) -> Self { - Self { data, nonce } - } - - /// Recovers the expected commitment from this opening. 
- /// - /// # Panics - /// - /// - If the number of encodings does not match the number of bytes in the opening. - /// - If an encoding is not for a u8. - pub fn recover(&self, encodings: &[EncodedValue]) -> Blake3Commitment { - assert_eq!( - encodings.len(), - self.data.len(), - "encodings and data must have the same length" - ); - - let encodings = encodings - .iter() - .zip(&self.data) - .map(|(encoding, data)| encoding.select(*data).expect("encoding is for a u8")) - .collect::>(); - - let hash = Decommitment::new_with_nonce(encodings, self.nonce).commit(); - - Blake3Commitment { - hash, - nonce: self.nonce, - } - } - - /// Returns the transcript data corresponding to this opening - pub fn data(&self) -> &[u8] { - &self.data - } - - /// Returns the transcript data corresponding to this opening - pub fn into_data(self) -> Vec { - self.data - } -} - -impl From for CommitmentOpening { - fn from(value: Blake3Opening) -> Self { - Self::Blake3(value) - } -} diff --git a/tlsn/tlsn-core/src/commitment/builder.rs b/tlsn/tlsn-core/src/commitment/builder.rs deleted file mode 100644 index afaba54340..0000000000 --- a/tlsn/tlsn-core/src/commitment/builder.rs +++ /dev/null @@ -1,171 +0,0 @@ -use std::collections::HashMap; - -use bimap::BiMap; -use mpz_core::hash::Hash; -use utils::range::RangeSet; - -use crate::{ - commitment::{ - blake3::Blake3Commitment, Commitment, CommitmentId, CommitmentInfo, CommitmentKind, - TranscriptCommitments, - }, - merkle::MerkleTree, - transcript::get_value_ids, - Direction, EncodingProvider, -}; - -/// An error for [`TranscriptCommitmentBuilder`] -#[derive(Debug, thiserror::Error)] -pub enum TranscriptCommitmentBuilderError { - /// Empty range - #[error("can not commit to an empty range")] - EmptyRange, - /// Range out of bounds - #[error("range out of bounds")] - RangeOutOfBounds, - /// Failed to retrieve encodings for the provided transcript ranges. 
- #[error("failed to retrieve encodings for the provided transcript ranges")] - MissingEncodings, - /// Duplicate commitment - #[error("attempted to create a duplicate commitment, overwriting: {0:?}")] - Duplicate(CommitmentId), - /// No commitments were added - #[error("no commitments were added")] - NoCommitments, -} - -/// A builder for [`TranscriptCommitments`]. -pub struct TranscriptCommitmentBuilder { - commitments: HashMap, - /// Information about the above `commitments`. - commitment_info: BiMap, - merkle_leaves: Vec, - /// A function that returns the encodings for the provided transcript byte ids. - encoding_provider: EncodingProvider, - sent_len: usize, - recv_len: usize, -} - -opaque_debug::implement!(TranscriptCommitmentBuilder); - -impl TranscriptCommitmentBuilder { - /// Creates a new builder. - /// - /// # Arguments - /// - /// * `encoding_provider` - A function that returns the encodings for the provided transcript byte ids. - #[doc(hidden)] - pub fn new(encoding_provider: EncodingProvider, sent_len: usize, recv_len: usize) -> Self { - Self { - commitments: HashMap::default(), - commitment_info: BiMap::default(), - merkle_leaves: Vec::default(), - encoding_provider, - sent_len, - recv_len, - } - } - - /// Commits to the provided ranges of the `sent` transcript. - pub fn commit_sent( - &mut self, - ranges: impl Into>, - ) -> Result { - self.add_substrings_commitment(ranges.into(), Direction::Sent) - } - - /// Commits to the provided ranges of the `received` transcript. - pub fn commit_recv( - &mut self, - ranges: impl Into>, - ) -> Result { - self.add_substrings_commitment(ranges.into(), Direction::Received) - } - - /// Gets the commitment id for the provided commitment info. 
- pub fn get_id( - &self, - kind: CommitmentKind, - ranges: impl Into>, - direction: Direction, - ) -> Option { - self.commitment_info - .get_by_right(&CommitmentInfo { - kind, - ranges: ranges.into(), - direction, - }) - .copied() - } - - /// Add a commitment to substrings of the transcript - fn add_substrings_commitment( - &mut self, - ranges: RangeSet, - direction: Direction, - ) -> Result { - let max = ranges - .max() - .ok_or(TranscriptCommitmentBuilderError::EmptyRange)?; - let len = match direction { - Direction::Sent => self.sent_len, - Direction::Received => self.recv_len, - }; - - if max > len { - return Err(TranscriptCommitmentBuilderError::RangeOutOfBounds); - } - - let ids: Vec<_> = get_value_ids(&ranges, direction).collect(); - - let id_refs = ids.iter().map(|id| id.as_ref()).collect::>(); - - let encodings = (self.encoding_provider)(&id_refs) - .ok_or(TranscriptCommitmentBuilderError::MissingEncodings)?; - - // We only support BLAKE3 for now - let commitment = Blake3Commitment::new(&encodings); - let hash = *commitment.hash(); - - let id = CommitmentId::new(self.merkle_leaves.len() as u32); - - let commitment: Commitment = commitment.into(); - - // Store commitment with its id - self.commitment_info - .insert_no_overwrite( - id, - CommitmentInfo::new(commitment.kind(), ranges, direction), - ) - .map_err(|(id, _)| TranscriptCommitmentBuilderError::Duplicate(id))?; - - if self.commitments.insert(id, commitment).is_some() { - // This shouldn't be possible, as we check for duplicates above. - panic!("commitment id already exists"); - } - - // Insert commitment hash into the merkle tree - self.merkle_leaves.push(hash); - - Ok(id) - } - - /// Builds the [`TranscriptCommitments`] - pub fn build(self) -> Result { - let Self { - commitments, - commitment_info, - merkle_leaves, - .. 
- } = self; - - let merkle_tree = MerkleTree::from_leaves(&merkle_leaves) - .map_err(|_| TranscriptCommitmentBuilderError::NoCommitments)?; - - Ok(TranscriptCommitments { - merkle_tree, - commitments, - commitment_info, - }) - } -} diff --git a/tlsn/tlsn-core/src/commitment/mod.rs b/tlsn/tlsn-core/src/commitment/mod.rs deleted file mode 100644 index 7c32e56d5c..0000000000 --- a/tlsn/tlsn-core/src/commitment/mod.rs +++ /dev/null @@ -1,193 +0,0 @@ -//! Types related to transcript commitments. - -/// BLAKE3 commitments. -pub mod blake3; -mod builder; - -use std::collections::HashMap; - -use bimap::BiMap; -use mpz_core::hash::Hash; -use mpz_garble_core::{encoding_state::Full, EncodedValue}; -use serde::{Deserialize, Serialize}; -use utils::range::RangeSet; - -use crate::{ - merkle::{MerkleRoot, MerkleTree}, - Direction, -}; - -pub use builder::{TranscriptCommitmentBuilder, TranscriptCommitmentBuilderError}; - -/// A commitment id. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] -pub struct CommitmentId(u32); - -impl CommitmentId { - /// Creates a new commitment id - pub(crate) fn new(id: u32) -> Self { - Self(id) - } - - /// Returns the inner value - pub(crate) fn to_inner(self) -> u32 { - self.0 - } -} - -/// Info of a transcript commitment -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct CommitmentInfo { - pub(crate) kind: CommitmentKind, - pub(crate) ranges: RangeSet, - pub(crate) direction: Direction, -} - -impl CommitmentInfo { - /// Creates new commitment info. 
- pub(crate) fn new(kind: CommitmentKind, ranges: RangeSet, direction: Direction) -> Self { - Self { - kind, - ranges, - direction, - } - } - - /// Returns the kind of this commitment - pub fn kind(&self) -> CommitmentKind { - self.kind - } - - /// Returns the ranges of this commitment - pub fn ranges(&self) -> &RangeSet { - &self.ranges - } - - /// Returns the direction of this commitment - pub fn direction(&self) -> &Direction { - &self.direction - } -} - -/// A commitment to some bytes in a transcript -#[derive(Clone, Serialize, Deserialize)] -#[non_exhaustive] -pub enum Commitment { - /// A BLAKE3 commitment to encodings of the transcript. - Blake3(blake3::Blake3Commitment), -} - -impl Commitment { - /// Returns the hash of this commitment - pub fn hash(&self) -> Hash { - match self { - Commitment::Blake3(commitment) => *commitment.hash(), - } - } - - /// Returns the kind of this commitment - pub fn kind(&self) -> CommitmentKind { - match self { - Commitment::Blake3(_) => CommitmentKind::Blake3, - } - } -} - -/// The kind of a [`Commitment`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[non_exhaustive] -pub enum CommitmentKind { - /// A BLAKE3 commitment to encodings of the transcript. - Blake3, -} - -/// An opening to a commitment to the transcript. -#[derive(Clone, Serialize, Deserialize)] -#[non_exhaustive] -pub enum CommitmentOpening { - /// An opening to a BLAKE3 commitment - Blake3(blake3::Blake3Opening), -} - -impl CommitmentOpening { - /// Returns the kind of this opening - pub fn kind(&self) -> CommitmentKind { - match self { - CommitmentOpening::Blake3(_) => CommitmentKind::Blake3, - } - } - - /// Recovers the expected commitment from this opening. - /// - /// # Panics - /// - /// Implementations may panic if the following conditions are not met: - /// - /// - If the number of encodings does not match the number of bytes in the opening. - /// - If an encoding is not for a u8. 
- pub fn recover(&self, encodings: &[EncodedValue]) -> Commitment { - match self { - CommitmentOpening::Blake3(opening) => opening.recover(encodings).into(), - } - } - - /// Returns the transcript data corresponding to this opening - pub fn data(&self) -> &[u8] { - match self { - CommitmentOpening::Blake3(opening) => opening.data(), - } - } - - /// Returns the transcript data corresponding to this opening - pub fn into_data(self) -> Vec { - match self { - CommitmentOpening::Blake3(opening) => opening.into_data(), - } - } -} - -/// A collection of transcript commitments. -#[derive(Clone, Serialize, Deserialize)] -pub struct TranscriptCommitments { - /// A Merkle tree of commitments. Each commitment's index in the tree matches its `CommitmentId`. - merkle_tree: MerkleTree, - commitments: HashMap, - /// Information about the above `commitments`. - commitment_info: BiMap, -} - -opaque_debug::implement!(TranscriptCommitments); - -impl TranscriptCommitments { - /// Returns the merkle tree of the commitments. - pub fn merkle_tree(&self) -> &MerkleTree { - &self.merkle_tree - } - - /// Returns the merkle root of the commitments. - pub fn merkle_root(&self) -> MerkleRoot { - self.merkle_tree.root() - } - - /// Returns a commitment if it exists. - pub fn get(&self, id: &CommitmentId) -> Option<&Commitment> { - self.commitments.get(id) - } - - /// Returns the commitment id for a commitment with the given info, if it exists. - pub fn get_id_by_info( - &self, - kind: CommitmentKind, - ranges: RangeSet, - direction: Direction, - ) -> Option { - self.commitment_info - .get_by_right(&CommitmentInfo::new(kind, ranges, direction)) - .copied() - } - - /// Returns commitment info, if it exists. 
- pub fn get_info(&self, id: &CommitmentId) -> Option<&CommitmentInfo> { - self.commitment_info.get_by_left(id) - } -} diff --git a/tlsn/tlsn-core/src/fixtures/cert.rs b/tlsn/tlsn-core/src/fixtures/cert.rs deleted file mode 100644 index c52d14be46..0000000000 --- a/tlsn/tlsn-core/src/fixtures/cert.rs +++ /dev/null @@ -1,130 +0,0 @@ -use tls_core::{ - key::{Certificate, PublicKey}, - msgs::{ - codec::Codec, - enums::{NamedGroup, SignatureScheme}, - handshake::{DigitallySignedStruct, Random, ServerECDHParams}, - }, -}; - -use hex::FromHex; - -/// Collects data needed for testing -pub struct TestData { - /// end-entity cert - pub ee: Certificate, - /// intermediate cert - pub inter: Certificate, - /// CA cert - pub ca: Certificate, - /// client random - pub cr: Random, - /// server random - pub sr: Random, - /// server ephemeral P256 pubkey - pub pubkey: PublicKey, - /// server signature over the key exchange parameters - pub sig: Vec, - /// unix time when TLS handshake began - pub time: u64, - /// algorithm used to create the sig - pub sig_scheme: SignatureScheme, - /// DNS name of the website - pub dns_name: String, -} - -impl TestData { - /// Returns the [ServerECDHParams] in encoded form - pub fn kx_params(&self) -> Vec { - let mut params = Vec::new(); - let ecdh_params = ServerECDHParams::new(NamedGroup::secp256r1, &self.pubkey.key); - ecdh_params.encode(&mut params); - params - } - - /// Returns the [DigitallySignedStruct] - pub fn dss(&self) -> DigitallySignedStruct { - DigitallySignedStruct::new(self.sig_scheme, self.sig.clone()) - } - - /// Returns the client random + server random + kx params in encoded form - pub fn signature_msg(&self) -> Vec { - let mut msg = Vec::new(); - msg.extend_from_slice(&self.cr.0); - msg.extend_from_slice(&self.sr.0); - msg.extend_from_slice(&self.kx_params()); - msg - } -} - -/// Returns test data for the tlsnotary.org website -pub fn tlsnotary() -> TestData { - TestData { - ee: 
Certificate(include_bytes!("testdata/key_exchange/tlsnotary.org/ee.der").to_vec()), - inter: Certificate( - include_bytes!("testdata/key_exchange/tlsnotary.org/inter.der").to_vec(), - ), - ca: Certificate(include_bytes!("testdata/key_exchange/tlsnotary.org/ca.der").to_vec()), - cr: Random( - <[u8; 32]>::from_hex(include_bytes!( - "testdata/key_exchange/tlsnotary.org/client_random" - )) - .unwrap(), - ), - sr: Random( - <[u8; 32]>::from_hex(include_bytes!( - "testdata/key_exchange/tlsnotary.org/server_random" - )) - .unwrap(), - ), - pubkey: PublicKey::new( - NamedGroup::secp256r1, - &Vec::::from_hex(include_bytes!("testdata/key_exchange/tlsnotary.org/pubkey")) - .unwrap(), - ), - sig: Vec::::from_hex(include_bytes!( - "testdata/key_exchange/tlsnotary.org/signature" - )) - .unwrap(), - time: 1671637529, - sig_scheme: SignatureScheme::RSA_PKCS1_SHA256, - dns_name: "tlsnotary.org".to_string(), - } -} - -/// Returns test data for the appliedzkp.org website -pub fn appliedzkp() -> TestData { - TestData { - ee: Certificate(include_bytes!("testdata/key_exchange/appliedzkp.org/ee.der").to_vec()), - inter: Certificate( - include_bytes!("testdata/key_exchange/appliedzkp.org/inter.der").to_vec(), - ), - ca: Certificate(include_bytes!("testdata/key_exchange/appliedzkp.org/ca.der").to_vec()), - cr: Random( - <[u8; 32]>::from_hex(include_bytes!( - "testdata/key_exchange/appliedzkp.org/client_random" - )) - .unwrap(), - ), - sr: Random( - <[u8; 32]>::from_hex(include_bytes!( - "testdata/key_exchange/appliedzkp.org/server_random" - )) - .unwrap(), - ), - pubkey: PublicKey::new( - NamedGroup::secp256r1, - &Vec::::from_hex(include_bytes!( - "testdata/key_exchange/appliedzkp.org/pubkey" - )) - .unwrap(), - ), - sig: Vec::::from_hex(include_bytes!( - "testdata/key_exchange/appliedzkp.org/signature" - )) - .unwrap(), - time: 1671637529, - sig_scheme: SignatureScheme::ECDSA_NISTP256_SHA256, - dns_name: "appliedzkp.org".to_string(), - } -} diff --git a/tlsn/tlsn-core/src/fixtures/mod.rs 
b/tlsn/tlsn-core/src/fixtures/mod.rs deleted file mode 100644 index 90aa7aa1bf..0000000000 --- a/tlsn/tlsn-core/src/fixtures/mod.rs +++ /dev/null @@ -1,169 +0,0 @@ -//! Fixtures for testing - -/// Certificate fixtures -pub mod cert; - -use std::collections::HashMap; - -use hex::FromHex; -use mpz_circuits::types::ValueType; -use mpz_core::{commit::HashCommit, hash::Hash, utils::blake3}; -use mpz_garble_core::{ChaChaEncoder, Encoder}; -use tls_core::{ - cert::ServerCertDetails, - handshake::HandshakeData, - ke::ServerKxDetails, - key::{Certificate, PublicKey}, - msgs::{ - codec::Codec, - enums::{NamedGroup, SignatureScheme}, - handshake::{DigitallySignedStruct, Random, ServerECDHParams}, - }, -}; - -use p256::ecdsa::SigningKey; - -use crate::{ - merkle::MerkleRoot, - session::{HandshakeSummary, SessionHeader}, - EncodingProvider, -}; - -fn value_id(id: &str) -> u64 { - let hash = blake3(id.as_bytes()); - u64::from_be_bytes(hash[..8].try_into().unwrap()) -} - -/// Returns a session header fixture using the given transcript lengths and merkle root. -/// -/// # Arguments -/// -/// * `root` - The merkle root of the transcript commitments. -/// * `sent_len` - The length of the sent transcript. -/// * `recv_len` - The length of the received transcript. -pub fn session_header(root: MerkleRoot, sent_len: usize, recv_len: usize) -> SessionHeader { - SessionHeader::new( - encoder_seed(), - root, - sent_len, - recv_len, - handshake_summary(), - ) -} - -/// Returns an encoding provider fixture using the given transcripts. 
-pub fn encoding_provider(transcript_tx: &[u8], transcript_rx: &[u8]) -> EncodingProvider { - let encoder = encoder(); - let mut active_encodings = HashMap::new(); - for (idx, byte) in transcript_tx.iter().enumerate() { - let id = format!("tx/{idx}"); - let enc = encoder.encode_by_type(value_id(&id), &ValueType::U8); - active_encodings.insert(id, enc.select(*byte).unwrap()); - } - for (idx, byte) in transcript_rx.iter().enumerate() { - let id = format!("rx/{idx}"); - let enc = encoder.encode_by_type(value_id(&id), &ValueType::U8); - active_encodings.insert(id, enc.select(*byte).unwrap()); - } - - Box::new(move |ids: &[&str]| { - ids.iter() - .map(|id| active_encodings.get(*id).cloned()) - .collect() - }) -} - -/// Returns a handshake summary fixture. -pub fn handshake_summary() -> HandshakeSummary { - HandshakeSummary::new(1671637529, server_ephemeral_key(), handshake_commitment()) -} - -/// Returns a handshake commitment fixture. -pub fn handshake_commitment() -> Hash { - let (_, hash) = handshake_data().hash_commit(); - hash -} - -/// Returns a handshake data fixture. -pub fn handshake_data() -> HandshakeData { - HandshakeData::new( - server_cert_details(), - server_kx_details(), - client_random(), - server_random(), - ) -} - -/// Returns a server certificate details fixture. -pub fn server_cert_details() -> ServerCertDetails { - ServerCertDetails::new( - vec![ - Certificate(include_bytes!("testdata/key_exchange/tlsnotary.org/ee.der").to_vec()), - Certificate(include_bytes!("testdata/key_exchange/tlsnotary.org/inter.der").to_vec()), - Certificate(include_bytes!("testdata/key_exchange/tlsnotary.org/ca.der").to_vec()), - ], - vec![], - None, - ) -} - -/// Returns a server key exchange details fixture. 
-pub fn server_kx_details() -> ServerKxDetails { - let mut params = Vec::new(); - let ecdh_params = ServerECDHParams::new(NamedGroup::secp256r1, &server_ephemeral_key().key); - ecdh_params.encode(&mut params); - - ServerKxDetails::new( - params, - DigitallySignedStruct::new( - SignatureScheme::RSA_PKCS1_SHA256, - Vec::::from_hex(include_bytes!( - "testdata/key_exchange/tlsnotary.org/signature" - )) - .unwrap(), - ), - ) -} - -/// Returns a client random fixture. -pub fn client_random() -> Random { - Random( - <[u8; 32]>::from_hex(include_bytes!( - "testdata/key_exchange/tlsnotary.org/client_random" - )) - .unwrap(), - ) -} - -/// Returns a server random fixture. -pub fn server_random() -> Random { - Random( - <[u8; 32]>::from_hex(include_bytes!( - "testdata/key_exchange/tlsnotary.org/server_random" - )) - .unwrap(), - ) -} - -/// Returns an encoder fixture. -pub fn encoder() -> ChaChaEncoder { - ChaChaEncoder::new(encoder_seed()) -} - -/// Returns an encoder seed fixture. -pub fn encoder_seed() -> [u8; 32] { - [0u8; 32] -} - -/// Returns a server ephemeral key fixture. -pub fn server_ephemeral_key() -> PublicKey { - PublicKey::new( - NamedGroup::secp256r1, - &Vec::::from_hex(include_bytes!("testdata/key_exchange/tlsnotary.org/pubkey")).unwrap(), - ) -} - -/// Returns a notary signing key fixture. -pub fn notary_signing_key() -> SigningKey { - SigningKey::from_slice(&[1; 32]).unwrap() -} diff --git a/tlsn/tlsn-core/src/lib.rs b/tlsn/tlsn-core/src/lib.rs deleted file mode 100644 index 037f801038..0000000000 --- a/tlsn/tlsn-core/src/lib.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! TLSNotary core protocol library. -//! -//! This crate contains core types for the TLSNotary protocol, including some functionality for selective disclosure. 
- -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -pub mod commitment; -#[cfg(any(test, feature = "fixtures"))] -pub mod fixtures; -pub mod merkle; -pub mod msg; -pub mod proof; -pub mod session; -mod signature; -pub mod transcript; - -pub use session::{HandshakeSummary, NotarizedSession, SessionData, SessionHeader}; -pub use signature::{NotaryPublicKey, Signature}; -pub use transcript::{Direction, RedactedTranscript, Transcript, TranscriptSlice}; - -use mpz_garble_core::{encoding_state, EncodedValue}; -use serde::{Deserialize, Serialize}; - -/// The maximum allowed total bytelength of all committed data. Used to prevent DoS during verification. -/// (this will cause the verifier to hash up to a max of 1GB * 128 = 128GB of plaintext encodings if the -/// commitment type is [crate::commitment::Blake3]). -/// -/// This value must not exceed bcs's MAX_SEQUENCE_LENGTH limit (which is (1 << 31) - 1 by default) -const MAX_TOTAL_COMMITTED_DATA: usize = 1_000_000_000; - -/// A provider of plaintext encodings. -pub(crate) type EncodingProvider = - Box Option>> + Send>; - -/// The encoding id -/// -/// A 64 bit Blake3 hash which is used for the plaintext encodings -#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)] -pub(crate) struct EncodingId(u64); - -impl EncodingId { - /// Create a new encoding ID. - pub(crate) fn new(id: &str) -> Self { - let hash = mpz_core::utils::blake3(id.as_bytes()); - Self(u64::from_be_bytes(hash[..8].try_into().unwrap())) - } - - /// Returns the encoding ID. - pub(crate) fn to_inner(self) -> u64 { - self.0 - } -} - -/// A Server's name. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum ServerName { - /// A DNS name. - Dns(String), -} - -impl ServerName { - /// Returns a reference to the server name as a string slice. 
- pub fn as_str(&self) -> &str { - match self { - Self::Dns(name) => name.as_str(), - } - } -} - -impl AsRef for ServerName { - fn as_ref(&self) -> &str { - match self { - Self::Dns(name) => name.as_ref(), - } - } -} diff --git a/tlsn/tlsn-core/src/merkle.rs b/tlsn/tlsn-core/src/merkle.rs deleted file mode 100644 index 941782090a..0000000000 --- a/tlsn/tlsn-core/src/merkle.rs +++ /dev/null @@ -1,385 +0,0 @@ -//! Merkle tree types. -//! -//! # Usage -//! -//! During notarization, the `Prover` generates various commitments to the transcript data, which are subsequently -//! inserted into a `MerkleTree`. Rather than send each commitment to the Notary individually, the `Prover` simply sends the -//! `MerkleRoot`. This hides the number of commitments from the Notary, which is important for privacy as it can leak -//! information about the content of the transcript. -//! -//! Later, during selective disclosure to a `Verifier`, the `Prover` can open any subset of the commitments in the `MerkleTree` -//! by providing a `MerkleProof` for the corresponding `MerkleRoot` which was signed by the Notary. - -use mpz_core::hash::Hash; -use rs_merkle::{ - algorithms::Sha256, proof_serializers, MerkleProof as MerkleProof_rs_merkle, - MerkleTree as MerkleTree_rs_merkle, -}; -use serde::{ser::Serializer, Deserialize, Deserializer, Serialize}; -use utils::iter::DuplicateCheck; - -/// A Merkle root. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -pub struct MerkleRoot([u8; 32]); - -impl MerkleRoot { - /// Returns the inner byte array - pub fn to_inner(self) -> [u8; 32] { - self.0 - } -} - -impl From<[u8; 32]> for MerkleRoot { - fn from(bytes: [u8; 32]) -> Self { - Self(bytes) - } -} - -/// Errors that can occur during operations with Merkle tree and Merkle proof -#[derive(Debug, thiserror::Error, PartialEq)] -#[allow(missing_docs)] -pub enum MerkleError { - #[error("Failed to verify a Merkle proof")] - MerkleProofVerificationFailed, - #[error("No leaves were provided when constructing a Merkle tree")] - MerkleNoLeavesProvided, -} - -/// A Merkle proof. -#[derive(Serialize, Deserialize)] -pub struct MerkleProof { - #[serde( - serialize_with = "merkle_proof_serialize", - deserialize_with = "merkle_proof_deserialize" - )] - proof: MerkleProof_rs_merkle, - total_leaves: usize, -} - -impl MerkleProof { - /// Checks if indices, hashes and leaves count are valid for the provided root - /// - /// # Panics - /// - /// - If the length of `leaf_indices` and `leaf_hashes` does not match. - /// - If `leaf_indices` contains duplicates. 
- pub fn verify( - &self, - root: &MerkleRoot, - leaf_indices: &[usize], - leaf_hashes: &[Hash], - ) -> Result<(), MerkleError> { - assert_eq!( - leaf_indices.len(), - leaf_hashes.len(), - "leaf indices length must match leaf hashes length" - ); - assert!( - !leaf_indices.iter().contains_dups(), - "duplicate indices provided {:?}", - leaf_indices - ); - - // zip indices and hashes - let mut tuples: Vec<(usize, [u8; 32])> = leaf_indices - .iter() - .cloned() - .zip(leaf_hashes.iter().cloned().map(|h| *h.as_bytes())) - .collect(); - - // sort by index and unzip - tuples.sort_by(|(a, _), (b, _)| a.cmp(b)); - let (indices, hashes): (Vec, Vec<[u8; 32]>) = tuples.into_iter().unzip(); - - if !self - .proof - .verify(root.to_inner(), &indices, &hashes, self.total_leaves) - { - return Err(MerkleError::MerkleProofVerificationFailed); - } - Ok(()) - } -} - -impl Clone for MerkleProof { - fn clone(&self) -> Self { - let bytes = self.proof.to_bytes(); - Self { - proof: MerkleProof_rs_merkle::::from_bytes(&bytes).unwrap(), - total_leaves: self.total_leaves, - } - } -} - -fn merkle_proof_serialize( - proof: &MerkleProof_rs_merkle, - serializer: S, -) -> Result -where - S: Serializer, -{ - let bytes = proof.serialize::(); - serializer.serialize_bytes(&bytes) -} - -fn merkle_proof_deserialize<'de, D>( - deserializer: D, -) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - let bytes = Vec::deserialize(deserializer)?; - MerkleProof_rs_merkle::::from_bytes(bytes.as_slice()).map_err(serde::de::Error::custom) -} - -/// A Merkle tree. 
-#[derive(Serialize, Deserialize, Default, Clone)] -pub struct MerkleTree( - #[serde( - serialize_with = "merkle_tree_serialize", - deserialize_with = "merkle_tree_deserialize" - )] - pub MerkleTree_rs_merkle, -); - -impl MerkleTree { - /// Create a new Merkle tree from the given `leaves` - pub fn from_leaves(leaves: &[Hash]) -> Result { - if leaves.is_empty() { - return Err(MerkleError::MerkleNoLeavesProvided); - } - let leaves: Vec<[u8; 32]> = leaves.iter().map(|h| *h.as_bytes()).collect(); - Ok(Self(MerkleTree_rs_merkle::::from_leaves(&leaves))) - } - - /// Creates an inclusion proof for the given `indices` - /// - /// # Panics - /// - /// - if `indices` is not sorted. - /// - if `indices` contains duplicates - pub fn proof(&self, indices: &[usize]) -> MerkleProof { - assert!( - indices.windows(2).all(|w| w[0] < w[1]), - "indices must be sorted" - ); - - let proof = self.0.proof(indices); - MerkleProof { - proof, - total_leaves: self.0.leaves_len(), - } - } - - /// Returns the Merkle root for this MerkleTree - pub fn root(&self) -> MerkleRoot { - self.0 - .root() - .expect("Merkle root should be available") - .into() - } -} - -/// Serialize the rs_merkle's `MerkleTree` type -fn merkle_tree_serialize( - tree: &MerkleTree_rs_merkle, - serializer: S, -) -> Result -where - S: Serializer, -{ - // all leaves are sha256 hashes - let hash_size = 32; - let mut bytes: Vec = Vec::with_capacity(tree.leaves_len() * hash_size); - if let Some(leaves) = tree.leaves() { - for leaf in leaves { - bytes.append(&mut leaf.to_vec()); - } - } - - serializer.serialize_bytes(&bytes) -} - -fn merkle_tree_deserialize<'de, D>( - deserializer: D, -) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - let bytes: Vec = Vec::deserialize(deserializer)?; - if bytes.len() % 32 != 0 { - return Err(serde::de::Error::custom("leaves must be 32 bytes")); - } - let leaves: Vec<[u8; 32]> = bytes.chunks(32).map(|c| c.try_into().unwrap()).collect(); - - Ok(MerkleTree_rs_merkle::::from_leaves( - 
leaves.as_slice(), - )) -} - -#[cfg(test)] -mod test { - use super::*; - - // Expect Merkle proof verification to succeed - #[test] - fn test_verify_success() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let proof = tree.proof(&[2, 3, 4]); - - assert!(proof - .verify(&tree.root(), &[2, 3, 4], &[leaf2, leaf3, leaf4]) - .is_ok(),); - } - - #[test] - fn test_verify_fail_wrong_leaf() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let proof = tree.proof(&[2, 3, 4]); - - // fail because the leaf is wrong - assert_eq!( - proof - .verify(&tree.root(), &[2, 3, 4], &[leaf1, leaf3, leaf4]) - .err() - .unwrap(), - MerkleError::MerkleProofVerificationFailed - ); - } - - #[test] - #[should_panic] - fn test_proof_fail_length_unsorted() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - _ = tree.proof(&[2, 4, 3]); - } - - #[test] - #[should_panic] - fn test_proof_fail_length_duplicates() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - _ = tree.proof(&[2, 2, 3]); - } - - #[test] - #[should_panic] - fn test_verify_fail_length_mismatch() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = 
Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let proof = tree.proof(&[2, 3, 4]); - - _ = proof.verify(&tree.root(), &[1, 2, 3, 4], &[leaf2, leaf3, leaf4]); - } - - #[test] - #[should_panic] - fn test_verify_fail_duplicates() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let proof = tree.proof(&[2, 3, 4]); - - _ = proof.verify(&tree.root(), &[2, 2, 3], &[leaf2, leaf2, leaf3]); - } - - #[test] - fn test_verify_fail_incorrect_leaf_count() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let mut proof = tree.proof(&[2, 3, 4]); - - proof.total_leaves = 6; - - // fail because leaf count is wrong - assert!(proof - .verify(&tree.root(), &[2, 3, 4], &[leaf2, leaf3, leaf4]) - .is_err()); - } - - #[test] - fn test_verify_fail_incorrect_indices() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let proof = tree.proof(&[2, 3, 4]); - - // fail because tree index is wrong - assert!(proof - .verify(&tree.root(), &[1, 3, 4], &[leaf1, leaf3, leaf4]) - .is_err()); - } - - #[test] - fn test_verify_fail_fewer_indices() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = 
Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let proof = tree.proof(&[2, 3, 4]); - - // trying to verify less leaves than what was included in the proof - assert!(proof - .verify(&tree.root(), &[3, 4], &[leaf3, leaf4]) - .is_err()); - } - - // Expect MerkleProof/MerkleTree custom serialization/deserialization to work - #[test] - fn test_serialization() { - let leaf0 = Hash::from([0u8; 32]); - let leaf1 = Hash::from([1u8; 32]); - let leaf2 = Hash::from([2u8; 32]); - let leaf3 = Hash::from([3u8; 32]); - let leaf4 = Hash::from([4u8; 32]); - let tree = MerkleTree::from_leaves(&[leaf0, leaf1, leaf2, leaf3, leaf4]).unwrap(); - let proof = tree.proof(&[2, 3, 4]); - - // serialize - let tree_bytes = bincode::serialize(&tree).unwrap(); - let proof_bytes = bincode::serialize(&proof).unwrap(); - - // deserialize - let tree2: MerkleTree = bincode::deserialize(&tree_bytes).unwrap(); - let proof2: MerkleProof = bincode::deserialize(&proof_bytes).unwrap(); - - assert!(proof2 - .verify(&tree2.root(), &[2, 3, 4], &[leaf2, leaf3, leaf4]) - .is_ok()); - } -} diff --git a/tlsn/tlsn-core/src/msg.rs b/tlsn/tlsn-core/src/msg.rs deleted file mode 100644 index 458447aa41..0000000000 --- a/tlsn/tlsn-core/src/msg.rs +++ /dev/null @@ -1,41 +0,0 @@ -//! Protocol message types. - -use serde::{Deserialize, Serialize}; -use utils::range::RangeSet; - -use crate::{merkle::MerkleRoot, proof::SessionInfo, signature::Signature, SessionHeader}; - -/// Top-level enum for all messages -#[derive(Debug, Serialize, Deserialize)] -pub enum TlsnMessage { - /// A Merkle root for the tree of commitments to the transcript. - TranscriptCommitmentRoot(MerkleRoot), - /// A session header signed by a notary. - SignedSessionHeader(SignedSessionHeader), - /// A session header. 
- SessionHeader(SessionHeader), - /// Information about the TLS session - SessionInfo(SessionInfo), - /// Information about the values the prover wants to prove - ProvingInfo(ProvingInfo), -} - -/// A signed session header. -#[derive(Debug, Serialize, Deserialize)] -pub struct SignedSessionHeader { - /// The session header - pub header: SessionHeader, - /// The notary's signature - pub signature: Signature, -} - -/// Information about the values the prover wants to prove -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct ProvingInfo { - /// The ids for the sent transcript - pub sent_ids: RangeSet, - /// The ids for the received transcript - pub recv_ids: RangeSet, - /// Purported cleartext values - pub cleartext: Vec, -} diff --git a/tlsn/tlsn-core/src/proof/mod.rs b/tlsn/tlsn-core/src/proof/mod.rs deleted file mode 100644 index e9f4ec9daa..0000000000 --- a/tlsn/tlsn-core/src/proof/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! Different types of proofs used in the TLSNotary protocol. - -mod session; -mod substrings; - -pub use session::{default_cert_verifier, SessionInfo, SessionProof, SessionProofError}; -pub use substrings::{ - SubstringsProof, SubstringsProofBuilder, SubstringsProofBuilderError, SubstringsProofError, -}; - -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; - -/// Proof that a transcript of communications took place between a Prover and Server. -#[derive(Debug, Serialize, Deserialize)] -pub struct TlsProof { - /// Proof of the TLS handshake, server identity, and commitments to the transcript. - pub session: SessionProof, - /// Proof regarding the contents of the transcript. 
- pub substrings: SubstringsProof, -} diff --git a/tlsn/tlsn-core/src/proof/session.rs b/tlsn/tlsn-core/src/proof/session.rs deleted file mode 100644 index 8a81ed324b..0000000000 --- a/tlsn/tlsn-core/src/proof/session.rs +++ /dev/null @@ -1,339 +0,0 @@ -use web_time::{Duration, UNIX_EPOCH}; - -use serde::{Deserialize, Serialize}; - -use mpz_core::{commit::Decommitment, serialize::CanonicalSerialize}; -use tls_core::{ - anchors::{OwnedTrustAnchor, RootCertStore}, - dns::ServerName as TlsServerName, - handshake::HandshakeData, - verify::{ServerCertVerifier, WebPkiVerifier}, -}; - -use crate::{ - session::SessionHeader, - signature::{Signature, SignatureVerifyError}, - HandshakeSummary, NotaryPublicKey, ServerName, -}; - -/// An error that can occur while verifying a [`SessionProof`]. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum SessionProofError { - /// Session proof is missing Notary signature - #[error("session proof is missing notary signature")] - MissingNotarySignature, - /// Invalid signature - #[error(transparent)] - InvalidSignature(#[from] SignatureVerifyError), - /// Invalid server name. - #[error("invalid server name: {0}")] - InvalidServerName(String), - /// Invalid handshake - #[error("handshake verification failed: {0}")] - InvalidHandshake(String), - /// Invalid server certificate - #[error("server certificate verification failed: {0}")] - InvalidServerCertificate(String), -} - -/// A session proof which is created from a [crate::session::NotarizedSession] -/// -/// Proof of the TLS handshake, server identity, and commitments to the transcript. -#[derive(Debug, Serialize, Deserialize)] -pub struct SessionProof { - /// The session header - pub header: SessionHeader, - /// Signature for the session header, if the notary signed it - pub signature: Option, - /// Information about the server - pub session_info: SessionInfo, -} - -impl SessionProof { - /// Verify the session proof. 
- /// - /// # Arguments - /// - /// * `notary_public_key` - The public key of the notary. - /// * `cert_verifier` - The certificate verifier. - pub fn verify( - &self, - notary_public_key: impl Into, - cert_verifier: &impl ServerCertVerifier, - ) -> Result<(), SessionProofError> { - // Verify notary signature - let signature = self - .signature - .as_ref() - .ok_or(SessionProofError::MissingNotarySignature)?; - - signature.verify(&self.header.to_bytes(), notary_public_key)?; - self.session_info - .verify(self.header.handshake_summary(), cert_verifier)?; - - Ok(()) - } - - /// Verify the session proof using trust anchors from the `webpki-roots` crate. - /// - /// # Arguments - /// - /// * `notary_public_key` - The public key of the notary. - pub fn verify_with_default_cert_verifier( - &self, - notary_public_key: impl Into, - ) -> Result<(), SessionProofError> { - self.verify(notary_public_key, &default_cert_verifier()) - } -} - -/// Contains information about the session -/// -/// Includes the [ServerName] and the decommitment to the [HandshakeData]. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct SessionInfo { - /// The server name. - pub server_name: ServerName, - /// Decommitment to the TLS handshake and server identity. - pub handshake_decommitment: Decommitment, -} - -impl SessionInfo { - /// Verify the session info. 
- pub fn verify( - &self, - handshake_summary: &HandshakeSummary, - cert_verifier: &impl ServerCertVerifier, - ) -> Result<(), SessionProofError> { - // Verify server name - let server_name = TlsServerName::try_from(self.server_name.as_ref()) - .map_err(|e| SessionProofError::InvalidServerName(e.to_string()))?; - - // Verify handshake - self.handshake_decommitment - .verify(handshake_summary.handshake_commitment()) - .map_err(|e| SessionProofError::InvalidHandshake(e.to_string()))?; - - // Verify server certificate - self.handshake_decommitment - .data() - .verify( - cert_verifier, - UNIX_EPOCH + Duration::from_secs(handshake_summary.time()), - &server_name, - ) - .map_err(|e| SessionProofError::InvalidServerCertificate(e.to_string()))?; - - Ok(()) - } - - /// Verify the session info using trust anchors from the `webpki-roots` crate. - /// - /// # Arguments - /// - /// * `notary_public_key` - The public key of the notary. - pub fn verify_with_default_cert_verifier( - &self, - handshake_summary: &HandshakeSummary, - ) -> Result<(), SessionProofError> { - self.verify(handshake_summary, &default_cert_verifier()) - } -} - -/// Create a new [`WebPkiVerifier`] with the default trust anchors from the `webpki-roots` crate. 
-pub fn default_cert_verifier() -> WebPkiVerifier { - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - })); - - WebPkiVerifier::new(root_store, None) -} - -#[cfg(test)] -mod tests { - use super::*; - use rstest::*; - - use crate::fixtures::cert::{appliedzkp, tlsnotary, TestData}; - use tls_core::{dns::ServerName, key::Certificate}; - use web_time::SystemTime; - - /// Expect chain verification to succeed - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_cert_chain_sucess_ca_implicit(#[case] data: TestData) { - assert!(default_cert_verifier() - .verify_server_cert( - &data.ee, - &[data.inter], - &ServerName::try_from(data.dns_name.as_ref()).unwrap(), - &mut std::iter::empty(), - &[], - SystemTime::UNIX_EPOCH + Duration::from_secs(data.time), - ) - .is_ok()); - } - - /// Expect chain verification to succeed even when a trusted CA is provided among the intermediate - /// certs. webpki handles such cases properly. 
- #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_cert_chain_success_ca_explicit(#[case] data: TestData) { - assert!(default_cert_verifier() - .verify_server_cert( - &data.ee, - &[data.inter, data.ca], - &ServerName::try_from(data.dns_name.as_ref()).unwrap(), - &mut std::iter::empty(), - &[], - SystemTime::UNIX_EPOCH + Duration::from_secs(data.time), - ) - .is_ok()); - } - - /// Expect to fail since the end entity cert was not valid at the time - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_cert_chain_fail_bad_time(#[case] data: TestData) { - // unix time when the cert chain was NOT valid - let bad_time: u64 = 1571465711; - - let err = default_cert_verifier().verify_server_cert( - &data.ee, - &[data.inter], - &ServerName::try_from(data.dns_name.as_ref()).unwrap(), - &mut std::iter::empty(), - &[], - SystemTime::UNIX_EPOCH + Duration::from_secs(bad_time), - ); - - assert!(matches!( - err.unwrap_err(), - tls_core::Error::InvalidCertificateData(_) - )); - } - - /// Expect to fail when no intermediate cert provided - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_cert_chain_fail_no_interm_cert(#[case] data: TestData) { - let err = default_cert_verifier().verify_server_cert( - &data.ee, - &[], - &ServerName::try_from(data.dns_name.as_ref()).unwrap(), - &mut std::iter::empty(), - &[], - SystemTime::UNIX_EPOCH + Duration::from_secs(data.time), - ); - - assert!(matches!( - err.unwrap_err(), - tls_core::Error::InvalidCertificateData(_) - )); - } - - /// Expect to fail when no intermediate cert provided even if a trusted CA cert is provided - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_cert_chain_fail_no_interm_cert_with_ca_cert(#[case] data: TestData) { - let err = default_cert_verifier().verify_server_cert( - &data.ee, - &[data.ca], - 
&ServerName::try_from(data.dns_name.as_ref()).unwrap(), - &mut std::iter::empty(), - &[], - SystemTime::UNIX_EPOCH + Duration::from_secs(data.time), - ); - - assert!(matches!( - err.unwrap_err(), - tls_core::Error::InvalidCertificateData(_) - )); - } - - /// Expect to fail because end-entity cert is wrong - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_cert_chain_fail_bad_ee_cert(#[case] data: TestData) { - let ee: &[u8] = include_bytes!("../fixtures/testdata/key_exchange/unknown/ee.der"); - - let err = default_cert_verifier().verify_server_cert( - &Certificate(ee.to_vec()), - &[data.inter], - &ServerName::try_from(data.dns_name.as_ref()).unwrap(), - &mut std::iter::empty(), - &[], - SystemTime::UNIX_EPOCH + Duration::from_secs(data.time), - ); - - assert!(matches!( - err.unwrap_err(), - tls_core::Error::InvalidCertificateData(_) - )); - } - - /// Expect to succeed when key exchange params signed correctly with a cert - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_sig_ke_params_success(#[case] data: TestData) { - assert!(default_cert_verifier() - .verify_tls12_signature(&data.signature_msg(), &data.ee, &data.dss()) - .is_ok()); - } - - /// Expect sig verification to fail because client_random is wrong - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_sig_ke_params_fail_bad_client_random(#[case] mut data: TestData) { - data.cr.0[31] = data.cr.0[31].wrapping_add(1); - - assert!(default_cert_verifier() - .verify_tls12_signature(&data.signature_msg(), &data.ee, &data.dss()) - .is_err()); - } - - /// Expect sig verification to fail because the sig is wrong - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_verify_sig_ke_params_fail_bad_sig(#[case] mut data: TestData) { - data.sig[31] = data.sig[31].wrapping_add(1); - - assert!(default_cert_verifier() - 
.verify_tls12_signature(&data.signature_msg(), &data.ee, &data.dss()) - .is_err()); - } - - /// Expect to fail because the dns name is not in the cert - #[rstest] - #[case::tlsnotary(tlsnotary())] - #[case::appliedzkp(appliedzkp())] - fn test_check_dns_name_present_in_cert_fail_bad_host(#[case] data: TestData) { - let bad_name = ServerName::try_from("badhost.com").unwrap(); - - assert!(default_cert_verifier() - .verify_server_cert( - &data.ee, - &[data.inter, data.ca], - &bad_name, - &mut std::iter::empty(), - &[], - SystemTime::UNIX_EPOCH + Duration::from_secs(data.time), - ) - .is_err()); - } -} diff --git a/tlsn/tlsn-core/src/proof/substrings.rs b/tlsn/tlsn-core/src/proof/substrings.rs deleted file mode 100644 index 7d35e8d9de..0000000000 --- a/tlsn/tlsn-core/src/proof/substrings.rs +++ /dev/null @@ -1,286 +0,0 @@ -//! Substrings proofs based on commitments. - -use crate::{ - commitment::{ - Commitment, CommitmentId, CommitmentInfo, CommitmentOpening, TranscriptCommitments, - }, - merkle::MerkleProof, - transcript::get_value_ids, - Direction, EncodingId, RedactedTranscript, SessionHeader, Transcript, TranscriptSlice, - MAX_TOTAL_COMMITTED_DATA, -}; -use mpz_circuits::types::ValueType; -use mpz_garble_core::Encoder; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use utils::range::{RangeDisjoint, RangeSet, RangeUnion}; - -/// An error for [`SubstringsProofBuilder`] -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum SubstringsProofBuilderError { - /// Invalid commitment id. - #[error("invalid commitment id: {0:?}")] - InvalidCommitmentId(CommitmentId), - /// Invalid commitment type. - #[error("commitment {0:?} is not a substrings commitment")] - InvalidCommitmentType(CommitmentId), - /// Attempted to add a commitment with a duplicate id. 
- #[error("commitment with id {0:?} already exists")] - DuplicateCommitmentId(CommitmentId), -} - -/// A builder for [`SubstringsProof`] -pub struct SubstringsProofBuilder<'a> { - commitments: &'a TranscriptCommitments, - transcript_tx: &'a Transcript, - transcript_rx: &'a Transcript, - openings: HashMap, -} - -opaque_debug::implement!(SubstringsProofBuilder<'_>); - -impl<'a> SubstringsProofBuilder<'a> { - /// Creates a new builder. - pub fn new( - commitments: &'a TranscriptCommitments, - transcript_tx: &'a Transcript, - transcript_rx: &'a Transcript, - ) -> Self { - Self { - commitments, - transcript_tx, - transcript_rx, - openings: HashMap::default(), - } - } - - /// Returns a reference to the commitments. - pub fn commitments(&self) -> &TranscriptCommitments { - self.commitments - } - - /// Reveals data corresponding to the provided commitment id - pub fn reveal(&mut self, id: CommitmentId) -> Result<&mut Self, SubstringsProofBuilderError> { - let commitment = self - .commitments() - .get(&id) - .ok_or(SubstringsProofBuilderError::InvalidCommitmentId(id))?; - - let info = self - .commitments() - .get_info(&id) - .expect("info exists if commitment exists"); - - #[allow(irrefutable_let_patterns)] - let Commitment::Blake3(commitment) = commitment - else { - return Err(SubstringsProofBuilderError::InvalidCommitmentType(id)); - }; - - let transcript = match info.direction() { - Direction::Sent => self.transcript_tx, - Direction::Received => self.transcript_rx, - }; - - let data = transcript.get_bytes_in_ranges(info.ranges()); - - // add commitment to openings and return an error if it is already present - if self - .openings - .insert(id, (info.clone(), commitment.open(data).into())) - .is_some() - { - return Err(SubstringsProofBuilderError::DuplicateCommitmentId(id)); - } - - Ok(self) - } - - /// Builds the [`SubstringsProof`] - pub fn build(self) -> Result { - let Self { - commitments, - openings, - .. 
- } = self; - - let mut indices = openings - .keys() - .map(|id| id.to_inner() as usize) - .collect::>(); - indices.sort(); - - let inclusion_proof = commitments.merkle_tree().proof(&indices); - - Ok(SubstringsProof { - openings, - inclusion_proof, - }) - } -} - -/// An error relating to [`SubstringsProof`] -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum SubstringsProofError { - /// The proof contains more data than the maximum allowed. - #[error( - "substrings proof opens more data than the maximum allowed: {0} > {}", - MAX_TOTAL_COMMITTED_DATA - )] - MaxDataExceeded(usize), - /// The proof contains duplicate transcript data. - #[error("proof contains duplicate transcript data")] - DuplicateData(Direction, RangeSet), - /// Range of the opening is out of bounds. - #[error("range of opening {0:?} is out of bounds: {1}")] - RangeOutOfBounds(CommitmentId, usize), - /// The proof contains an invalid commitment opening. - #[error("invalid opening for commitment id: {0:?}")] - InvalidOpening(CommitmentId), - /// The proof contains an invalid inclusion proof. - #[error("invalid inclusion proof: {0}")] - InvalidInclusionProof(String), -} - -/// A substring proof using commitments -/// -/// This substring proof contains the commitment openings and a proof -/// that the corresponding commitments are present in the merkle tree. -#[derive(Serialize, Deserialize)] -pub struct SubstringsProof { - openings: HashMap, - inclusion_proof: MerkleProof, -} - -opaque_debug::implement!(SubstringsProof); - -impl SubstringsProof { - /// Verifies this proof and, if successful, returns the redacted sent and received transcripts. - /// - /// # Arguments - /// - /// * `header` - The session header. 
- pub fn verify( - self, - header: &SessionHeader, - ) -> Result<(RedactedTranscript, RedactedTranscript), SubstringsProofError> { - let Self { - openings, - inclusion_proof, - } = self; - - let mut indices = Vec::with_capacity(openings.len()); - let mut expected_hashes = Vec::with_capacity(openings.len()); - let mut sent = vec![0u8; header.sent_len()]; - let mut recv = vec![0u8; header.recv_len()]; - let mut sent_ranges = RangeSet::default(); - let mut recv_ranges = RangeSet::default(); - let mut total_opened = 0u128; - for (id, (info, opening)) in openings { - let CommitmentInfo { - ranges, direction, .. - } = info; - - let opened_len = ranges.len(); - - // Make sure the amount of data being proved is bounded. - total_opened += opened_len as u128; - if total_opened > MAX_TOTAL_COMMITTED_DATA as u128 { - return Err(SubstringsProofError::MaxDataExceeded(total_opened as usize)); - } - - // Make sure the opening length matches the ranges length. - if opening.data().len() != opened_len { - return Err(SubstringsProofError::InvalidOpening(id)); - } - - // Make sure duplicate data is not opened. - match direction { - Direction::Sent => { - if !sent_ranges.is_disjoint(&ranges) { - return Err(SubstringsProofError::DuplicateData(direction, ranges)); - } - sent_ranges = sent_ranges.union(&ranges); - } - Direction::Received => { - if !recv_ranges.is_disjoint(&ranges) { - return Err(SubstringsProofError::DuplicateData(direction, ranges)); - } - recv_ranges = recv_ranges.union(&ranges); - } - } - - // Make sure the ranges are within the bounds of the transcript - let max = ranges - .max() - .ok_or(SubstringsProofError::InvalidOpening(id))?; - let transcript_len = match direction { - Direction::Sent => header.sent_len(), - Direction::Received => header.recv_len(), - }; - - if max > transcript_len { - return Err(SubstringsProofError::RangeOutOfBounds(id, max)); - } - - // Generate the expected encodings for the purported data in the opening. 
- let encodings = get_value_ids(&ranges, direction) - .map(|id| { - header - .encoder() - .encode_by_type(EncodingId::new(&id).to_inner(), &ValueType::U8) - }) - .collect::>(); - - // Compute the expected hash of the commitment to make sure it is - // present in the merkle tree. - indices.push(id.to_inner() as usize); - expected_hashes.push(opening.recover(&encodings).hash()); - - // Make sure the length of data from the opening matches the commitment. - let mut data = opening.into_data(); - if data.len() != ranges.len() { - return Err(SubstringsProofError::InvalidOpening(id)); - } - - let dest = match direction { - Direction::Sent => &mut sent, - Direction::Received => &mut recv, - }; - - // Iterate over the ranges backwards, copying the data from the opening - // then truncating it. - for range in ranges.iter_ranges().rev() { - let start = data.len() - range.len(); - dest[range].copy_from_slice(&data[start..]); - data.truncate(start); - } - } - - // Verify that the expected hashes are present in the merkle tree. - // - // This proves the Prover committed to the purported data prior to the encoder - // seed being revealed. - inclusion_proof - .verify(header.merkle_root(), &indices, &expected_hashes) - .map_err(|e| SubstringsProofError::InvalidInclusionProof(e.to_string()))?; - - // Iterate over the unioned ranges and create TranscriptSlices for each. - // This ensures that the slices are sorted and disjoint. 
- let sent_slices = sent_ranges - .iter_ranges() - .map(|range| TranscriptSlice::new(range.clone(), sent[range].to_vec())) - .collect(); - let recv_slices = recv_ranges - .iter_ranges() - .map(|range| TranscriptSlice::new(range.clone(), recv[range].to_vec())) - .collect(); - - Ok(( - RedactedTranscript::new(header.sent_len(), sent_slices), - RedactedTranscript::new(header.recv_len(), recv_slices), - )) - } -} diff --git a/tlsn/tlsn-core/src/session/data.rs b/tlsn/tlsn-core/src/session/data.rs deleted file mode 100644 index 5bb148e802..0000000000 --- a/tlsn/tlsn-core/src/session/data.rs +++ /dev/null @@ -1,77 +0,0 @@ -use crate::{ - commitment::TranscriptCommitments, - proof::{SessionInfo, SubstringsProofBuilder}, - ServerName, Transcript, -}; -use mpz_core::commit::Decommitment; -use serde::{Deserialize, Serialize}; -use tls_core::handshake::HandshakeData; - -/// Session data used for notarization. -/// -/// This contains all the private data held by the `Prover` after notarization including -/// commitments to the parts of the transcript. -/// -/// # Selective disclosure -/// -/// The `Prover` can selectively disclose parts of the transcript to a `Verifier` using a -/// [`SubstringsProof`](crate::proof::SubstringsProof). -/// -/// See [`build_substrings_proof`](SessionData::build_substrings_proof). -#[derive(Serialize, Deserialize)] -pub struct SessionData { - session_info: SessionInfo, - transcript_tx: Transcript, - transcript_rx: Transcript, - commitments: TranscriptCommitments, -} - -impl SessionData { - /// Creates new session data. 
- pub fn new( - server_name: ServerName, - handshake_data_decommitment: Decommitment, - transcript_tx: Transcript, - transcript_rx: Transcript, - commitments: TranscriptCommitments, - ) -> Self { - let session_info = SessionInfo { - server_name, - handshake_decommitment: handshake_data_decommitment, - }; - - Self { - session_info, - transcript_tx, - transcript_rx, - commitments, - } - } - - /// Returns the session info - pub fn session_info(&self) -> &SessionInfo { - &self.session_info - } - - /// Returns the transcript for data sent to the server - pub fn sent_transcript(&self) -> &Transcript { - &self.transcript_tx - } - - /// Returns the transcript for data received from the server - pub fn recv_transcript(&self) -> &Transcript { - &self.transcript_rx - } - - /// Returns the transcript commitments. - pub fn commitments(&self) -> &TranscriptCommitments { - &self.commitments - } - - /// Returns a substrings proof builder. - pub fn build_substrings_proof(&self) -> SubstringsProofBuilder { - SubstringsProofBuilder::new(&self.commitments, &self.transcript_tx, &self.transcript_rx) - } -} - -opaque_debug::implement!(SessionData); diff --git a/tlsn/tlsn-core/src/session/handshake.rs b/tlsn/tlsn-core/src/session/handshake.rs deleted file mode 100644 index 142f9a66b9..0000000000 --- a/tlsn/tlsn-core/src/session/handshake.rs +++ /dev/null @@ -1,82 +0,0 @@ -use mpz_core::{commit::Decommitment, hash::Hash}; -use serde::{Deserialize, Serialize}; -use tls_core::{handshake::HandshakeData, key::PublicKey, msgs::handshake::ServerECDHParams}; - -/// An error that can occur while verifying a handshake summary -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum HandshakeVerifyError { - /// The handshake data does not match the commitment - #[error("Handshake data does not match commitment")] - Commitment, - /// The key exchange parameters are invalid - #[error("Key exchange parameters are invalid")] - KxParams, - /// The server ephemeral key does not match - 
#[error("Server ephemeral key does not match")] - ServerEphemKey, -} - -/// Handshake summary is part of the session header signed by the Notary -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct HandshakeSummary { - /// time when Notary signed the session header - // TODO: we can change this to be the time when the Notary started the TLS handshake 2PC - time: u64, - /// server ephemeral public key - server_public_key: PublicKey, - /// Prover's commitment to [crate::handshake_data::HandshakeData] - handshake_commitment: Hash, -} - -impl HandshakeSummary { - /// Creates a new HandshakeSummary - pub fn new(time: u64, ephemeral_ec_pubkey: PublicKey, handshake_commitment: Hash) -> Self { - Self { - time, - server_public_key: ephemeral_ec_pubkey, - handshake_commitment, - } - } - - /// Time of the TLS session, in seconds since the UNIX epoch. - /// - /// # Note - /// - /// This time is not necessarily exactly aligned with the TLS handshake. - pub fn time(&self) -> u64 { - self.time - } - - /// Returns the server ephemeral public key - pub fn server_public_key(&self) -> &PublicKey { - &self.server_public_key - } - - /// Returns commitment to the handshake data - pub fn handshake_commitment(&self) -> &Hash { - &self.handshake_commitment - } - - /// Verifies that the provided handshake data matches this handshake summary - pub fn verify(&self, data: &Decommitment) -> Result<(), HandshakeVerifyError> { - // Verify the handshake data matches the commitment in the session header - data.verify(&self.handshake_commitment) - .map_err(|_| HandshakeVerifyError::Commitment)?; - - let ecdh_params = tls_core::suites::tls12::decode_ecdh_params::( - data.data().server_kx_details().kx_params(), - ) - .ok_or(HandshakeVerifyError::KxParams)?; - - let server_public_key = - PublicKey::new(ecdh_params.curve_params.named_group, &ecdh_params.public.0); - - // Ephemeral pubkey must match the one which the Notary signed - if server_public_key != self.server_public_key { - return 
Err(HandshakeVerifyError::ServerEphemKey); - } - - Ok(()) - } -} diff --git a/tlsn/tlsn-core/src/session/header.rs b/tlsn/tlsn-core/src/session/header.rs deleted file mode 100644 index 38c8fbeb19..0000000000 --- a/tlsn/tlsn-core/src/session/header.rs +++ /dev/null @@ -1,120 +0,0 @@ -use mpz_core::commit::Decommitment; -use serde::{Deserialize, Serialize}; - -use mpz_garble_core::ChaChaEncoder; -use tls_core::{handshake::HandshakeData, key::PublicKey}; - -use crate::{merkle::MerkleRoot, HandshakeSummary}; - -/// An error that can occur while verifying a session header -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum SessionHeaderVerifyError { - /// The session header is not consistent with the provided data - #[error("session header is not consistent with the provided data")] - InconsistentHeader, -} - -/// An authentic session header from the Notary -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SessionHeader { - /// A PRG seeds used to generate encodings for the plaintext - encoder_seed: [u8; 32], - - /// The root of the Merkle tree of all the commitments. The Prover must prove that each one of the - /// `commitments` is included in the Merkle tree. - /// This approach allows the Prover to hide from the Notary the exact amount of commitments thus - /// increasing Prover privacy against the Notary. - /// The root was made known to the Notary before the Notary opened his garbled circuits - /// to the Prover. 
- merkle_root: MerkleRoot, - - /// Bytelength of all data which was sent to the webserver - sent_len: usize, - /// Bytelength of all data which was received from the webserver - recv_len: usize, - - handshake_summary: HandshakeSummary, -} - -impl SessionHeader { - /// Create a new instance of SessionHeader - pub fn new( - encoder_seed: [u8; 32], - merkle_root: MerkleRoot, - sent_len: usize, - recv_len: usize, - handshake_summary: HandshakeSummary, - ) -> Self { - Self { - encoder_seed, - merkle_root, - sent_len, - recv_len, - handshake_summary, - } - } - - /// Verify the data in the header is consistent with the Prover's view - pub fn verify( - &self, - time: u64, - server_public_key: &PublicKey, - root: &MerkleRoot, - encoder_seed: &[u8; 32], - handshake_data_decommitment: &Decommitment, - ) -> Result<(), SessionHeaderVerifyError> { - let ok_time = self.handshake_summary.time().abs_diff(time) <= 300; - let ok_root = &self.merkle_root == root; - let ok_encoder_seed = &self.encoder_seed == encoder_seed; - let ok_handshake_data = handshake_data_decommitment - .verify(self.handshake_summary.handshake_commitment()) - .is_ok(); - let ok_server_public_key = self.handshake_summary.server_public_key() == server_public_key; - - if !(ok_time && ok_root && ok_encoder_seed && ok_handshake_data && ok_server_public_key) { - return Err(SessionHeaderVerifyError::InconsistentHeader); - } - - Ok(()) - } - - /// Create a new [ChaChaEncoder] from encoder_seed - pub fn encoder(&self) -> ChaChaEncoder { - ChaChaEncoder::new(self.encoder_seed) - } - - /// Returns the seed used to generate plaintext encodings - pub fn encoder_seed(&self) -> &[u8; 32] { - &self.encoder_seed - } - - /// Returns the merkle_root of the merkle tree of the prover's commitments - pub fn merkle_root(&self) -> &MerkleRoot { - &self.merkle_root - } - - /// Returns the [HandshakeSummary] of the TLS session between prover and server - pub fn handshake_summary(&self) -> &HandshakeSummary { - &self.handshake_summary - 
} - - /// Time of the TLS session, in seconds since the UNIX epoch. - /// - /// # Note - /// - /// This time is not necessarily exactly aligned with the TLS handshake. - pub fn time(&self) -> u64 { - self.handshake_summary.time() - } - - /// Returns the number of bytes sent to the server - pub fn sent_len(&self) -> usize { - self.sent_len - } - - /// Returns the number of bytes received by the server - pub fn recv_len(&self) -> usize { - self.recv_len - } -} diff --git a/tlsn/tlsn-core/src/session/mod.rs b/tlsn/tlsn-core/src/session/mod.rs deleted file mode 100644 index a9762f1e2f..0000000000 --- a/tlsn/tlsn-core/src/session/mod.rs +++ /dev/null @@ -1,66 +0,0 @@ -//! TLS session types. - -mod data; -mod handshake; -mod header; - -use serde::{Deserialize, Serialize}; - -pub use data::SessionData; -pub use handshake::{HandshakeSummary, HandshakeVerifyError}; -pub use header::{SessionHeader, SessionHeaderVerifyError}; - -use crate::{ - proof::{SessionInfo, SessionProof}, - signature::Signature, -}; - -/// A validated notarized session stored by the Prover -#[derive(Serialize, Deserialize)] -pub struct NotarizedSession { - header: SessionHeader, - signature: Option, - data: SessionData, -} - -opaque_debug::implement!(NotarizedSession); - -impl NotarizedSession { - /// Create a new notarized session. 
- pub fn new(header: SessionHeader, signature: Option, data: SessionData) -> Self { - Self { - header, - signature, - data, - } - } - - /// Returns a proof of the TLS session - pub fn session_proof(&self) -> SessionProof { - let session_info = SessionInfo { - server_name: self.data.session_info().server_name.clone(), - handshake_decommitment: self.data.session_info().handshake_decommitment.clone(), - }; - - SessionProof { - header: self.header.clone(), - signature: self.signature.clone(), - session_info, - } - } - - /// Returns the [SessionHeader] - pub fn header(&self) -> &SessionHeader { - &self.header - } - - /// Returns the signature for the session header, if the notary signed it - pub fn signature(&self) -> &Option { - &self.signature - } - - /// Returns the [SessionData] - pub fn data(&self) -> &SessionData { - &self.data - } -} diff --git a/tlsn/tlsn-core/src/signature.rs b/tlsn/tlsn-core/src/signature.rs deleted file mode 100644 index b25dbf6743..0000000000 --- a/tlsn/tlsn-core/src/signature.rs +++ /dev/null @@ -1,63 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use p256::ecdsa::{signature::Verifier, VerifyingKey}; - -/// A Notary public key. -#[derive(Debug, Clone, Deserialize, Serialize)] -#[non_exhaustive] -pub enum NotaryPublicKey { - /// A NIST P-256 public key. - P256(p256::PublicKey), -} - -impl From for NotaryPublicKey { - fn from(key: p256::PublicKey) -> Self { - Self::P256(key) - } -} - -/// An error occurred while verifying a signature. -#[derive(Debug, thiserror::Error)] -#[error("signature verification failed: {0}")] -pub struct SignatureVerifyError(String); - -/// A Notary signature. -#[derive(Debug, Clone, Deserialize, Serialize)] -#[non_exhaustive] -pub enum Signature { - /// A secp256r1 signature. - P256(p256::ecdsa::Signature), -} - -impl From for Signature { - fn from(sig: p256::ecdsa::Signature) -> Self { - Self::P256(sig) - } -} - -impl Signature { - /// Returns the bytes of this signature. 
- pub fn to_bytes(&self) -> Vec { - match self { - Self::P256(sig) => sig.to_vec(), - } - } - - /// Verifies the signature. - /// - /// # Arguments - /// - /// * `msg` - The message to verify. - /// * `notary_public_key` - The public key of the notary. - pub fn verify( - &self, - msg: &[u8], - notary_public_key: impl Into, - ) -> Result<(), SignatureVerifyError> { - match (self, notary_public_key.into()) { - (Self::P256(sig), NotaryPublicKey::P256(key)) => VerifyingKey::from(key) - .verify(msg, sig) - .map_err(|e| SignatureVerifyError(e.to_string())), - } - } -} diff --git a/tlsn/tlsn-core/src/transcript.rs b/tlsn/tlsn-core/src/transcript.rs deleted file mode 100644 index 13751767f8..0000000000 --- a/tlsn/tlsn-core/src/transcript.rs +++ /dev/null @@ -1,244 +0,0 @@ -//! Transcript data types. - -use std::ops::Range; - -use bytes::Bytes; -use serde::{Deserialize, Serialize}; -use utils::range::{RangeDifference, RangeSet, RangeUnion}; - -pub(crate) static TX_TRANSCRIPT_ID: &str = "tx"; -pub(crate) static RX_TRANSCRIPT_ID: &str = "rx"; - -/// A transcript contains a subset of bytes from a TLS session -#[derive(Default, Serialize, Deserialize, Clone, Debug)] -pub struct Transcript { - data: Bytes, -} - -impl Transcript { - /// Creates a new transcript with the given ID and data - pub fn new(data: impl Into) -> Self { - Self { data: data.into() } - } - - /// Returns the actual traffic data of this transcript - pub fn data(&self) -> &Bytes { - &self.data - } - - /// Returns a concatenated bytestring located in the given ranges of the transcript. - /// - /// # Panics - /// - /// Panics if the range set is empty or is out of bounds. 
- pub(crate) fn get_bytes_in_ranges(&self, ranges: &RangeSet) -> Vec { - let max = ranges.max().expect("range set is not empty"); - assert!(max <= self.data.len(), "range set is out of bounds"); - - ranges - .iter_ranges() - .flat_map(|range| &self.data[range]) - .copied() - .collect() - } -} - -/// A transcript which may have some data redacted. -#[derive(Debug)] -pub struct RedactedTranscript { - data: Vec, - /// Ranges of `data` which have been authenticated - auth: RangeSet, - /// Ranges of `data` which have been redacted - redacted: RangeSet, -} - -impl RedactedTranscript { - /// Creates a new redacted transcript with the given length. - /// - /// All bytes in the transcript are initialized to 0. - /// - /// # Arguments - /// - /// * `len` - The length of the transcript - /// * `slices` - A list of slices of data which have been authenticated - pub fn new(len: usize, slices: Vec) -> Self { - let mut data = vec![0u8; len]; - let mut auth = RangeSet::default(); - for slice in slices { - data[slice.range()].copy_from_slice(slice.data()); - auth = auth.union(&slice.range()); - } - let redacted = RangeSet::from(0..len).difference(&auth); - - Self { - data, - auth, - redacted, - } - } - - /// Returns a reference to the data. - /// - /// # Warning - /// - /// Not all of the data in the transcript may have been authenticated. See - /// [authed](RedactedTranscript::authed) for a set of ranges which have been. - pub fn data(&self) -> &[u8] { - &self.data - } - - /// Returns all the ranges of data which have been authenticated. - pub fn authed(&self) -> &RangeSet { - &self.auth - } - - /// Returns all the ranges of data which have been redacted. - pub fn redacted(&self) -> &RangeSet { - &self.redacted - } - - /// Sets all bytes in the transcript which were redacted. 
- /// - /// # Arguments - /// - /// * `value` - The value to set the redacted bytes to - pub fn set_redacted(&mut self, value: u8) { - for range in self.redacted().clone().iter_ranges() { - self.data[range].fill(value); - } - } - - /// Sets all bytes in the transcript which were redacted in the given range. - /// - /// # Arguments - /// - /// * `value` - The value to set the redacted bytes to - /// * `range` - The range of redacted bytes to set - pub fn set_redacted_range(&mut self, value: u8, range: Range) { - for range in self - .redacted - .difference(&(0..self.data.len()).difference(&range)) - .iter_ranges() - { - self.data[range].fill(value); - } - } -} - -/// Slice of a transcript. -#[derive(PartialEq, Debug, Clone, Default)] -pub struct TranscriptSlice { - /// A byte range of this slice - range: Range, - /// The actual byte content of the slice - data: Vec, -} - -impl TranscriptSlice { - /// Creates a new transcript slice. - pub fn new(range: Range, data: Vec) -> Self { - Self { range, data } - } - - /// Returns the range of bytes this slice refers to in the transcript - pub fn range(&self) -> Range { - self.range.clone() - } - - /// Returns the bytes of this slice - pub fn data(&self) -> &[u8] { - &self.data - } - - /// Returns the bytes of this slice - pub fn into_bytes(self) -> Vec { - self.data - } -} - -/// The direction of data communicated over a TLS connection. -/// -/// This is used to differentiate between data sent to the Server, and data received from the Server. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum Direction { - /// Sent from the prover to the server - Sent, - /// Received by the prover from the server - Received, -} - -/// Returns the value ID for each byte in the provided range set -pub fn get_value_ids( - ranges: &RangeSet, - direction: Direction, -) -> impl Iterator + '_ { - let id = match direction { - Direction::Sent => TX_TRANSCRIPT_ID, - Direction::Received => RX_TRANSCRIPT_ID, - }; - - ranges.iter().map(move |idx| format!("{}/{}", id, idx)) -} - -#[cfg(test)] -mod tests { - use rstest::{fixture, rstest}; - - use super::*; - - #[fixture] - fn transcripts() -> (Transcript, Transcript) { - let sent = "data sent 123456789".as_bytes().to_vec(); - let recv = "data received 987654321".as_bytes().to_vec(); - (Transcript::new(sent), Transcript::new(recv)) - } - - #[rstest] - fn test_get_bytes_in_ranges(transcripts: (Transcript, Transcript)) { - let (sent, recv) = transcripts; - - let range1 = Range { start: 2, end: 4 }; - let range2 = Range { start: 10, end: 15 }; - // a full range spanning the entirety of the data - let range3 = Range { - start: 0, - end: sent.data().len(), - }; - - let expected = "ta12345".as_bytes().to_vec(); - assert_eq!( - expected, - sent.get_bytes_in_ranges(&RangeSet::from([range1.clone(), range2.clone()])) - ); - - let expected = "taved 9".as_bytes().to_vec(); - assert_eq!( - expected, - recv.get_bytes_in_ranges(&RangeSet::from([range1, range2])) - ); - - assert_eq!( - sent.data().as_ref(), - sent.get_bytes_in_ranges(&RangeSet::from([range3])) - ); - } - - #[rstest] - #[should_panic] - fn test_get_bytes_in_ranges_empty(transcripts: (Transcript, Transcript)) { - let (sent, _) = transcripts; - sent.get_bytes_in_ranges(&RangeSet::default()); - } - - #[rstest] - #[should_panic] - fn test_get_bytes_in_ranges_out_of_bounds(transcripts: (Transcript, Transcript)) { - let (sent, _) = transcripts; - let range = Range { - start: 0, - end: sent.data().len() + 1, 
- }; - sent.get_bytes_in_ranges(&RangeSet::from([range])); - } -} diff --git a/tlsn/tlsn-core/tests/api.rs b/tlsn/tlsn-core/tests/api.rs deleted file mode 100644 index d219664f4e..0000000000 --- a/tlsn/tlsn-core/tests/api.rs +++ /dev/null @@ -1,192 +0,0 @@ -use std::ops::Range; - -use p256::{ - ecdsa::{ - signature::{SignerMut, Verifier}, - Signature as P256Signature, SigningKey, - }, - PublicKey, -}; -use rand_chacha::ChaCha20Rng; -use rand_core::SeedableRng; - -use tls_core::{ - cert::ServerCertDetails, - handshake::HandshakeData, - ke::ServerKxDetails, - msgs::{enums::SignatureScheme, handshake::DigitallySignedStruct}, -}; - -use mpz_core::{commit::HashCommit, serialize::CanonicalSerialize}; - -use tlsn_core::{ - commitment::TranscriptCommitmentBuilder, - fixtures, - msg::SignedSessionHeader, - proof::{SessionProof, SubstringsProof}, - HandshakeSummary, NotarizedSession, ServerName, SessionData, SessionHeader, Signature, - Transcript, -}; - -#[test] -/// Tests that the commitment creation protocol and verification work end-to-end -fn test_api() { - let testdata = fixtures::cert::tlsnotary(); - // Prover's transcript - let data_sent = "sent data".as_bytes(); - let data_recv = "received data".as_bytes(); - let transcript_tx = Transcript::new(data_sent.to_vec()); - let transcript_rx = Transcript::new(data_recv.to_vec()); - - // Ranges of plaintext for which the Prover wants to create a commitment - let range1: Range = Range { start: 0, end: 2 }; - let range2: Range = Range { start: 1, end: 3 }; - - // Plaintext encodings which the Prover obtained from GC evaluation - let encodings_provider = fixtures::encoding_provider(data_sent, data_recv); - - // At the end of the session the Prover holds the: - // - time when the TLS handshake began - // - server ephemeral key - // - handshake data (to which the Prover sent a commitment earlier) - // - encoder seed revealed by the Notary at the end of the label commitment protocol - - let time = testdata.time; - let ephem_key = 
testdata.pubkey.clone(); - - let handshake_data = HandshakeData::new( - ServerCertDetails::new( - vec![ - testdata.ee.clone(), - testdata.inter.clone(), - testdata.ca.clone(), - ], - vec![], - None, - ), - ServerKxDetails::new( - testdata.kx_params(), - DigitallySignedStruct::new(SignatureScheme::RSA_PKCS1_SHA256, testdata.sig.clone()), - ), - testdata.cr, - testdata.sr, - ); - - // Commitment to the handshake which the Prover sent at the start of the TLS handshake - let (hs_decommitment, hs_commitment) = handshake_data.hash_commit(); - - let mut commitment_builder = - TranscriptCommitmentBuilder::new(encodings_provider, data_sent.len(), data_recv.len()); - - let commitment_id_1 = commitment_builder.commit_sent(range1.clone()).unwrap(); - let commitment_id_2 = commitment_builder.commit_recv(range2.clone()).unwrap(); - - let commitments = commitment_builder.build().unwrap(); - - let notarized_session_data = SessionData::new( - ServerName::Dns(testdata.dns_name.clone()), - hs_decommitment.clone(), - transcript_tx, - transcript_rx, - commitments, - ); - - // Some outer context generates an (ephemeral) signing key for the Notary, e.g. 
- let mut rng = ChaCha20Rng::from_seed([6u8; 32]); - let signing_key = SigningKey::random(&mut rng); - let raw_key = signing_key.to_bytes(); - - // Notary receives the raw signing key from some outer context - let mut signer = SigningKey::from_bytes(&raw_key).unwrap(); - let notary_pubkey = PublicKey::from(*signer.verifying_key()); - let notary_verifing_key = *signer.verifying_key(); - - // Notary creates the session header - assert!(data_sent.len() <= (u32::MAX as usize) && data_recv.len() <= (u32::MAX as usize)); - - let header = SessionHeader::new( - fixtures::encoder_seed(), - notarized_session_data.commitments().merkle_root(), - data_sent.len(), - data_recv.len(), - // the session's end time and TLS handshake start time may be a few mins apart - HandshakeSummary::new(time + 60, ephem_key.clone(), hs_commitment), - ); - - let signature: P256Signature = signer.sign(&header.to_bytes()); - // Notary creates a msg and sends it to Prover - let msg = SignedSessionHeader { - header, - signature: signature.into(), - }; - - //--------------------------------------- - let msg_bytes = bincode::serialize(&msg).unwrap(); - let SignedSessionHeader { header, signature } = bincode::deserialize(&msg_bytes).unwrap(); - //--------------------------------------- - - // Prover verifies the signature - #[allow(irrefutable_let_patterns)] - if let Signature::P256(signature) = signature { - notary_verifing_key - .verify(&header.to_bytes(), &signature) - .unwrap(); - } else { - panic!("Notary signature is not P256"); - }; - - // Prover verifies the header and stores it with the signature in NotarizedSession - header - .verify( - time, - &ephem_key, - ¬arized_session_data.commitments().merkle_root(), - header.encoder_seed(), - ¬arized_session_data.session_info().handshake_decommitment, - ) - .unwrap(); - - let session = NotarizedSession::new(header, Some(signature), notarized_session_data); - - // Prover converts NotarizedSession into SessionProof and SubstringsProof and sends them to 
the Verifier - let session_proof = session.session_proof(); - - let mut substrings_proof_builder = session.data().build_substrings_proof(); - - substrings_proof_builder - .reveal(commitment_id_1) - .unwrap() - .reveal(commitment_id_2) - .unwrap(); - - let substrings_proof = substrings_proof_builder.build().unwrap(); - - //--------------------------------------- - let session_proof_bytes = bincode::serialize(&session_proof).unwrap(); - let substrings_proof_bytes = bincode::serialize(&substrings_proof).unwrap(); - let session_proof: SessionProof = bincode::deserialize(&session_proof_bytes).unwrap(); - let substrings_proof: SubstringsProof = bincode::deserialize(&substrings_proof_bytes).unwrap(); - //--------------------------------------- - - // The Verifier does: - session_proof - .verify_with_default_cert_verifier(notary_pubkey) - .unwrap(); - - let SessionProof { - header, - session_info, - .. - } = session_proof; - - // assert dns name is expected - assert_eq!( - session_info.server_name.as_ref(), - testdata.dns_name.as_str() - ); - - let (sent, recv) = substrings_proof.verify(&header).unwrap(); - - assert_eq!(&sent.data()[range1], b"se".as_slice()); - assert_eq!(&recv.data()[range2], b"ec".as_slice()); -} diff --git a/tlsn/tlsn-formats/Cargo.toml b/tlsn/tlsn-formats/Cargo.toml deleted file mode 100644 index 6ac3af0118..0000000000 --- a/tlsn/tlsn-formats/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "tlsn-formats" -version = "0.1.0" -edition = "2021" - -[dependencies] -tlsn-core.workspace = true -tlsn-utils.workspace = true - -bytes.workspace = true -spansy = { workspace = true, features = ["serde"] } -serde.workspace = true -thiserror.workspace = true - -[dev-dependencies] -tlsn-core = { workspace = true, features = ["fixtures"] } -rstest.workspace = true diff --git a/tlsn/tlsn-formats/src/http/body.rs b/tlsn/tlsn-formats/src/http/body.rs deleted file mode 100644 index 4a5bae223e..0000000000 --- a/tlsn/tlsn-formats/src/http/body.rs +++ /dev/null 
@@ -1,136 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use tlsn_core::{ - commitment::{CommitmentId, TranscriptCommitmentBuilder, TranscriptCommitments}, - proof::SubstringsProofBuilder, - Direction, -}; - -use crate::{ - http::{HttpCommitmentBuilderError, HttpProofBuilderError}, - json::{JsonBody, JsonCommitmentBuilder, JsonProofBuilder}, - unknown::{UnknownCommitmentBuilder, UnknownProofBuilder, UnknownSpan}, -}; - -/// A body of an HTTP request or response -#[derive(Debug, Serialize, Deserialize)] -#[non_exhaustive] -pub enum Body { - /// A JSON body - Json(JsonBody), - /// A body with an unsupported content type - Unknown(UnknownSpan), -} - -/// Builder for commitments to an HTTP body. -#[derive(Debug)] -#[non_exhaustive] -pub enum BodyCommitmentBuilder<'a> { - /// Builder for commitments to a JSON body. - Json(JsonCommitmentBuilder<'a>), - /// Builder for commitments to a body with an unknown format. - Unknown(UnknownCommitmentBuilder<'a>), -} - -impl<'a> BodyCommitmentBuilder<'a> { - pub(crate) fn new( - builder: &'a mut TranscriptCommitmentBuilder, - value: &'a Body, - direction: Direction, - built: &'a mut bool, - ) -> Self { - match value { - Body::Json(body) => BodyCommitmentBuilder::Json(JsonCommitmentBuilder::new( - builder, &body.0, direction, built, - )), - Body::Unknown(body) => BodyCommitmentBuilder::Unknown(UnknownCommitmentBuilder::new( - builder, body, direction, built, - )), - } - } - - /// Commits to the entire body. - pub fn all(&mut self) -> Result { - match self { - BodyCommitmentBuilder::Json(builder) => builder - .all() - .map_err(|e| HttpCommitmentBuilderError::Body(e.to_string())), - BodyCommitmentBuilder::Unknown(builder) => builder - .all() - .map_err(|e| HttpCommitmentBuilderError::Body(e.to_string())), - } - } - - /// Builds the commitment to the body. 
- pub fn build(self) -> Result<(), HttpCommitmentBuilderError> { - match self { - BodyCommitmentBuilder::Json(builder) => builder - .build() - .map_err(|e| HttpCommitmentBuilderError::Body(e.to_string())), - BodyCommitmentBuilder::Unknown(builder) => builder - .build() - .map_err(|e| HttpCommitmentBuilderError::Body(e.to_string())), - } - } -} - -/// Builder for proofs of an HTTP body. -#[derive(Debug)] -#[non_exhaustive] -pub enum BodyProofBuilder<'a, 'b> { - /// Builder for proofs of a JSON body. - Json(JsonProofBuilder<'a, 'b>), - /// Builder for proofs of a body with an unknown format. - Unknown(UnknownProofBuilder<'a, 'b>), -} - -impl<'a, 'b> BodyProofBuilder<'a, 'b> { - pub(crate) fn new( - builder: &'a mut SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - value: &'a Body, - direction: Direction, - built: &'a mut bool, - ) -> Self { - match value { - Body::Json(body) => BodyProofBuilder::Json(JsonProofBuilder::new( - builder, - commitments, - &body.0, - direction, - built, - )), - Body::Unknown(body) => BodyProofBuilder::Unknown(UnknownProofBuilder::new( - builder, - commitments, - body, - direction, - built, - )), - } - } - - /// Proves the entire body. - pub fn all(&mut self) -> Result<(), HttpProofBuilderError> { - match self { - BodyProofBuilder::Json(builder) => builder - .all() - .map_err(|e| HttpProofBuilderError::Body(e.to_string())), - BodyProofBuilder::Unknown(builder) => builder - .all() - .map_err(|e| HttpProofBuilderError::Body(e.to_string())), - } - } - - /// Builds the proof for the body. 
- pub fn build(self) -> Result<(), HttpProofBuilderError> { - match self { - BodyProofBuilder::Json(builder) => builder - .build() - .map_err(|e| HttpProofBuilderError::Body(e.to_string())), - BodyProofBuilder::Unknown(builder) => builder - .build() - .map_err(|e| HttpProofBuilderError::Body(e.to_string())), - } - } -} diff --git a/tlsn/tlsn-formats/src/http/commitment.rs b/tlsn/tlsn-formats/src/http/commitment.rs deleted file mode 100644 index fb98e558b3..0000000000 --- a/tlsn/tlsn-formats/src/http/commitment.rs +++ /dev/null @@ -1,332 +0,0 @@ -use std::fmt::Debug; - -use crate::http::{Body, BodyCommitmentBuilder, Request, Response}; -use spansy::Spanned; -use tlsn_core::{ - commitment::{CommitmentId, TranscriptCommitmentBuilder, TranscriptCommitmentBuilderError}, - Direction, -}; -use utils::range::{RangeSet, RangeSubset, RangeUnion}; - -use super::PUBLIC_HEADERS; - -/// HTTP commitment builder error. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum HttpCommitmentBuilderError { - /// Header is missing. - #[error("header with name \"{0}\" does not exist.")] - MissingHeader(String), - /// Body commitment error. - #[error("body commitment error: {0}")] - Body(String), - /// Transcript commitment builder error. - #[error("commitment builder error: {0}")] - Commitment(#[from] TranscriptCommitmentBuilderError), -} - -/// Builder for commitments to data in an HTTP connection. 
-#[derive(Debug)] -pub struct HttpCommitmentBuilder<'a> { - builder: &'a mut TranscriptCommitmentBuilder, - requests: &'a [(Request, Option)], - responses: &'a [(Response, Option)], - built_requests: Vec, - built_responses: Vec, -} - -impl<'a> HttpCommitmentBuilder<'a> { - #[doc(hidden)] - pub fn new( - builder: &'a mut TranscriptCommitmentBuilder, - requests: &'a [(Request, Option)], - responses: &'a [(Response, Option)], - ) -> Self { - Self { - builder, - requests, - responses, - built_requests: vec![false; requests.len()], - built_responses: vec![false; responses.len()], - } - } - - /// Returns a commitment builder for the request at the given index. - /// - /// # Arguments - /// - /// * `index` - The index of the request. - #[must_use] - pub fn request(&mut self, index: usize) -> Option> { - self.requests.get(index).map(|request| { - HttpRequestCommitmentBuilder::new( - self.builder, - &request.0, - request.1.as_ref(), - &mut self.built_requests[index], - ) - }) - } - - /// Returns a commitment builder for the response at the given index. - /// - /// # Arguments - /// - /// * `index` - The index of the response. - #[must_use] - pub fn response(&mut self, index: usize) -> Option> { - self.responses.get(index).map(|response| { - HttpResponseCommitmentBuilder::new( - self.builder, - &response.0, - response.1.as_ref(), - &mut self.built_responses[index], - ) - }) - } - - /// Builds commitments to the HTTP requests and responses. - /// - /// This automatically will commit to all header values which have no yet been committed. - pub fn build(mut self) -> Result<(), HttpCommitmentBuilderError> { - // Builds all request commitments - for i in 0..self.requests.len() { - if !self.built_requests[i] { - self.request(i).unwrap().build()?; - } - } - - // Build all response commitments - for i in 0..self.responses.len() { - if !self.built_responses[i] { - self.response(i).unwrap().build()?; - } - } - - Ok(()) - } -} - -/// Builder for commitments to an HTTP request. 
-#[derive(Debug)] -pub struct HttpRequestCommitmentBuilder<'a> { - builder: &'a mut TranscriptCommitmentBuilder, - request: &'a Request, - body: Option<&'a Body>, - committed: RangeSet, - built: &'a mut bool, - body_built: bool, -} - -impl<'a> HttpRequestCommitmentBuilder<'a> { - pub(crate) fn new( - builder: &'a mut TranscriptCommitmentBuilder, - request: &'a Request, - body: Option<&'a Body>, - built: &'a mut bool, - ) -> Self { - Self { - builder, - request, - body, - committed: RangeSet::default(), - built, - body_built: false, - } - } - - /// Commits to the path of the request. - pub fn path(&mut self) -> Result { - let range = self.request.0.path.range(); - let id = self.builder.commit_sent(range.clone())?; - - self.committed = self.committed.union(&range); - - Ok(id) - } - - /// Commits the value of the header with the given name. - /// - /// # Arguments - /// - /// * `name` - The name of the header value to commit. - pub fn header(&mut self, name: &str) -> Result { - let header = self - .request - .0 - .header(name) - .ok_or(HttpCommitmentBuilderError::MissingHeader(name.to_string()))?; - - let range = header.value.span().range(); - let id = self.builder.commit_sent(range.clone())?; - - self.committed = self.committed.union(&range); - - Ok(id) - } - - /// Commits all request headers. - /// - /// Returns a vector of the names of the headers that were committed and their commitment IDs. - pub fn headers(&mut self) -> Result, HttpCommitmentBuilderError> { - let mut commitments = Vec::new(); - - for header in &self.request.0.headers { - let name = header.name.span().as_str().to_string(); - let id = self.header(&name)?; - - commitments.push((name, id)); - } - - Ok(commitments) - } - - /// Returns a commitment builder for the request body if it exists. - pub fn body(&mut self) -> Option> { - self.body.map(|body| { - BodyCommitmentBuilder::new(self.builder, body, Direction::Sent, &mut self.body_built) - }) - } - - /// Finishes building the request commitment. 
- /// - /// This commits to everything that has not already been committed, including a commitment - /// to the format data of the request. - pub fn build(mut self) -> Result<(), HttpCommitmentBuilderError> { - // Commit to the path if it has not already been committed. - let path_range = self.request.0.path.range(); - if !path_range.is_subset(&self.committed) { - self.path()?; - } - - // Commit to any headers that have not already been committed. - for header in &self.request.0.headers { - let name = header.name.span().as_str().to_ascii_lowercase(); - - // Public headers can not be committed separately - if PUBLIC_HEADERS.contains(&name.as_str()) { - continue; - } - - let range = header.value.span().range(); - if !range.is_subset(&self.committed) { - self.header(&name)?; - } - } - - self.builder.commit_sent(self.request.public_ranges())?; - - if self.body.is_some() && !self.body_built { - self.body().unwrap().build()?; - } - - *self.built = true; - - Ok(()) - } -} - -/// Builder for commitments to an HTTP response. -#[derive(Debug)] -pub struct HttpResponseCommitmentBuilder<'a> { - builder: &'a mut TranscriptCommitmentBuilder, - response: &'a Response, - body: Option<&'a Body>, - committed: RangeSet, - built: &'a mut bool, - body_built: bool, -} - -impl<'a> HttpResponseCommitmentBuilder<'a> { - pub(crate) fn new( - builder: &'a mut TranscriptCommitmentBuilder, - response: &'a Response, - body: Option<&'a Body>, - built: &'a mut bool, - ) -> Self { - Self { - builder, - response, - body, - committed: RangeSet::default(), - built, - body_built: false, - } - } - - /// Commits the value of the header with the given name. - /// - /// # Arguments - /// - /// * `name` - The name of the header value to commit. 
- pub fn header(&mut self, name: &str) -> Result { - let header = self - .response - .0 - .header(name) - .ok_or(HttpCommitmentBuilderError::MissingHeader(name.to_string()))?; - - self.builder - .commit_recv(header.value.span().range()) - .map_err(From::from) - } - - /// Commits all response headers. - /// - /// Returns a vector of the names of the headers that were committed and their commitment IDs. - pub fn headers(&mut self) -> Result, HttpCommitmentBuilderError> { - let mut commitments = Vec::new(); - - for header in &self.response.0.headers { - let name = header.name.span().as_str().to_string(); - let id = self.header(&name)?; - - commitments.push((name, id)); - } - - Ok(commitments) - } - - /// Returns a commitment builder for the response body if it exists. - pub fn body(&mut self) -> Option> { - self.body.map(|body| { - BodyCommitmentBuilder::new( - self.builder, - body, - Direction::Received, - &mut self.body_built, - ) - }) - } - - /// Finishes building the response commitment. - /// - /// This commits to everything that has not already been committed, including a commitment - /// to the format data of the response. - pub fn build(mut self) -> Result<(), HttpCommitmentBuilderError> { - // Commit to any headers that have not already been committed. 
- for header in &self.response.0.headers { - let name = header.name.span().as_str().to_ascii_lowercase(); - - // Public headers can not be committed separately - if PUBLIC_HEADERS.contains(&name.as_str()) { - continue; - } - - let range = header.value.span().range(); - if !range.is_subset(&self.committed) { - self.header(&name)?; - } - } - - self.builder.commit_recv(self.response.public_ranges())?; - - if self.body.is_some() && !self.body_built { - self.body().unwrap().build()?; - } - - *self.built = true; - - Ok(()) - } -} diff --git a/tlsn/tlsn-formats/src/http/mod.rs b/tlsn/tlsn-formats/src/http/mod.rs deleted file mode 100644 index 707516ad72..0000000000 --- a/tlsn/tlsn-formats/src/http/mod.rs +++ /dev/null @@ -1,208 +0,0 @@ -//! Tooling for working with HTTP data. - -mod body; -mod commitment; -mod parse; -mod proof; -mod session; - -pub use body::{Body, BodyCommitmentBuilder, BodyProofBuilder}; -pub use commitment::{ - HttpCommitmentBuilder, HttpCommitmentBuilderError, HttpRequestCommitmentBuilder, - HttpResponseCommitmentBuilder, -}; -pub use parse::{parse_body, parse_requests, parse_responses, ParseError}; -pub use proof::{HttpProofBuilder, HttpProofBuilderError}; -pub use session::NotarizedHttpSession; - -use serde::{Deserialize, Serialize}; -use spansy::Spanned; -use utils::range::{RangeDifference, RangeSet, RangeUnion}; - -static PUBLIC_HEADERS: &[&str] = &["content-length", "content-type"]; - -/// An HTTP request. 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Request(pub(crate) spansy::http::Request); - -impl Request { - pub(crate) fn public_ranges(&self) -> RangeSet { - let mut private_ranges = RangeSet::default(); - - let path_range = self.0.path.range(); - - private_ranges = private_ranges.union(&path_range); - - for header in &self.0.headers { - let name = header.name.span().as_str().to_ascii_lowercase(); - let range = header.value.span().range(); - if !PUBLIC_HEADERS.contains(&name.as_str()) { - private_ranges = private_ranges.union(&range); - } - } - - if let Some(body) = &self.0.body { - private_ranges = private_ranges.union(&body.span().range()); - } - - self.0.span().range().difference(&private_ranges) - } -} - -/// An HTTP response. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Response(pub(crate) spansy::http::Response); - -impl Response { - pub(crate) fn public_ranges(&self) -> RangeSet { - let mut private_ranges = RangeSet::default(); - - for header in &self.0.headers { - let name = header.name.span().as_str().to_ascii_lowercase(); - let range = header.value.span().range(); - if !PUBLIC_HEADERS.contains(&name.as_str()) { - private_ranges = private_ranges.union(&range); - } - } - - if let Some(body) = &self.0.body { - private_ranges = private_ranges.union(&body.span().range()); - } - - self.0.span().range().difference(&private_ranges) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use bytes::Bytes; - use tlsn_core::{ - commitment::{CommitmentKind, TranscriptCommitmentBuilder}, - fixtures, - proof::SubstringsProofBuilder, - Direction, Transcript, - }; - - static TX: &[u8] = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n\ - POST /hello HTTP/1.1\r\nHost: localhost\r\nContent-Length: 44\r\nContent-Type: application/json\r\n\r\n\ - {\"foo\": \"bar\", \"bazz\": 123, \"buzz\": [1,\"5\"]}"; - static RX: &[u8] = - b"HTTP/1.1 200 OK\r\nCookie: very-secret-cookie\r\nContent-Length: 14\r\nContent-Type: 
application/json\r\n\r\n\ - {\"foo\": \"bar\"}\r\n\ - HTTP/1.1 200 OK\r\nContent-Length: 14\r\nContent-Type: text/plain\r\n\r\n\ - Hello World!!!"; - - #[test] - fn test_http_commit() { - let mut transcript_commitment_builder = TranscriptCommitmentBuilder::new( - fixtures::encoding_provider(TX, RX), - TX.len(), - RX.len(), - ); - - let requests = parse_requests(Bytes::copy_from_slice(TX)).unwrap(); - let responses = parse_responses(Bytes::copy_from_slice(RX)).unwrap(); - - HttpCommitmentBuilder::new(&mut transcript_commitment_builder, &requests, &responses) - .build() - .unwrap(); - - let commitments = transcript_commitment_builder.build().unwrap(); - - // Path - assert!(commitments - .get_id_by_info(CommitmentKind::Blake3, (4..5).into(), Direction::Sent) - .is_some()); - - // Host - assert!(commitments - .get_id_by_info(CommitmentKind::Blake3, (22..31).into(), Direction::Sent) - .is_some()); - // foo - assert!(commitments - .get_id_by_info(CommitmentKind::Blake3, (137..140).into(), Direction::Sent) - .is_some()); - - // Cookie - assert!(commitments - .get_id_by_info(CommitmentKind::Blake3, (25..43).into(), Direction::Received) - .is_some()); - // Body - assert!(commitments - .get_id_by_info( - CommitmentKind::Blake3, - (180..194).into(), - Direction::Received - ) - .is_some()); - } - - #[test] - fn test_http_prove() { - let transcript_tx = Transcript::new(TX); - let transcript_rx = Transcript::new(RX); - - let mut transcript_commitment_builder = TranscriptCommitmentBuilder::new( - fixtures::encoding_provider(TX, RX), - TX.len(), - RX.len(), - ); - - let requests = parse_requests(Bytes::copy_from_slice(TX)).unwrap(); - let responses = parse_responses(Bytes::copy_from_slice(RX)).unwrap(); - - HttpCommitmentBuilder::new(&mut transcript_commitment_builder, &requests, &responses) - .build() - .unwrap(); - - let commitments = transcript_commitment_builder.build().unwrap(); - - let spb = SubstringsProofBuilder::new(&commitments, &transcript_tx, &transcript_rx); - - let 
mut builder = HttpProofBuilder::new(spb, &commitments, &requests, &responses); - - let mut req_0 = builder.request(0).unwrap(); - - req_0.path().unwrap(); - req_0.header("host").unwrap(); - - let mut req_1 = builder.request(1).unwrap(); - - req_1.path().unwrap(); - - let BodyProofBuilder::Json(mut json) = req_1.body().unwrap() else { - unreachable!(); - }; - - json.path("bazz").unwrap(); - - let mut resp_0 = builder.response(0).unwrap(); - - resp_0.header("cookie").unwrap(); - - assert!(matches!(resp_0.body().unwrap(), BodyProofBuilder::Json(_))); - - let mut resp_1 = builder.response(1).unwrap(); - - let BodyProofBuilder::Unknown(mut unknown) = resp_1.body().unwrap() else { - unreachable!(); - }; - - unknown.all().unwrap(); - - let proof = builder.build().unwrap(); - - let header = fixtures::session_header(commitments.merkle_root(), TX.len(), RX.len()); - - let (sent, recv) = proof.verify(&header).unwrap(); - - assert_eq!(&sent.data()[4..5], b"/"); - assert_eq!(&sent.data()[22..31], b"localhost"); - assert_eq!(&sent.data()[151..154], b"123"); - - assert_eq!(&recv.data()[25..43], b"very-secret-cookie"); - assert_eq!(&recv.data()[180..194], b"Hello World!!!"); - } -} diff --git a/tlsn/tlsn-formats/src/http/parse.rs b/tlsn/tlsn-formats/src/http/parse.rs deleted file mode 100644 index 8b309e828e..0000000000 --- a/tlsn/tlsn-formats/src/http/parse.rs +++ /dev/null @@ -1,262 +0,0 @@ -use bytes::Bytes; -use spansy::{ - http::{Requests, Responses}, - json::{self}, - Spanned, -}; - -use crate::{ - http::{Body, Request, Response}, - json::JsonBody, - unknown::UnknownSpan, -}; - -/// An HTTP transcript parse error -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum ParseError { - /// Failed to parse request - #[error("failed to parse request at index {index}: {reason}")] - Request { - /// The index of the request - index: usize, - /// The reason for the error - reason: String, - }, - /// Failed to parse response - #[error("failed to parse response at index 
{index}: {reason}")] - Response { - /// The index of the response - index: usize, - /// The reason for the error - reason: String, - }, - /// Failed to parse JSON body - #[error("failed to parse JSON at index {index}: {reason}")] - Json { - /// The index of the request or response - index: usize, - /// The reason for the error - reason: String, - }, -} - -/// Parses a body of an HTTP request or response -/// -/// # Arguments -/// -/// * `index` - The index of the request or response -/// * `content_type` - The content type of the body -/// * `body` - The body data -/// * `offset` - The offset of the body from the start of the transcript -/// -/// # Panics -/// -/// Panics if the range and body length do not match. -pub fn parse_body( - index: usize, - content_type: &[u8], - body: Bytes, - offset: usize, -) -> Result { - if content_type.get(..16) == Some(b"application/json".as_slice()) { - let mut body = json::parse(body).map_err(|e| ParseError::Json { - index, - reason: e.to_string(), - })?; - - body.offset(offset); - - Ok(Body::Json(JsonBody(body))) - } else { - Ok(Body::Unknown(UnknownSpan::new(offset..offset + body.len()))) - } -} - -/// Parses the requests of an HTTP transcript. -/// -/// # Arguments -/// -/// * `data` - The HTTP transcript data -pub fn parse_requests(data: Bytes) -> Result)>, ParseError> { - let mut requests = Vec::new(); - for (index, request) in Requests::new(data.clone()).enumerate() { - let request = request.map_err(|e| ParseError::Request { - index, - reason: e.to_string(), - })?; - - let body = if let Some(ref body) = request.body { - let range = body.span().range(); - let body = data.slice(range.clone()); - - let body = if let Some(content_type) = request.header("content-type") { - parse_body( - index, - content_type.value.span().as_bytes(), - body, - range.start, - )? 
- } else { - Body::Unknown(UnknownSpan::new(range)) - }; - - Some(body) - } else { - None - }; - - requests.push((Request(request), body)); - } - - Ok(requests) -} - -/// Parses the responses of an HTTP transcript. -/// -/// # Arguments -/// -/// * `data` - The HTTP transcript data -pub fn parse_responses(data: Bytes) -> Result)>, ParseError> { - let mut responses = Vec::new(); - for (index, response) in Responses::new(data.clone()).enumerate() { - let response = response.map_err(|e| ParseError::Response { - index, - reason: e.to_string(), - })?; - - let body = if let Some(ref body) = response.body { - let range = body.span().range(); - let body = data.slice(range.clone()); - - let body = if let Some(content_type) = response.header("content-type") { - parse_body( - index, - content_type.value.span().as_bytes(), - body, - range.start, - )? - } else { - Body::Unknown(UnknownSpan::new(range)) - }; - - Some(body) - } else { - None - }; - - responses.push((Response(response), body)); - } - - Ok(responses) -} - -#[cfg(test)] -mod tests { - use super::*; - - use bytes::Bytes; - - #[test] - fn test_parse_body_json() { - let body = b"{\"foo\": \"bar\"}"; - - let body = parse_body(0, b"application/json", Bytes::copy_from_slice(body), 0).unwrap(); - - let Body::Json(body) = body else { - unreachable!(); - }; - - let range = body.0.span().range(); - - assert_eq!(range.start, 0); - assert_eq!(range.end, 14); - assert_eq!(body.0.span().as_str(), "{\"foo\": \"bar\"}"); - - let foo = body.0.get("foo").unwrap(); - let range = foo.span().range(); - - assert_eq!(range.start, 9); - assert_eq!(range.end, 12); - assert_eq!(foo.span().as_str(), "bar"); - } - - #[test] - fn test_parse_body_json_offset() { - let body = b" {\"foo\": \"bar\"}"; - - let body = parse_body( - 0, - b"application/json", - Bytes::copy_from_slice(&body[4..]), - 4, - ) - .unwrap(); - - let Body::Json(body) = body else { - unreachable!(); - }; - - let range = body.0.span().range(); - - assert_eq!(range.start, 4); - 
assert_eq!(range.end, 18); - assert_eq!(body.0.span().as_str(), "{\"foo\": \"bar\"}"); - - let foo = body.0.get("foo").unwrap(); - let range = foo.span().range(); - - assert_eq!(range.start, 13); - assert_eq!(range.end, 16); - assert_eq!(foo.span().as_str(), "bar"); - } - - #[test] - fn test_parse_body_unknown() { - let body = b"foo"; - - let body = parse_body(0, b"text/plain", Bytes::copy_from_slice(body), 0).unwrap(); - - assert!(matches!(body, Body::Unknown(_))); - } - - #[test] - fn test_parse_requests() { - let reqs = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n\ - POST /hello HTTP/1.1\r\nHost: localhost\r\nContent-Length: 14\r\nContent-Type: application/json\r\n\r\n\ - {\"foo\": \"bar\"}"; - - let requests = parse_requests(Bytes::copy_from_slice(reqs)).unwrap(); - - assert_eq!(requests.len(), 2); - assert!(requests[0].1.is_none()); - assert!(requests[1].1.is_some()); - - let Body::Json(body) = requests[1].1.as_ref().unwrap() else { - unreachable!(); - }; - - let foo = body.0.get("foo").unwrap(); - let range = foo.span().range(); - - assert_eq!(range.start, 137); - assert_eq!(range.end, 140); - } - - #[test] - fn test_parse_responses() { - let resps = - b"HTTP/1.1 200 OK\r\nContent-Length: 14\r\nContent-Type: application/json\r\n\r\n\ - {\"foo\": \"bar\"}\r\n\ - HTTP/1.1 200 OK\r\nContent-Length: 14\r\nContent-Type: text/plain\r\n\r\n\ - Hello World!!!"; - - let responses = parse_responses(Bytes::copy_from_slice(resps)).unwrap(); - - assert_eq!(responses.len(), 2); - assert!(responses[0].1.is_some()); - assert!(responses[1].1.is_some()); - assert!(matches!(responses[0].1.as_ref().unwrap(), Body::Json(_))); - assert!(matches!(responses[1].1.as_ref().unwrap(), Body::Unknown(_))); - } -} diff --git a/tlsn/tlsn-formats/src/http/proof.rs b/tlsn/tlsn-formats/src/http/proof.rs deleted file mode 100644 index 78a93b671a..0000000000 --- a/tlsn/tlsn-formats/src/http/proof.rs +++ /dev/null @@ -1,315 +0,0 @@ -use std::ops::Range; - -use crate::http::{body::BodyProofBuilder, 
Body, Request, Response}; -use spansy::Spanned; -use tlsn_core::{ - commitment::{CommitmentId, CommitmentKind, TranscriptCommitments}, - proof::{SubstringsProof, SubstringsProofBuilder, SubstringsProofBuilderError}, - Direction, -}; - -/// HTTP proof builder error. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum HttpProofBuilderError { - /// Header is missing. - #[error("header with name \"{0}\" does not exist.")] - MissingHeader(String), - /// Body proof error. - #[error("body proof error: {0}")] - Body(String), - /// Missing commitment for value. - #[error("missing commitment for {0}")] - MissingCommitment(String), - /// Substrings proof builder error. - #[error("proof builder error: {0}")] - Proof(#[from] SubstringsProofBuilderError), -} - -/// Builder for proofs of data in an HTTP connection. -#[derive(Debug)] -pub struct HttpProofBuilder<'a, 'b> { - builder: SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - requests: &'a [(Request, Option)], - responses: &'a [(Response, Option)], - built_requests: Vec, - built_responses: Vec, -} - -impl<'a, 'b> HttpProofBuilder<'a, 'b> { - #[doc(hidden)] - pub fn new( - builder: SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - requests: &'a [(Request, Option)], - responses: &'a [(Response, Option)], - ) -> Self { - Self { - builder, - commitments, - requests, - responses, - built_requests: vec![false; requests.len()], - built_responses: vec![false; responses.len()], - } - } - - /// Returns a proof builder for the given request, if it exists. - /// - /// # Arguments - /// - /// * `index` - The index of the request to build a proof for. 
- pub fn request<'c>(&'c mut self, index: usize) -> Option> - where - 'a: 'c, - { - self.requests - .get(index) - .map(|request| HttpRequestProofBuilder { - builder: &mut self.builder, - commitments: self.commitments, - request: &request.0, - body: request.1.as_ref(), - built: &mut self.built_requests[index], - body_built: false, - }) - } - - /// Returns a proof builder for the given response, if it exists. - /// - /// # Arguments - /// - /// * `index` - The index of the response to build a proof for. - pub fn response<'c>(&'c mut self, index: usize) -> Option> - where - 'a: 'c, - { - self.responses - .get(index) - .map(|response| HttpResponseProofBuilder { - builder: &mut self.builder, - commitments: self.commitments, - response: &response.0, - body: response.1.as_ref(), - built: &mut self.built_responses[index], - body_built: false, - }) - } - - /// Builds the HTTP transcript proof. - pub fn build(mut self) -> Result { - // Build any remaining request proofs - for i in 0..self.requests.len() { - if !self.built_requests[i] { - self.request(i).unwrap().build()?; - } - } - - // Build any remaining response proofs - for i in 0..self.responses.len() { - if !self.built_responses[i] { - self.response(i).unwrap().build()?; - } - } - - self.builder.build().map_err(From::from) - } -} - -#[derive(Debug)] -pub struct HttpRequestProofBuilder<'a, 'b> { - builder: &'a mut SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - request: &'a Request, - body: Option<&'a Body>, - built: &'a mut bool, - // TODO: this field will be used in the future to support advanced configurations - // but for now we don't want to build the body proof unless it is specifically requested - body_built: bool, -} - -impl<'a, 'b> HttpRequestProofBuilder<'a, 'b> { - /// Reveals the entirety of the request. - /// - /// # Arguments - /// - /// * `body` - Whether to reveal the entirety of the request body as well. 
- pub fn all(&mut self, body: bool) -> Result<&mut Self, HttpProofBuilderError> { - let id = self - .commit_id(self.request.0.span().range()) - .ok_or_else(|| { - HttpProofBuilderError::MissingCommitment("the entire request".to_string()) - })?; - - self.builder.reveal(id)?; - - if body && self.body.is_some() { - self.body().unwrap().all()?; - } - - Ok(self) - } - - /// Reveals the path of the request. - pub fn path(&mut self) -> Result<&mut Self, HttpProofBuilderError> { - let id = self - .commit_id(self.request.0.path.range()) - .ok_or_else(|| HttpProofBuilderError::MissingCommitment("path".to_string()))?; - - self.builder.reveal(id)?; - - Ok(self) - } - - /// Reveals the value of the given header. - /// - /// # Arguments - /// - /// * `name` - The name of the header value to reveal. - pub fn header(&mut self, name: &str) -> Result<&mut Self, HttpProofBuilderError> { - let header = self - .request - .0 - .header(name) - .ok_or_else(|| HttpProofBuilderError::MissingHeader(name.to_string()))?; - - let id = self.commit_id(header.value.span().range()).ok_or_else(|| { - HttpProofBuilderError::MissingCommitment(format!("header \"{}\"", name)) - })?; - - self.builder.reveal(id)?; - - Ok(self) - } - - /// Returns a proof builder for the request body, if it exists. - pub fn body<'c>(&'c mut self) -> Option> { - self.body.map(|body| { - BodyProofBuilder::new( - self.builder, - self.commitments, - body, - Direction::Sent, - &mut self.body_built, - ) - }) - } - - /// Builds the HTTP request proof. 
- pub fn build(self) -> Result<(), HttpProofBuilderError> { - let public_id = self - .commitments - .get_id_by_info( - CommitmentKind::Blake3, - self.request.public_ranges(), - Direction::Sent, - ) - .ok_or_else(|| HttpProofBuilderError::MissingCommitment("public data".to_string()))?; - - self.builder.reveal(public_id)?; - - *self.built = true; - - Ok(()) - } - - fn commit_id(&self, range: Range) -> Option { - // TODO: support different kinds of commitments - self.commitments - .get_id_by_info(CommitmentKind::Blake3, range.into(), Direction::Sent) - } -} - -#[derive(Debug)] -pub struct HttpResponseProofBuilder<'a, 'b: 'a> { - builder: &'a mut SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - response: &'a Response, - body: Option<&'a Body>, - built: &'a mut bool, - // TODO: this field will be used in the future to support advanced configurations - // but for now we don't want to build the body proof unless it is specifically requested - body_built: bool, -} - -impl<'a, 'b> HttpResponseProofBuilder<'a, 'b> { - /// Reveals the entirety of the response. - /// - /// # Arguments - /// - /// * `body` - Whether to reveal the entirety of the response body as well. - pub fn all(&mut self, body: bool) -> Result<&mut Self, HttpProofBuilderError> { - let id = self - .commit_id(self.response.0.span().range()) - .ok_or_else(|| { - HttpProofBuilderError::MissingCommitment("the entire response".to_string()) - })?; - - self.builder.reveal(id)?; - - if body && self.body.is_some() { - self.body().unwrap().all()?; - } - - Ok(self) - } - - /// Reveals the value of the given header. - /// - /// # Arguments - /// - /// * `name` - The name of the header value to reveal. 
- pub fn header(&mut self, name: &str) -> Result<&mut Self, HttpProofBuilderError> { - let header = self - .response - .0 - .header(name) - .ok_or_else(|| HttpProofBuilderError::MissingHeader(name.to_string()))?; - - let id = self.commit_id(header.value.span().range()).ok_or_else(|| { - HttpProofBuilderError::MissingCommitment(format!("header \"{}\"", name)) - })?; - - self.builder.reveal(id)?; - - Ok(self) - } - - /// Returns a proof builder for the response body, if it exists. - pub fn body<'c>(&'c mut self) -> Option> { - self.body.map(|body| { - BodyProofBuilder::new( - self.builder, - self.commitments, - body, - Direction::Received, - &mut self.body_built, - ) - }) - } - - /// Builds the HTTP response proof. - pub fn build(self) -> Result<(), HttpProofBuilderError> { - let public_id = self - .commitments - .get_id_by_info( - CommitmentKind::Blake3, - self.response.public_ranges(), - Direction::Received, - ) - .ok_or_else(|| HttpProofBuilderError::MissingCommitment("public data".to_string()))?; - - self.builder.reveal(public_id)?; - - *self.built = true; - - Ok(()) - } - - fn commit_id(&self, range: Range) -> Option { - // TODO: support different kinds of commitments - self.commitments - .get_id_by_info(CommitmentKind::Blake3, range.into(), Direction::Received) - } -} diff --git a/tlsn/tlsn-formats/src/http/session.rs b/tlsn/tlsn-formats/src/http/session.rs deleted file mode 100644 index 72624ddf0c..0000000000 --- a/tlsn/tlsn-formats/src/http/session.rs +++ /dev/null @@ -1,51 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use tlsn_core::{proof::SessionProof, NotarizedSession}; - -use crate::http::{Body, Request, Response}; - -use super::HttpProofBuilder; - -/// A notarized HTTP session. -#[derive(Debug, Serialize, Deserialize)] -pub struct NotarizedHttpSession { - session: NotarizedSession, - requests: Vec<(Request, Option)>, - responses: Vec<(Response, Option)>, -} - -impl NotarizedHttpSession { - /// Creates a new notarized HTTP session. 
- #[doc(hidden)] - pub fn new( - session: NotarizedSession, - requests: Vec<(Request, Option)>, - responses: Vec<(Response, Option)>, - ) -> Self { - Self { - session, - requests, - responses, - } - } - - /// Returns the notarized TLS session. - pub fn session(&self) -> &NotarizedSession { - &self.session - } - - /// Returns a proof for the TLS session. - pub fn session_proof(&self) -> SessionProof { - self.session.session_proof() - } - - /// Returns a proof builder for the HTTP session. - pub fn proof_builder(&self) -> HttpProofBuilder { - HttpProofBuilder::new( - self.session.data().build_substrings_proof(), - self.session.data().commitments(), - &self.requests, - &self.responses, - ) - } -} diff --git a/tlsn/tlsn-formats/src/json/commitment.rs b/tlsn/tlsn-formats/src/json/commitment.rs deleted file mode 100644 index 85aa758576..0000000000 --- a/tlsn/tlsn-formats/src/json/commitment.rs +++ /dev/null @@ -1,157 +0,0 @@ -use spansy::{ - json::{JsonValue, JsonVisit}, - Spanned, -}; -use tlsn_core::{ - commitment::{ - CommitmentId, CommitmentKind, TranscriptCommitmentBuilder, TranscriptCommitmentBuilderError, - }, - Direction, -}; - -use super::public_ranges; - -/// JSON commitment builder error. -#[derive(Debug, thiserror::Error)] -pub enum JsonCommitmentBuilderError { - /// Invalid path. - #[error("invalid path: {0}")] - InvalidPath(String), - /// Transcript commitment builder error. - #[error("commitment builder error: {0}")] - Commitment(#[from] TranscriptCommitmentBuilderError), -} - -/// Builder for commitments to a JSON value. 
-#[derive(Debug)] -pub struct JsonCommitmentBuilder<'a> { - builder: &'a mut TranscriptCommitmentBuilder, - value: &'a JsonValue, - direction: Direction, - built: &'a mut bool, -} - -impl<'a> JsonCommitmentBuilder<'a> { - pub(crate) fn new( - builder: &'a mut TranscriptCommitmentBuilder, - value: &'a JsonValue, - direction: Direction, - built: &'a mut bool, - ) -> Self { - JsonCommitmentBuilder { - builder, - value, - direction, - built, - } - } - - /// Commits to the entire JSON value. - pub fn all(&mut self) -> Result { - match self.direction { - Direction::Sent => self.builder.commit_sent(self.value.span().range()), - Direction::Received => self.builder.commit_recv(self.value.span().range()), - } - .map_err(From::from) - } - - /// Commits to the value at the given path. - pub fn path(&mut self, path: &str) -> Result { - let value = self.value.get(path).ok_or_else(|| { - JsonCommitmentBuilderError::InvalidPath(format!("invalid path: {}", path)) - })?; - - let range = value.span().range(); - match self.direction { - Direction::Sent => self.builder.commit_sent(range), - Direction::Received => self.builder.commit_recv(range), - } - .map_err(From::from) - } - - /// Finishes building commitments the a JSON value. - pub fn build(self) -> Result<(), JsonCommitmentBuilderError> { - let public_ranges = public_ranges(self.value); - - match self.direction { - Direction::Sent => self.builder.commit_sent(public_ranges)?, - Direction::Received => self.builder.commit_recv(public_ranges)?, - }; - - let mut visitor = JsonCommitter { - builder: self.builder, - direction: self.direction, - err: None, - }; - - visitor.visit_value(self.value); - - if let Some(err) = visitor.err { - err? 
- } - - *self.built = true; - - Ok(()) - } -} - -struct JsonCommitter<'a> { - builder: &'a mut TranscriptCommitmentBuilder, - direction: Direction, - err: Option>, -} - -impl<'a> JsonVisit for JsonCommitter<'a> { - fn visit_number(&mut self, node: &spansy::json::Number) { - if self.err.is_some() { - return; - } - - let range = node.span().range(); - if self - .builder - .get_id(CommitmentKind::Blake3, range.clone(), self.direction) - .is_some() - { - return; - } - - let res = match self.direction { - Direction::Sent => self.builder.commit_sent(range), - Direction::Received => self.builder.commit_recv(range), - } - .map(|_| ()) - .map_err(From::from); - - if res.is_err() { - self.err = Some(res); - } - } - - fn visit_string(&mut self, node: &spansy::json::String) { - if self.err.is_some() { - return; - } - - let range = node.span().range(); - if self - .builder - .get_id(CommitmentKind::Blake3, range.clone(), self.direction) - .is_some() - { - return; - } - - let res = match self.direction { - Direction::Sent => self.builder.commit_sent(range), - Direction::Received => self.builder.commit_recv(range), - } - .map(|_| ()) - .map_err(From::from); - - if res.is_err() { - self.err = Some(res); - } - } -} diff --git a/tlsn/tlsn-formats/src/json/mod.rs b/tlsn/tlsn-formats/src/json/mod.rs deleted file mode 100644 index 9e4375a168..0000000000 --- a/tlsn/tlsn-formats/src/json/mod.rs +++ /dev/null @@ -1,44 +0,0 @@ -//! Tooling for working with JSON data. - -mod commitment; -mod proof; - -pub use commitment::{JsonCommitmentBuilder, JsonCommitmentBuilderError}; -pub use proof::{JsonProofBuilder, JsonProofBuilderError}; - -use serde::{Deserialize, Serialize}; -use spansy::{ - json::{JsonValue, JsonVisit}, - Spanned, -}; -use utils::range::{RangeDifference, RangeSet, RangeUnion}; - -/// A JSON body -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct JsonBody(pub(crate) JsonValue); - -/// Computes all the public ranges of a JSON value. 
-/// -/// Right now this is just the ranges of all the numbers and strings. -pub(crate) fn public_ranges(value: &JsonValue) -> RangeSet { - #[derive(Default)] - struct PrivateRanges { - private_ranges: RangeSet, - } - - // For now only numbers and strings are redactable. - impl JsonVisit for PrivateRanges { - fn visit_number(&mut self, node: &spansy::json::Number) { - self.private_ranges = self.private_ranges.union(&node.span().range()); - } - - fn visit_string(&mut self, node: &spansy::json::String) { - self.private_ranges = self.private_ranges.union(&node.span().range()); - } - } - - let mut visitor = PrivateRanges::default(); - visitor.visit_value(value); - - value.span().range().difference(&visitor.private_ranges) -} diff --git a/tlsn/tlsn-formats/src/json/proof.rs b/tlsn/tlsn-formats/src/json/proof.rs deleted file mode 100644 index 30f16d7fec..0000000000 --- a/tlsn/tlsn-formats/src/json/proof.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::ops::Range; - -use spansy::{json::JsonValue, Spanned}; -use tlsn_core::{ - commitment::{CommitmentId, CommitmentKind, TranscriptCommitments}, - proof::{SubstringsProofBuilder, SubstringsProofBuilderError}, - Direction, -}; - -use crate::json::public_ranges; - -/// JSON proof builder error. -#[derive(Debug, thiserror::Error)] -pub enum JsonProofBuilderError { - /// Missing value - #[error("missing value at path: {0}")] - MissingValue(String), - /// Missing commitment. - #[error("missing commitment")] - MissingCommitment, - /// Substrings proof builder error. - #[error("proof builder error: {0}")] - Proof(#[from] SubstringsProofBuilderError), -} - -/// Builder for proofs of a JSON value. 
-#[derive(Debug)] -pub struct JsonProofBuilder<'a, 'b> { - builder: &'a mut SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - value: &'a JsonValue, - direction: Direction, - built: &'a mut bool, -} - -impl<'a, 'b> JsonProofBuilder<'a, 'b> { - pub(crate) fn new( - builder: &'a mut SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - value: &'a JsonValue, - direction: Direction, - built: &'a mut bool, - ) -> Self { - JsonProofBuilder { - builder, - commitments, - value, - direction, - built, - } - } - - /// Proves the entire JSON value. - pub fn all(&mut self) -> Result<(), JsonProofBuilderError> { - let id = self - .commit_id(self.value.span().range()) - .ok_or(JsonProofBuilderError::MissingCommitment)?; - - self.builder.reveal(id)?; - - Ok(()) - } - - /// Proves the value at the given path. - /// - /// # Arguments - /// - /// * `path` - The path to the value to prove. - pub fn path(&mut self, path: &str) -> Result<(), JsonProofBuilderError> { - let value = self - .value - .get(path) - .ok_or_else(|| JsonProofBuilderError::MissingValue(format!("\"{}\"", path)))?; - - let id = self - .commit_id(value.span().range()) - .ok_or(JsonProofBuilderError::MissingCommitment)?; - - self.builder.reveal(id)?; - - Ok(()) - } - - /// Finishes building the JSON proof. 
- pub fn build(self) -> Result<(), JsonProofBuilderError> { - let public_ranges = public_ranges(self.value); - - let public_id = self - .commitments - .get_id_by_info(CommitmentKind::Blake3, public_ranges, self.direction) - .ok_or(JsonProofBuilderError::MissingCommitment)?; - - self.builder.reveal(public_id)?; - - *self.built = true; - - Ok(()) - } - - fn commit_id(&self, range: Range) -> Option { - // TODO: support different kinds of commitments - self.commitments - .get_id_by_info(CommitmentKind::Blake3, range.into(), self.direction) - } -} diff --git a/tlsn/tlsn-formats/src/unknown.rs b/tlsn/tlsn-formats/src/unknown.rs deleted file mode 100644 index 1321311478..0000000000 --- a/tlsn/tlsn-formats/src/unknown.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::ops::Range; - -use serde::{Deserialize, Serialize}; - -use tlsn_core::{ - commitment::{ - CommitmentId, CommitmentKind, TranscriptCommitmentBuilder, - TranscriptCommitmentBuilderError, TranscriptCommitments, - }, - proof::{SubstringsProofBuilder, SubstringsProofBuilderError}, - Direction, -}; - -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum UnknownCommitmentBuilderError { - /// The provided range is out of bounds of the span. - #[error("provided range is out of bounds of the span")] - OutOfBounds, - #[error("commitment builder error: {0}")] - Commitment(#[from] TranscriptCommitmentBuilderError), -} - -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum UnknownProofBuilderError { - /// Missing commitment. - #[error("missing commitment")] - MissingCommitment, - /// The provided range is out of bounds of the span. - #[error("provided range is out of bounds of the span")] - OutOfBounds, - /// Substrings proof builder error. - #[error("proof builder error: {0}")] - Proof(#[from] SubstringsProofBuilderError), -} - -/// A span within the transcript with an unknown format. 
-#[derive(Debug, Serialize, Deserialize)] -pub struct UnknownSpan(pub(crate) Range); - -impl UnknownSpan { - pub(crate) fn new(span: Range) -> Self { - UnknownSpan(span) - } -} - -/// A builder for commitments to spans with an unknown format. -#[derive(Debug)] -pub struct UnknownCommitmentBuilder<'a> { - builder: &'a mut TranscriptCommitmentBuilder, - span: Range, - direction: Direction, - built: &'a mut bool, -} - -impl<'a> UnknownCommitmentBuilder<'a> { - pub(crate) fn new( - builder: &'a mut TranscriptCommitmentBuilder, - span: &'a UnknownSpan, - direction: Direction, - built: &'a mut bool, - ) -> Self { - UnknownCommitmentBuilder { - builder, - span: span.0.clone(), - direction, - built, - } - } - - /// Commits to the given range within the span. - pub fn range( - &mut self, - range: Range, - ) -> Result { - let span_range = self.span.clone(); - - let start = span_range.start + range.start; - let end = span_range.start + range.end; - - if start >= end || end > span_range.end { - return Err(UnknownCommitmentBuilderError::OutOfBounds); - } - - match self.direction { - Direction::Sent => self.builder.commit_sent(start..end), - Direction::Received => self.builder.commit_recv(start..end), - } - .map_err(From::from) - } - - /// Commits to the entire body. - pub fn all(&mut self) -> Result { - match self.direction { - Direction::Sent => self.builder.commit_sent(self.span.clone()), - Direction::Received => self.builder.commit_recv(self.span.clone()), - } - .map_err(From::from) - } - - /// Builds the commitment. - pub fn build(self) -> Result<(), UnknownCommitmentBuilderError> { - // commit to the entire span - match self.direction { - Direction::Sent => self.builder.commit_sent(self.span.clone()), - Direction::Received => self.builder.commit_recv(self.span.clone()), - }?; - - *self.built = true; - - Ok(()) - } -} - -/// A proof builder for spans with an unknown format. 
-#[derive(Debug)] -pub struct UnknownProofBuilder<'a, 'b> { - builder: &'a mut SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - span: Range, - direction: Direction, - built: &'a mut bool, -} - -impl<'a, 'b> UnknownProofBuilder<'a, 'b> { - pub(crate) fn new( - builder: &'a mut SubstringsProofBuilder<'b>, - commitments: &'a TranscriptCommitments, - span: &'a UnknownSpan, - direction: Direction, - built: &'a mut bool, - ) -> Self { - UnknownProofBuilder { - builder, - commitments, - span: span.0.clone(), - direction, - built, - } - } - - /// Reveals the entire span. - pub fn all(&mut self) -> Result<(), UnknownProofBuilderError> { - let id = self - .commit_id(self.span.clone()) - .ok_or(UnknownProofBuilderError::MissingCommitment)?; - - self.builder.reveal(id)?; - - Ok(()) - } - - /// Reveals the given range within the span. - /// - /// # Arguments - /// - /// * `range` - The range to reveal. - pub fn range(&mut self, range: Range) -> Result<(), UnknownProofBuilderError> { - let span_range = self.span.clone(); - - let start = span_range.start + range.start; - let end = span_range.start + range.end; - - if start >= end || end > span_range.end { - return Err(UnknownProofBuilderError::OutOfBounds); - } - - let id = self - .commit_id(start..end) - .ok_or(UnknownProofBuilderError::MissingCommitment)?; - - self.builder.reveal(id)?; - - Ok(()) - } - - /// Builds the proof. 
- pub fn build(self) -> Result<(), UnknownProofBuilderError> { - *self.built = true; - - Ok(()) - } - - fn commit_id(&self, range: Range) -> Option { - // TODO: support different kinds of commitments - self.commitments - .get_id_by_info(CommitmentKind::Blake3, range.into(), self.direction) - } -} diff --git a/tlsn/tlsn-prover/Cargo.toml b/tlsn/tlsn-prover/Cargo.toml deleted file mode 100644 index 296e458e74..0000000000 --- a/tlsn/tlsn-prover/Cargo.toml +++ /dev/null @@ -1,53 +0,0 @@ -[package] -name = "tlsn-prover" -authors = ["TLSNotary Team"] -description = "Contains the prover library" -keywords = ["tls", "mpc", "2pc", "prover"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[features] -default = [] -#formats = ["dep:tlsn-formats"] -tracing = [ - "dep:tracing", - "tlsn-tls-client-async/tracing", - "tlsn-tls-mpc/tracing", - "tlsn-common/tracing", -] - -[dependencies] -tlsn-tls-core.workspace = true -tlsn-tls-client.workspace = true -tlsn-tls-client-async.workspace = true -tlsn-core.workspace = true -tlsn-common.workspace = true -#tlsn-formats = { workspace = true, optional = true } -tlsn-tls-mpc.workspace = true - -tlsn-utils.workspace = true -tlsn-utils-aio.workspace = true - -mpz-share-conversion.workspace = true -mpz-garble.workspace = true -mpz-garble-core.workspace = true -mpz-ot.workspace = true -mpz-core.workspace = true - -rand.workspace = true -futures.workspace = true -thiserror.workspace = true -webpki-roots.workspace = true -derive_builder.workspace = true -opaque-debug.workspace = true -bytes.workspace = true - -tracing = { workspace = true, optional = true } - -web-time.workspace = true - -[target.'cfg(target_arch = "wasm32")'.dependencies] -ring = { version = "0.17", features = ["wasm32_unknown_unknown_js"] } -getrandom = { version = "0.2", features = ["js"] } diff --git a/tlsn/tlsn-prover/src/http/mod.rs b/tlsn/tlsn-prover/src/http/mod.rs deleted file mode 100644 index 
0438583f6b..0000000000 --- a/tlsn/tlsn-prover/src/http/mod.rs +++ /dev/null @@ -1,101 +0,0 @@ -//! HTTP Prover. -//! -//! An HTTP prover can be created from a TLS [`Prover`](crate::tls::Prover), after the TLS connection has been closed, by calling the -//! [`to_http`](crate::tls::Prover::to_http) method. -//! -//! The [`HttpProver`] provides higher-level APIs for committing and proving data communicated during an HTTP connection. - -pub mod state; - -use tlsn_formats::http::{parse_requests, parse_responses, ParseError}; - -use crate::tls::{state as prover_state, Prover, ProverError}; - -pub use tlsn_formats::{ - http::{ - HttpCommitmentBuilder, HttpCommitmentBuilderError, HttpProofBuilder, HttpProofBuilderError, - HttpRequestCommitmentBuilder, HttpResponseCommitmentBuilder, NotarizedHttpSession, - }, - json::{ - JsonBody, JsonCommitmentBuilder, JsonCommitmentBuilderError, JsonProofBuilder, - JsonProofBuilderError, - }, -}; - -/// HTTP prover error. -#[derive(Debug, thiserror::Error)] -pub enum HttpProverError { - /// An error originated from the TLS prover. - #[error(transparent)] - Prover(#[from] ProverError), - /// Commitment error. - #[error(transparent)] - Commitment(#[from] HttpCommitmentBuilderError), - /// An error occurred while parsing the HTTP data. - #[error(transparent)] - Parse(#[from] ParseError), -} - -/// An HTTP prover. -pub struct HttpProver { - state: S, -} - -impl HttpProver { - /// Creates a new HTTP prover. - pub fn new(prover: Prover) -> Result { - let requests = parse_requests(prover.sent_transcript().data().clone())?; - let responses = parse_responses(prover.recv_transcript().data().clone())?; - - Ok(Self { - state: state::Closed { - prover, - requests, - responses, - }, - }) - } - - /// Starts notarization of the HTTP session. - /// - /// If the verifier is a Notary, this function will transition the prover to the next state - /// where it can generate commitments to the transcript prior to finalization. 
- pub fn start_notarize(self) -> HttpProver { - HttpProver { - state: state::Notarize { - prover: self.state.prover.start_notarize(), - requests: self.state.requests, - responses: self.state.responses, - }, - } - } -} - -impl HttpProver { - /// Generates commitments to the HTTP session prior to finalization. - pub fn commit(&mut self) -> Result<(), HttpProverError> { - self.commitment_builder().build()?; - - Ok(()) - } - - /// Returns a commitment builder for the HTTP session. - /// - /// This is for more advanced use cases, you should prefer using `commit` instead. - pub fn commitment_builder(&mut self) -> HttpCommitmentBuilder { - HttpCommitmentBuilder::new( - self.state.prover.commitment_builder(), - &self.state.requests, - &self.state.responses, - ) - } - - /// Finalizes the HTTP session. - pub async fn finalize(self) -> Result { - Ok(NotarizedHttpSession::new( - self.state.prover.finalize().await?, - self.state.requests, - self.state.responses, - )) - } -} diff --git a/tlsn/tlsn-prover/src/http/state.rs b/tlsn/tlsn-prover/src/http/state.rs deleted file mode 100644 index b767a56bfc..0000000000 --- a/tlsn/tlsn-prover/src/http/state.rs +++ /dev/null @@ -1,32 +0,0 @@ -//! HTTP prover state. - -use tlsn_formats::http::{Body, Request, Response}; - -use crate::tls::{state as prover_state, Prover}; - -/// The state of an HTTP prover -pub trait State: sealed::Sealed {} - -/// Connection closed state. -pub struct Closed { - pub(super) prover: Prover, - pub(super) requests: Vec<(Request, Option)>, - pub(super) responses: Vec<(Response, Option)>, -} - -/// Notarizing state. 
-pub struct Notarize { - pub(super) prover: Prover, - pub(super) requests: Vec<(Request, Option)>, - pub(super) responses: Vec<(Response, Option)>, -} - -impl State for Closed {} -impl State for Notarize {} - -mod sealed { - pub trait Sealed {} - - impl Sealed for super::Closed {} - impl Sealed for super::Notarize {} -} diff --git a/tlsn/tlsn-prover/src/lib.rs b/tlsn/tlsn-prover/src/lib.rs deleted file mode 100644 index fb65664075..0000000000 --- a/tlsn/tlsn-prover/src/lib.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! The prover library -//! -//! This library contains TLSNotary prover implementations: -//! * [`tls`] for the low-level API for working with the underlying byte streams of a TLS connection. -//! * [`http`] for a higher-level API which provides abstractions for working with HTTP connections. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -#[cfg(feature = "formats")] -pub mod http; -pub mod tls; diff --git a/tlsn/tlsn-prover/src/tls/config.rs b/tlsn/tlsn-prover/src/tls/config.rs deleted file mode 100644 index e1d51a35f2..0000000000 --- a/tlsn/tlsn-prover/src/tls/config.rs +++ /dev/null @@ -1,108 +0,0 @@ -use mpz_ot::{chou_orlandi, kos}; -use mpz_share_conversion::{ReceiverConfig, SenderConfig}; -use tls_client::RootCertStore; -use tls_mpc::{MpcTlsCommonConfig, MpcTlsLeaderConfig}; - -const DEFAULT_MAX_TRANSCRIPT_SIZE: usize = 1 << 14; // 16Kb - -/// Configuration for the prover -#[derive(Debug, Clone, derive_builder::Builder)] -pub struct ProverConfig { - /// Id of the notarization session. - #[builder(setter(into))] - id: String, - /// The server DNS name. - #[builder(setter(into))] - server_dns: String, - /// TLS root certificate store. - #[builder(setter(strip_option), default = "default_root_store()")] - pub(crate) root_cert_store: RootCertStore, - /// Maximum transcript size in bytes - /// - /// This includes the number of bytes sent and received to the server. 
- #[builder(default = "DEFAULT_MAX_TRANSCRIPT_SIZE")] - max_transcript_size: usize, -} - -impl ProverConfig { - /// Create a new builder for `ProverConfig`. - pub fn builder() -> ProverConfigBuilder { - ProverConfigBuilder::default() - } - - /// Get the maximum transcript size in bytes. - pub fn max_transcript_size(&self) -> usize { - self.max_transcript_size - } - - /// Returns the server DNS name. - pub fn server_dns(&self) -> &str { - &self.server_dns - } - - pub(crate) fn build_mpc_tls_config(&self) -> MpcTlsLeaderConfig { - MpcTlsLeaderConfig::builder() - .common( - MpcTlsCommonConfig::builder() - .id(format!("{}/mpc_tls", &self.id)) - .max_transcript_size(self.max_transcript_size) - .handshake_commit(true) - .build() - .unwrap(), - ) - .build() - .unwrap() - } - - pub(crate) fn build_base_ot_sender_config(&self) -> chou_orlandi::SenderConfig { - chou_orlandi::SenderConfig::builder() - .receiver_commit() - .build() - .unwrap() - } - - pub(crate) fn build_base_ot_receiver_config(&self) -> chou_orlandi::ReceiverConfig { - chou_orlandi::ReceiverConfig::default() - } - - pub(crate) fn build_ot_sender_config(&self) -> kos::SenderConfig { - kos::SenderConfig::default() - } - - pub(crate) fn build_ot_receiver_config(&self) -> kos::ReceiverConfig { - kos::ReceiverConfig::builder() - .sender_commit() - .build() - .unwrap() - } - - pub(crate) fn ot_count(&self) -> usize { - self.max_transcript_size * 8 - } - - pub(crate) fn build_p256_sender_config(&self) -> SenderConfig { - SenderConfig::builder().id("p256/0").build().unwrap() - } - - pub(crate) fn build_p256_receiver_config(&self) -> ReceiverConfig { - ReceiverConfig::builder().id("p256/1").build().unwrap() - } - - pub(crate) fn build_gf2_config(&self) -> SenderConfig { - SenderConfig::builder().id("gf2").record().build().unwrap() - } -} - -/// Default root store using mozilla certs. 
-fn default_root_store() -> RootCertStore { - let mut root_store = tls_client::RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - tls_client::OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - })); - - root_store -} diff --git a/tlsn/tlsn-prover/src/tls/error.rs b/tlsn/tlsn-prover/src/tls/error.rs deleted file mode 100644 index 5ecc8751c6..0000000000 --- a/tlsn/tlsn-prover/src/tls/error.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::error::Error; -use tls_mpc::MpcTlsError; -use tlsn_core::commitment::TranscriptCommitmentBuilderError; - -/// An error that can occur during proving. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ProverError { - #[error(transparent)] - TlsClientError(#[from] tls_client::Error), - #[error(transparent)] - AsyncClientError(#[from] tls_client_async::ConnectionError), - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - MuxerError(#[from] utils_aio::mux::MuxerError), - #[error("notarization error: {0}")] - NotarizationError(String), - #[error(transparent)] - CommitmentBuilder(#[from] TranscriptCommitmentBuilderError), - #[error(transparent)] - InvalidServerName(#[from] tls_core::dns::InvalidDnsNameError), - #[error("error occurred in MPC protocol: {0}")] - MpcError(Box), - #[error("server did not send a close_notify")] - ServerNoCloseNotify, - #[error(transparent)] - CommitmentError(#[from] CommitmentError), - #[error("Range exceeds transcript length")] - InvalidRange, -} - -impl From for ProverError { - fn from(e: MpcTlsError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for ProverError { - fn from(e: mpz_ot::OTError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for ProverError { - fn from(e: mpz_garble::VmError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From 
for ProverError { - fn from(e: mpz_garble::MemoryError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for ProverError { - fn from(e: mpz_garble::ProveError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for ProverError { - fn from(value: mpz_ot::actor::kos::SenderActorError) -> Self { - Self::MpcError(Box::new(value)) - } -} - -impl From for ProverError { - fn from(value: mpz_ot::actor::kos::ReceiverActorError) -> Self { - Self::MpcError(Box::new(value)) - } -} - -#[derive(Debug)] -pub struct OTShutdownError; - -impl std::fmt::Display for OTShutdownError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("ot shutdown prior to completion") - } -} - -impl Error for OTShutdownError {} - -impl From for ProverError { - fn from(e: OTShutdownError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for ProverError { - fn from(e: tlsn_core::merkle::MerkleError) -> Self { - Self::CommitmentError(e.into()) - } -} - -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum CommitmentError { - #[error(transparent)] - MerkleError(#[from] tlsn_core::merkle::MerkleError), -} diff --git a/tlsn/tlsn-prover/src/tls/future.rs b/tlsn/tlsn-prover/src/tls/future.rs deleted file mode 100644 index e4fad79769..0000000000 --- a/tlsn/tlsn-prover/src/tls/future.rs +++ /dev/null @@ -1,75 +0,0 @@ -//! This module collects futures which are used by the [Prover]. - -use super::{state, Prover, ProverControl, ProverError}; -use futures::{future::FusedFuture, Future}; -use std::pin::Pin; - -/// Prover future which must be polled for the TLS connection to make progress. -pub struct ProverFuture { - #[allow(clippy::type_complexity)] - pub(crate) fut: - Pin, ProverError>> + Send + 'static>>, - pub(crate) ctrl: ProverControl, -} - -impl ProverFuture { - /// Returns a controller for the prover for advanced functionality. 
- pub fn control(&self) -> ProverControl { - self.ctrl.clone() - } -} - -impl Future for ProverFuture { - type Output = Result, ProverError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -/// A future which must be polled for the muxer to make progress. -pub(crate) struct MuxFuture { - pub(crate) fut: Pin> + Send + 'static>>, -} - -impl Future for MuxFuture { - type Output = Result<(), ProverError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for MuxFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} - -/// A future which must be polled for the Oblivious Transfer protocol to make progress. -pub(crate) struct OTFuture { - pub(crate) fut: Pin> + Send + 'static>>, -} - -impl Future for OTFuture { - type Output = Result<(), ProverError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for OTFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} diff --git a/tlsn/tlsn-prover/src/tls/mod.rs b/tlsn/tlsn-prover/src/tls/mod.rs deleted file mode 100644 index 1b0958bb6f..0000000000 --- a/tlsn/tlsn-prover/src/tls/mod.rs +++ /dev/null @@ -1,378 +0,0 @@ -//! TLS prover. -//! -//! This module provides the TLS prover, which is used with a TLS verifier to prove a transcript of communications with a server. -//! -//! The TLS prover provides a low-level API, see the [`HTTP prover`](crate::http) which provides abstractions for working -//! with HTTP sessions. 
- -mod config; -mod error; -mod future; -mod notarize; -mod prove; -pub mod state; - -pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError}; -pub use error::ProverError; -pub use future::ProverFuture; -use tlsn_common::{ - mux::{attach_mux, MuxControl}, - Role, -}; - -use error::OTShutdownError; -use future::{MuxFuture, OTFuture}; -use futures::{AsyncRead, AsyncWrite, FutureExt, StreamExt, TryFutureExt}; -use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPVm}; -use mpz_ot::{ - actor::kos::{ReceiverActor, SenderActor, SharedReceiver, SharedSender}, - chou_orlandi, kos, -}; -use mpz_share_conversion as ff; -use rand::Rng; -use state::{Notarize, Prove}; -use std::sync::Arc; -use tls_client::{ClientConnection, ServerName as TlsServerName}; -use tls_client_async::{bind_client, ClosedConnection, TlsConnection}; -use tls_mpc::{setup_components, LeaderCtrl, MpcTlsLeader, TlsRole}; -use tlsn_core::transcript::Transcript; -use utils_aio::mux::MuxChannel; - -#[cfg(feature = "formats")] -use http::{state as http_state, HttpProver, HttpProverError}; - -#[cfg(feature = "tracing")] -use tracing::{debug, debug_span, instrument, Instrument}; - -/// A prover instance. -#[derive(Debug)] -pub struct Prover { - config: ProverConfig, - state: T, -} - -impl Prover { - /// Creates a new prover. - /// - /// # Arguments - /// - /// * `config` - The configuration for the prover. - pub fn new(config: ProverConfig) -> Self { - Self { - config, - state: state::Initialized, - } - } - - /// Set up the prover. - /// - /// This performs all MPC setup prior to establishing the connection to the - /// application server. - /// - /// # Arguments - /// - /// * `socket` - The socket to the notary. 
- pub async fn setup( - self, - socket: S, - ) -> Result, ProverError> { - let (mut mux, mux_ctrl) = attach_mux(socket, Role::Prover); - - let mut mux_fut = MuxFuture { - fut: Box::pin(async move { mux.run().await.map_err(ProverError::from) }.fuse()), - }; - - let mpc_setup_fut = setup_mpc_backend(&self.config, mux_ctrl.clone()); - let (mpc_tls, vm, _, gf2, ot_fut) = futures::select! { - res = mpc_setup_fut.fuse() => res?, - _ = (&mut mux_fut).fuse() => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - - Ok(Prover { - config: self.config, - state: state::Setup { - mux_ctrl, - mux_fut, - mpc_tls, - vm, - ot_fut, - gf2, - }, - }) - } -} - -impl Prover { - /// Connects to the server using the provided socket. - /// - /// Returns a handle to the TLS connection, a future which returns the prover once the connection is - /// closed. - /// - /// # Arguments - /// - /// * `socket` - The socket to the server. - #[cfg_attr( - feature = "tracing", - instrument(level = "debug", skip(self, socket), err) - )] - pub async fn connect( - self, - socket: S, - ) -> Result<(TlsConnection, ProverFuture), ProverError> { - let state::Setup { - mux_ctrl, - mut mux_fut, - mpc_tls, - vm, - mut ot_fut, - gf2, - } = self.state; - - let (mpc_ctrl, mpc_fut) = mpc_tls.run(); - - let server_name = TlsServerName::try_from(self.config.server_dns())?; - let config = tls_client::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(self.config.root_cert_store.clone()) - .with_no_client_auth(); - let client = - ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)?; - - let (conn, conn_fut) = bind_client(socket, client); - - let start_time = web_time::UNIX_EPOCH.elapsed().unwrap().as_secs(); - - let fut = Box::pin({ - let mpc_ctrl = mpc_ctrl.clone(); - #[allow(clippy::let_and_return)] - let fut = async move { - let conn_fut = async { - let ClosedConnection { sent, recv, .. } = futures::select! 
{ - res = conn_fut.fuse() => res?, - _ = ot_fut => return Err(OTShutdownError)?, - _ = mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - - mpc_ctrl.close_connection().await?; - - Ok::<_, ProverError>((sent, recv)) - }; - - let ((sent, recv), mpc_tls_data) = - futures::try_join!(conn_fut, mpc_fut.map_err(ProverError::from))?; - - Ok(Prover { - config: self.config, - state: state::Closed { - mux_ctrl, - mux_fut, - vm, - ot_fut, - gf2, - start_time, - handshake_decommitment: mpc_tls_data - .handshake_decommitment - .expect("handshake was committed"), - server_public_key: mpc_tls_data.server_public_key, - transcript_tx: Transcript::new(sent), - transcript_rx: Transcript::new(recv), - }, - }) - }; - #[cfg(feature = "tracing")] - let fut = fut.instrument(debug_span!("prover_tls_connection")); - fut - }); - - Ok(( - conn, - ProverFuture { - fut, - ctrl: ProverControl { mpc_ctrl }, - }, - )) - } -} - -impl Prover { - /// Returns the transcript of the sent requests - pub fn sent_transcript(&self) -> &Transcript { - &self.state.transcript_tx - } - - /// Returns the transcript of the received responses - pub fn recv_transcript(&self) -> &Transcript { - &self.state.transcript_rx - } - - /// Creates an HTTP prover. - #[cfg(feature = "formats")] - pub fn to_http(self) -> Result, HttpProverError> { - HttpProver::new(self) - } - - /// Starts notarization of the TLS session. - /// - /// If the verifier is a Notary, this function will transition the prover to the next state - /// where it can generate commitments to the transcript prior to finalization. - pub fn start_notarize(self) -> Prover { - Prover { - config: self.config, - state: self.state.into(), - } - } - - /// Starts proving the TLS session. - /// - /// This function transitions the prover into a state where it can prove content of the - /// transcript. 
- pub fn start_prove(self) -> Prover { - Prover { - config: self.config, - state: self.state.into(), - } - } -} - -/// Performs a setup of the various MPC subprotocols. -#[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] -#[allow(clippy::type_complexity)] -async fn setup_mpc_backend( - config: &ProverConfig, - mut mux: MuxControl, -) -> Result< - ( - MpcTlsLeader, - DEAPVm, - SharedReceiver, - ff::ConverterSender, - OTFuture, - ), - ProverError, -> { - let (ot_send_sink, ot_send_stream) = mux.get_channel("ot/0").await?.split(); - let (ot_recv_sink, ot_recv_stream) = mux.get_channel("ot/1").await?.split(); - - let mut ot_sender_actor = SenderActor::new( - kos::Sender::new( - config.build_ot_sender_config(), - chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), - ), - ot_send_sink, - ot_send_stream, - ); - - let mut ot_receiver_actor = ReceiverActor::new( - kos::Receiver::new( - config.build_ot_receiver_config(), - chou_orlandi::Sender::new(config.build_base_ot_sender_config()), - ), - ot_recv_sink, - ot_recv_stream, - ); - - let ot_send = ot_sender_actor.sender(); - let ot_recv = ot_receiver_actor.receiver(); - - #[cfg(feature = "tracing")] - debug!("Starting OT setup"); - - futures::try_join!( - ot_sender_actor - .setup(config.ot_count()) - .map_err(ProverError::from), - ot_receiver_actor - .setup(config.ot_count()) - .map_err(ProverError::from) - )?; - - #[cfg(feature = "tracing")] - debug!("OT setup complete"); - - let ot_fut = OTFuture { - fut: Box::pin( - async move { - futures::try_join!( - ot_sender_actor.run().map_err(ProverError::from), - ot_receiver_actor.run().map_err(ProverError::from) - )?; - - Ok(()) - } - .fuse(), - ), - }; - - let mut vm = DEAPVm::new( - "vm", - DEAPRole::Leader, - rand::rngs::OsRng.gen(), - mux.get_channel("vm").await?, - Box::new(mux.clone()), - ot_send.clone(), - ot_recv.clone(), - ); - - let p256_sender_config = config.build_p256_sender_config(); - let channel = 
mux.get_channel(p256_sender_config.id()).await?; - let p256_send = - ff::ConverterSender::::new(p256_sender_config, ot_send.clone(), channel); - - let p256_receiver_config = config.build_p256_receiver_config(); - let channel = mux.get_channel(p256_receiver_config.id()).await?; - let p256_recv = - ff::ConverterReceiver::::new(p256_receiver_config, ot_recv.clone(), channel); - - let gf2_config = config.build_gf2_config(); - let channel = mux.get_channel(gf2_config.id()).await?; - let gf2 = ff::ConverterSender::::new(gf2_config, ot_send.clone(), channel); - - let mpc_tls_config = config.build_mpc_tls_config(); - - let (ke, prf, encrypter, decrypter) = setup_components( - mpc_tls_config.common(), - TlsRole::Leader, - &mut mux, - &mut vm, - p256_send, - p256_recv, - gf2.handle() - .map_err(|e| ProverError::MpcError(Box::new(e)))?, - ) - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))?; - - let channel = mux.get_channel(mpc_tls_config.common().id()).await?; - let mut mpc_tls = MpcTlsLeader::new(mpc_tls_config, channel, ke, prf, encrypter, decrypter); - - mpc_tls.setup().await?; - - #[cfg(feature = "tracing")] - debug!("MPC backend setup complete"); - - Ok((mpc_tls, vm, ot_recv, gf2, ot_fut)) -} - -/// A controller for the prover. -#[derive(Clone)] -pub struct ProverControl { - mpc_ctrl: LeaderCtrl, -} - -impl ProverControl { - /// Defers decryption of data from the server until the server has closed the connection. - /// - /// This is a performance optimization which will significantly reduce the amount of upload bandwidth - /// used by the prover. - /// - /// # Notes - /// - /// * The prover may need to close the connection to the server in order for it to close the connection - /// on its end. If neither the prover or server close the connection this will cause a deadlock. 
- pub async fn defer_decryption(&self) -> Result<(), ProverError> { - self.mpc_ctrl - .defer_decryption() - .await - .map_err(ProverError::from) - } -} diff --git a/tlsn/tlsn-prover/src/tls/notarize.rs b/tlsn/tlsn-prover/src/tls/notarize.rs deleted file mode 100644 index 76d550beff..0000000000 --- a/tlsn/tlsn-prover/src/tls/notarize.rs +++ /dev/null @@ -1,115 +0,0 @@ -//! This module handles the notarization phase of the prover. -//! -//! The prover deals with a TLS verifier that is only a notary. - -use crate::tls::error::OTShutdownError; - -use super::{ff::ShareConversionReveal, state::Notarize, Prover, ProverError}; -use futures::{FutureExt, SinkExt, StreamExt}; -use tlsn_core::{ - commitment::TranscriptCommitmentBuilder, - msg::{SignedSessionHeader, TlsnMessage}, - transcript::Transcript, - NotarizedSession, ServerName, SessionData, -}; -#[cfg(feature = "tracing")] -use tracing::instrument; -use utils_aio::{expect_msg_or_err, mux::MuxChannel}; - -impl Prover { - /// Returns the transcript of the sent requests - pub fn sent_transcript(&self) -> &Transcript { - &self.state.transcript_tx - } - - /// Returns the transcript of the received responses - pub fn recv_transcript(&self) -> &Transcript { - &self.state.transcript_rx - } - - /// Returns the transcript commitment builder - pub fn commitment_builder(&mut self) -> &mut TranscriptCommitmentBuilder { - &mut self.state.builder - } - - /// Finalize the notarization returning a [`NotarizedSession`] - #[cfg_attr(feature = "tracing", instrument(level = "info", skip(self), err))] - pub async fn finalize(self) -> Result { - let Notarize { - mut mux_ctrl, - mut mux_fut, - mut vm, - mut ot_fut, - mut gf2, - start_time, - handshake_decommitment, - server_public_key, - transcript_tx, - transcript_rx, - builder, - } = self.state; - - let commitments = builder.build()?; - - let session_data = SessionData::new( - ServerName::Dns(self.config.server_dns().to_string()), - handshake_decommitment, - transcript_tx, - transcript_rx, 
- commitments, - ); - - let merkle_root = session_data.commitments().merkle_root(); - - let mut notarize_fut = Box::pin(async move { - let mut channel = mux_ctrl.get_channel("notarize").await?; - - channel - .send(TlsnMessage::TranscriptCommitmentRoot(merkle_root)) - .await?; - - let notary_encoder_seed = vm - .finalize() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))? - .expect("encoder seed returned"); - - // This is a temporary approach until a maliciously secure share conversion protocol is implemented. - // The prover is essentially revealing the TLS MAC key. In some exotic scenarios this allows a malicious - // TLS verifier to modify the prover's request. - gf2.reveal() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))?; - - let signed_header = expect_msg_or_err!(channel, TlsnMessage::SignedSessionHeader)?; - - Ok::<_, ProverError>((notary_encoder_seed, signed_header)) - }) - .fuse(); - - let (notary_encoder_seed, SignedSessionHeader { header, signature }) = futures::select_biased! { - res = notarize_fut => res?, - _ = ot_fut => return Err(OTShutdownError)?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - // Wait for the notary to correctly close the connection - mux_fut.await?; - - // Check the header is consistent with the Prover's view - header - .verify( - start_time, - &server_public_key, - &session_data.commitments().merkle_root(), - ¬ary_encoder_seed, - &session_data.session_info().handshake_decommitment, - ) - .map_err(|_| { - ProverError::NotarizationError( - "notary signed an inconsistent session header".to_string(), - ) - })?; - - Ok(NotarizedSession::new(header, Some(signature), session_data)) - } -} diff --git a/tlsn/tlsn-prover/src/tls/prove.rs b/tlsn/tlsn-prover/src/tls/prove.rs deleted file mode 100644 index af2b9ce78d..0000000000 --- a/tlsn/tlsn-prover/src/tls/prove.rs +++ /dev/null @@ -1,206 +0,0 @@ -//! This module handles the proving phase of the prover. -//! -//! 
Here the prover deals with a verifier directly, so there is no notary involved. Instead -//! the verifier directly verifies parts of the transcript. - -use super::{state::Prove as ProveState, Prover, ProverError}; -use crate::tls::error::OTShutdownError; -use futures::{FutureExt, SinkExt}; -use mpz_garble::{Memory, Prove, Vm}; -use mpz_share_conversion::ShareConversionReveal; -use tlsn_core::{ - msg::TlsnMessage, proof::SessionInfo, transcript::get_value_ids, Direction, ServerName, - Transcript, -}; -use utils::range::{RangeSet, RangeUnion}; -use utils_aio::mux::MuxChannel; - -#[cfg(feature = "tracing")] -use tracing::info; - -impl Prover { - /// Returns the transcript of the sent requests - pub fn sent_transcript(&self) -> &Transcript { - &self.state.transcript_tx - } - - /// Returns the transcript of the received responses - pub fn recv_transcript(&self) -> &Transcript { - &self.state.transcript_rx - } - - /// Reveal certain parts of the transcripts to the verifier - /// - /// This function allows to collect certain transcript ranges. When [Prover::prove] is called, these - /// ranges will be opened to the verifier. 
- /// - /// # Arguments - /// * `ranges` - The ranges of the transcript to reveal - /// * `direction` - The direction of the transcript to reveal - pub fn reveal( - &mut self, - ranges: impl Into>, - direction: Direction, - ) -> Result<(), ProverError> { - let sent_ids = &mut self.state.proving_info.sent_ids; - let recv_ids = &mut self.state.proving_info.recv_ids; - - let range_set = ranges.into(); - - // Check ranges - let transcript = match direction { - Direction::Sent => &self.state.transcript_tx, - Direction::Received => &self.state.transcript_rx, - }; - - if range_set.max().unwrap_or_default() > transcript.data().len() { - return Err(ProverError::InvalidRange); - } - - match direction { - Direction::Sent => *sent_ids = sent_ids.union(&range_set), - Direction::Received => *recv_ids = recv_ids.union(&range_set), - } - - Ok(()) - } - - /// Prove transcript values - pub async fn prove(&mut self) -> Result<(), ProverError> { - let mut proving_info = std::mem::take(&mut self.state.proving_info); - - let mut prove_fut = Box::pin(async { - // Create a new channel and vm thread if not already present - let channel = if let Some(ref mut channel) = self.state.channel { - channel - } else { - self.state.channel = Some(self.state.mux_ctrl.get_channel("prove-verify").await?); - self.state.channel.as_mut().unwrap() - }; - - let prove_thread = if let Some(ref mut prove_thread) = self.state.prove_thread { - prove_thread - } else { - self.state.prove_thread = Some(self.state.vm.new_thread("prove-verify").await?); - self.state.prove_thread.as_mut().unwrap() - }; - - // Now prove the transcript parts which have been marked for reveal - let sent_value_ids = proving_info - .sent_ids - .iter_ranges() - .map(|r| get_value_ids(&r.into(), Direction::Sent).collect::>()); - let recv_value_ids = proving_info - .recv_ids - .iter_ranges() - .map(|r| get_value_ids(&r.into(), Direction::Received).collect::>()); - - let value_refs = sent_value_ids - .chain(recv_value_ids) - .map(|ids| { - let 
inner_refs = ids - .iter() - .map(|id| { - prove_thread - .get_value(id.as_str()) - .expect("Byte should be in VM memory") - }) - .collect::>(); - - prove_thread - .array_from_values(inner_refs.as_slice()) - .expect("Byte should be in VM Memory") - }) - .collect::>(); - - // Extract cleartext we want to reveal from transcripts - let mut cleartext = - Vec::with_capacity(proving_info.sent_ids.len() + proving_info.recv_ids.len()); - proving_info - .sent_ids - .iter_ranges() - .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_tx.data()[r])); - proving_info - .recv_ids - .iter_ranges() - .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_rx.data()[r])); - proving_info.cleartext = cleartext; - - // Send the proving info to the verifier - channel.send(TlsnMessage::ProvingInfo(proving_info)).await?; - - #[cfg(feature = "tracing")] - info!("Sent proving info to verifier"); - - // Prove the revealed transcript parts - prove_thread.prove(value_refs.as_slice()).await?; - - #[cfg(feature = "tracing")] - info!("Successfully proved cleartext"); - - Ok::<_, ProverError>(()) - }) - .fuse(); - - futures::select_biased! { - res = prove_fut => res?, - _ = &mut self.state.ot_fut => return Err(OTShutdownError)?, - _ = &mut self.state.mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - - Ok(()) - } - - /// Finalize the proving - pub async fn finalize(self) -> Result<(), ProverError> { - let ProveState { - mut mux_ctrl, - mut mux_fut, - mut vm, - mut ot_fut, - mut gf2, - handshake_decommitment, - .. - } = self.state; - - // Create session data and session_info - let session_info = SessionInfo { - server_name: ServerName::Dns(self.config.server_dns().to_string()), - handshake_decommitment, - }; - - let mut finalize_fut = Box::pin(async move { - let mut channel = mux_ctrl.get_channel("finalize").await?; - - _ = vm - .finalize() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))? 
- .expect("encoder seed returned"); - - // This is a temporary approach until a maliciously secure share conversion protocol is implemented. - // The prover is essentially revealing the TLS MAC key. In some exotic scenarios this allows a malicious - // TLS verifier to modify the prover's request. - gf2.reveal() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))?; - - // Send session_info to the verifier - channel.send(TlsnMessage::SessionInfo(session_info)).await?; - - Ok::<_, ProverError>(()) - }) - .fuse(); - - futures::select_biased! { - res = finalize_fut => res?, - _ = ot_fut => return Err(OTShutdownError)?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - - // We need to wait for the verifier to correctly close the connection. Otherwise the prover - // would rush ahead and close the connection before the verifier has finished. - mux_fut.await?; - Ok(()) - } -} diff --git a/tlsn/tlsn-prover/src/tls/state.rs b/tlsn/tlsn-prover/src/tls/state.rs deleted file mode 100644 index ba1f4d75a7..0000000000 --- a/tlsn/tlsn-prover/src/tls/state.rs +++ /dev/null @@ -1,182 +0,0 @@ -//! TLS prover states. - -use crate::tls::{MuxFuture, OTFuture}; -use mpz_core::commit::Decommitment; -use mpz_garble::protocol::deap::{DEAPThread, DEAPVm, PeerEncodings}; -use mpz_garble_core::{encoding_state, EncodedValue}; -use mpz_ot::actor::kos::{SharedReceiver, SharedSender}; -use mpz_share_conversion::{ConverterSender, Gf2_128}; -use std::collections::HashMap; -use tls_core::{handshake::HandshakeData, key::PublicKey}; -use tls_mpc::MpcTlsLeader; -use tlsn_common::mux::MuxControl; -use tlsn_core::{ - commitment::TranscriptCommitmentBuilder, - msg::{ProvingInfo, TlsnMessage}, - Transcript, -}; -use utils_aio::duplex::Duplex; - -/// Entry state -pub struct Initialized; - -opaque_debug::implement!(Initialized); - -/// State after MPC setup has completed. 
-pub struct Setup { - /// A muxer for communication with the Notary - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) mpc_tls: MpcTlsLeader, - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, -} - -opaque_debug::implement!(Setup); - -/// State after the TLS connection has been closed. -pub struct Closed { - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, - - pub(crate) start_time: u64, - pub(crate) handshake_decommitment: Decommitment, - pub(crate) server_public_key: PublicKey, - - pub(crate) transcript_tx: Transcript, - pub(crate) transcript_rx: Transcript, -} - -opaque_debug::implement!(Closed); - -/// Notarizing state. -pub struct Notarize { - /// A muxer for communication with the Notary - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, - - pub(crate) start_time: u64, - pub(crate) handshake_decommitment: Decommitment, - pub(crate) server_public_key: PublicKey, - - pub(crate) transcript_tx: Transcript, - pub(crate) transcript_rx: Transcript, - - pub(crate) builder: TranscriptCommitmentBuilder, -} - -opaque_debug::implement!(Notarize); - -impl From for Notarize { - fn from(state: Closed) -> Self { - let encodings = collect_encodings(&state.vm, &state.transcript_tx, &state.transcript_rx); - - let encoding_provider = Box::new(move |ids: &[&str]| { - ids.iter().map(|id| encodings.get(*id).cloned()).collect() - }); - - let builder = TranscriptCommitmentBuilder::new( - encoding_provider, - state.transcript_tx.data().len(), - state.transcript_rx.data().len(), - ); - - Self { - mux_ctrl: state.mux_ctrl, - mux_fut: state.mux_fut, - vm: state.vm, - ot_fut: state.ot_fut, - gf2: state.gf2, - start_time: state.start_time, - handshake_decommitment: state.handshake_decommitment, - 
server_public_key: state.server_public_key, - transcript_tx: state.transcript_tx, - transcript_rx: state.transcript_rx, - builder, - } - } -} - -/// Proving state. -pub struct Prove { - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, - - pub(crate) handshake_decommitment: Decommitment, - - pub(crate) transcript_tx: Transcript, - pub(crate) transcript_rx: Transcript, - - pub(crate) proving_info: ProvingInfo, - pub(crate) channel: Option>>, - pub(crate) prove_thread: Option>, -} - -impl From for Prove { - fn from(state: Closed) -> Self { - Self { - mux_ctrl: state.mux_ctrl, - mux_fut: state.mux_fut, - vm: state.vm, - ot_fut: state.ot_fut, - gf2: state.gf2, - handshake_decommitment: state.handshake_decommitment, - transcript_tx: state.transcript_tx, - transcript_rx: state.transcript_rx, - proving_info: ProvingInfo::default(), - channel: None, - prove_thread: None, - } - } -} - -#[allow(missing_docs)] -pub trait ProverState: sealed::Sealed {} - -impl ProverState for Initialized {} -impl ProverState for Setup {} -impl ProverState for Closed {} -impl ProverState for Notarize {} -impl ProverState for Prove {} - -mod sealed { - pub trait Sealed {} - impl Sealed for super::Initialized {} - impl Sealed for super::Setup {} - impl Sealed for super::Closed {} - impl Sealed for super::Notarize {} - impl Sealed for super::Prove {} -} - -fn collect_encodings( - vm: &DEAPVm, - transcript_tx: &Transcript, - transcript_rx: &Transcript, -) -> HashMap> { - let tx_ids = (0..transcript_tx.data().len()).map(|id| format!("tx/{id}")); - let rx_ids = (0..transcript_rx.data().len()).map(|id| format!("rx/{id}")); - - let ids = tx_ids.chain(rx_ids).collect::>(); - let id_refs = ids.iter().map(|id| id.as_ref()).collect::>(); - - vm.get_peer_encodings(&id_refs) - .expect("encodings for all transcript values should be present") - .into_iter() - .zip(ids) - .map(|(encoding, id)| (id, 
encoding)) - .collect() -} diff --git a/tlsn/tlsn-server-fixture/Cargo.toml b/tlsn/tlsn-server-fixture/Cargo.toml deleted file mode 100644 index b54e93155e..0000000000 --- a/tlsn/tlsn-server-fixture/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "tlsn-server-fixture" -version = "0.1.0" -edition = "2021" - -[dependencies] -async-rustls = "0.4.1" -axum = "0.6" -futures.workspace = true -hyper.workspace = true -rustls = "0.21.7" -tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } -tokio-util = { workspace = true, features = ["compat", "io"] } - -[[bin]] -name = "main" -path = "src/main.rs" diff --git a/tlsn/tlsn-server-fixture/src/lib.rs b/tlsn/tlsn-server-fixture/src/lib.rs deleted file mode 100644 index a6be004108..0000000000 --- a/tlsn/tlsn-server-fixture/src/lib.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use async_rustls::TlsAcceptor; -use axum::{ - extract::Query, - response::{Html, Json}, - routing::get, - Router, -}; -use futures::{AsyncRead, AsyncWrite}; -use hyper::{server::conn::Http, StatusCode}; -use rustls::{Certificate, PrivateKey, ServerConfig}; - -use tokio_util::compat::FuturesAsyncReadCompatExt; - -/// A certificate authority certificate fixture. -pub static CA_CERT_DER: &[u8] = include_bytes!("tls/rootCA.der"); -/// A server certificate (domain=test-server.io) fixture. -pub static SERVER_CERT_DER: &[u8] = include_bytes!("tls/domain.der"); -/// A server private key fixture. -pub static SERVER_KEY_DER: &[u8] = include_bytes!("tls/domain_key.der"); -/// The domain name bound to the server certificate. -pub static SERVER_DOMAIN: &str = "test-server.io"; - -fn app() -> Router { - Router::new() - .route("/", get(|| async { "Hello, World!" })) - .route("/formats/json", get(json)) - .route("/formats/html", get(html)) -} - -/// Bind the server to the given socket. 
-pub async fn bind(socket: T) { - let key = PrivateKey(SERVER_KEY_DER.to_vec()); - let cert = Certificate(SERVER_CERT_DER.to_vec()); - - let config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![cert], key) - .unwrap(); - - let acceptor = TlsAcceptor::from(Arc::new(config)); - - let conn = acceptor.accept(socket).await.unwrap(); - - Http::new() - .http1_only(true) - .http1_keep_alive(false) - .serve_connection(conn.compat(), app()) - .await - .unwrap(); -} - -async fn json( - Query(params): Query>, -) -> Result, StatusCode> { - let size = params - .get("size") - .and_then(|size| size.parse::().ok()) - .unwrap_or(1); - - match size { - 1 => Ok(Json(include_str!("data/1kb.json"))), - 4 => Ok(Json(include_str!("data/4kb.json"))), - 8 => Ok(Json(include_str!("data/8kb.json"))), - _ => Err(StatusCode::NOT_FOUND), - } -} - -async fn html() -> Html<&'static str> { - Html(include_str!("data/4kb.html")) -} diff --git a/tlsn/tlsn-server-fixture/src/tls/domain.crt b/tlsn/tlsn-server-fixture/src/tls/domain.crt deleted file mode 100644 index f946fbdc7c..0000000000 --- a/tlsn/tlsn-server-fixture/src/tls/domain.crt +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDmzCCAoOgAwIBAgIUXYmS5GJNi70RMwqR1zzSjI6gjnIwDQYJKoZIhvcNAQEL -BQAwSjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxEjAQBgNVBAoM -CXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5MB4XDTIzMDYwNzAxMzcyOFoX -DTI0MDYwNjAxMzcyOFowWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3Rh -dGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwL -dGVzdC1zZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQClrwyM -b9JVV56IqAAE7QoAfnyYMRxL/93NV2II35Hq1DrBeXGvQ9EjD0qKMKCNIJFLQUaO -dIQ853+OGez9Q73836bPOqM7hSRdn34bB+4phKnxM2QegEL+0oR6YhXS+9iavAWj -obfgVmmZhtLMXAIRuZrMLVPRNm+MIi4BcEu6Ckgdxhvkvp7Gzb74NXffgdilqTJG -bKdvCuF7IgFZXTi/8ACAZDKuXt8t9w0VUZRC9X4YTNslesW8MaPbmRwlVR8+qDMA -f2UdoUS5t9NGnniM/zsVK+RvbrFD0yXoJ3Pr+NdpFwLFPuhQCX2l0r2edb5Xsw6d 
-unCm2vZIrajFmzLTAgMBAAGjaDBmMB8GA1UdIwQYMBaAFCzjJ91fETcwwaoAM0mB -Z8r8srGlMAkGA1UdEwQCMAAwGQYDVR0RBBIwEIIOdGVzdC1zZXJ2ZXIuaW8wHQYD -VR0OBBYEFJMXVtfmDSqllg3BuymCj4OgENVnMA0GCSqGSIb3DQEBCwUAA4IBAQBd -W2Y58hHXei5K1wXRKaSZV8uyI5a4F4h+75vNNDGcbU204YRAtwmTYLXZnUCtxhL5 -wDnH00z8Z8s+ZHfDdH/64lxM0VmVKNxIdF6KUMIyvrdK9aL2wRMLSWCZTMBGibs0 -npY6fGgXdAeZnG8iP6ede3tF7vNpr+no+lrsx7ZCYhUg/XvaGsR2wIoMpMhVyDv2 -jkxc+Xnt/Prr89mQQUQVg2zkwcPrgEM+NwpDMqH3BFVsx6Qu1FO6sAIREewSM6t9 -kgfkzmH97Z5HEjGV2CWjsNBEAPaafAnE8qqvHQkFUmps12LnsEGbZbM/8kxifjNX -V6wbaLYrV6WDttQINST8 ------END CERTIFICATE----- diff --git a/tlsn/tlsn-server-fixture/src/tls/domain.csr b/tlsn/tlsn-server-fixture/src/tls/domain.csr deleted file mode 100644 index a4201d2587..0000000000 --- a/tlsn/tlsn-server-fixture/src/tls/domain.csr +++ /dev/null @@ -1,17 +0,0 @@ ------BEGIN CERTIFICATE REQUEST----- -MIICoDCCAYgCAQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx -ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLdGVz -dC1zZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQClrwyMb9JV -V56IqAAE7QoAfnyYMRxL/93NV2II35Hq1DrBeXGvQ9EjD0qKMKCNIJFLQUaOdIQ8 -53+OGez9Q73836bPOqM7hSRdn34bB+4phKnxM2QegEL+0oR6YhXS+9iavAWjobfg -VmmZhtLMXAIRuZrMLVPRNm+MIi4BcEu6Ckgdxhvkvp7Gzb74NXffgdilqTJGbKdv -CuF7IgFZXTi/8ACAZDKuXt8t9w0VUZRC9X4YTNslesW8MaPbmRwlVR8+qDMAf2Ud -oUS5t9NGnniM/zsVK+RvbrFD0yXoJ3Pr+NdpFwLFPuhQCX2l0r2edb5Xsw6dunCm -2vZIrajFmzLTAgMBAAGgADANBgkqhkiG9w0BAQsFAAOCAQEAUY8AusY1fDOy9Ary -rbkaX+pSUd4HSoElYEvn5ikrguqsTny5+Dvd+4aVGLHSO0/Y/cA1oRq8YqFEC/2C -2PxQkiZY8BdGKTWCT26f/3S197K/lOzxGmXbfZDZVHstzzobMIHFd0NEGBamLH4w -iSnuNRHre48cVIlCx3S2CVskNikJGBiZVfQeNVhkd5zqEkAdYViEycoQ5RyCu3po -HGlUQx3Z2TY4wNv1iKMPj8701C+c1uIosHMrjDPpIqHjS+nbaKPkXVwTF2lbOOcO -111t3fk0EeCdnGq8g6rS9mEL/kBZZNZI0/Pwl3SIBqoxvBv6U+ennqZhnPVJSK07 -//ovUA== ------END CERTIFICATE REQUEST----- diff --git a/tlsn/tlsn-server-fixture/src/tls/domain.der b/tlsn/tlsn-server-fixture/src/tls/domain.der deleted file mode 100644 index 0a9191c164..0000000000 Binary files 
a/tlsn/tlsn-server-fixture/src/tls/domain.der and /dev/null differ diff --git a/tlsn/tlsn-server-fixture/src/tls/domain.ext b/tlsn/tlsn-server-fixture/src/tls/domain.ext deleted file mode 100644 index 3853d7ccb9..0000000000 --- a/tlsn/tlsn-server-fixture/src/tls/domain.ext +++ /dev/null @@ -1,5 +0,0 @@ -authorityKeyIdentifier=keyid,issuer -basicConstraints=CA:FALSE -subjectAltName = @alt_names -[alt_names] -DNS.1 = test-server.io diff --git a/tlsn/tlsn-server-fixture/src/tls/domain.key b/tlsn/tlsn-server-fixture/src/tls/domain.key deleted file mode 100644 index 0ca070d38c..0000000000 --- a/tlsn/tlsn-server-fixture/src/tls/domain.key +++ /dev/null @@ -1,30 +0,0 @@ ------BEGIN ENCRYPTED PRIVATE KEY----- -MIIFHDBOBgkqhkiG9w0BBQ0wQTApBgkqhkiG9w0BBQwwHAQInlBxGC6Q/nwCAggA -MAwGCCqGSIb3DQIJBQAwFAYIKoZIhvcNAwcECPW0f2DL+mvdBIIEyOCxtYAeFPSA -IHXX/UtF0Z0FtFP+d1/osMI0wxJrxz0KM9fN8z7XEkzXIR8ZjMCsduIl/zirl+bB -L9UGAHBNIn3yYpwh9UXhFBC4t8GM5xy89E6GMD7DU5X0xA/B1RvGFlHY+EkyTYr6 -crhY8uxTEoVPaxgbeyZfJaTP4bcZyAVbgIN0TqOX7fk+zIKmCHauRETwJho8ZZ+d -tvucnpA4WqNyLhgWW+ScApzgVEsbBXV80tEDOXEw7TVtaS1VqlKP2pFVVHuCWQWk -GnZT57LY4vzBkmpEnmpNj10IVPg5D9Ro+XNsXv9ji67NU42AOopWngOLhqMEyVS2 -DJL9X5DQYQJKSx7cKVkpNkBMLZC6V7iPhG1WHr+OUgYdNSf6l/59tLLQTpN0u/19 -VsgaS84CqAv4eFOVwfr9QTwPg7BNKbvdzn41IvXA05OT3xs0P2FhbjafFaJrivu+ -0krVb1T7RyiwqJSr8EujufHptIzkax/HE5+g++v6laiHlZXealGmGsYKNkO4Tyn9 -3QU38U41A4LyBrnlWn8mEZxY/SM7nQhYdfZ2CDCwqpBmvozxuY4p6fwcyPXmMmxv -bwzDlIRkk1VxobRKYqoWhCXKCRm9c/WZoymbWghLhEhbwgRAfM6sTAtNKbAooY15 -RGGCZE/QIqroWdMyH9hWQ01KOSA+CWWgsAUr+l6w678Jcx8bOSlbJkjpHVY4Kg0G -kngficAHb4nVNmGXNoDqg/EgDi57J8jNEKxLK3SIs5xwcFqTQl5qF6vJUR64yYS1 -aJe6+AedSvubrSzqnuI7SgX44WoHpOduhu99AAMgHOSAOgg5Nxr29Jyo/h6qKsx8 -u+EDrIcQPsgpITNALmI6yT7uZJYnkFfvuZdiTOqAURVUVLUsnAuCzaUzbBdM2Vbb -rJlOkHJqrX9NxhIl8CSzdGiwdBc1gwi5eL84NY3MTSOD4+HN0u17a9O/EK2OhVNR -Jyt+r0WgE2lmaLKNaRvruwuU9oVG0cfWxRilOxFKnDpk9SMe4ZvUUIGe8wjx2BNs -u87EvPK71+gGLkWxo6qGH6xk/26dO/yjclFHDxjLe63WBTBrUrMucKjMbyxuKYrC 
-JgNyHskeK7cqOixg4cDMCIfNoHg8nnawF2fAYY+NEKJq29YcAWO1P8PuqpINdgke -Sac+ogpp56ZtFlmgQ6PYOM/24b3dFQwdt46HzK7iluVR7Ihq0uiBO7UDilz9qUPX -sdPSmOYoFST6kgx9Occk6MeJKMBH+f8nVnTauwhBs3s2AZEgHd07MreUF+G1njbI -ukCvL8MHYhwqCCd6EZaH7U7oNGA8roIRPkvpSXeuNhDsabxQIzQNyOy9v3fKNtyt -EAR1bnrKwXX/32dC/JFcJOfzMuhCclw/88miNtzDz/tXym7EUU4UbX5ARoDNHuaJ -0DHfEMCTSa2umiqMClTvfOgn9OWI140y6HPqWXmJJReiKkGkDJiBcRiBCGj0uxwz -qrclNr5iDk+USY09qR2gWt30xaBclgzAL9b7YEMulfcrlrOGxrVUlDSezs2YZQRH -156WWeCrruY+MjFZ4w0jhqkRW0osO9TwbyAZTlVYqrn9RChmfnW1Mn6p526ZkDer -nSQuQlLmTNOHGRRFvKg/6Q== ------END ENCRYPTED PRIVATE KEY----- diff --git a/tlsn/tlsn-server-fixture/src/tls/domain_key.der b/tlsn/tlsn-server-fixture/src/tls/domain_key.der deleted file mode 100644 index 984284521b..0000000000 Binary files a/tlsn/tlsn-server-fixture/src/tls/domain_key.der and /dev/null differ diff --git a/tlsn/tlsn-server-fixture/src/tls/rootCA.crt b/tlsn/tlsn-server-fixture/src/tls/rootCA.crt deleted file mode 100644 index 728847ce25..0000000000 --- a/tlsn/tlsn-server-fixture/src/tls/rootCA.crt +++ /dev/null @@ -1,21 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDdTCCAl2gAwIBAgIUYL1NTquyWYgFShuzP3SXUW1LRLswDQYJKoZIhvcNAQEL -BQAwSjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxEjAQBgNVBAoM -CXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5MB4XDTIzMDYwNzAxMzQyN1oX -DTI4MDYwNTAxMzQyN1owSjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3Rh -dGUxEjAQBgNVBAoMCXRsc25vdGFyeTESMBAGA1UEAwwJdGxzbm90YXJ5MIIBIjAN -BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyhPglwjjaUxTvlQZK3IOVdiormCz -eHCvyZscBcolbhL8CbrtJYkozfMLwaBSccqBLpMbvulOeKBBB3ll9iffeWkQIcnh -m4ozf7DdG7EhDnp0lBS40UmDV4zHKbqlUzBBt737RGaDSDzf0w2b8tVOA0ZkBXw7 -aGlbj1ikV45GFlyH5b+wCvcboUaJkt4CllqF8uIo4Csjp3EqLMnGW6556HNukLoi -YlqW89VM+7C0yaL43ROjPr0lXPBpCrV9jnsVSaBJ2u3Ae35KgaqFxcMrXzpMArjS -jEXSW6Gi7XdI7bWMpfnLv3eyRkuEzkhCu1PPTA1EBa4eEikAqilU7ukKzQIDAQAB -o1MwUTAdBgNVHQ4EFgQULOMn3V8RNzDBqgAzSYFnyvyysaUwHwYDVR0jBBgwFoAU -LOMn3V8RNzDBqgAzSYFnyvyysaUwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B 
-AQsFAAOCAQEArIKYiJ6eLqNcy1UCg4vboRWByVG4nN5NOPoBNRT1xTzEPsjEoZpw -EdE8roKD4jG6Pf73yiMsLGhSUA7HBET6NTS5g7wGgGL0QPQIeLHLf1oLWb17Zqmm -i316NBzwPEXx8snK/WcBzPimbcKRDAYJG5UahGJyaIE5tcXORq7tK4YCycM1qTq0 -dx5RjO+ZJ1160nstXRUhRumilPW1FzxOt+wnaJza07zwLF9NphnAh6XGCsisz/KW -8scL+DNu6W0+irixhGl3Ly9idNQLmDUEWoO6sp3MFW7He5iJgzFlIY1oo7uQFkHR -GzrhAlDfCcoCHFvHc5MuqFpvgSsWEoC2AQ== ------END CERTIFICATE----- diff --git a/tlsn/tlsn-server-fixture/src/tls/rootCA.der b/tlsn/tlsn-server-fixture/src/tls/rootCA.der deleted file mode 100644 index d974cd41ff..0000000000 Binary files a/tlsn/tlsn-server-fixture/src/tls/rootCA.der and /dev/null differ diff --git a/tlsn/tlsn-server-fixture/src/tls/rootCA.key b/tlsn/tlsn-server-fixture/src/tls/rootCA.key deleted file mode 100644 index dbd954eb03..0000000000 --- a/tlsn/tlsn-server-fixture/src/tls/rootCA.key +++ /dev/null @@ -1,30 +0,0 @@ ------BEGIN ENCRYPTED PRIVATE KEY----- -MIIFHDBOBgkqhkiG9w0BBQ0wQTApBgkqhkiG9w0BBQwwHAQIwMUhKGCZU0kCAggA -MAwGCCqGSIb3DQIJBQAwFAYIKoZIhvcNAwcECJxXqH54fWnLBIIEyI9L1uoU2blz -siT/FPwMzDTgGJCYtre8yIuA81OTURH8V3PM4r8SguPk5xv75WcXu71/+I5hK0s6 -ia1BZHHSQosnu8XrGNTyIaybg1JzrXtunlFUoH64Tg5tXc2G53lbXb7LMGs7YAhh -SG6bmbSerVVOj5Iu1X3+Vdjlsp26BdcOGZ83xZVyRWRWUDjHfrQ0xmMBdB2pe0nD -7cqpSD6WGFXHaRoV+zL/ErjelkqU7V9GSDajZxQxniENEPddpucc7Pz29rqE/2Qf -XX/HD/9hz95RvssYX7b3dczoqheRxIsx72Fn5B/QzEGW/PPaNGi8iQZr8rsoE3S1 -S+nbBKCRiMTwLYu3Z5EjWeTHBtA8tKGPW6HF/9Ga1eEnKTX5+7SgyxbfkeDK1SJN -M77V4qN6R9HvbvduSjO9utDr/zCvHu68Ju9dYCNOiKdLS417pVNtB47EnUiN7kDF -Zrd+lL/8crqNWokocrw2qxLuGMsBqHdvtl6F5O57w9rKZl72gfQTg0dsh5YPN7K9 -xckGaZg0n8nUDAop6bSzXc+0JCwzXfGtKsdq4VjBgy/QtqLX24SgloHJHrl1qyGj -eSnKRvDcazD4XEDEsn9HBIaHHHm8lIrh8IhgsE3UvifgfOf4DO9uasKEdXahZaEu -9WH1mxnY67X+z9IaslMgkzoVnZARq1Swin+SxdY6rWJHucCLO8pCh4R9w6J/olXM -+bs3ivrjygJ2TOYgPV8CJ8IU65Z4C/Wv/Rt4p1QwxK/swkUjINqryh6rg2BGfHE8 -yLUioVimx4c/kziZEq3z8+eq6vs2cB28q3ze39REHAtk3ezsvIirtY/FG10zrWgx -cWHJYUKjYug1RkPWwYr3/vuTh68IrEmGi4TMcojxSUpq1vdc2lk9zu9vXTEsKF00 
-NQP5txscxQZwdb5Q01pylcTHOeiEjbQ0hRQ+oelwRniKXAuAp5D7H4LnUcw07wfp -nA1GEqMh7M7/dM3E7eXmzS2izDnwE9E0Va0Qu4fYpu+VPwahgP71bVheYRb1Rr1a -ODnUcReuNqJhZ7q/BTRZ4fK7vrVfuJjecMGNBdFRYH1etQgcZPJzJWnZaMP+S+iO -Jqk/1aMAR5E6ejO5KwExmO9JFQplfbJmrvtGqq1AkkOr+EefR5CK7Tr2jNdEE2xg -b5Z1BmqniJ2I/o9L9AxFFKE9PjXeLLS6XHS3N5VPjLzACJhchQZWI1RbEt6qAW6i -LwizIDJLN57mxaxtuQsEg+hD1tNiRIaF3dsCpBq7MTc64/iWphCydYF8adRSzoVy -07GwAmZ9JX/fd6WCWYAo97N0S7/adBqAi/XXA1wqkW/ZpRF63tXNcLJXwLcoeclr -DKpMb0NtjCT5QNJNVGHTUv4RdGEWbN2u6Oq6rfOfYnXvjrzmpAVVkRJ8BlQsTkzN -5T+BV4txwS5a1EH2ERfY4KuQWcSKIP2lKjNvqKCtwbHOSoQ6wmtLJKzw9mLD8JzQ -9wROywbA4CDW4dbSPuHUTmSpfBrt36mfV8+loJ8m3+eqcQNIh44drO3P8wwaIAEe -1YT7Yt75YD2EjtXvw2umxvQP/PWk2t+EI+Fo9/t/U2zPHgFuxevCMzNsRktz108M -rIuTToR3EfYojhEvyMyf7g== ------END ENCRYPTED PRIVATE KEY----- diff --git a/tlsn/tlsn-verifier/Cargo.toml b/tlsn/tlsn-verifier/Cargo.toml deleted file mode 100644 index 0c3bd6a3d4..0000000000 --- a/tlsn/tlsn-verifier/Cargo.toml +++ /dev/null @@ -1,36 +0,0 @@ -[package] -name = "tlsn-verifier" -authors = ["TLSNotary Team"] -description = "A library for the TLSNotary verifier" -keywords = ["tls", "mpc", "2pc"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.3" -edition = "2021" - -[features] -tracing = ["dep:tracing", "tlsn-tls-mpc/tracing", "tlsn-common/tracing"] - -[dependencies] -tlsn-core.workspace = true -tlsn-common.workspace = true -tlsn-tls-core.workspace = true -tlsn-tls-mpc.workspace = true -uid-mux.workspace = true - -tlsn-utils-aio.workspace = true - -mpz-core.workspace = true -mpz-garble.workspace = true -mpz-ot.workspace = true -mpz-share-conversion.workspace = true -mpz-circuits.workspace = true - -futures.workspace = true -thiserror.workspace = true -derive_builder.workspace = true -rand.workspace = true -signature.workspace = true -opaque-debug.workspace = true - -tracing = { workspace = true, optional = true } diff --git a/tlsn/tlsn-verifier/src/lib.rs b/tlsn/tlsn-verifier/src/lib.rs deleted file 
mode 100644 index 403a9f3f44..0000000000 --- a/tlsn/tlsn-verifier/src/lib.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! TLSNotary verifier library. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -pub mod tls; diff --git a/tlsn/tlsn-verifier/src/tls/config.rs b/tlsn/tlsn-verifier/src/tls/config.rs deleted file mode 100644 index 55122a23a0..0000000000 --- a/tlsn/tlsn-verifier/src/tls/config.rs +++ /dev/null @@ -1,119 +0,0 @@ -use mpz_ot::{chou_orlandi, kos}; -use mpz_share_conversion::{ReceiverConfig, SenderConfig}; -use std::fmt::{Debug, Formatter, Result}; -use tls_core::verify::{ServerCertVerifier, WebPkiVerifier}; -use tls_mpc::{MpcTlsCommonConfig, MpcTlsFollowerConfig}; -use tlsn_core::proof::default_cert_verifier; - -const DEFAULT_MAX_TRANSCRIPT_SIZE: usize = 1 << 14; // 16Kb - -/// Configuration for the [`Verifier`](crate::tls::Verifier) -#[allow(missing_docs)] -#[derive(derive_builder::Builder)] -#[builder(pattern = "owned")] -pub struct VerifierConfig { - #[builder(setter(into))] - id: String, - - /// Maximum transcript size in bytes - /// - /// This includes the number of bytes sent and received to the server. - #[builder(default = "DEFAULT_MAX_TRANSCRIPT_SIZE")] - max_transcript_size: usize, - #[builder( - pattern = "owned", - setter(strip_option), - default = "Some(default_cert_verifier())" - )] - cert_verifier: Option, -} - -impl Debug for VerifierConfig { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - f.debug_struct("VerifierConfig") - .field("id", &self.id) - .field("max_transcript_size", &self.max_transcript_size) - .field("cert_verifier", &"_") - .finish() - } -} - -impl VerifierConfig { - /// Create a new configuration builder. - pub fn builder() -> VerifierConfigBuilder { - VerifierConfigBuilder::default() - } - - /// Returns the ID of the notarization session. - pub fn id(&self) -> &str { - &self.id - } - - /// Get the maximum transcript size in bytes. 
- pub fn max_transcript_size(&self) -> usize { - self.max_transcript_size - } - - /// Get the certificate verifier. - pub fn cert_verifier(&self) -> &impl ServerCertVerifier { - self.cert_verifier - .as_ref() - .expect("Certificate verifier should be set") - } - - pub(crate) fn build_base_ot_sender_config(&self) -> chou_orlandi::SenderConfig { - chou_orlandi::SenderConfig::default() - } - - pub(crate) fn build_base_ot_receiver_config(&self) -> chou_orlandi::ReceiverConfig { - chou_orlandi::ReceiverConfig::builder() - .receiver_commit() - .build() - .unwrap() - } - - pub(crate) fn build_ot_sender_config(&self) -> kos::SenderConfig { - kos::SenderConfig::builder() - .sender_commit() - .build() - .unwrap() - } - - pub(crate) fn build_ot_receiver_config(&self) -> kos::ReceiverConfig { - kos::ReceiverConfig::default() - } - - pub(crate) fn build_mpc_tls_config(&self) -> MpcTlsFollowerConfig { - MpcTlsFollowerConfig::builder() - .common( - MpcTlsCommonConfig::builder() - .id(format!("{}/mpc_tls", &self.id)) - .max_transcript_size(self.max_transcript_size) - .handshake_commit(true) - .build() - .unwrap(), - ) - .build() - .unwrap() - } - - pub(crate) fn ot_count(&self) -> usize { - self.max_transcript_size * 8 - } - - pub(crate) fn build_p256_sender_config(&self) -> SenderConfig { - SenderConfig::builder().id("p256/1").build().unwrap() - } - - pub(crate) fn build_p256_receiver_config(&self) -> ReceiverConfig { - ReceiverConfig::builder().id("p256/0").build().unwrap() - } - - pub(crate) fn build_gf2_config(&self) -> ReceiverConfig { - ReceiverConfig::builder() - .id("gf2") - .record() - .build() - .unwrap() - } -} diff --git a/tlsn/tlsn-verifier/src/tls/error.rs b/tlsn/tlsn-verifier/src/tls/error.rs deleted file mode 100644 index 9791c3ebeb..0000000000 --- a/tlsn/tlsn-verifier/src/tls/error.rs +++ /dev/null @@ -1,64 +0,0 @@ -use std::error::Error; -use tls_mpc::MpcTlsError; - -/// An error that can occur during TLS verification. 
-#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum VerifierError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - MuxerError(#[from] utils_aio::mux::MuxerError), - #[error("error occurred in MPC protocol: {0}")] - MpcError(Box), - #[error("Range exceeds transcript length")] - InvalidRange, -} - -impl From for VerifierError { - fn from(e: MpcTlsError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for VerifierError { - fn from(e: mpz_ot::OTError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for VerifierError { - fn from(e: mpz_ot::actor::kos::SenderActorError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for VerifierError { - fn from(e: mpz_ot::actor::kos::ReceiverActorError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for VerifierError { - fn from(e: mpz_garble::VerifyError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for VerifierError { - fn from(e: mpz_garble::MemoryError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for VerifierError { - fn from(e: tlsn_core::proof::SessionProofError) -> Self { - Self::MpcError(Box::new(e)) - } -} - -impl From for VerifierError { - fn from(e: mpz_garble::VmError) -> Self { - Self::MpcError(Box::new(e)) - } -} diff --git a/tlsn/tlsn-verifier/src/tls/future.rs b/tlsn/tlsn-verifier/src/tls/future.rs deleted file mode 100644 index 886fb4128e..0000000000 --- a/tlsn/tlsn-verifier/src/tls/future.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! This module collects futures which are used by the [Verifier](crate::tls::Verifier). - -use super::{OTSenderActor, VerifierError}; -use futures::{future::FusedFuture, Future}; -use std::pin::Pin; - -/// A future which must be polled for the muxer to make progress. 
-pub(crate) struct MuxFuture { - pub(crate) fut: Pin> + Send + 'static>>, -} - -impl Future for MuxFuture { - type Output = Result<(), VerifierError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for MuxFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} - -/// A future which must be polled for the Oblivious Transfer protocol to make progress. -pub(crate) struct OTFuture { - pub(crate) fut: - Pin> + Send + 'static>>, -} - -impl Future for OTFuture { - type Output = Result; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for OTFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} diff --git a/tlsn/tlsn-verifier/src/tls/mod.rs b/tlsn/tlsn-verifier/src/tls/mod.rs deleted file mode 100644 index 4acfffcff3..0000000000 --- a/tlsn/tlsn-verifier/src/tls/mod.rs +++ /dev/null @@ -1,345 +0,0 @@ -//! 
TLS Verifier - -pub(crate) mod config; -mod error; -mod future; -mod notarize; -pub mod state; -mod verify; - -pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError}; -pub use error::VerifierError; - -use std::time::{SystemTime, UNIX_EPOCH}; - -use crate::tls::future::OTFuture; -use future::MuxFuture; -use futures::{ - stream::{SplitSink, SplitStream}, - AsyncRead, AsyncWrite, FutureExt, StreamExt, TryFutureExt, -}; -use mpz_garble::{config::Role as GarbleRole, protocol::deap::DEAPVm}; -use mpz_ot::{ - actor::kos::{ - msgs::Message as ActorMessage, ReceiverActor, SenderActor, SharedReceiver, SharedSender, - }, - chou_orlandi, kos, -}; -use mpz_share_conversion as ff; -use rand::Rng; -use signature::Signer; -use state::{Notarize, Verify}; -use tls_mpc::{setup_components, MpcTlsFollower, MpcTlsFollowerData, TlsRole}; -use tlsn_common::{ - mux::{attach_mux, MuxControl}, - Role, -}; -use tlsn_core::{proof::SessionInfo, RedactedTranscript, SessionHeader, Signature}; -use utils_aio::{duplex::Duplex, mux::MuxChannel}; - -#[cfg(feature = "tracing")] -use tracing::{debug, info, instrument}; - -type OTSenderActor = SenderActor< - chou_orlandi::Receiver, - SplitSink< - Box>>, - ActorMessage, - >, - SplitStream>>>, ->; - -/// A Verifier instance. -pub struct Verifier { - config: VerifierConfig, - state: T, -} - -impl Verifier { - /// Create a new verifier. - pub fn new(config: VerifierConfig) -> Self { - Self { - config, - state: state::Initialized, - } - } - - /// Set up the verifier. - /// - /// This performs all MPC setup. - /// - /// # Arguments - /// - /// * `socket` - The socket to the prover. 
- pub async fn setup( - self, - socket: S, - ) -> Result, VerifierError> { - let (mut mux, mux_ctrl) = attach_mux(socket, Role::Verifier); - - let mut mux_fut = MuxFuture { - fut: Box::pin(async move { mux.run().await.map_err(VerifierError::from) }.fuse()), - }; - - let encoder_seed: [u8; 32] = rand::rngs::OsRng.gen(); - let mpc_setup_fut = setup_mpc_backend(&self.config, mux_ctrl.clone(), encoder_seed); - let (mpc_tls, vm, ot_send, ot_recv, gf2, ot_fut) = futures::select! { - res = mpc_setup_fut.fuse() => res?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - - Ok(Verifier { - config: self.config, - state: state::Setup { - mux_ctrl, - mux_fut, - mpc_tls, - vm, - ot_send, - ot_recv, - ot_fut, - gf2, - encoder_seed, - }, - }) - } - - /// Runs the TLS verifier to completion, notarizing the TLS session. - /// - /// This is a convenience method which runs all the steps needed for notarization. - pub async fn notarize( - self, - socket: S, - signer: &impl Signer, - ) -> Result - where - T: Into, - { - self.setup(socket) - .await? - .run() - .await? - .start_notarize() - .finalize(signer) - .await - } - - /// Runs the TLS verifier to completion, verifying the TLS session. - /// - /// This is a convenience method which runs all the steps needed for verification. - pub async fn verify( - self, - socket: S, - ) -> Result<(RedactedTranscript, RedactedTranscript, SessionInfo), VerifierError> { - let mut verifier = self.setup(socket).await?.run().await?.start_verify(); - let (redacted_sent, redacted_received) = verifier.receive().await?; - - let session_info = verifier.finalize().await?; - Ok((redacted_sent, redacted_received, session_info)) - } -} - -impl Verifier { - /// Runs the verifier until the TLS connection is closed. 
- pub async fn run(self) -> Result, VerifierError> { - let state::Setup { - mux_ctrl, - mut mux_fut, - mpc_tls, - vm, - ot_send, - ot_recv, - mut ot_fut, - gf2, - encoder_seed, - } = self.state; - - let start_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - let (_, mpc_fut) = mpc_tls.run(); - - let MpcTlsFollowerData { - handshake_commitment, - server_key: server_ephemeral_key, - bytes_sent: sent_len, - bytes_recv: recv_len, - } = futures::select! { - res = mpc_fut.fuse() => res?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - res = ot_fut => return Err(res.map(|_| ()).expect_err("future will not return Ok here")) - }; - - #[cfg(feature = "tracing")] - info!("Finished TLS session"); - - // TODO: We should be able to skip this commitment and verify the handshake directly. - let handshake_commitment = handshake_commitment.expect("handshake commitment is set"); - - Ok(Verifier { - config: self.config, - state: state::Closed { - mux_ctrl, - mux_fut, - vm, - ot_send, - ot_recv, - ot_fut, - gf2, - encoder_seed, - start_time, - server_ephemeral_key, - handshake_commitment, - sent_len, - recv_len, - }, - }) - } -} - -impl Verifier { - /// Starts notarization of the TLS session. - /// - /// If the verifier is a Notary, this function will transition the verifier to the next state - /// where it can sign the prover's commitments to the transcript. - pub fn start_notarize(self) -> Verifier { - Verifier { - config: self.config, - state: self.state.into(), - } - } - - /// Starts verification of the TLS session. - /// - /// This function transitions the verifier into a state where it can verify content of the - /// transcript. - pub fn start_verify(self) -> Verifier { - Verifier { - config: self.config, - state: self.state.into(), - } - } -} - -/// Performs a setup of the various MPC subprotocols. 
-#[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] -#[allow(clippy::type_complexity)] -async fn setup_mpc_backend( - config: &VerifierConfig, - mut mux_ctrl: MuxControl, - encoder_seed: [u8; 32], -) -> Result< - ( - MpcTlsFollower, - DEAPVm, - SharedSender, - SharedReceiver, - ff::ConverterReceiver, - OTFuture, - ), - VerifierError, -> { - let (ot_send_sink, ot_send_stream) = mux_ctrl.get_channel("ot/1").await?.split(); - let (ot_recv_sink, ot_recv_stream) = mux_ctrl.get_channel("ot/0").await?.split(); - - let mut ot_sender_actor = OTSenderActor::new( - kos::Sender::new( - config.build_ot_sender_config(), - chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), - ), - ot_send_sink, - ot_send_stream, - ); - - let mut ot_receiver_actor = ReceiverActor::new( - kos::Receiver::new( - config.build_ot_receiver_config(), - chou_orlandi::Sender::new(config.build_base_ot_sender_config()), - ), - ot_recv_sink, - ot_recv_stream, - ); - - let ot_send = ot_sender_actor.sender(); - let ot_recv = ot_receiver_actor.receiver(); - - #[cfg(feature = "tracing")] - debug!("Starting OT setup"); - - futures::try_join!( - ot_sender_actor - .setup(config.ot_count()) - .map_err(VerifierError::from), - ot_receiver_actor - .setup(config.ot_count()) - .map_err(VerifierError::from) - )?; - - #[cfg(feature = "tracing")] - debug!("OT setup complete"); - - let ot_fut = OTFuture { - fut: Box::pin( - async move { - futures::try_join!( - ot_sender_actor.run().map_err(VerifierError::from), - ot_receiver_actor.run().map_err(VerifierError::from) - )?; - - Ok(ot_sender_actor) - } - .fuse(), - ), - }; - - let mut vm = DEAPVm::new( - "vm", - GarbleRole::Follower, - encoder_seed, - mux_ctrl.get_channel("vm").await?, - Box::new(mux_ctrl.clone()), - ot_send.clone(), - ot_recv.clone(), - ); - - let p256_sender_config = config.build_p256_sender_config(); - let channel = mux_ctrl.get_channel(p256_sender_config.id()).await?; - let p256_send = - 
ff::ConverterSender::::new(p256_sender_config, ot_send.clone(), channel); - - let p256_receiver_config = config.build_p256_receiver_config(); - let channel = mux_ctrl.get_channel(p256_receiver_config.id()).await?; - let p256_recv = - ff::ConverterReceiver::::new(p256_receiver_config, ot_recv.clone(), channel); - - let gf2_config = config.build_gf2_config(); - let channel = mux_ctrl.get_channel(gf2_config.id()).await?; - let gf2 = ff::ConverterReceiver::::new(gf2_config, ot_recv.clone(), channel); - - let mpc_tls_config = config.build_mpc_tls_config(); - - let (ke, prf, encrypter, decrypter) = setup_components( - mpc_tls_config.common(), - TlsRole::Follower, - &mut mux_ctrl, - &mut vm, - p256_send, - p256_recv, - gf2.handle() - .map_err(|e| VerifierError::MpcError(Box::new(e)))?, - ) - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; - - let channel = mux_ctrl.get_channel(mpc_tls_config.common().id()).await?; - let mut mpc_tls = MpcTlsFollower::new(mpc_tls_config, channel, ke, prf, encrypter, decrypter); - - mpc_tls.setup().await?; - - #[cfg(feature = "tracing")] - debug!("MPC backend setup complete"); - - Ok((mpc_tls, vm, ot_send, ot_recv, gf2, ot_fut)) -} diff --git a/tlsn/tlsn-verifier/src/tls/notarize.rs b/tlsn/tlsn-verifier/src/tls/notarize.rs deleted file mode 100644 index a3fc9c39e1..0000000000 --- a/tlsn/tlsn-verifier/src/tls/notarize.rs +++ /dev/null @@ -1,107 +0,0 @@ -//! This module handles the notarization phase of the verifier. -//! -//! The TLS verifier is only a notary. 
- -use super::{state::Notarize, Verifier, VerifierError}; -use futures::{FutureExt, SinkExt, StreamExt, TryFutureExt}; -use mpz_core::serialize::CanonicalSerialize; -use mpz_share_conversion::ShareConversionVerify; -use signature::Signer; -use tlsn_core::{ - msg::{SignedSessionHeader, TlsnMessage}, - HandshakeSummary, SessionHeader, Signature, -}; -use utils_aio::{expect_msg_or_err, mux::MuxChannel}; - -#[cfg(feature = "tracing")] -use tracing::info; - -impl Verifier { - /// Notarizes the TLS session. - pub async fn finalize(self, signer: &impl Signer) -> Result - where - T: Into, - { - let Notarize { - mut mux_ctrl, - mut mux_fut, - mut vm, - ot_send, - ot_recv, - ot_fut, - mut gf2, - encoder_seed, - start_time, - server_ephemeral_key, - handshake_commitment, - sent_len, - recv_len, - } = self.state; - - let notarize_fut = async { - let mut notarize_channel = mux_ctrl.get_channel("notarize").await?; - - let merkle_root = - expect_msg_or_err!(notarize_channel, TlsnMessage::TranscriptCommitmentRoot)?; - - // Finalize all MPC before signing the session header - let (mut ot_sender_actor, _, _) = futures::try_join!( - ot_fut, - ot_send.shutdown().map_err(VerifierError::from), - ot_recv.shutdown().map_err(VerifierError::from) - )?; - - ot_sender_actor.reveal().await?; - - vm.finalize() - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; - - gf2.verify() - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; - - #[cfg(feature = "tracing")] - info!("Finalized all MPC"); - - let handshake_summary = - HandshakeSummary::new(start_time, server_ephemeral_key, handshake_commitment); - - let session_header = SessionHeader::new( - encoder_seed, - merkle_root, - sent_len, - recv_len, - handshake_summary, - ); - - let signature = signer.sign(&session_header.to_bytes()); - - #[cfg(feature = "tracing")] - info!("Signed session header"); - - notarize_channel - .send(TlsnMessage::SignedSessionHeader(SignedSessionHeader { - header: session_header.clone(), - 
signature: signature.into(), - })) - .await?; - - #[cfg(feature = "tracing")] - info!("Sent session header"); - - Ok::<_, VerifierError>(session_header) - }; - - let session_header = futures::select! { - res = notarize_fut.fuse() => res?, - _ = &mut mux_fut => Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - - let mut mux_ctrl = mux_ctrl.into_inner(); - - futures::try_join!(mux_ctrl.close().map_err(VerifierError::from), mux_fut)?; - - Ok(session_header) - } -} diff --git a/tlsn/tlsn-verifier/src/tls/state.rs b/tlsn/tlsn-verifier/src/tls/state.rs deleted file mode 100644 index 3100fa061e..0000000000 --- a/tlsn/tlsn-verifier/src/tls/state.rs +++ /dev/null @@ -1,157 +0,0 @@ -//! TLS Verifier state. - -use mpz_core::hash::Hash; -use mpz_garble::protocol::deap::{DEAPThread, DEAPVm}; -use mpz_ot::actor::kos::{SharedReceiver, SharedSender}; -use mpz_share_conversion::{ConverterReceiver, Gf2_128}; -use tls_core::key::PublicKey; -use tls_mpc::MpcTlsFollower; -use tlsn_common::mux::MuxControl; -use tlsn_core::msg::TlsnMessage; -use utils_aio::duplex::Duplex; - -use crate::tls::future::{MuxFuture, OTFuture}; - -/// TLS Verifier state. -pub trait VerifierState: sealed::Sealed {} - -/// Initialized state. -pub struct Initialized; - -opaque_debug::implement!(Initialized); - -/// State after MPC setup has completed. -pub struct Setup { - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) mpc_tls: MpcTlsFollower, - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, - - pub(crate) encoder_seed: [u8; 32], -} - -/// State after the TLS connection has been closed. 
-pub struct Closed { - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, - - pub(crate) encoder_seed: [u8; 32], - pub(crate) start_time: u64, - pub(crate) server_ephemeral_key: PublicKey, - pub(crate) handshake_commitment: Hash, - pub(crate) sent_len: usize, - pub(crate) recv_len: usize, -} - -opaque_debug::implement!(Closed); - -/// Notarizing state. -pub struct Notarize { - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, - - pub(crate) encoder_seed: [u8; 32], - pub(crate) start_time: u64, - pub(crate) server_ephemeral_key: PublicKey, - pub(crate) handshake_commitment: Hash, - pub(crate) sent_len: usize, - pub(crate) recv_len: usize, -} - -opaque_debug::implement!(Notarize); - -impl From for Notarize { - fn from(value: Closed) -> Self { - Self { - mux_ctrl: value.mux_ctrl, - mux_fut: value.mux_fut, - vm: value.vm, - ot_send: value.ot_send, - ot_recv: value.ot_recv, - ot_fut: value.ot_fut, - gf2: value.gf2, - encoder_seed: value.encoder_seed, - start_time: value.start_time, - server_ephemeral_key: value.server_ephemeral_key, - handshake_commitment: value.handshake_commitment, - sent_len: value.sent_len, - recv_len: value.recv_len, - } - } -} - -/// Verifying state. 
-pub struct Verify { - pub(crate) mux_ctrl: MuxControl, - pub(crate) mux_fut: MuxFuture, - - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, - - pub(crate) start_time: u64, - pub(crate) server_ephemeral_key: PublicKey, - pub(crate) handshake_commitment: Hash, - pub(crate) sent_len: usize, - pub(crate) recv_len: usize, - - pub(crate) channel: Option>>, - pub(crate) verify_thread: Option>, -} - -opaque_debug::implement!(Verify); - -impl From for Verify { - fn from(value: Closed) -> Self { - Self { - mux_ctrl: value.mux_ctrl, - mux_fut: value.mux_fut, - vm: value.vm, - ot_send: value.ot_send, - ot_recv: value.ot_recv, - ot_fut: value.ot_fut, - gf2: value.gf2, - start_time: value.start_time, - server_ephemeral_key: value.server_ephemeral_key, - handshake_commitment: value.handshake_commitment, - sent_len: value.sent_len, - recv_len: value.recv_len, - channel: None, - verify_thread: None, - } - } -} - -impl VerifierState for Initialized {} -impl VerifierState for Setup {} -impl VerifierState for Closed {} -impl VerifierState for Notarize {} -impl VerifierState for Verify {} - -mod sealed { - pub trait Sealed {} - impl Sealed for super::Initialized {} - impl Sealed for super::Setup {} - impl Sealed for super::Closed {} - impl Sealed for super::Notarize {} - impl Sealed for super::Verify {} -} diff --git a/tlsn/tlsn-verifier/src/tls/verify.rs b/tlsn/tlsn-verifier/src/tls/verify.rs deleted file mode 100644 index 8535834fa9..0000000000 --- a/tlsn/tlsn-verifier/src/tls/verify.rs +++ /dev/null @@ -1,197 +0,0 @@ -//! This module handles the verification phase of the verifier. -//! -//! The TLS verifier is an application-specific verifier. 
-
-use super::{state::Verify as VerifyState, Verifier, VerifierError};
-use futures::{FutureExt, StreamExt, TryFutureExt};
-use mpz_circuits::types::Value;
-use mpz_garble::{Memory, Verify, Vm};
-use mpz_share_conversion::ShareConversionVerify;
-use tlsn_core::{
-    msg::TlsnMessage, proof::SessionInfo, transcript::get_value_ids, Direction, HandshakeSummary,
-    RedactedTranscript, TranscriptSlice,
-};
-use utils_aio::{expect_msg_or_err, mux::MuxChannel};
-
-#[cfg(feature = "tracing")]
-use tracing::info;
-
-impl Verifier<VerifyState> {
-    /// Receives the **purported** transcript from the Prover.
-    ///
-    /// # Warning
-    ///
-    /// The content of the received transcripts can not be considered authentic until after finalization.
-    pub async fn receive(
-        &mut self,
-    ) -> Result<(RedactedTranscript, RedactedTranscript), VerifierError> {
-        let verify_fut = async {
-            // Create a new channel and vm thread if not already present
-            let channel = if let Some(ref mut channel) = self.state.channel {
-                channel
-            } else {
-                self.state.channel = Some(self.state.mux_ctrl.get_channel("prove-verify").await?);
-                self.state.channel.as_mut().unwrap()
-            };
-
-            let verify_thread = if let Some(ref mut verify_thread) = self.state.verify_thread {
-                verify_thread
-            } else {
-                self.state.verify_thread = Some(self.state.vm.new_thread("prove-verify").await?);
-                self.state.verify_thread.as_mut().unwrap()
-            };
-
-            // Receive the proving info from the prover
-            let mut proving_info = expect_msg_or_err!(channel, TlsnMessage::ProvingInfo)?;
-            let mut cleartext = proving_info.cleartext.clone();
-
-            #[cfg(feature = "tracing")]
-            info!("Received proving info from prover");
-
-            // Check ranges
-            if proving_info.sent_ids.max().unwrap_or_default() > self.state.sent_len
-                || proving_info.recv_ids.max().unwrap_or_default() > self.state.recv_len
-            {
-                return Err(VerifierError::InvalidRange);
-            }
-
-            // Now verify the transcript parts which the prover wants to reveal
-            let sent_value_ids = proving_info
-                .sent_ids
-                .iter_ranges()
-                .map(|r| get_value_ids(&r.into(), Direction::Sent).collect::<Vec<String>>());
-            let recv_value_ids = proving_info
-                .recv_ids
-                .iter_ranges()
-                .map(|r| get_value_ids(&r.into(), Direction::Received).collect::<Vec<String>>());
-
-            let value_refs = sent_value_ids
-                .chain(recv_value_ids)
-                .map(|ids| {
-                    let inner_refs = ids
-                        .iter()
-                        .map(|id| {
-                            verify_thread
-                                .get_value(id.as_str())
-                                .expect("Byte should be in VM memory")
-                        })
-                        .collect::<Vec<_>>();
-
-                    verify_thread
-                        .array_from_values(inner_refs.as_slice())
-                        .expect("Byte should be in VM Memory")
-                })
-                .collect::<Vec<_>>();
-
-            let values = proving_info
-                .sent_ids
-                .iter_ranges()
-                .chain(proving_info.recv_ids.iter_ranges())
-                .map(|range| {
-                    Value::Array(cleartext.drain(..range.len()).map(|b| (b).into()).collect())
-                })
-                .collect::<Vec<_>>();
-
-            // Check that purported values are correct
-            verify_thread.verify(&value_refs, &values).await?;
-
-            #[cfg(feature = "tracing")]
-            info!("Successfully verified purported cleartext");
-
-            // Create redacted transcripts
-            let mut transcripts = proving_info
-                .sent_ids
-                .iter_ranges()
-                .chain(proving_info.recv_ids.iter_ranges())
-                .map(|range| {
-                    TranscriptSlice::new(
-                        range.clone(),
-                        proving_info.cleartext.drain(..range.len()).collect(),
-                    )
-                })
-                .collect::<Vec<_>>();
-
-            let recv_transcripts =
-                transcripts.split_off(proving_info.sent_ids.iter_ranges().count());
-            let (sent_redacted, recv_redacted) = (
-                RedactedTranscript::new(self.state.sent_len, transcripts),
-                RedactedTranscript::new(self.state.recv_len, recv_transcripts),
-            );
-
-            #[cfg(feature = "tracing")]
-            info!("Successfully created redacted transcripts");
-
-            Ok::<_, VerifierError>((sent_redacted, recv_redacted))
-        };
-
-        futures::select! {
-            res = verify_fut.fuse() => res,
-            _ = &mut self.state.mux_fut => Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?,
-        }
-    }
-
-    /// Verify the TLS session.
-    pub async fn finalize(self) -> Result<SessionInfo, VerifierError> {
-        let VerifyState {
-            mut mux_ctrl,
-            mut mux_fut,
-            mut vm,
-            ot_send,
-            ot_recv,
-            ot_fut,
-            mut gf2,
-            start_time,
-            server_ephemeral_key,
-            handshake_commitment,
-            ..
-        } = self.state;
-
-        let finalize_fut = async {
-            let mut channel = mux_ctrl.get_channel("finalize").await?;
-
-            // Finalize all MPC
-            let (mut ot_sender_actor, _, _) = futures::try_join!(
-                ot_fut,
-                ot_send.shutdown().map_err(VerifierError::from),
-                ot_recv.shutdown().map_err(VerifierError::from)
-            )?;
-
-            ot_sender_actor.reveal().await?;
-
-            vm.finalize()
-                .await
-                .map_err(|e| VerifierError::MpcError(Box::new(e)))?;
-
-            gf2.verify()
-                .await
-                .map_err(|e| VerifierError::MpcError(Box::new(e)))?;
-
-            let session_info = expect_msg_or_err!(channel, TlsnMessage::SessionInfo)?;
-
-            #[cfg(feature = "tracing")]
-            info!("Finalized all MPC");
-
-            Ok::<_, VerifierError>(session_info)
-        };
-
-        let session_info = futures::select! {
-            res = finalize_fut.fuse() => res?,
-            _ = &mut mux_fut => Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?,
-        };
-
-        let handshake_summary =
-            HandshakeSummary::new(start_time, server_ephemeral_key, handshake_commitment);
-
-        // Verify the TLS session
-        session_info.verify(&handshake_summary, self.config.cert_verifier())?;
-
-        #[cfg(feature = "tracing")]
-        info!("Successfully verified session");
-
-        let mut mux_ctrl = mux_ctrl.into_inner();
-
-        futures::try_join!(mux_ctrl.close().map_err(VerifierError::from), mux_fut)?;
-
-        Ok(session_info)
-    }
-}