From 2aeabff24fc58ecad701f1177a15642e68292587 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Thu, 8 Feb 2024 10:31:06 -0800 Subject: [PATCH 01/40] Rough WIP - adding some notes related to kicking off the audit. --- docs/audit.md | 91 +++++++++++++++++++++++++++++++++++++++++ docs/merkle-sum-trie.md | 21 +++++----- docs/smt.md | 57 +++++++++++++------------- 3 files changed, 131 insertions(+), 38 deletions(-) create mode 100644 docs/audit.md diff --git a/docs/audit.md b/docs/audit.md new file mode 100644 index 0000000..f1852f1 --- /dev/null +++ b/docs/audit.md @@ -0,0 +1,91 @@ +# Audit + +- [Audit](#audit) + +## Pre-Audit Checklist + +### Checklist Requirements + +Pre-Audit Checklist: + +A reminder to provide us with the must haves 3 business days prior to the audit start days to avoid delays: + +**Must haves:** + +- A URL of the repository containing the source code +- The release branch / commit hash to be reviewed +- An explicit list of files in scope and out of scope for the security audit +- Robust and comprehensive documentation describing the intended functionality of the system + +**Nice-to-haves:** + +- Clear instructions for setting up the system and run the tests (usually in the README file) +- Any past audits +- Any tooling output logs +- Output generated by running the test suite +- Test coverage report + +Please disregard what may be irrelevant in the above list for this particular audit. +Reminder that we cannot accept any scope changes during the course of the audit. + +And more on audit preparation in this blogpost: https://medium.com/thesis-defense/why-crypto-needs-security-audits-d12f3909ac21 - thanks! 
+ +### Checklist Response + +**Repository**: https://github.com/pokt-network/smt + +- **Branch**: `main` +- **Hash**: `868237978c0b3c0e2added161b36eeb7a3dc93b0` + +**Documentation** + +- **Background**: [Relay Mining](https://arxiv.org/abs/2305.10672) +- **Technical Documentation**: https://github.com/pokt-network/smt/tree/main/docs +- _NOTE: we may integrate this into [https://dev.poktroll.com](https://dev.poktroll.com/) (which is out of scope) in the future_ + +**Files** + +- Nothing is explicitly out of scope but the focus should be on the files below +- The following is a manually filtered list of files after running `tree -P '*.go' -I '*_test.go'` + +```bash + . + ├── errors.go + ├── hasher.go + ├── kvstore + │   ├── badger + │   │   ├── errors.go + │   │   ├── godoc.go + │   │   ├── interface.go + │   │   └── kvstore.go + │   ├── interfaces.go + ├── options.go + ├── proofs.go + ├── smst.go + ├── smt.go + ├── types.go + └── utils.go +``` + +**Makefile** + +- Running `make` in the root of the repo shows a list of options +- This gives access to tests, benchmarks, etc... 
+ +```bash +make +help Prints all the targets in all the Makefiles +list List all make targets +test_all runs the test suite +test_badger runs the badger KVStore submodule's test suite +mod_tidy runs go mod tidy for all (sub)modules +go_docs Generate documentation for the project +benchmark_all runs all benchmarks +benchmark_smt runs all benchmarks for the SMT +benchmark_smt_fill runs a benchmark on filling the SMT with different amounts of values +benchmark_smt_ops runs the benchmarks testing different operations on the SMT against different sized tries +benchmark_smst runs all benchmarks for the SMST +benchmark_smst_fill runs a benchmark on filling the SMST with different amounts of values +benchmark_smst_ops runs the benchmarks test different operations on the SMST against different sized tries +benchmark_proof_sizes runs the benchmarks test the proof sizes for different sized tries +``` diff --git a/docs/merkle-sum-trie.md b/docs/merkle-sum-trie.md index 299c0e6..4c8e96f 100644 --- a/docs/merkle-sum-trie.md +++ b/docs/merkle-sum-trie.md @@ -2,16 +2,17 @@ -- [Overview](#overview) -- [Implementation](#implementation) - * [Sum Encoding](#sum-encoding) - * [Digests](#digests) - * [Visualisations](#visualisations) - + [General Trie Structure](#general-trie-structure) - + [Binary Sum Digests](#binary-sum-digests) -- [Sum](#sum) -- [Roots](#roots) -- [Nil Values](#nil-values) +- [Sparse Merkle Sum Trie (smst)](#sparse-merkle-sum-trie-smst) + - [Overview](#overview) + - [Implementation](#implementation) + - [Sum Encoding](#sum-encoding) + - [Digests](#digests) + - [Visualisations](#visualisations) + - [General Trie Structure](#general-trie-structure) + - [Binary Sum Digests](#binary-sum-digests) + - [Sum](#sum) + - [Roots](#roots) + - [Nil Values](#nil-values) diff --git a/docs/smt.md b/docs/smt.md index bae85d2..3731328 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -2,34 +2,35 @@ -- [Overview](#overview) -- [Implementation](#implementation) - * [Inner 
Nodes](#inner-nodes) - * [Extension Nodes](#extension-nodes) - * [Leaf Nodes](#leaf-nodes) - * [Lazy Nodes](#lazy-nodes) - * [Lazy Loading](#lazy-loading) - * [Visualisations](#visualisations) - + [General Trie Structure](#general-trie-structure) - + [Lazy Nodes](#lazy-nodes-1) -- [Paths](#paths) - * [Visualisation](#visualisation) -- [Values](#values) - * [Nil values](#nil-values) -- [Hashers & Digests](#hashers--digests) -- [Roots](#roots) -- [Proofs](#proofs) - * [Verification](#verification) - * [Closest Proof](#closest-proof) - + [Closest Proof Use Cases](#closest-proof-use-cases) - * [Compression](#compression) - * [Serialisation](#serialisation) -- [Database](#database) - * [Database Submodules](#database-submodules) - + [SimpleMap](#simplemap) - + [Badger](#badger) - * [Data Loss](#data-loss) -- [Sparse Merkle Sum Trie](#sparse-merkle-sum-trie) +- [smt](#smt) + - [Overview](#overview) + - [Implementation](#implementation) + - [Inner Nodes](#inner-nodes) + - [Extension Nodes](#extension-nodes) + - [Leaf Nodes](#leaf-nodes) + - [Lazy Nodes](#lazy-nodes) + - [Lazy Loading](#lazy-loading) + - [Visualisations](#visualisations) + - [General Trie Structure](#general-trie-structure) + - [Lazy Nodes](#lazy-nodes-1) + - [Paths](#paths) + - [Visualisation](#visualisation) + - [Values](#values) + - [Nil values](#nil-values) + - [Hashers \& Digests](#hashers--digests) + - [Roots](#roots) + - [Proofs](#proofs) + - [Verification](#verification) + - [Closest Proof](#closest-proof) + - [Closest Proof Use Cases](#closest-proof-use-cases) + - [Compression](#compression) + - [Serialisation](#serialisation) + - [Database](#database) + - [Database Submodules](#database-submodules) + - [SimpleMap](#simplemap) + - [Badger](#badger) + - [Data Loss](#data-loss) + - [Sparse Merkle Sum Trie](#sparse-merkle-sum-trie) From b876799248885b9b349a43a2ffbc9bca834ffa1c Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Thu, 8 Feb 2024 11:17:34 -0800 Subject: [PATCH 02/40] WIP - Going 
through the extension node documentation --- docs/faq.md | 13 +++++++++++ docs/smt.md | 44 +++++++++++++++++++----------------- smt.go | 64 ++++++++++++++++++++++++++++++++++++----------------- 3 files changed, 81 insertions(+), 40 deletions(-) create mode 100644 docs/faq.md diff --git a/docs/faq.md b/docs/faq.md new file mode 100644 index 0000000..aea8b75 --- /dev/null +++ b/docs/faq.md @@ -0,0 +1,13 @@ +# FAQ + +- [Implementation](#implementation) + - [What's the story behind Extension Node Implementation?](#whats-the-story-behind-extension-node-implementation) + +This documentation is meant to capture common questions that come up and act +as a supplement or secondary reference to the primary documentation. + +## Implementation + +### What's the story behind Extension Node Implementation? + +[smt](./smt.md#extension-nodes) diff --git a/docs/smt.md b/docs/smt.md index 3731328..2fd73c9 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -5,9 +5,9 @@ - [smt](#smt) - [Overview](#overview) - [Implementation](#implementation) + - [Leaf Nodes](#leaf-nodes) - [Inner Nodes](#inner-nodes) - [Extension Nodes](#extension-nodes) - - [Leaf Nodes](#leaf-nodes) - [Lazy Nodes](#lazy-nodes) - [Lazy Loading](#lazy-loading) - [Visualisations](#visualisations) @@ -51,26 +51,38 @@ See [smt.go](../smt.go) for more details on the implementation. 
The SMT has 4 node types that are used to construct the trie: -- Inner Nodes - - Prefixed `[]byte{1}` - - `digest = hash([]byte{1} + leftChild.digest + rightChild.digest)` -- Extension Nodes - - Prefixed `[]byte{2}` - - `digest = hash([]byte{2} + pathBounds + path + child.digest)` -- Leaf Nodes - - Prefixed `[]byte{0}` - - `digest = hash([]byte{0} + path + value)` +- [Inner Nodes](#inner-nodes) +- [Extension Nodes](#extension-nodes) +- [Leaf Nodes](#leaf-nodes) - Lazy Nodes - Prefix of the actual node type is stored in the persisted digest as determined above - `digest = persistedDigest` +### Leaf Nodes + +Leaf nodes store the full path which they represent and also the hash of the +value they store. The `digest` of a leaf node is the hash of the leaf nodes path +and value concatenated. + +The SMT stores only the hashes of the values in the trie, not the raw values +themselves. In order to store the raw values in the underlying database the +option `WithValueHasher(nil)` must be passed into the `NewSparseMerkleTrie` +constructor. + +- _Prefix_: `[]byte{0}` +- _Digest_: `hash([]byte{0} + path + value)` + ### Inner Nodes Inner nodes represent a branch in the trie with two **non-nil** child nodes. + The inner node has an internal `digest` which represents the hash of the child nodes concatenated hashes. +- _Prefix_: `[]byte{1}` +- _Digest_: `hash([]byte{1} + leftChild.digest + rightChild.digest)` + ### Extension Nodes Extension nodes represent a singly linked chain of inner nodes, with a single @@ -79,16 +91,8 @@ the path and bounds of the path they represent. The `digest` of an extension node is the hash of its path bounds, the path itself and the child nodes digest concatenated. -### Leaf Nodes - -Leaf nodes store the full path which they represent and also the hash of the -value they store. The `digest` of a leaf node is the hash of the leaf nodes path -and value concatenated. 
- -The SMT stores only the hashes of the values in the trie, not the raw values -themselves. In order to store the raw values in the underlying database the -option `WithValueHasher(nil)` must be passed into the `NewSparseMerkleTrie` -constructor. +- _Prefix_: `[]byte{2}` +- _Digest_: `hash([]byte{2} + pathBounds + path + child.digest)` ### Lazy Nodes diff --git a/smt.go b/smt.go index 558db63..6ef2e81 100644 --- a/smt.go +++ b/smt.go @@ -13,21 +13,27 @@ var ( _ SparseMerkleTrie = (*SMT)(nil) ) +// A high-level interface that captures the behaviour of all types of nodes type trieNode interface { - // when committing a node to disk, skip if already persisted + // Persisted returns a boolean to determine whether or not the node + // has been persisted to disk or only held in memory. + // It can be used skip unnecessary iops if already persisted Persisted() bool + + // The digest of the node, returning a cached value if available. CachedDigest() []byte } -// A branch within the trie +// A branch within the binary trie pointing to a left & right child. type innerNode struct { - // Both child nodes are always non-nil + // Left and right child nodes. + // Both child nodes are always expected to be non-nil. leftChild, rightChild trieNode persisted bool digest []byte } -// Stores data and full path +// A leaf node storing a key-value pair for a full path. type leafNode struct { path []byte valueHash []byte @@ -39,7 +45,7 @@ type leafNode struct { type extensionNode struct { path []byte // Offsets into path slice of bounds defining actual path segment. - // Note: assumes path is <=256 bits + // NOTE: assumes path is <=256 bits pathBounds [2]byte // Child is always an inner node, or lazy. 
child trieNode @@ -171,23 +177,29 @@ func (smt *SMT) update( return newLeaf, nil } if leaf, ok := node.(*leafNode); ok { - prefixlen := countCommonPrefixBits(path, leaf.path, depth) - if prefixlen == smt.depth() { // replace leaf if paths are equal + prefixLen := countCommonPrefixBits(path, leaf.path, depth) + if prefixLen == smt.depth() { // replace leaf if paths are equal smt.addOrphan(orphans, node) return newLeaf, nil } // We insert an "extension" representing multiple single-branch inner nodes last := &node - if depth < prefixlen { + if depth < prefixLen { // note: this keeps path slice alive - GC inefficiency? if depth > 0xff { panic("invalid depth") } - ext := extensionNode{path: path, pathBounds: [2]byte{byte(depth), byte(prefixlen)}} + ext := extensionNode{ + path: path, + pathBounds: [2]byte{ + byte(depth), + byte(prefixLen), + }, + } *last = &ext last = &ext.child } - if getPathBit(path, prefixlen) == left { + if getPathBit(path, prefixLen) == left { *last = &innerNode{leftChild: newLeaf, rightChild: leaf} } else { *last = &innerNode{leftChild: leaf, rightChild: newLeaf} @@ -685,14 +697,22 @@ func (smt *SMT) addOrphan(orphans *[][]byte, node trieNode) { } } -func (node *leafNode) Persisted() bool { return node.persisted } -func (node *innerNode) Persisted() bool { return node.persisted } -func (node *lazyNode) Persisted() bool { return true } +// TODO_IMPROVE: Lots of opportunity to modularize and improve the code here. 
+ +func (node *leafNode) Persisted() bool { return node.persisted } + +func (node *innerNode) Persisted() bool { return node.persisted } + +func (node *lazyNode) Persisted() bool { return true } + func (node *extensionNode) Persisted() bool { return node.persisted } -func (node *leafNode) CachedDigest() []byte { return node.digest } -func (node *innerNode) CachedDigest() []byte { return node.digest } -func (node *lazyNode) CachedDigest() []byte { return node.digest } +func (node *leafNode) CachedDigest() []byte { return node.digest } + +func (node *innerNode) CachedDigest() []byte { return node.digest } + +func (node *lazyNode) CachedDigest() []byte { return node.digest } + func (node *extensionNode) CachedDigest() []byte { return node.digest } func (inner *innerNode) setDirty() { @@ -733,7 +753,8 @@ func (ext *extensionNode) commonPrefix(path []byte) int { } func (ext *extensionNode) pathStart() int { return int(ext.pathBounds[0]) } -func (ext *extensionNode) pathEnd() int { return int(ext.pathBounds[1]) } + +func (ext *extensionNode) pathEnd() int { return int(ext.pathBounds[1]) } // Splits the node in-place; returns replacement node, child node at the split, and split depth func (ext *extensionNode) split(path []byte, depth int) (trieNode, *trieNode, int) { @@ -780,9 +801,12 @@ func (ext *extensionNode) split(path []byte, depth int) (trieNode, *trieNode, in *tail = child } else { *tail = &extensionNode{ - path: ext.path, - pathBounds: [2]byte{byte(index + 1), ext.pathBounds[1]}, - child: child, + path: ext.path, + pathBounds: [2]byte{ + byte(index + 1), + ext.pathBounds[1], + }, + child: child, } } ext.pathBounds[1] = byte(index) From d83d3749114f886d5496d6325b379aa2e6118c75 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Thu, 8 Feb 2024 11:57:05 -0800 Subject: [PATCH 03/40] WIP - Moved extension node into its own file and removed the resolver helper --- docs/faq.md | 16 ++- docs/merkle-sum-trie.md | 186 ++++++++++++++++----------------- docs/smt.md | 13 
++- extension_node.go | 128 +++++++++++++++++++++++ hasher.go | 2 + smst.go | 4 +- smt.go | 224 ++++++++-------------------------------- utils.go | 11 +- 8 files changed, 296 insertions(+), 288 deletions(-) create mode 100644 extension_node.go diff --git a/docs/faq.md b/docs/faq.md index aea8b75..d66b496 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -1,13 +1,27 @@ # FAQ +- [History](#history) + - [Fork](#fork) - [Implementation](#implementation) - [What's the story behind Extension Node Implementation?](#whats-the-story-behind-extension-node-implementation) This documentation is meant to capture common questions that come up and act as a supplement or secondary reference to the primary documentation. +## History + +### Fork + +This library was originally forked off of [celestiaorg/smt](https://github.com/celestiaorg/smt) +which was archived on Feb 27th, 2023. + ## Implementation ### What's the story behind Extension Node Implementation? -[smt](./smt.md#extension-nodes) +The [SMT extension node](./smt.md#extension-nodes) is very similar to that of +Ethereum's [Modified Merkle Patricia Trie](https://ethereum.org/developers/docs/data-structures-and-encoding/patricia-merkle-trie). + +A quick primer on it can be found in this [5P;1R post](https://olshansky.substack.com/p/5p1r-ethereums-modified-merkle-patricia). + +WIP diff --git a/docs/merkle-sum-trie.md b/docs/merkle-sum-trie.md index 4c8e96f..ee881db 100644 --- a/docs/merkle-sum-trie.md +++ b/docs/merkle-sum-trie.md @@ -7,7 +7,7 @@ - [Implementation](#implementation) - [Sum Encoding](#sum-encoding) - [Digests](#digests) - - [Visualisations](#visualisations) + - [Visualizations](#visualizations) - [General Trie Structure](#general-trie-structure) - [Binary Sum Digests](#binary-sum-digests) - [Sum](#sum) @@ -65,34 +65,34 @@ The golang `encoding/binary` package is used to encode the sum with `binary.BigEndian.PutUint64(sumBz[:], sum)` into a byte array `sumBz`. 
In order for the SMST to include the sum into a leaf node the SMT the SMST -initialises the SMT with the `WithValueHasher(nil)` option so that the SMT does +initializes the SMT with the `WithValueHasher(nil)` option so that the SMT does **not** hash any values. The SMST will then hash the value and append the sum bytes to the end of the hashed value, using whatever `ValueHasher` was given to -the SMST on initialisation. +the SMST on initialization. ```mermaid graph TD - subgraph KVS[Key-Value-Sum] - K1["Key: foo"] - K2["Value: bar"] - K3["Sum: 10"] - end - subgraph SMST[SMST] - SS1[ValueHasher: SHA256] - subgraph SUM["SMST.Update()"] - SU1["valueHash = ValueHasher(Value)"] - SU2["sumBytes = binary(Sum)"] - SU3["valueHash = append(valueHash, sumBytes...)"] - end - end - subgraph SMT[SMT] - SM1[ValueHasher: nil] - subgraph UPD["SMT.Update()"] - U2["SMT.nodeStore.Set(Key, valueHash)"] - end - end - KVS --"Key + Value + Sum"--> SMST - SMST --"Key + valueHash"--> SMT + subgraph KVS[Key-Value-Sum] + K1["Key: foo"] + K2["Value: bar"] + K3["Sum: 10"] + end + subgraph SMST[SMST] + SS1[ValueHasher: SHA256] + subgraph SUM["SMST.Update()"] + SU1["valueHash = ValueHasher(Value)"] + SU2["sumBytes = binary(Sum)"] + SU3["valueHash = append(valueHash, sumBytes...)"] + end + end + subgraph SMT[SMT] + SM1[ValueHasher: nil] + subgraph UPD["SMT.Update()"] + U2["SMT.nodeStore.Set(Key, valueHash)"] + end + end + KVS --"Key + Value + Sum"--> SMST + SMST --"Key + valueHash"--> SMT ``` ### Digests @@ -129,10 +129,10 @@ Therefore for the following node types, the digests are computed as follows: This means that with a hasher such as `sha256.New()` whose hash size is `32 bytes`, the digest of any node will be `40 bytes` in length. -### Visualisations +### Visualizations The following diagrams are representations of how the trie and its components -can be visualised. +can be visualized. #### General Trie Structure @@ -143,45 +143,45 @@ nodes as an extra field. 
```mermaid graph TB - subgraph Root - A1["Digest: Hash(Hash(Path+H1)+Hash(H2+(Hash(H3+H4)))+Binary(20))+Binary(20)"] + subgraph Root + A1["Digest: Hash(Hash(Path+H1)+Hash(H2+(Hash(H3+H4)))+Binary(20))+Binary(20)"] A2[Sum: 20] - end - subgraph BI[Inner Node] - B1["Digest: Hash(H2+(Hash(H3+H4))+Binary(12))+Binary(12)"] + end + subgraph BI[Inner Node] + B1["Digest: Hash(H2+(Hash(H3+H4))+Binary(12))+Binary(12)"] B2[Sum: 12] - end - subgraph BE[Extension Node] - B3["Digest: Hash(Path+H1+Binary(8))+Binary(8)"] + end + subgraph BE[Extension Node] + B3["Digest: Hash(Path+H1+Binary(8))+Binary(8)"] B4[Sum: 8] - end - subgraph CI[Inner Node] - C1["Digest: Hash(H3+H4+Binary(7))+Binary(7)"] + end + subgraph CI[Inner Node] + C1["Digest: Hash(H3+H4+Binary(7))+Binary(7)"] C2[Sum: 7] - end - subgraph CL[Leaf Node] - C3[Digest: H2] + end + subgraph CL[Leaf Node] + C3[Digest: H2] C4[Sum: 5] - end - subgraph DL1[Leaf Node] - D1[Digest: H3] + end + subgraph DL1[Leaf Node] + D1[Digest: H3] D2[Sum: 4] - end - subgraph DL2[Leaf Node] - D3[Digest: H4] + end + subgraph DL2[Leaf Node] + D3[Digest: H4] D4[Sum: 3] - end - subgraph EL[Leaf Node] - E1[Digest: H1] + end + subgraph EL[Leaf Node] + E1[Digest: H1] E2[Sum: 8] - end - Root-->|0| BE - Root-->|1| BI - BI-->|0| CL - BI-->|1| CI - CI-->|0| DL1 - CI-->|1| DL2 - BE-->EL + end + Root-->|0| BE + Root-->|1| BI + BI-->|0| CL + BI-->|1| CI + CI-->|0| DL1 + CI-->|1| DL2 + BE-->EL ``` #### Binary Sum Digests @@ -193,56 +193,56 @@ exception of the leaf nodes where the sum is shown as part of its value. 
```mermaid graph TB - subgraph RI[Inner Node] - RIA["Root Hash: Hash(D6+D7+Binary(18))+Binary(18)"] + subgraph RI[Inner Node] + RIA["Root Hash: Hash(D6+D7+Binary(18))+Binary(18)"] RIB[Sum: 15] - end - subgraph I1[Inner Node] - I1A["D7: Hash(D1+D5+Binary(11))+Binary(11)"] + end + subgraph I1[Inner Node] + I1A["D7: Hash(D1+D5+Binary(11))+Binary(11)"] I1B[Sum: 11] - end - subgraph I2[Inner Node] - I2A["D6: Hash(D3+D4+Binary(7))+Binary(7)"] + end + subgraph I2[Inner Node] + I2A["D6: Hash(D3+D4+Binary(7))+Binary(7)"] I2B[Sum: 7] - end - subgraph L1[Leaf Node] - L1A[Path: 0b0010000] - L1B["Value: 0x01+Binary(6)"] + end + subgraph L1[Leaf Node] + L1A[Path: 0b0010000] + L1B["Value: 0x01+Binary(6)"] L1C["H1: Hash(Path+Value+Binary(6))"] L1D["D1: H1+Binary(6)"] - end - subgraph L3[Leaf Node] - L3A[Path: 0b1010000] - L3B["Value: 0x03+Binary(3)"] + end + subgraph L3[Leaf Node] + L3A[Path: 0b1010000] + L3B["Value: 0x03+Binary(3)"] L3C["H3: Hash(Path+Value+Binary(3))"] L3D["D3: H3+Binary(3)"] - end - subgraph L4[Leaf Node] - L4A[Path: 0b1100000] - L4B["Value: 0x04+Binary(4)"] + end + subgraph L4[Leaf Node] + L4A[Path: 0b1100000] + L4B["Value: 0x04+Binary(4)"] L4C["H4: Hash(Path+Value+Binary(4))"] L4D["D4: H4+Binary(4)"] - end - subgraph E1[Extension Node] - E1A[Path: 0b01100101] - E1B["Path Bounds: [2, 6)"] + end + subgraph E1[Extension Node] + E1A[Path: 0b01100101] + E1B["Path Bounds: [2, 6)"] E1C[Sum: 5] E1D["H5: Hash(Path+PathBounds+D2+Binary(5))"] E1E["D5: H5+Binary(5)"] - end - subgraph L2[Leaf Node] - L2A[Path: 0b01100101] - L2B["Value: 0x02+Binary(5)"] + end + subgraph L2[Leaf Node] + L2A[Path: 0b01100101] + L2B["Value: 0x02+Binary(5)"] L2C["H2: Hash(Path+Value+Hex(5))+Binary(5)"] L2D["D2: H2+Binary(5)"] - end - RI -->|0| I1 - RI -->|1| I2 - I1 -->|0| L1 - I1 -->|1| E1 - E1 --> L2 - I2 -->|0| L3 - I2 -->|1| L4 + end + RI -->|0| I1 + RI -->|1| I2 + I1 -->|0| L1 + I1 -->|1| E1 + E1 --> L2 + I2 -->|0| L3 + I2 -->|1| L4 ``` ## Sum diff --git a/docs/smt.md b/docs/smt.md index 
2fd73c9..07e711b 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -54,7 +54,7 @@ The SMT has 4 node types that are used to construct the trie: - [Inner Nodes](#inner-nodes) - [Extension Nodes](#extension-nodes) - [Leaf Nodes](#leaf-nodes) -- Lazy Nodes +- [Lazy Nodes](#lazy-nodes) - Prefix of the actual node type is stored in the persisted digest as determined above - `digest = persistedDigest` @@ -205,12 +205,15 @@ Where `Hash(Hash1 + Hash2)` is the same root hash as the previous example. ## Paths -Paths are **only** stored in two types of nodes: Leaf nodes and Extension nodes. +Paths are **only** stored in two types of nodes: `Leaf` nodes and `Extension` nodes. -- Extension nodes contain not only the path they represent but also the path +- `Leaf` nodes contain: + - The full path which it represent + - The value stored at that path +- `Extension` nodes contain: + - not only the path they represent but also the path bounds (ie. the start and end of the path they cover). -- Leaf nodes contain the full path which they represent, as well as the value - stored at that path. + Inner nodes do **not** contain a path, as they represent a branch in the trie and not a path. As such their children, _if they are extension nodes or leaf diff --git a/extension_node.go b/extension_node.go new file mode 100644 index 0000000..1520d5d --- /dev/null +++ b/extension_node.go @@ -0,0 +1,128 @@ +package smt + +var _ trieNode = (*extensionNode)(nil) + +// A compressed chain of singly-linked inner nodes +type extensionNode struct { + path []byte + // Offsets into path slice of bounds defining actual path segment. + // NOTE: assumes path is <=256 bits + pathBounds [2]byte + // Child is always an inner node, or lazy. 
+ child trieNode + // Bool whether or not the node has been flushed to disk + persisted bool + // The cached digest of the node trie + digest []byte +} + +func (node *extensionNode) Persisted() bool { + return node.persisted +} + +func (node *extensionNode) CachedDigest() []byte { + return node.digest +} + +func (ext *extensionNode) length() int { return int(ext.pathBounds[1] - ext.pathBounds[0]) } + +func (ext *extensionNode) setDirty() { + ext.persisted = false + ext.digest = nil +} + +// Returns length of matching prefix, and whether it's a full match +func (ext *extensionNode) match(path []byte, depth int) (int, bool) { + if depth != ext.pathStart() { + panic("depth != path_begin") + } + for i := ext.pathStart(); i < ext.pathEnd(); i++ { + if getPathBit(ext.path, i) != getPathBit(path, i) { + return i - ext.pathStart(), false + } + } + return ext.length(), true +} + +func (ext *extensionNode) pathStart() int { + return int(ext.pathBounds[0]) +} + +func (ext *extensionNode) pathEnd() int { + return int(ext.pathBounds[1]) +} + +// Splits the node in-place; returns replacement node, child node at the split, and split depth +func (ext *extensionNode) split(path []byte, depth int) (trieNode, *trieNode, int) { + if depth != ext.pathStart() { + panic("depth != path_begin") + } + index := ext.pathStart() + var myBit, branchBit int + for ; index < ext.pathEnd(); index++ { + myBit = getPathBit(ext.path, index) + branchBit = getPathBit(path, index) + if myBit != branchBit { + break + } + } + if index == ext.pathEnd() { + return ext, &ext.child, index + } + + child := ext.child + var branch innerNode + var head trieNode + var tail *trieNode + if myBit == left { + tail = &branch.leftChild + } else { + tail = &branch.rightChild + } + + // Split at first bit: chain starts with new node + if index == ext.pathStart() { + head = &branch + ext.pathBounds[0]++ // Shrink the extension from front + if ext.length() == 0 { + *tail = child + } else { + *tail = ext + } + } else { + // 
Split inside: chain ends at index + head = ext + ext.child = &branch + if index == ext.pathEnd()-1 { + *tail = child + } else { + *tail = &extensionNode{ + path: ext.path, + pathBounds: [2]byte{ + byte(index + 1), + ext.pathBounds[1], + }, + child: child, + } + } + ext.pathBounds[1] = byte(index) + } + var b trieNode = &branch + return head, &b, index +} + +// expand returns the inner node that represents the start of the singly +// linked list that this extension node represents +func (ext *extensionNode) expand() trieNode { + last := ext.child + for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- { + var next innerNode + if getPathBit(ext.path, i) == left { + next.leftChild = last + } else { + next.rightChild = last + } + last = &next + } + return last +} diff --git a/hasher.go b/hasher.go index 1b9b2cd..beb0007 100644 --- a/hasher.go +++ b/hasher.go @@ -35,9 +35,11 @@ type trieHasher struct { hasher hash.Hash zeroValue []byte } + type pathHasher struct { trieHasher } + type valueHasher struct { trieHasher } diff --git a/smst.go b/smst.go index 1de9b14..bbf94df 100644 --- a/smst.go +++ b/smst.go @@ -49,8 +49,8 @@ func ImportSparseMerkleSumTrie( options ...Option, ) *SMST { smst := NewSparseMerkleSumTrie(nodes, hasher, options...) - smst.trie = &lazyNode{root} - smst.savedRoot = root + smst.root = &lazyNode{root} + smst.rootHash = root return smst } diff --git a/smt.go b/smt.go index 6ef2e81..17676ff 100644 --- a/smt.go +++ b/smt.go @@ -41,18 +41,6 @@ type leafNode struct { digest []byte } -// A compressed chain of singly-linked inner nodes -type extensionNode struct { - path []byte - // Offsets into path slice of bounds defining actual path segment. - // NOTE: assumes path is <=256 bits - pathBounds [2]byte - // Child is always an inner node, or lazy. 
- child trieNode - persisted bool - digest []byte -} - // Represents an uncached, persisted node type lazyNode struct { digest []byte @@ -61,11 +49,12 @@ type lazyNode struct { // SMT is a Sparse Merkle Trie object that implements the SparseMerkleTrie interface type SMT struct { TrieSpec + // Backing key-value store for the node nodes kvstore.MapStore // Last persisted root hash - savedRoot []byte - // Current state of trie - trie trieNode + rootHash []byte + // The current view of the SMT + root trieNode // Lists of per-operation orphan sets orphans []orphanNodes } @@ -99,8 +88,8 @@ func ImportSparseMerkleTrie( options ...Option, ) *SMT { smt := NewSparseMerkleTrie(nodes, hasher, options...) - smt.trie = &lazyNode{root} - smt.savedRoot = root + smt.root = &lazyNode{root} + smt.rootHash = root return smt } @@ -109,7 +98,7 @@ func (smt *SMT) Get(key []byte) ([]byte, error) { path := smt.ph.Path(key) var leaf *leafNode var err error - for node, depth := &smt.trie, 0; ; depth++ { + for node, depth := &smt.root, 0; ; depth++ { *node, err = smt.resolveLazy(*node) if err != nil { return nil, err @@ -147,24 +136,29 @@ func (smt *SMT) Get(key []byte) ([]byte, error) { return leaf.valueHash, nil } -// Update sets the value for the given key, to the digest of the provided value +// Update inserts the `value` for the given `key` into the SMT func (smt *SMT) Update(key []byte, value []byte) error { + // Expand path := smt.ph.Path(key) valueHash := smt.digestValue(value) var orphans orphanNodes - trie, err := smt.update(smt.trie, 0, path, valueHash, &orphans) + trie, err := smt.update(smt.root, 0, path, valueHash, &orphans) if err != nil { return err } - smt.trie = trie + smt.root = trie if len(orphans) > 0 { smt.orphans = append(smt.orphans, orphans) } return nil } +// Internal helper to the `Update` method func (smt *SMT) update( - node trieNode, depth int, path, value []byte, orphans *orphanNodes, + node trieNode, + depth int, + path, value []byte, + orphans *orphanNodes, ) 
(trieNode, error) { node, err := smt.resolveLazy(node) if err != nil { @@ -200,9 +194,15 @@ func (smt *SMT) update( last = &ext.child } if getPathBit(path, prefixLen) == left { - *last = &innerNode{leftChild: newLeaf, rightChild: leaf} + *last = &innerNode{ + leftChild: newLeaf, + rightChild: leaf, + } } else { - *last = &innerNode{leftChild: leaf, rightChild: newLeaf} + *last = &innerNode{ + leftChild: leaf, + rightChild: newLeaf, + } } return node, nil } @@ -239,11 +239,11 @@ func (smt *SMT) update( func (smt *SMT) Delete(key []byte) error { path := smt.ph.Path(key) var orphans orphanNodes - trie, err := smt.delete(smt.trie, 0, path, &orphans) + trie, err := smt.delete(smt.root, 0, path, &orphans) if err != nil { return err } - smt.trie = trie + smt.root = trie if len(orphans) > 0 { smt.orphans = append(smt.orphans, orphans) } @@ -333,7 +333,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { var siblings []trieNode var sib trieNode - node := smt.trie + node := smt.root for depth := 0; depth < smt.depth(); depth++ { node, err = smt.resolveLazy(node) if err != nil { @@ -433,7 +433,7 @@ func (smt *SMT) ProveClosest(path []byte) ( FlippedBits: make([]int, 0), } - node := smt.trie + node := smt.root depth := 0 // continuously traverse the trie until we hit a leaf node for depth < smt.depth() { @@ -543,29 +543,20 @@ func (smt *SMT) ProveClosest(path []byte) ( return proof, nil } -//nolint:unused -func (smt *SMT) recursiveLoad(hash []byte) (trieNode, error) { - return smt.resolve(hash, smt.recursiveLoad) -} - // resolves a stub into a cached node func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) { stub, ok := node.(*lazyNode) if !ok { return node, nil } - resolver := func(hash []byte) (trieNode, error) { - return &lazyNode{hash}, nil - } - ret, err := resolve(smt, stub.digest, resolver) + ret, err := resolve(smt, stub.digest) if err != nil { return node, err } return ret, nil } -func (smt *SMT) resolve(hash []byte, resolver 
func([]byte) (trieNode, error), -) (ret trieNode, err error) { +func (smt *SMT) resolve(hash []byte) (ret trieNode, err error) { if bytes.Equal(smt.th.placeholder(), hash) { return } @@ -583,27 +574,18 @@ func (smt *SMT) resolve(hash []byte, resolver func([]byte) (trieNode, error), pathBounds, path, childHash := parseExtension(data, smt.ph) ext.path = path copy(ext.pathBounds[:], pathBounds) - ext.child, err = resolver(childHash) - if err != nil { - return - } + ext.child = &lazyNode{childHash} return &ext, nil } leftHash, rightHash := smt.th.parseNode(data) inner := innerNode{persisted: true, digest: hash} - inner.leftChild, err = resolver(leftHash) - if err != nil { - return - } - inner.rightChild, err = resolver(rightHash) - if err != nil { - return - } + inner.leftChild = &lazyNode{leftHash} + inner.rightChild = &lazyNode{rightHash} return &inner, nil } -func (smt *SMT) resolveSum(hash []byte, resolver func([]byte) (trieNode, error), -) (ret trieNode, err error) { +// resolveSum resolves +func (smt *SMT) resolveSum(hash []byte) (ret trieNode, err error) { if bytes.Equal(placeholder(smt.Spec()), hash) { return } @@ -621,22 +603,13 @@ func (smt *SMT) resolveSum(hash []byte, resolver func([]byte) (trieNode, error), pathBounds, path, childHash, _ := parseSumExtension(data, smt.ph) ext.path = path copy(ext.pathBounds[:], pathBounds) - ext.child, err = resolver(childHash) - if err != nil { - return - } + ext.child = &lazyNode{childHash} return &ext, nil } leftHash, rightHash := smt.th.parseSumNode(data) inner := innerNode{persisted: true, digest: hash} - inner.leftChild, err = resolver(leftHash) - if err != nil { - return - } - inner.rightChild, err = resolver(rightHash) - if err != nil { - return - } + inner.leftChild = &lazyNode{leftHash} + inner.rightChild = &lazyNode{rightHash} return &inner, nil } @@ -652,10 +625,10 @@ func (smt *SMT) Commit() (err error) { } } smt.orphans = nil - if err = smt.commit(smt.trie); err != nil { + if err = smt.commit(smt.root); err 
!= nil { return } - smt.savedRoot = smt.Root() + smt.rootHash = smt.Root() return } @@ -688,7 +661,7 @@ func (smt *SMT) commit(node trieNode) error { // Root returns the root hash of the trie func (smt *SMT) Root() MerkleRoot { - return hashNode(smt.Spec(), smt.trie) + return hashNode(smt.Spec(), smt.root) } func (smt *SMT) addOrphan(orphans *[][]byte, node trieNode) { @@ -705,128 +678,13 @@ func (node *innerNode) Persisted() bool { return node.persisted } func (node *lazyNode) Persisted() bool { return true } -func (node *extensionNode) Persisted() bool { return node.persisted } - func (node *leafNode) CachedDigest() []byte { return node.digest } func (node *innerNode) CachedDigest() []byte { return node.digest } func (node *lazyNode) CachedDigest() []byte { return node.digest } -func (node *extensionNode) CachedDigest() []byte { return node.digest } - func (inner *innerNode) setDirty() { inner.persisted = false inner.digest = nil } - -func (ext *extensionNode) length() int { return int(ext.pathBounds[1] - ext.pathBounds[0]) } - -func (ext *extensionNode) setDirty() { - ext.persisted = false - ext.digest = nil -} - -// Returns length of matching prefix, and whether it's a full match -func (ext *extensionNode) match(path []byte, depth int) (int, bool) { - if depth != ext.pathStart() { - panic("depth != path_begin") - } - for i := ext.pathStart(); i < ext.pathEnd(); i++ { - if getPathBit(ext.path, i) != getPathBit(path, i) { - return i - ext.pathStart(), false - } - } - return ext.length(), true -} - -//nolint:unused -func (ext *extensionNode) commonPrefix(path []byte) int { - count := 0 - for i := ext.pathStart(); i < ext.pathEnd(); i++ { - if getPathBit(ext.path, i) != getPathBit(path, i) { - break - } - count++ - } - return count -} - -func (ext *extensionNode) pathStart() int { return int(ext.pathBounds[0]) } - -func (ext *extensionNode) pathEnd() int { return int(ext.pathBounds[1]) } - -// Splits the node in-place; returns replacement node, child node at the 
split, and split depth -func (ext *extensionNode) split(path []byte, depth int) (trieNode, *trieNode, int) { - if depth != ext.pathStart() { - panic("depth != path_begin") - } - index := ext.pathStart() - var myBit, branchBit int - for ; index < ext.pathEnd(); index++ { - myBit = getPathBit(ext.path, index) - branchBit = getPathBit(path, index) - if myBit != branchBit { - break - } - } - if index == ext.pathEnd() { - return ext, &ext.child, index - } - - child := ext.child - var branch innerNode - var head trieNode - var tail *trieNode - if myBit == left { - tail = &branch.leftChild - } else { - tail = &branch.rightChild - } - - // Split at first bit: chain starts with new node - if index == ext.pathStart() { - head = &branch - ext.pathBounds[0]++ // Shrink the extension from front - if ext.length() == 0 { - *tail = child - } else { - *tail = ext - } - } else { - // Split inside: chain ends at index - head = ext - ext.child = &branch - if index == ext.pathEnd()-1 { - *tail = child - } else { - *tail = &extensionNode{ - path: ext.path, - pathBounds: [2]byte{ - byte(index + 1), - ext.pathBounds[1], - }, - child: child, - } - } - ext.pathBounds[1] = byte(index) - } - var b trieNode = &branch - return head, &b, index -} - -// expand returns the inner node that represents the start of the singly -// linked list that this extension node represents -func (ext *extensionNode) expand() trieNode { - last := ext.child - for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- { - var next innerNode - if getPathBit(ext.path, i) == left { - next.leftChild = last - } else { - next.rightChild = last - } - last = &next - } - return last -} diff --git a/utils.go b/utils.go index cc5bca8..bb268c2 100644 --- a/utils.go +++ b/utils.go @@ -9,7 +9,8 @@ type nilPathHasher struct { } func (n *nilPathHasher) Path(key []byte) []byte { return key[:n.hashSize] } -func (n *nilPathHasher) PathSize() int { return n.hashSize } + +func (n *nilPathHasher) PathSize() int { return n.hashSize } func 
newNilPathHasher(hashSize int) PathHasher { return &nilPathHasher{hashSize: hashSize} @@ -200,10 +201,12 @@ func hashSumSerialization(smt *TrieSpec, data []byte) []byte { } // resolve resolves a lazy node depending on the trie type -func resolve(smt *SMT, hash []byte, resolver func([]byte) (trieNode, error), +func resolve( + smt *SMT, + hash []byte, ) (trieNode, error) { if smt.sumTrie { - return smt.resolveSum(hash, resolver) + return smt.resolveSum(hash) } - return smt.resolve(hash, resolver) + return smt.resolve(hash) } From 3cb019efdb7e1514c24aada0d1476f41115613e0 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Thu, 8 Feb 2024 12:10:02 -0800 Subject: [PATCH 04/40] Split all the node types into different files --- docs/smt.md | 26 +++++++++++++++----------- extension_node.go | 15 +++++++++++++-- inner_node.go | 23 +++++++++++++++++++++++ lazy_node.go | 14 ++++++++++++++ leaf_node.go | 17 +++++++++++++++++ smt.go | 47 +---------------------------------------------- 6 files changed, 83 insertions(+), 59 deletions(-) create mode 100644 inner_node.go create mode 100644 lazy_node.go create mode 100644 leaf_node.go diff --git a/docs/smt.md b/docs/smt.md index 07e711b..16220c5 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -61,13 +61,15 @@ The SMT has 4 node types that are used to construct the trie: ### Leaf Nodes -Leaf nodes store the full path which they represent and also the hash of the -value they store. The `digest` of a leaf node is the hash of the leaf nodes path -and value concatenated. +Leaf nodes store the full path associated with the `key`. A leaf node also +store the hash of the `value` stored. -The SMT stores only the hashes of the values in the trie, not the raw values -themselves. In order to store the raw values in the underlying database the -option `WithValueHasher(nil)` must be passed into the `NewSparseMerkleTrie` +The `digest` of a leaf node is the hash of concatenation of the leaf node's +path and value. 
+ +By default, the SMT only stores the hashes of the values in the trie, and not the +raw values themselves. In order to store the raw values in the underlying database, +the option `WithValueHasher(nil)` must be passed into the `NewSparseMerkleTrie` constructor. - _Prefix_: `[]byte{0}` @@ -86,8 +88,11 @@ nodes concatenated hashes. ### Extension Nodes Extension nodes represent a singly linked chain of inner nodes, with a single -child. They are used to represent a common path in the trie and as such contain -the path and bounds of the path they represent. The `digest` of an extension +child. In other words, they are an optimization to avoid having a long chain of +inner nodes where each inner node only has one child. + +In other words, they are used to +represent a common path in the trie and as such contain the path and bounds of the path they represent. The `digest` of an extension node is the hash of its path bounds, the path itself and the child nodes digest concatenated. @@ -211,9 +216,8 @@ Paths are **only** stored in two types of nodes: `Leaf` nodes and `Extension` no - The full path which it represent - The value stored at that path - `Extension` nodes contain: - - not only the path they represent but also the path - bounds (ie. the start and end of the path they cover). - + - not only the path they represent but also the path + bounds (ie. the start and end of the path they cover). Inner nodes do **not** contain a path, as they represent a branch in the trie and not a path. As such their children, _if they are extension nodes or leaf diff --git a/extension_node.go b/extension_node.go index 1520d5d..8e57408 100644 --- a/extension_node.go +++ b/extension_node.go @@ -2,8 +2,11 @@ package smt var _ trieNode = (*extensionNode)(nil) -// A compressed chain of singly-linked inner nodes +// A compressed chain of singly-linked inner nodes. 
+// Instead of storing innerNodes pointing to other innerNodes, the extensionNode +// at `path` capture all the innerNodes from `pathBounds[0]` to `pathBounds[1]`. type extensionNode struct { + // The path to this extensionNode starting at the root. path []byte // Offsets into path slice of bounds defining actual path segment. // NOTE: assumes path is <=256 bits @@ -16,16 +19,24 @@ type extensionNode struct { digest []byte } +// Persisted satisfied the trieNode#Persisted interface func (node *extensionNode) Persisted() bool { return node.persisted } +// Persisted satisfied the trieNode#CachedDigest interface func (node *extensionNode) CachedDigest() []byte { return node.digest } -func (ext *extensionNode) length() int { return int(ext.pathBounds[1] - ext.pathBounds[0]) } +// length returns the length of the path segment (number of inner nodes replaced) +// by this single extensionNode +func (ext *extensionNode) length() int { + return int(ext.pathBounds[1] - ext.pathBounds[0]) +} +// setDirty marks the node as dirty (i.e. not flushed to disk) and clears +// its digest func (ext *extensionNode) setDirty() { ext.persisted = false ext.digest = nil diff --git a/inner_node.go b/inner_node.go new file mode 100644 index 0000000..e6b3a63 --- /dev/null +++ b/inner_node.go @@ -0,0 +1,23 @@ +package smt + +var _ trieNode = (*innerNode)(nil) + +// A branch within the binary trie pointing to a left & right child. +type innerNode struct { + // Left and right child nodes. + // Both child nodes are always expected to be non-nil. 
+ leftChild, rightChild trieNode + persisted bool + digest []byte +} + +// Persisted satisfied the trieNode#Persisted interface +func (node *innerNode) Persisted() bool { return node.persisted } + +// Persisted satisfied the trieNode#CachedDigest interface +func (node *innerNode) CachedDigest() []byte { return node.digest } + +func (node *innerNode) setDirty() { + node.persisted = false + node.digest = nil +} diff --git a/lazy_node.go b/lazy_node.go new file mode 100644 index 0000000..a0fc822 --- /dev/null +++ b/lazy_node.go @@ -0,0 +1,14 @@ +package smt + +var _ trieNode = (*lazyNode)(nil) + +// Represents an uncached, persisted node +type lazyNode struct { + digest []byte +} + +// Persisted satisfied the trieNode#Persisted interface +func (node *lazyNode) Persisted() bool { return true } + +// Persisted satisfied the trieNode#CachedDigest interface +func (node *lazyNode) CachedDigest() []byte { return node.digest } diff --git a/leaf_node.go b/leaf_node.go new file mode 100644 index 0000000..4c894e4 --- /dev/null +++ b/leaf_node.go @@ -0,0 +1,17 @@ +package smt + +var _ trieNode = (*leafNode)(nil) + +// A leaf node storing a key-value pair for a full path. 
+type leafNode struct { + path []byte + valueHash []byte + persisted bool + digest []byte +} + +// Persisted satisfied the trieNode#Persisted interface +func (node *leafNode) Persisted() bool { return node.persisted } + +// Persisted satisfied the trieNode#CachedDigest interface +func (node *leafNode) CachedDigest() []byte { return node.digest } diff --git a/smt.go b/smt.go index 17676ff..8e646b2 100644 --- a/smt.go +++ b/smt.go @@ -7,11 +7,7 @@ import ( "github.com/pokt-network/smt/kvstore" ) -var ( - _ trieNode = (*innerNode)(nil) - _ trieNode = (*leafNode)(nil) - _ SparseMerkleTrie = (*SMT)(nil) -) +var _ SparseMerkleTrie = (*SMT)(nil) // A high-level interface that captures the behaviour of all types of nodes type trieNode interface { @@ -24,28 +20,6 @@ type trieNode interface { CachedDigest() []byte } -// A branch within the binary trie pointing to a left & right child. -type innerNode struct { - // Left and right child nodes. - // Both child nodes are always expected to be non-nil. - leftChild, rightChild trieNode - persisted bool - digest []byte -} - -// A leaf node storing a key-value pair for a full path. -type leafNode struct { - path []byte - valueHash []byte - persisted bool - digest []byte -} - -// Represents an uncached, persisted node -type lazyNode struct { - digest []byte -} - // SMT is a Sparse Merkle Trie object that implements the SparseMerkleTrie interface type SMT struct { TrieSpec @@ -669,22 +643,3 @@ func (smt *SMT) addOrphan(orphans *[][]byte, node trieNode) { *orphans = append(*orphans, node.CachedDigest()) } } - -// TODO_IMPROVE: Lots of opportunity to modularize and improve the code here. 
- -func (node *leafNode) Persisted() bool { return node.persisted } - -func (node *innerNode) Persisted() bool { return node.persisted } - -func (node *lazyNode) Persisted() bool { return true } - -func (node *leafNode) CachedDigest() []byte { return node.digest } - -func (node *innerNode) CachedDigest() []byte { return node.digest } - -func (node *lazyNode) CachedDigest() []byte { return node.digest } - -func (inner *innerNode) setDirty() { - inner.persisted = false - inner.digest = nil -} From 562d7fafcc8416b5098eb8396f7059c52c0b7a4a Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Thu, 8 Feb 2024 12:19:06 -0800 Subject: [PATCH 05/40] Remove the unnecessary resolve helper and rename default value helpers --- bulk_test.go | 2 +- hasher.go | 4 ++-- proofs.go | 6 +++--- smst.go | 4 ++-- smst_proofs_test.go | 10 +++++----- smst_test.go | 10 +++++----- smst_utils_test.go | 4 ++-- smt.go | 39 +++++++++++++++++++++------------------ smt_proofs_test.go | 8 ++++---- smt_test.go | 10 +++++----- smt_utils_test.go | 4 ++-- types.go | 7 +++++-- utils.go | 13 +------------ 13 files changed, 58 insertions(+), 63 deletions(-) diff --git a/bulk_test.go b/bulk_test.go index 0f378d1..a0ab126 100644 --- a/bulk_test.go +++ b/bulk_test.go @@ -85,7 +85,7 @@ func bulkOperations(t *testing.T, operations int, insert int, update int, delete if err != nil && err != ErrKeyNotFound { t.Fatalf("error: %v", err) } - kv[ki].val = defaultValue + kv[ki].val = defaultEmptyValue } } diff --git a/hasher.go b/hasher.go index beb0007..e537382 100644 --- a/hasher.go +++ b/hasher.go @@ -168,10 +168,10 @@ func encodeSumInner(leftData []byte, rightData []byte) []byte { rightSum := uint64(0) leftSumBz := leftData[len(leftData)-sumSize:] rightSumBz := rightData[len(rightData)-sumSize:] - if !bytes.Equal(leftSumBz, defaultSum[:]) { + if !bytes.Equal(leftSumBz, defaultEmptySum[:]) { leftSum = binary.BigEndian.Uint64(leftSumBz) } - if !bytes.Equal(rightSumBz, defaultSum[:]) { + if !bytes.Equal(rightSumBz, 
defaultEmptySum[:]) { rightSum = binary.BigEndian.Uint64(rightSumBz) } binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) diff --git a/proofs.go b/proofs.go index 64ce171..714f4db 100644 --- a/proofs.go +++ b/proofs.go @@ -285,8 +285,8 @@ func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint6 binary.BigEndian.PutUint64(sumBz[:], sum) valueHash := spec.digestValue(value) valueHash = append(valueHash, sumBz[:]...) - if bytes.Equal(value, defaultValue) && sum == 0 { - valueHash = defaultValue + if bytes.Equal(value, defaultEmptyValue) && sum == 0 { + valueHash = defaultEmptyValue } smtSpec := &TrieSpec{ th: spec.th, @@ -331,7 +331,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v // Determine what the leaf hash should be. var currentHash, currentData []byte - if bytes.Equal(value, defaultValue) { // Non-membership proof. + if bytes.Equal(value, defaultEmptyValue) { // Non-membership proof. if proof.NonMembershipLeafData == nil { // Leaf is a placeholder value. currentHash = placeholder(spec) } else { // Leaf is an unrelated leaf. 
diff --git a/smst.go b/smst.go index bbf94df..48fa501 100644 --- a/smst.go +++ b/smst.go @@ -66,8 +66,8 @@ func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { if err != nil { return nil, 0, err } - if bytes.Equal(valueHash, defaultValue) { - return defaultValue, 0, nil + if bytes.Equal(valueHash, defaultEmptyValue) { + return defaultEmptyValue, 0, nil } var weightBz [sumSize]byte copy(weightBz[:], valueHash[len(valueHash)-sumSize:]) diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 21fd454..d231868 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -33,7 +33,7 @@ func TestSMST_Proof_Operations(t *testing.T) { proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, placeholder(base), []byte("testKey3"), defaultValue, 0, base) + result, err = VerifySumProof(proof, placeholder(base), []byte("testKey3"), defaultEmptyValue, 0, base) require.NoError(t, err) require.True(t, result) result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, base) @@ -112,7 +112,7 @@ func TestSMST_Proof_Operations(t *testing.T) { SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result, err = VerifySumProof(proof, root, []byte("testKey2"), defaultValue, 0, base) + result, err = VerifySumProof(proof, root, []byte("testKey2"), defaultEmptyValue, 0, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) @@ -120,16 +120,16 @@ func TestSMST_Proof_Operations(t *testing.T) { proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 0, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 0, base) // valid require.NoError(t, err) require.True(t, result) result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 0, base) // wrong value 
require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 5, base) // wrong sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultValue, 0, base) // invalid proof + result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultEmptyValue, 0, base) // invalid proof require.NoError(t, err) require.False(t, result) } diff --git a/smst_test.go b/smst_test.go index d8e0b3c..868e856 100644 --- a/smst_test.go +++ b/smst_test.go @@ -38,7 +38,7 @@ func TestSMST_TrieUpdateBasic(t *testing.T) { // Test getting an empty key. value, sum, err = smst.GetValueSum([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value) + require.Equal(t, defaultEmptyValue, value) require.Equal(t, uint64(0), sum) has, err = smst.Has([]byte("testKey")) @@ -132,7 +132,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err := smst.GetValueSum([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") has, err := smst.Has([]byte("testKey")) @@ -157,7 +157,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err = smst.GetValueSum([]byte("testKey2")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") value, sum, err = smst.GetValueSum([]byte("testKey")) @@ -179,7 +179,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err = smst.GetValueSum([]byte("foo")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted 
key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") value, sum, err = smst.GetValueSum([]byte("testKey")) @@ -202,7 +202,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err = smst.GetValueSum([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") has, err = smst.Has([]byte("testKey")) diff --git a/smst_utils_test.go b/smst_utils_test.go index 21a9ac7..dcac64f 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -52,7 +52,7 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { if err != nil { if errors.Is(err, ErrKeyNotFound) { // If key isn't found, return default value and sum - return defaultValue, 0, nil + return defaultEmptyValue, 0, nil } // Otherwise percolate up any other error return nil, 0, err @@ -69,5 +69,5 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { // Has returns true if the value at the given key is non-default, false otherwise. func (smst *SMSTWithStorage) Has(key []byte) (bool, error) { val, sum, err := smst.GetValueSum(key) - return !bytes.Equal(defaultValue, val) || sum != 0, err + return !bytes.Equal(defaultEmptyValue, val) || sum != 0, err } diff --git a/smt.go b/smt.go index 8e646b2..f418e11 100644 --- a/smt.go +++ b/smt.go @@ -67,45 +67,49 @@ func ImportSparseMerkleTrie( return smt } -// Get returns the digest of the value stored at the given key +// Get returns the hash (i.e. 
digest) of the leaf value stored at the given key func (smt *SMT) Get(key []byte) ([]byte, error) { path := smt.ph.Path(key) + // The leaf node whose value will be returned var leaf *leafNode var err error - for node, depth := &smt.root, 0; ; depth++ { - *node, err = smt.resolveLazy(*node) + + // Loop throughout the entire trie to find the corresponding leaf for the + // given key. + for currNode, depth := &smt.root, 0; ; depth++ { + *currNode, err = smt.resolveLazy(*currNode) if err != nil { return nil, err } - if *node == nil { + if *currNode == nil { break } - if n, ok := (*node).(*leafNode); ok { + if n, ok := (*currNode).(*leafNode); ok { if bytes.Equal(path, n.path) { leaf = n } break } - if ext, ok := (*node).(*extensionNode); ok { - if _, match := ext.match(path, depth); !match { + if extNode, ok := (*currNode).(*extensionNode); ok { + if _, match := extNode.match(path, depth); !match { break } - depth += ext.length() - node = &ext.child - *node, err = smt.resolveLazy(*node) + depth += extNode.length() + currNode = &extNode.child + *currNode, err = smt.resolveLazy(*currNode) if err != nil { return nil, err } } - inner := (*node).(*innerNode) + inner := (*currNode).(*innerNode) if getPathBit(path, depth) == left { - node = &inner.leftChild + currNode = &inner.leftChild } else { - node = &inner.rightChild + currNode = &inner.rightChild } } if leaf == nil { - return defaultValue, nil + return defaultEmptyValue, nil } return leaf.valueHash, nil } @@ -523,11 +527,10 @@ func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) { if !ok { return node, nil } - ret, err := resolve(smt, stub.digest) - if err != nil { - return node, err + if smt.sumTrie { + return smt.resolveSum(stub.digest) } - return ret, nil + return smt.resolve(stub.digest) } func (smt *SMT) resolve(hash []byte) (ret trieNode, err error) { diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 2cf70c8..b1d5005 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -32,7 +32,7 @@ func 
TestSMT_Proof_Operations(t *testing.T) { proof, err = smt.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifyProof(proof, base.th.placeholder(), []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(proof, base.th.placeholder(), []byte("testKey3"), defaultEmptyValue, base) require.NoError(t, err) require.True(t, result) result, err = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) @@ -89,7 +89,7 @@ func TestSMT_Proof_Operations(t *testing.T) { SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result, err = VerifyProof(proof, root, []byte("testKey2"), defaultValue, base) + result, err = VerifyProof(proof, root, []byte("testKey2"), defaultEmptyValue, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) @@ -97,13 +97,13 @@ func TestSMT_Proof_Operations(t *testing.T) { proof, err = smt.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifyProof(proof, root, []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(proof, root, []byte("testKey3"), defaultEmptyValue, base) require.NoError(t, err) require.True(t, result) result, err = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) require.NoError(t, err) require.False(t, result) - result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultEmptyValue, base) require.NoError(t, err) require.False(t, result) } diff --git a/smt_test.go b/smt_test.go index fd1adac..28f1c6b 100644 --- a/smt_test.go +++ b/smt_test.go @@ -34,7 +34,7 @@ func TestSMT_TrieUpdateBasic(t *testing.T) { // Test getting an empty key. 
value, err := smt.GetValue([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value) + require.Equal(t, defaultEmptyValue, value) has, err = smt.Has([]byte("testKey")) require.NoError(t, err) @@ -119,7 +119,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err := smt.GetValue([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") has, err := smt.Has([]byte("testKey")) require.NoError(t, err) @@ -142,7 +142,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err = smt.GetValue([]byte("testKey2")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") value, err = smt.GetValue([]byte("testKey")) require.NoError(t, err) @@ -162,7 +162,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err = smt.GetValue([]byte("foo")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") value, err = smt.GetValue([]byte("testKey")) require.NoError(t, err) @@ -183,7 +183,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err = smt.GetValue([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") has, err = smt.Has([]byte("testKey")) require.NoError(t, err) diff --git a/smt_utils_test.go b/smt_utils_test.go index 2912e26..cff400c 100644 --- a/smt_utils_test.go +++ b/smt_utils_test.go @@ -46,7 +46,7 @@ func (smt *SMTWithStorage) GetValue(key []byte) ([]byte, error) { if err != nil { if errors.Is(err, ErrKeyNotFound) { // If key isn't found, return default value - value = defaultValue + value = defaultEmptyValue } else { // Otherwise percolate up any other error return nil, err @@ -59,7 +59,7 @@ func (smt 
*SMTWithStorage) GetValue(key []byte) ([]byte, error) { // otherwise. func (smt *SMTWithStorage) Has(key []byte) (bool, error) { val, err := smt.GetValue(key) - return !bytes.Equal(defaultValue, val), err + return !bytes.Equal(defaultEmptyValue, val), err } // ProveCompact generates a compacted Merkle proof for a key against the diff --git a/types.go b/types.go index 150d41f..c88f311 100644 --- a/types.go +++ b/types.go @@ -11,8 +11,10 @@ const ( ) var ( - defaultValue []byte - defaultSum [sumSize]byte + // defaultEmptyValue is the default value for a leaf node + defaultEmptyValue []byte + // defaultEmptySum is the default sum value for a leaf node + defaultEmptySum [sumSize]byte ) // MerkleRoot is a type alias for a byte slice returned from the Root method @@ -99,6 +101,7 @@ func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec { func (spec *TrieSpec) Spec() *TrieSpec { return spec } func (spec *TrieSpec) depth() int { return spec.ph.PathSize() * 8 } + func (spec *TrieSpec) digestValue(data []byte) []byte { if spec.vh == nil { return data diff --git a/utils.go b/utils.go index bb268c2..26abfe4 100644 --- a/utils.go +++ b/utils.go @@ -122,7 +122,7 @@ func bytesToInt(bz []byte) int { func placeholder(spec *TrieSpec) []byte { if spec.sumTrie { placeholder := spec.th.placeholder() - placeholder = append(placeholder, defaultSum[:]...) + placeholder = append(placeholder, defaultEmptySum[:]...) return placeholder } return spec.th.placeholder() @@ -199,14 +199,3 @@ func hashSumSerialization(smt *TrieSpec, data []byte) []byte { digest = append(digest, data[len(data)-sumSize:]...) 
return digest } - -// resolve resolves a lazy node depending on the trie type -func resolve( - smt *SMT, - hash []byte, -) (trieNode, error) { - if smt.sumTrie { - return smt.resolveSum(hash) - } - return smt.resolve(hash) -} From 04853f714e782e1726cf5b29bafbaa477bb09e9b Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Thu, 8 Feb 2024 12:58:21 -0800 Subject: [PATCH 06/40] WIP - adding comments to different parts of extension_node.go --- extension_node.go | 114 ++++++++++++++++++++++++-------------------- hasher.go | 26 +++++----- proofs.go | 8 ++-- proofs_test.go | 4 +- smst.go | 8 ++-- smst_proofs_test.go | 6 +-- smst_test.go | 2 +- smst_utils_test.go | 8 ++-- smt.go | 82 +++++++++++++++---------------- types.go | 12 ++--- utils.go | 4 +- 11 files changed, 141 insertions(+), 133 deletions(-) diff --git a/extension_node.go b/extension_node.go index 8e57408..88335f3 100644 --- a/extension_node.go +++ b/extension_node.go @@ -3,15 +3,19 @@ package smt var _ trieNode = (*extensionNode)(nil) // A compressed chain of singly-linked inner nodes. -// Instead of storing innerNodes pointing to other innerNodes, the extensionNode -// at `path` capture all the innerNodes from `pathBounds[0]` to `pathBounds[1]`. +// +// Extension nodes are used to captures a series of inner nodes that only +// have one child in a succinct `pathBounds` for optimization purposes. +// +// Assumption: the path is <=256 bits type extensionNode struct { - // The path to this extensionNode starting at the root. + // The path (starting at the root) to this extension node. path []byte - // Offsets into path slice of bounds defining actual path segment. - // NOTE: assumes path is <=256 bits + // The path (starting at pathBounds[0] and ending at pathBounds[1]) of + // inner nodes that this single extension node replaces. pathBounds [2]byte - // Child is always an inner node, or lazy. + // A child node from this extension node. + // It MUST be either an innerNode or a lazyNode. 
child trieNode // Bool whether or not the node has been flushed to disk persisted bool @@ -29,10 +33,21 @@ func (node *extensionNode) CachedDigest() []byte { return node.digest } -// length returns the length of the path segment (number of inner nodes replaced) -// by this single extensionNode +// Length returns the length of the path segment represented by this single +// extensionNode. Since the SMT is a binary trie, the length represents both +// the depth and the number of nodes replaced by a single extension node. If +// this SMT were to have k-ary support, the depth would be strictly less than +// the number of nodes replaced. func (ext *extensionNode) length() int { - return int(ext.pathBounds[1] - ext.pathBounds[0]) + return ext.pathEnd() - ext.pathStart() +} + +func (ext *extensionNode) pathStart() int { + return int(ext.pathBounds[0]) +} + +func (ext *extensionNode) pathEnd() int { + return int(ext.pathBounds[1]) } // setDirty marks the node as dirty (i.e. not flushed to disk) and clears @@ -42,84 +57,77 @@ func (ext *extensionNode) setDirty() { ext.digest = nil } -// Returns length of matching prefix, and whether it's a full match -func (ext *extensionNode) match(path []byte, depth int) (int, bool) { - if depth != ext.pathStart() { - panic("depth != path_begin") +// boundsMatch returns the length of the matching prefix between `ext.pathBounds` +// and `path` starting at index `depth`, along with a bool if a full match is found. 
+func (extNode *extensionNode) boundsMatch(path []byte, depth int) (int, bool) { + if depth != extNode.pathStart() { + panic("depth != extNode.pathStart") } - for i := ext.pathStart(); i < ext.pathEnd(); i++ { - if getPathBit(ext.path, i) != getPathBit(path, i) { - return i - ext.pathStart(), false + for pathIdx := extNode.pathStart(); pathIdx < extNode.pathEnd(); pathIdx++ { + if getPathBit(extNode.path, pathIdx) != getPathBit(path, pathIdx) { + return pathIdx - extNode.pathStart(), false } } - return ext.length(), true + return extNode.length(), true } -func (ext *extensionNode) pathStart() int { - return int(ext.pathBounds[0]) -} - -func (ext *extensionNode) pathEnd() int { - return int(ext.pathBounds[1]) -} - -// Splits the node in-place; returns replacement node, child node at the split, and split depth -func (ext *extensionNode) split(path []byte, depth int) (trieNode, *trieNode, int) { - if depth != ext.pathStart() { - panic("depth != path_begin") - } - index := ext.pathStart() - var myBit, branchBit int - for ; index < ext.pathEnd(); index++ { - myBit = getPathBit(ext.path, index) - branchBit = getPathBit(path, index) - if myBit != branchBit { +// split splits the node in-place by returning a new node at the extension node, +// a child node at the split and split depth. 
+func (extNode *extensionNode) split(path []byte) (trieNode, *trieNode, int) { + // Start path to extNode.pathBounds until there is no match + var extNodeBit, pathBit int + pathIdx := extNode.pathStart() + for ; pathIdx < extNode.pathEnd(); pathIdx++ { + extNodeBit = getPathBit(extNode.path, pathIdx) + pathBit = getPathBit(path, pathIdx) + if extNodeBit != pathBit { break } } - if index == ext.pathEnd() { - return ext, &ext.child, index + // Return the extension node's child if path fully matches extNode.pathBounds + if pathIdx == extNode.pathEnd() { + return extNode, &extNode.child, pathIdx } - child := ext.child + child := extNode.child var branch innerNode var head trieNode var tail *trieNode - if myBit == left { + if extNodeBit == leftChildBit { tail = &branch.leftChild } else { tail = &branch.rightChild } // Split at first bit: chain starts with new node - if index == ext.pathStart() { + if pathIdx == extNode.pathStart() { head = &branch - ext.pathBounds[0]++ // Shrink the extension from front - if ext.length() == 0 { + extNode.pathBounds[0]++ // Shrink the extension from front + if extNode.length() == 0 { *tail = child } else { - *tail = ext + *tail = extNode } } else { // Split inside: chain ends at index - head = ext - ext.child = &branch - if index == ext.pathEnd()-1 { + head = extNode + extNode.child = &branch + if pathIdx == extNode.pathEnd()-1 { *tail = child } else { *tail = &extensionNode{ - path: ext.path, + path: extNode.path, pathBounds: [2]byte{ - byte(index + 1), - ext.pathBounds[1], + byte(pathIdx + 1), + extNode.pathBounds[1], }, child: child, } } - ext.pathBounds[1] = byte(index) + extNode.pathBounds[1] = byte(pathIdx) } var b trieNode = &branch - return head, &b, index + return head, &b, pathIdx } // expand returns the inner node that represents the start of the singly @@ -128,7 +136,7 @@ func (ext *extensionNode) expand() trieNode { last := ext.child for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- { var next innerNode - if 
getPathBit(ext.path, i) == left { + if getPathBit(ext.path, i) == leftChildBit { next.leftChild = last } else { next.rightChild = last diff --git a/hasher.go b/hasher.go index e537382..57b1614 100644 --- a/hasher.go +++ b/hasher.go @@ -80,7 +80,7 @@ func (th *trieHasher) digestLeaf(path []byte, leafData []byte) ([]byte, []byte) func (th *trieHasher) digestSumLeaf(path []byte, leafData []byte) ([]byte, []byte) { value := encodeLeaf(path, leafData) digest := th.digest(value) - digest = append(digest, value[len(value)-sumSize:]...) + digest = append(digest, value[len(value)-sumSizeBits:]...) return digest, value } @@ -92,7 +92,7 @@ func (th *trieHasher) digestNode(leftData []byte, rightData []byte) ([]byte, []b func (th *trieHasher) digestSumNode(leftData []byte, rightData []byte) ([]byte, []byte) { value := encodeSumInner(leftData, rightData) digest := th.digest(value) - digest = append(digest, value[len(value)-sumSize:]...) + digest = append(digest, value[len(value)-sumSizeBits:]...) return digest, value } @@ -101,8 +101,8 @@ func (th *trieHasher) parseNode(data []byte) ([]byte, []byte) { } func (th *trieHasher) parseSumNode(data []byte) ([]byte, []byte) { - sumless := data[:len(data)-sumSize] - return sumless[len(innerPrefix) : th.hashSize()+sumSize+len(innerPrefix)], sumless[len(innerPrefix)+th.hashSize()+sumSize:] + sumless := data[:len(data)-sumSizeBits] + return sumless[len(innerPrefix) : th.hashSize()+sumSizeBits+len(innerPrefix)], sumless[len(innerPrefix)+th.hashSize()+sumSizeBits:] } func (th *trieHasher) hashSize() int { @@ -131,12 +131,12 @@ func parseExtension(data []byte, ph PathHasher) (pathBounds, path, childData []b data[len(extPrefix)+2+ph.PathSize():] } -func parseSumExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSize]byte) { - var sumBz [sumSize]byte - copy(sumBz[:], data[len(data)-sumSize:]) +func parseSumExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) { + var 
sumBz [sumSizeBits]byte + copy(sumBz[:], data[len(data)-sumSizeBits:]) return data[len(extPrefix) : len(extPrefix)+2], // +2 represents the length of the pathBounds data[len(extPrefix)+2 : len(extPrefix)+2+ph.PathSize()], - data[len(extPrefix)+2+ph.PathSize() : len(data)-sumSize], + data[len(extPrefix)+2+ph.PathSize() : len(data)-sumSizeBits], sumBz } @@ -163,11 +163,11 @@ func encodeSumInner(leftData []byte, rightData []byte) []byte { value = append(value, innerPrefix...) value = append(value, leftData...) value = append(value, rightData...) - var sum [sumSize]byte + var sum [sumSizeBits]byte leftSum := uint64(0) rightSum := uint64(0) - leftSumBz := leftData[len(leftData)-sumSize:] - rightSumBz := rightData[len(rightData)-sumSize:] + leftSumBz := leftData[len(leftData)-sumSizeBits:] + rightSumBz := rightData[len(rightData)-sumSizeBits:] if !bytes.Equal(leftSumBz, defaultEmptySum[:]) { leftSum = binary.BigEndian.Uint64(leftSumBz) } @@ -194,8 +194,8 @@ func encodeSumExtension(pathBounds [2]byte, path []byte, childData []byte) []byt value = append(value, pathBounds[:]...) value = append(value, path...) value = append(value, childData...) - var sum [sumSize]byte - copy(sum[:], childData[len(childData)-sumSize:]) + var sum [sumSizeBits]byte + copy(sum[:], childData[len(childData)-sumSizeBits:]) value = append(value, sum[:]...) return value } diff --git a/proofs.go b/proofs.go index 714f4db..90176a6 100644 --- a/proofs.go +++ b/proofs.go @@ -281,7 +281,7 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp // VerifySumProof verifies a Merkle proof for a sum trie. func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { - var sumBz [sumSize]byte + var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) valueHash := spec.digestValue(value) valueHash = append(valueHash, sumBz[:]...) 
@@ -314,9 +314,9 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie if proof.ClosestValueHash == nil { return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, spec) } - sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:] + sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSizeBits:] sum := binary.BigEndian.Uint64(sumBz) - valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] + valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBits] return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec) } @@ -360,7 +360,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v node := make([]byte, hashSize(spec)) copy(node, proof.SideNodes[i]) - if getPathBit(path, len(proof.SideNodes)-1-i) == left { + if getPathBit(path, len(proof.SideNodes)-1-i) == leftChildBit { currentHash, currentData = digestNode(spec, currentHash, node) } else { currentHash, currentData = digestNode(spec, node, currentHash) diff --git a/proofs_test.go b/proofs_test.go index 6248e5c..de4d3e9 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -177,9 +177,9 @@ func randomiseProof(proof *SparseMerkleProof) *SparseMerkleProof { func randomiseSumProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range sideNodes { - sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSize) + sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSizeBits) rand.Read(sideNodes[i]) // nolint: errcheck - sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSize:]...) + sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSizeBits:]...) 
} return &SparseMerkleProof{ SideNodes: sideNodes, diff --git a/smst.go b/smst.go index 48fa501..6684705 100644 --- a/smst.go +++ b/smst.go @@ -69,10 +69,10 @@ func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { if bytes.Equal(valueHash, defaultEmptyValue) { return defaultEmptyValue, 0, nil } - var weightBz [sumSize]byte - copy(weightBz[:], valueHash[len(valueHash)-sumSize:]) + var weightBz [sumSizeBits]byte + copy(weightBz[:], valueHash[len(valueHash)-sumSizeBits:]) weight := binary.BigEndian.Uint64(weightBz[:]) - return valueHash[:len(valueHash)-sumSize], weight, nil + return valueHash[:len(valueHash)-sumSizeBits], weight, nil } // Update sets the value for the given key, to the digest of the provided value @@ -80,7 +80,7 @@ func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { // is used to compute the interim and total sum of the trie. func (smst *SMST) Update(key, value []byte, weight uint64) error { valueHash := smst.digestValue(value) - var weightBz [sumSize]byte + var weightBz [sumSizeBits]byte binary.BigEndian.PutUint64(weightBz[:], weight) valueHash = append(valueHash, weightBz[:]...) return smst.SMT.Update(key, valueHash) diff --git a/smst_proofs_test.go b/smst_proofs_test.go index d231868..394d595 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -103,7 +103,7 @@ func TestSMST_Proof_Operations(t *testing.T) { require.False(t, result) // Try proving a default value for a non-default leaf. - var sum [sumSize]byte + var sum [sumSizeBits]byte binary.BigEndian.PutUint64(sum[:], 5) tval := base.digestValue([]byte("testValue")) tval = append(tval, sum[:]...) 
@@ -281,7 +281,7 @@ func TestSMST_ProveClosest(t *testing.T) { var result bool var root []byte var err error - var sumBz [sumSize]byte + var sumBz [sumSizeBits]byte smn = simplemap.NewSimpleMap() require.NoError(t, err) @@ -407,7 +407,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { closestPath := sha256.Sum256([]byte("foo")) closestValueHash := []byte("bar") - var sumBz [sumSize]byte + var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], 5) closestValueHash = append(closestValueHash, sumBz[:]...) require.Equal(t, proof, &SparseMerkleClosestProof{ diff --git a/smst_test.go b/smst_test.go index 868e856..fc0c75d 100644 --- a/smst_test.go +++ b/smst_test.go @@ -441,7 +441,7 @@ func TestSMST_TotalSum(t *testing.T) { // Check root hash contains the correct hex sum root1 := smst.Root() - sumBz := root1[len(root1)-sumSize:] + sumBz := root1[len(root1)-sumSizeBits:] rootSum := binary.BigEndian.Uint64(sumBz) require.NoError(t, err) diff --git a/smst_utils_test.go b/smst_utils_test.go index dcac64f..90e27d9 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -27,7 +27,7 @@ func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { return err } valueHash := smst.digestValue(value) - var sumBz [sumSize]byte + var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) 
return smst.preimages.Set(valueHash, value) @@ -57,13 +57,13 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { // Otherwise percolate up any other error return nil, 0, err } - var sumBz [sumSize]byte - copy(sumBz[:], value[len(value)-sumSize:]) + var sumBz [sumSizeBits]byte + copy(sumBz[:], value[len(value)-sumSizeBits:]) storedSum := binary.BigEndian.Uint64(sumBz[:]) if storedSum != sum { return nil, 0, fmt.Errorf("sum mismatch for %s: got %d, expected %d", string(key), storedSum, sum) } - return value[:len(value)-sumSize], storedSum, nil + return value[:len(value)-sumSizeBits], storedSum, nil } // Has returns true if the value at the given key is non-default, false otherwise. diff --git a/smt.go b/smt.go index f418e11..603620f 100644 --- a/smt.go +++ b/smt.go @@ -91,7 +91,7 @@ func (smt *SMT) Get(key []byte) ([]byte, error) { break } if extNode, ok := (*currNode).(*extensionNode); ok { - if _, match := extNode.match(path, depth); !match { + if _, fullMatch := extNode.boundsMatch(path, depth); !fullMatch { break } depth += extNode.length() @@ -102,7 +102,7 @@ func (smt *SMT) Get(key []byte) ([]byte, error) { } } inner := (*currNode).(*innerNode) - if getPathBit(path, depth) == left { + if getPathBit(path, depth) == leftChildBit { currNode = &inner.leftChild } else { currNode = &inner.rightChild @@ -171,7 +171,7 @@ func (smt *SMT) update( *last = &ext last = &ext.child } - if getPathBit(path, prefixLen) == left { + if getPathBit(path, prefixLen) == leftChildBit { *last = &innerNode{ leftChild: newLeaf, rightChild: leaf, @@ -187,20 +187,20 @@ func (smt *SMT) update( smt.addOrphan(orphans, node) - if ext, ok := node.(*extensionNode); ok { + if extNode, ok := node.(*extensionNode); ok { var branch *trieNode - node, branch, depth = ext.split(path, depth) + node, branch, depth = extNode.split(path) *branch, err = smt.update(*branch, depth, path, value, orphans) if err != nil { return node, err } - ext.setDirty() + extNode.setDirty() 
return node, nil } inner := node.(*innerNode) var child *trieNode - if getPathBit(path, depth) == left { + if getPathBit(path, depth) == leftChildBit { child = &inner.leftChild } else { child = &inner.rightChild @@ -248,30 +248,30 @@ func (smt *SMT) delete(node trieNode, depth int, path []byte, orphans *orphanNod smt.addOrphan(orphans, node) - if ext, ok := node.(*extensionNode); ok { - if _, match := ext.match(path, depth); !match { + if extNode, ok := node.(*extensionNode); ok { + if _, fullMatch := extNode.boundsMatch(path, depth); !fullMatch { return node, ErrKeyNotFound } - ext.child, err = smt.delete(ext.child, depth+ext.length(), path, orphans) + extNode.child, err = smt.delete(extNode.child, depth+extNode.length(), path, orphans) if err != nil { return node, err } - switch n := ext.child.(type) { + switch n := extNode.child.(type) { case *leafNode: return n, nil case *extensionNode: // Join this extension with the child smt.addOrphan(orphans, n) - n.pathBounds[0] = ext.pathBounds[0] + n.pathBounds[0] = extNode.pathBounds[0] node = n } - ext.setDirty() + extNode.setDirty() return node, nil } inner := node.(*innerNode) var child, sib *trieNode - if getPathBit(path, depth) == left { + if getPathBit(path, depth) == leftChildBit { child, sib = &inner.leftChild, &inner.rightChild } else { child, sib = &inner.rightChild, &inner.leftChild @@ -323,24 +323,24 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { if _, ok := node.(*leafNode); ok { break } - if ext, ok := node.(*extensionNode); ok { - length, match := ext.match(path, depth) - if match { - for i := 0; i < length; i++ { + if extNode, ok := node.(*extensionNode); ok { + matchLen, fullMatch := extNode.boundsMatch(path, depth) + if fullMatch { + for i := 0; i < matchLen; i++ { siblings = append(siblings, nil) } - depth += length - node = ext.child + depth += matchLen + node = extNode.child node, err = smt.resolveLazy(node) if err != nil { return nil, err } } else { - node = 
ext.expand() + node = extNode.expand() } } inner := node.(*innerNode) - if getPathBit(path, depth) == left { + if getPathBit(path, depth) == leftChildBit { node, sib = inner.leftChild, inner.rightChild } else { node, sib = inner.rightChild, inner.leftChild @@ -451,21 +451,21 @@ func (smt *SMT) ProveClosest(path []byte) ( proof.Depth = depth break } - if ext, ok := node.(*extensionNode); ok { - length, match := ext.match(workingPath, depth) + if extNode, ok := node.(*extensionNode); ok { + matchLen, fullMatch := extNode.boundsMatch(workingPath, depth) // workingPath from depth to end of extension node's path bounds // is a perfect match - if !match { - node = ext.expand() + if !fullMatch { + node = extNode.expand() } else { // extension nodes represent a singly linked list of inner nodes // add nil siblings to represent the empty neighbours - for i := 0; i < length; i++ { + for i := 0; i < matchLen; i++ { siblings = append(siblings, nil) } - depth += length - depthDelta += length - node = ext.child + depth += matchLen + depthDelta += matchLen + node = extNode.child node, err = smt.resolveLazy(node) if err != nil { return nil, err @@ -477,7 +477,7 @@ func (smt *SMT) ProveClosest(path []byte) ( proof.Depth = depth break } - if getPathBit(workingPath, depth) == left { + if getPathBit(workingPath, depth) == leftChildBit { node, sib = inner.leftChild, inner.rightChild } else { node, sib = inner.rightChild, inner.leftChild @@ -547,12 +547,12 @@ func (smt *SMT) resolve(hash []byte) (ret trieNode, err error) { return &leaf, nil } if isExtension(data) { - ext := extensionNode{persisted: true, digest: hash} + extNode := extensionNode{persisted: true, digest: hash} pathBounds, path, childHash := parseExtension(data, smt.ph) - ext.path = path - copy(ext.pathBounds[:], pathBounds) - ext.child = &lazyNode{childHash} - return &ext, nil + extNode.path = path + copy(extNode.pathBounds[:], pathBounds) + extNode.child = &lazyNode{childHash} + return &extNode, nil } leftHash, rightHash 
:= smt.th.parseNode(data) inner := innerNode{persisted: true, digest: hash} @@ -576,12 +576,12 @@ func (smt *SMT) resolveSum(hash []byte) (ret trieNode, err error) { return &leaf, nil } if isExtension(data) { - ext := extensionNode{persisted: true, digest: hash} + extNode := extensionNode{persisted: true, digest: hash} pathBounds, path, childHash, _ := parseSumExtension(data, smt.ph) - ext.path = path - copy(ext.pathBounds[:], pathBounds) - ext.child = &lazyNode{childHash} - return &ext, nil + extNode.path = path + copy(extNode.pathBounds[:], pathBounds) + extNode.child = &lazyNode{childHash} + return &extNode, nil } leftHash, rightHash := smt.th.parseSumNode(data) inner := innerNode{persisted: true, digest: hash} diff --git a/types.go b/types.go index c88f311..c2b757f 100644 --- a/types.go +++ b/types.go @@ -6,15 +6,15 @@ import ( ) const ( - left = 0 - sumSize = 8 + leftChildBit = 0 + sumSizeBits = 8 ) var ( // defaultEmptyValue is the default value for a leaf node defaultEmptyValue []byte // defaultEmptySum is the default sum value for a leaf node - defaultEmptySum [sumSize]byte + defaultEmptySum [sumSizeBits]byte ) // MerkleRoot is a type alias for a byte slice returned from the Root method @@ -27,8 +27,8 @@ func (r MerkleRoot) Sum() uint64 { if len(r)%32 == 0 { panic("roo#sum: not a merkle sum trie") } - var sumbz [sumSize]byte - copy(sumbz[:], []byte(r)[len([]byte(r))-sumSize:]) + var sumbz [sumSizeBits]byte + copy(sumbz[:], []byte(r)[len([]byte(r))-sumSizeBits:]) return binary.BigEndian.Uint64(sumbz[:]) } @@ -193,7 +193,7 @@ func (spec *TrieSpec) hashSumNode(node trieNode) []byte { if *cache == nil { preimage := spec.sumSerialize(node) *cache = spec.th.digest(preimage) - *cache = append(*cache, preimage[len(preimage)-sumSize:]...) + *cache = append(*cache, preimage[len(preimage)-sumSizeBits:]...) 
} return *cache } diff --git a/utils.go b/utils.go index 26abfe4..bd8e33a 100644 --- a/utils.go +++ b/utils.go @@ -131,7 +131,7 @@ func placeholder(spec *TrieSpec) []byte { // hashSize returns the hash size depending on the trie type func hashSize(spec *TrieSpec) int { if spec.sumTrie { - return spec.th.hashSize() + sumSize + return spec.th.hashSize() + sumSizeBits } return spec.th.hashSize() } @@ -196,6 +196,6 @@ func hashSumSerialization(smt *TrieSpec, data []byte) []byte { return smt.hashSumNode(&ext) } digest := smt.th.digest(data) - digest = append(digest, data[len(data)-sumSize:]...) + digest = append(digest, data[len(data)-sumSizeBits:]...) return digest } From 6b368c0b7cda9e2239a885f9db8a4beb1d224cb8 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Fri, 9 Feb 2024 14:58:34 -0800 Subject: [PATCH 07/40] Remove the need for the sum workaround function --- root_test.go | 11 +---------- smst.go | 17 ++++++++++++++--- types.go | 15 +-------------- 3 files changed, 16 insertions(+), 27 deletions(-) diff --git a/root_test.go b/root_test.go index da6293c..d6e588b 100644 --- a/root_test.go +++ b/root_test.go @@ -3,7 +3,6 @@ package smt_test import ( "crypto/sha256" "crypto/sha512" - "encoding/binary" "fmt" "hash" "testing" @@ -59,8 +58,7 @@ func TestMerkleRoot_TrieTypes(t *testing.T) { for i := uint64(0); i < 10; i++ { require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i)) } - root := trie.Root() - require.Equal(t, root.Sum(), getSumBzHelper(t, root)) + require.NotNil(t, trie.Sum()) return } trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher) @@ -73,10 +71,3 @@ func TestMerkleRoot_TrieTypes(t *testing.T) { }) } } - -func getSumBzHelper(t *testing.T, r []byte) uint64 { - sumSize := len(r) % 32 - sumBz := make([]byte, sumSize) - copy(sumBz[:], []byte(r)[len([]byte(r))-sumSize:]) - return binary.BigEndian.Uint64(sumBz[:]) -} diff --git a/smst.go b/smst.go index 6684705..394cdb3 100644 --- a/smst.go +++ b/smst.go @@ 
-8,6 +8,11 @@ import ( "github.com/pokt-network/smt/kvstore" ) +const ( + // The number of bits used to represent the sum of a node + sumSizeBits = 8 +) + var _ SparseMerkleSumTrie = (*SMST)(nil) // SMST is an object wrapping a Sparse Merkle Trie for custom encoding @@ -116,8 +121,14 @@ func (smst *SMST) Root() MerkleRoot { return smst.SMT.Root() // [digest]+[binary sum] } -// Sum returns the uint64 sum of the entire trie +// Sum returns the sum of the entire trie stored in the root. +// If the tree is not a sum tree, it will panic. func (smst *SMST) Sum() uint64 { - digest := smst.Root() - return digest.Sum() + rootDigest := smst.Root() + if len(rootDigest) != smst.th.hashSize()+sumSizeBits { + panic("roo#sum: not a merkle sum trie") + } + var sumbz [sumSizeBits]byte + copy(sumbz[:], []byte(rootDigest)[len([]byte(rootDigest))-sumSizeBits:]) + return binary.BigEndian.Uint64(sumbz[:]) } diff --git a/types.go b/types.go index c2b757f..edd3dee 100644 --- a/types.go +++ b/types.go @@ -1,13 +1,12 @@ package smt import ( - "encoding/binary" "hash" ) const ( + // The bit value use to distinguish an inner nodes left child and right child leftChildBit = 0 - sumSizeBits = 8 ) var ( @@ -20,18 +19,6 @@ var ( // MerkleRoot is a type alias for a byte slice returned from the Root method type MerkleRoot []byte -// Sum returns the uint64 sum of the merkle root, it checks the length of the -// merkle root and if it is no the same as the size of the SMST's expected -// root hash it will panic. -func (r MerkleRoot) Sum() uint64 { - if len(r)%32 == 0 { - panic("roo#sum: not a merkle sum trie") - } - var sumbz [sumSizeBits]byte - copy(sumbz[:], []byte(r)[len([]byte(r))-sumSizeBits:]) - return binary.BigEndian.Uint64(sumbz[:]) -} - // SparseMerkleTrie represents a Sparse Merkle Trie. type SparseMerkleTrie interface { // Update inserts a value into the SMT. 
From 89f9ceb253eca5befaddcbe6f737351f90389a08 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Fri, 9 Feb 2024 15:27:59 -0800 Subject: [PATCH 08/40] WIP - refactoring node hasher's and encoders --- .gitignore | 3 ++ hasher.go | 116 ++++++----------------------------------------- node_encoders.go | 109 ++++++++++++++++++++++++++++++++++++++++++++ proofs.go | 4 +- smt.go | 26 ++++------- types.go | 53 ++++++++++++++++------ utils.go | 4 +- 7 files changed, 176 insertions(+), 139 deletions(-) create mode 100644 node_encoders.go diff --git a/.gitignore b/.gitignore index aa638db..36195c1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ # Ignore Goland and JetBrains IDE files .idea/ + +# Ignore vscode files +.vscode diff --git a/hasher.go b/hasher.go index 57b1614..e1b0ce3 100644 --- a/hasher.go +++ b/hasher.go @@ -1,17 +1,9 @@ package smt import ( - "bytes" - "encoding/binary" "hash" ) -var ( - leafPrefix = []byte{0} - innerPrefix = []byte{1} - extPrefix = []byte{2} -) - var ( _ PathHasher = (*pathHasher)(nil) _ ValueHasher = (*valueHasher)(nil) @@ -31,19 +23,23 @@ type ValueHasher interface { HashValue([]byte) []byte } +// trieHasher is a common hasher for all trie hashers (paths & values). type trieHasher struct { hasher hash.Hash zeroValue []byte } +// pathHasher is a hasher for trie paths. type pathHasher struct { trieHasher } +// valueHasher is a hasher for leaf values. type valueHasher struct { trieHasher } +// newTrieHasher returns a new trie hasher with the given hash function. func newTrieHasher(hasher hash.Hash) *trieHasher { th := trieHasher{hasher: hasher} th.zeroValue = make([]byte, th.hashSize()) @@ -61,10 +57,12 @@ func (ph *pathHasher) PathSize() int { return ph.hasher.Size() } +// HashValue hashes the data provided using the value hasher func (vh *valueHasher) HashValue(data []byte) []byte { return vh.digest(data) } +// digest returns the hash of the data provided using the trie hasher. 
func (th *trieHasher) digest(data []byte) []byte {
 	th.hasher.Write(data)
 	sum := th.hasher.Sum(nil)
@@ -72,37 +70,38 @@ func (th *trieHasher) digest(data []byte) []byte {
 	return sum
 }
 
-func (th *trieHasher) digestLeaf(path []byte, leafData []byte) ([]byte, []byte) {
-	value := encodeLeaf(path, leafData)
+// digestLeaf returns the hash of the leaf data & path provided using the trie hasher.
+func (th *trieHasher) digestLeaf(path, data []byte) ([]byte, []byte) {
+	value := encodeLeafNode(path, data)
 	return th.digest(value), value
 }
 
 func (th *trieHasher) digestSumLeaf(path []byte, leafData []byte) ([]byte, []byte) {
-	value := encodeLeaf(path, leafData)
+	value := encodeLeafNode(path, leafData)
 	digest := th.digest(value)
 	digest = append(digest, value[len(value)-sumSizeBits:]...)
 	return digest, value
 }
 
 func (th *trieHasher) digestNode(leftData []byte, rightData []byte) ([]byte, []byte) {
-	value := encodeInner(leftData, rightData)
+	value := encodeInnerNode(leftData, rightData)
 	return th.digest(value), value
 }
 
 func (th *trieHasher) digestSumNode(leftData []byte, rightData []byte) ([]byte, []byte) {
-	value := encodeSumInner(leftData, rightData)
+	value := encodeSumInnerNode(leftData, rightData)
 	digest := th.digest(value)
 	digest = append(digest, value[len(value)-sumSizeBits:]...)
 
return digest, value } func (th *trieHasher) parseNode(data []byte) ([]byte, []byte) { - return data[len(innerPrefix) : th.hashSize()+len(innerPrefix)], data[len(innerPrefix)+th.hashSize():] + return data[len(innerNodePrefix) : th.hashSize()+len(innerNodePrefix)], data[len(innerNodePrefix)+th.hashSize():] } func (th *trieHasher) parseSumNode(data []byte) ([]byte, []byte) { sumless := data[:len(data)-sumSizeBits] - return sumless[len(innerPrefix) : th.hashSize()+sumSizeBits+len(innerPrefix)], sumless[len(innerPrefix)+th.hashSize()+sumSizeBits:] + return sumless[len(innerNodePrefix) : th.hashSize()+sumSizeBits+len(innerNodePrefix)], sumless[len(innerNodePrefix)+th.hashSize()+sumSizeBits:] } func (th *trieHasher) hashSize() int { @@ -112,90 +111,3 @@ func (th *trieHasher) hashSize() int { func (th *trieHasher) placeholder() []byte { return th.zeroValue } - -func isLeaf(data []byte) bool { - return bytes.Equal(data[:len(leafPrefix)], leafPrefix) -} - -func isExtension(data []byte) bool { - return bytes.Equal(data[:len(extPrefix)], extPrefix) -} - -func parseLeaf(data []byte, ph PathHasher) ([]byte, []byte) { - return data[len(leafPrefix) : ph.PathSize()+len(leafPrefix)], data[len(leafPrefix)+ph.PathSize():] -} - -func parseExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte) { - return data[len(extPrefix) : len(extPrefix)+2], // +2 represents the length of the pathBounds - data[len(extPrefix)+2 : len(extPrefix)+2+ph.PathSize()], - data[len(extPrefix)+2+ph.PathSize():] -} - -func parseSumExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) { - var sumBz [sumSizeBits]byte - copy(sumBz[:], data[len(data)-sumSizeBits:]) - return data[len(extPrefix) : len(extPrefix)+2], // +2 represents the length of the pathBounds - data[len(extPrefix)+2 : len(extPrefix)+2+ph.PathSize()], - data[len(extPrefix)+2+ph.PathSize() : len(data)-sumSizeBits], - sumBz -} - -// encodeLeaf encodes both normal and sum leaves as in the 
sum leaf the -// sum is appended to the end of the valueHash -func encodeLeaf(path []byte, leafData []byte) []byte { - value := make([]byte, 0, len(leafPrefix)+len(path)+len(leafData)) - value = append(value, leafPrefix...) - value = append(value, path...) - value = append(value, leafData...) - return value -} - -func encodeInner(leftData []byte, rightData []byte) []byte { - value := make([]byte, 0, len(innerPrefix)+len(leftData)+len(rightData)) - value = append(value, innerPrefix...) - value = append(value, leftData...) - value = append(value, rightData...) - return value -} - -func encodeSumInner(leftData []byte, rightData []byte) []byte { - value := make([]byte, 0, len(innerPrefix)+len(leftData)+len(rightData)) - value = append(value, innerPrefix...) - value = append(value, leftData...) - value = append(value, rightData...) - var sum [sumSizeBits]byte - leftSum := uint64(0) - rightSum := uint64(0) - leftSumBz := leftData[len(leftData)-sumSizeBits:] - rightSumBz := rightData[len(rightData)-sumSizeBits:] - if !bytes.Equal(leftSumBz, defaultEmptySum[:]) { - leftSum = binary.BigEndian.Uint64(leftSumBz) - } - if !bytes.Equal(rightSumBz, defaultEmptySum[:]) { - rightSum = binary.BigEndian.Uint64(rightSumBz) - } - binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) - value = append(value, sum[:]...) - return value -} - -func encodeExtension(pathBounds [2]byte, path []byte, childData []byte) []byte { - value := make([]byte, 0, len(extPrefix)+len(path)+2+len(childData)) - value = append(value, extPrefix...) - value = append(value, pathBounds[:]...) - value = append(value, path...) - value = append(value, childData...) - return value -} - -func encodeSumExtension(pathBounds [2]byte, path []byte, childData []byte) []byte { - value := make([]byte, 0, len(extPrefix)+len(path)+2+len(childData)) - value = append(value, extPrefix...) - value = append(value, pathBounds[:]...) - value = append(value, path...) - value = append(value, childData...) 
- var sum [sumSizeBits]byte - copy(sum[:], childData[len(childData)-sumSizeBits:]) - value = append(value, sum[:]...) - return value -} diff --git a/node_encoders.go b/node_encoders.go new file mode 100644 index 0000000..b9d1bfd --- /dev/null +++ b/node_encoders.go @@ -0,0 +1,109 @@ +package smt + +import ( + "bytes" + "encoding/binary" +) + +// TODO_IMPROVE: All of the parsing, encoding and checking functions in this file +// can be abstracted out into the `trieNode` interface. + +// NB: In this file, all references to the variable `data` should be treated as `encodedNodeData`. +// It was abbreviated to `data` for brevity. + +var ( + leafNodePrefix = []byte{0} + innerNodePrefix = []byte{1} + extNodePrefix = []byte{2} +) + +// isLeafNode returns true if the encoded node data is a leaf node +func isLeafNode(data []byte) bool { + return bytes.Equal(data[:len(leafNodePrefix)], leafNodePrefix) +} + +// isExtNode returns true if the encoded node data is an extension node +func isExtNode(data []byte) bool { + return bytes.Equal(data[:len(extNodePrefix)], extNodePrefix) +} + +func parseLeafNode(data []byte, ph PathHasher) ([]byte, []byte) { + return data[len(leafNodePrefix) : ph.PathSize()+len(leafNodePrefix)], data[len(leafNodePrefix)+ph.PathSize():] +} + +func parseExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte) { + return data[len(extNodePrefix) : len(extNodePrefix)+2], // +2 represents the length of the pathBounds + data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()], + data[len(extNodePrefix)+2+ph.PathSize():] +} + +func parseSumExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) { + var sumBz [sumSizeBits]byte + copy(sumBz[:], data[len(data)-sumSizeBits:]) + return data[len(extNodePrefix) : len(extNodePrefix)+2], // +2 represents the length of the pathBounds + data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()], + data[len(extNodePrefix)+2+ph.PathSize() : 
len(data)-sumSizeBits],
+		sumBz
+}
+
+// encodeLeafNode encodes both normal and sum leaf nodes; for sum leaves the
+// sum is appended to the end of the valueHash
+func encodeLeafNode(path []byte, leafData []byte) []byte {
+	data := make([]byte, 0, len(leafNodePrefix)+len(path)+len(leafData))
+	data = append(data, leafNodePrefix...)
+	data = append(data, path...)
+	data = append(data, leafData...)
+	return data
+}
+
+func encodeInnerNode(leftData []byte, rightData []byte) []byte {
+	data := make([]byte, 0, len(innerNodePrefix)+len(leftData)+len(rightData))
+	data = append(data, innerNodePrefix...)
+	data = append(data, leftData...)
+	data = append(data, rightData...)
+	return data
+}
+
+func encodeSumInnerNode(leftData []byte, rightData []byte) []byte {
+	data := make([]byte, 0, len(innerNodePrefix)+len(leftData)+len(rightData))
+	data = append(data, innerNodePrefix...)
+	data = append(data, leftData...)
+	data = append(data, rightData...)
+
+	var sum [sumSizeBits]byte
+	leftSum := uint64(0)
+	rightSum := uint64(0)
+	leftSumBz := leftData[len(leftData)-sumSizeBits:]
+	rightSumBz := rightData[len(rightData)-sumSizeBits:]
+	if !bytes.Equal(leftSumBz, defaultEmptySum[:]) {
+		leftSum = binary.BigEndian.Uint64(leftSumBz)
+	}
+	if !bytes.Equal(rightSumBz, defaultEmptySum[:]) {
+		rightSum = binary.BigEndian.Uint64(rightSumBz)
+	}
+	binary.BigEndian.PutUint64(sum[:], leftSum+rightSum)
+	data = append(data, sum[:]...)
+	return data
+}
+
+// encodeExtensionNode encodes the data of an extension node
+func encodeExtensionNode(pathBounds [2]byte, path []byte, childData []byte) []byte {
+	data := []byte{}
+	data = append(data, extNodePrefix...)
+	data = append(data, pathBounds[:]...)
+	data = append(data, path...)
+	data = append(data, childData...)
 
+ return data +} + +// encodeSumExtensionNode encodes the data of a sum extension nodes +func encodeSumExtensionNode(pathBounds [2]byte, path []byte, childData []byte) []byte { + data := encodeExtensionNode(pathBounds, path, childData) + + // Append the sum to the end of the data + var sum [sumSizeBits]byte + copy(sum[:], childData[len(childData)-sumSizeBits:]) + data = append(data, sum[:]...) + + return data +} diff --git a/proofs.go b/proofs.go index 90176a6..211eed5 100644 --- a/proofs.go +++ b/proofs.go @@ -59,7 +59,7 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { return fmt.Errorf("too many side nodes: got %d but max is %d", len(proof.SideNodes), spec.ph.PathSize()*8) } // Check that leaf data for non-membership proofs is a valid size. - lps := len(leafPrefix) + spec.ph.PathSize() + lps := len(leafNodePrefix) + spec.ph.PathSize() if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps { return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps) } @@ -336,7 +336,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v currentHash = placeholder(spec) } else { // Leaf is an unrelated leaf. var actualPath, valueHash []byte - actualPath, valueHash = parseLeaf(proof.NonMembershipLeafData, spec.ph) + actualPath, valueHash = parseLeafNode(proof.NonMembershipLeafData, spec.ph) if bytes.Equal(actualPath, path) { // This is not an unrelated leaf; non-membership proof failed. 
return false, nil, errors.Join(ErrBadProof, errors.New("non-membership proof on related leaf")) diff --git a/smt.go b/smt.go index 603620f..18e075f 100644 --- a/smt.go +++ b/smt.go @@ -7,19 +7,9 @@ import ( "github.com/pokt-network/smt/kvstore" ) +// Make sure the `SMT` struct implements the `SparseMerkleTrie` interface var _ SparseMerkleTrie = (*SMT)(nil) -// A high-level interface that captures the behaviour of all types of nodes -type trieNode interface { - // Persisted returns a boolean to determine whether or not the node - // has been persisted to disk or only held in memory. - // It can be used skip unnecessary iops if already persisted - Persisted() bool - - // The digest of the node, returning a cached value if available. - CachedDigest() []byte -} - // SMT is a Sparse Merkle Trie object that implements the SparseMerkleTrie interface type SMT struct { TrieSpec @@ -356,7 +346,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { if !bytes.Equal(leaf.path, path) { // This is a non-membership proof that involves showing a different leaf. // Add the leaf data to the proof. - leafData = encodeLeaf(leaf.path, leaf.valueHash) + leafData = encodeLeafNode(leaf.path, leaf.valueHash) } } // Hash siblings from bottom up. 
@@ -541,12 +531,12 @@ func (smt *SMT) resolve(hash []byte) (ret trieNode, err error) { if err != nil { return } - if isLeaf(data) { + if isLeafNode(data) { leaf := leafNode{persisted: true, digest: hash} - leaf.path, leaf.valueHash = parseLeaf(data, smt.ph) + leaf.path, leaf.valueHash = parseLeafNode(data, smt.ph) return &leaf, nil } - if isExtension(data) { + if isExtNode(data) { extNode := extensionNode{persisted: true, digest: hash} pathBounds, path, childHash := parseExtension(data, smt.ph) extNode.path = path @@ -570,12 +560,12 @@ func (smt *SMT) resolveSum(hash []byte) (ret trieNode, err error) { if err != nil { return nil, err } - if isLeaf(data) { + if isLeafNode(data) { leaf := leafNode{persisted: true, digest: hash} - leaf.path, leaf.valueHash = parseLeaf(data, smt.ph) + leaf.path, leaf.valueHash = parseLeafNode(data, smt.ph) return &leaf, nil } - if isExtension(data) { + if isExtNode(data) { extNode := extensionNode{persisted: true, digest: hash} pathBounds, path, childHash, _ := parseSumExtension(data, smt.ph) extNode.path = path diff --git a/types.go b/types.go index edd3dee..ab54818 100644 --- a/types.go +++ b/types.go @@ -4,6 +4,11 @@ import ( "hash" ) +// TODO_DISCUSS_IN_THIS_PR_IMPROVEMENTS: +// 1. Should we rename all instances of digest to hash? +// 2. Should we introduce a shared interface between SparseMerkleTrie and SparseMerkleSumTrie? +// 3. Should we rename Commit to FlushToDisk? + const ( // The bit value use to distinguish an inner nodes left child and right child leftChildBit = 0 @@ -19,6 +24,17 @@ var ( // MerkleRoot is a type alias for a byte slice returned from the Root method type MerkleRoot []byte +// A high-level interface that captures the behaviour of all types of nodes +type trieNode interface { + // Persisted returns a boolean to determine whether or not the node + // has been persisted to disk or only held in memory. 
+	// It can be used to skip unnecessary iops if already persisted
+	Persisted() bool
+
+	// The digest of the node, returning a cached value if available.
+	CachedDigest() []byte
+}
+
 // SparseMerkleTrie represents a Sparse Merkle Trie.
 type SparseMerkleTrie interface {
 	// Update inserts a value into the SMT.
@@ -76,6 +92,7 @@ type TrieSpec struct {
 	sumTrie bool
 }
 
+// newTrieSpec returns a new TrieSpec with the given hasher and sumTrie flag
 func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec {
 	spec := TrieSpec{th: *newTrieHasher(hasher)}
 	spec.ph = &pathHasher{spec.th}
@@ -85,9 +102,15 @@ func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec {
 }
 
 // Spec returns the TrieSpec associated with the given trie
-func (spec *TrieSpec) Spec() *TrieSpec { return spec }
+func (spec *TrieSpec) Spec() *TrieSpec {
+	return spec
+}
 
-func (spec *TrieSpec) depth() int { return spec.ph.PathSize() * 8 }
+// depth returns the maximum depth of the trie.
+// Since this tree is a binary tree, the depth is the number of bits in the path
+func (spec *TrieSpec) depth() int {
+	return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits
+}
 
 func (spec *TrieSpec) digestValue(data []byte) []byte {
 	if spec.vh == nil {
@@ -101,14 +124,14 @@ func (spec *TrieSpec) serialize(node trieNode) (data []byte) {
 	case *lazyNode:
 		panic("serialize(lazyNode)")
 	case *leafNode:
-		return encodeLeaf(n.path, n.valueHash)
+		return encodeLeafNode(n.path, n.valueHash)
 	case *innerNode:
 		lchild := spec.hashNode(n.leftChild)
 		rchild := spec.hashNode(n.rightChild)
-		return encodeInner(lchild, rchild)
+		return encodeInnerNode(lchild, rchild)
 	case *extensionNode:
 		child := spec.hashNode(n.child)
-		return encodeExtension(n.pathBounds, n.path, child)
+		return encodeExtensionNode(n.pathBounds, n.path, child)
 	}
 	return nil
 }
@@ -139,20 +162,20 @@ func (spec *TrieSpec) hashNode(node trieNode) []byte {
 
 // sumSerialize serializes a node returning the preimage hash, its sum and any
 // errors 
encountered -func (spec *TrieSpec) sumSerialize(node trieNode) (preimage []byte) { +func (spec *TrieSpec) sumSerialize(node trieNode) (preImage []byte) { switch n := node.(type) { case *lazyNode: panic("serialize(lazyNode)") case *leafNode: - return encodeLeaf(n.path, n.valueHash) + return encodeLeafNode(n.path, n.valueHash) case *innerNode: - lchild := spec.hashSumNode(n.leftChild) - rchild := spec.hashSumNode(n.rightChild) - preimage = encodeSumInner(lchild, rchild) - return preimage + leftChild := spec.hashSumNode(n.leftChild) + rightChild := spec.hashSumNode(n.rightChild) + preImage = encodeSumInnerNode(leftChild, rightChild) + return preImage case *extensionNode: child := spec.hashSumNode(n.child) - return encodeSumExtension(n.pathBounds, n.path, child) + return encodeSumExtensionNode(n.pathBounds, n.path, child) } return nil } @@ -178,9 +201,9 @@ func (spec *TrieSpec) hashSumNode(node trieNode) []byte { return n.digest } if *cache == nil { - preimage := spec.sumSerialize(node) - *cache = spec.th.digest(preimage) - *cache = append(*cache, preimage[len(preimage)-sumSizeBits:]...) + preImage := spec.sumSerialize(node) + *cache = spec.th.digest(preImage) + *cache = append(*cache, preImage[len(preImage)-sumSizeBits:]...) 
} return *cache } diff --git a/utils.go b/utils.go index bd8e33a..cda5ae2 100644 --- a/utils.go +++ b/utils.go @@ -178,7 +178,7 @@ func hashPreimage(spec *TrieSpec, data []byte) []byte { // Used for verification of serialized proof data func hashSerialization(smt *TrieSpec, data []byte) []byte { - if isExtension(data) { + if isExtNode(data) { pathBounds, path, childHash := parseExtension(data, smt.ph) ext := extensionNode{path: path, child: &lazyNode{childHash}} copy(ext.pathBounds[:], pathBounds) @@ -189,7 +189,7 @@ func hashSerialization(smt *TrieSpec, data []byte) []byte { // Used for verification of serialized proof data for sum trie nodes func hashSumSerialization(smt *TrieSpec, data []byte) []byte { - if isExtension(data) { + if isExtNode(data) { pathBounds, path, childHash, _ := parseSumExtension(data, smt.ph) ext := extensionNode{path: path, child: &lazyNode{childHash}} copy(ext.pathBounds[:], pathBounds) From 185749638d16cce724bfcb301ffaec7339bcc7ef Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Fri, 9 Feb 2024 15:54:07 -0800 Subject: [PATCH 09/40] Renamed a few things and converted the examples to proper tests --- extension_node.go | 8 ++++---- proofs.go | 4 ++-- smst.go | 2 +- smst_example_test.go | 33 ++++++++++++++++++++++++--------- smst_proofs_test.go | 2 +- smst_utils_test.go | 2 +- smt.go | 2 +- smt_example_test.go | 12 +++++++----- smt_proofs_test.go | 2 +- smt_utils_test.go | 2 +- types.go | 28 ++++++++++++++++------------ utils.go | 6 +++--- 12 files changed, 62 insertions(+), 41 deletions(-) diff --git a/extension_node.go b/extension_node.go index 88335f3..f2d2e32 100644 --- a/extension_node.go +++ b/extension_node.go @@ -132,11 +132,11 @@ func (extNode *extensionNode) split(path []byte) (trieNode, *trieNode, int) { // expand returns the inner node that represents the start of the singly // linked list that this extension node represents -func (ext *extensionNode) expand() trieNode { - last := ext.child - for i := ext.pathEnd() - 1; i >= 
ext.pathStart(); i-- { +func (extNode *extensionNode) expand() trieNode { + last := extNode.child + for i := extNode.pathEnd() - 1; i >= extNode.pathStart(); i-- { var next innerNode - if getPathBit(ext.path, i) == leftChildBit { + if getPathBit(extNode.path, i) == leftChildBit { next.leftChild = last } else { next.rightChild = last diff --git a/proofs.go b/proofs.go index 211eed5..af4126c 100644 --- a/proofs.go +++ b/proofs.go @@ -283,7 +283,7 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) - valueHash := spec.digestValue(value) + valueHash := spec.valueDigest(value) valueHash = append(valueHash, sumBz[:]...) if bytes.Equal(value, defaultEmptyValue) && sum == 0 { valueHash = defaultEmptyValue @@ -348,7 +348,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v updates = append(updates, update) } } else { // Membership proof. - valueHash := spec.digestValue(value) + valueHash := spec.valueDigest(value) currentHash, currentData = digestLeaf(spec, path, valueHash) update := make([][]byte, 2) update[0], update[1] = currentHash, currentData diff --git a/smst.go b/smst.go index 394cdb3..58dc58b 100644 --- a/smst.go +++ b/smst.go @@ -84,7 +84,7 @@ func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { // appended with the binary representation of the weight provided. The weight // is used to compute the interim and total sum of the trie. func (smst *SMST) Update(key, value []byte, weight uint64) error { - valueHash := smst.digestValue(value) + valueHash := smst.valueDigest(value) var weightBz [sumSizeBits]byte binary.BigEndian.PutUint64(weightBz[:], weight) valueHash = append(valueHash, weightBz[:]...) 
diff --git a/smst_example_test.go b/smst_example_test.go index f1cdbca..79d1772 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -2,18 +2,21 @@ package smt_test import ( "crypto/sha256" - "fmt" + "testing" + + "github.com/stretchr/testify/require" "github.com/pokt-network/smt" "github.com/pokt-network/smt/kvstore/simplemap" ) -func ExampleSMST() { - // Initialise a new in-memory key-value store to store the nodes of the trie +// TestExampleSMT is a test that aims to act as an example of how to use the SMST. +func TestExampleSMST(t *testing.T) { + // Initialize a new in-memory key-value store to store the nodes of the trie // (Note: the trie only stores hashed values, not raw value data) nodeStore := simplemap.NewSimpleMap() - // Initialise the trie + // Initialize the trie trie := smt.NewSparseMerkleSumTrie(nodeStore, sha256.New()) // Update trie with keys, values and their sums @@ -36,13 +39,25 @@ func ExampleSMST() { root := trie.Root() // Verify the Merkle proof for "foo"="oof" where "foo" has a sum of 10 - valid_true1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, trie.Spec()) + valid_true1, err := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, trie.Spec()) + require.NoError(t, err) + require.True(t, valid_true1) + // Verify the Merkle proof for "baz"="zab" where "baz" has a sum of 7 - valid_true2, _ := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, trie.Spec()) + valid_true2, err := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, trie.Spec()) + require.NoError(t, err) + require.True(t, valid_true2) + // Verify the Merkle proof for "bin"="nib" where "bin" has a sum of 3 - valid_true3, _ := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, trie.Spec()) + valid_true3, err := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, trie.Spec()) + require.NoError(t, err) + require.True(t, valid_true3) + // Fail to verify the Merkle proof for 
"foo"="oof" where "foo" has a sum of 11 - valid_false1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, trie.Spec()) - fmt.Println(valid_true1, valid_true2, valid_true3, valid_false1) + valid_false1, err := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, trie.Spec()) + require.NoError(t, err) + require.False(t, valid_false1) + // Output: true true true false + t.Log(valid_true1, valid_true2, valid_true3, valid_false1) } diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 394d595..41ad48d 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -105,7 +105,7 @@ func TestSMST_Proof_Operations(t *testing.T) { // Try proving a default value for a non-default leaf. var sum [sumSizeBits]byte binary.BigEndian.PutUint64(sum[:], 5) - tval := base.digestValue([]byte("testValue")) + tval := base.valueDigest([]byte("testValue")) tval = append(tval, sum[:]...) _, leafData := base.th.digestSumLeaf(base.ph.Path([]byte("testKey2")), tval) proof = &SparseMerkleProof{ diff --git a/smst_utils_test.go b/smst_utils_test.go index 90e27d9..4526fcc 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -26,7 +26,7 @@ func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { if err := smst.SMST.Update(key, value, sum); err != nil { return err } - valueHash := smst.digestValue(value) + valueHash := smst.valueDigest(value) var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) 
diff --git a/smt.go b/smt.go
index 18e075f..de20e07 100644
--- a/smt.go
+++ b/smt.go
@@ -108,7 +108,7 @@ func (smt *SMT) Get(key []byte) ([]byte, error) {
 func (smt *SMT) Update(key []byte, value []byte) error {
 	// Expand
 	path := smt.ph.Path(key)
-	valueHash := smt.digestValue(value)
+	valueHash := smt.valueDigest(value)
 	var orphans orphanNodes
 	trie, err := smt.update(smt.root, 0, path, valueHash, &orphans)
 	if err != nil {
diff --git a/smt_example_test.go b/smt_example_test.go
index 6d74980..2f7af1b 100644
--- a/smt_example_test.go
+++ b/smt_example_test.go
@@ -2,18 +2,19 @@ package smt_test
 
 import (
 	"crypto/sha256"
-	"fmt"
+	"testing"
 
 	"github.com/pokt-network/smt"
 	"github.com/pokt-network/smt/kvstore/simplemap"
 )
 
-func ExampleSMT() {
-	// Initialise a new in-memory key-value store to store the nodes of the trie
+// TestExampleSMT is a test that aims to act as an example of how to use the SMT.
+func TestExampleSMT(t *testing.T) {
+	// Initialize a new in-memory key-value store to store the nodes of the trie
 	// (Note: the trie only stores hashed values, not raw value data)
 	nodeStore := simplemap.NewSimpleMap()
 
-	// Initialise the trie
+	// Initialize the trie
 	trie := smt.NewSparseMerkleTrie(nodeStore, sha256.New())
 
 	// Update the key "foo" with the value "bar"
@@ -30,6 +31,7 @@
 	valid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("bar"), trie.Spec())
 	// Attempt to verify the Merkle proof for "foo"="baz"
 	invalid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("baz"), trie.Spec())
-	fmt.Println(valid, invalid)
+	// Output: true false
+	t.Log(valid, invalid)
 }
 
diff --git a/smt_proofs_test.go b/smt_proofs_test.go
index b1d5005..2364295 100644
--- a/smt_proofs_test.go
+++ b/smt_proofs_test.go
@@ -84,7 +84,7 @@
 	require.False(t, result)
 
 	// Try proving a default value for a non-default leaf. 
- _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.digestValue([]byte("testValue"))) + _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.valueDigest([]byte("testValue"))) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, diff --git a/smt_utils_test.go b/smt_utils_test.go index cff400c..7137fa8 100644 --- a/smt_utils_test.go +++ b/smt_utils_test.go @@ -24,7 +24,7 @@ func (smt *SMTWithStorage) Update(key, value []byte) error { if err := smt.SMT.Update(key, value); err != nil { return err } - valueHash := smt.digestValue(value) + valueHash := smt.valueDigest(value) return smt.preimages.Set(valueHash, value) } diff --git a/types.go b/types.go index ab54818..8d438b2 100644 --- a/types.go +++ b/types.go @@ -108,35 +108,39 @@ func (spec *TrieSpec) Spec() *TrieSpec { // depth returns the maximum depth of the trie. // Since this tree is a binary tree, the depth is the number of bits in the path +// TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output func (spec *TrieSpec) depth() int { return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits } -func (spec *TrieSpec) digestValue(data []byte) []byte { +// valueDigest returns the hash of a value, or the value itself if no value hasher is specified. 
+func (spec *TrieSpec) valueDigest(value []byte) []byte { if spec.vh == nil { - return data + return value } - return spec.vh.HashValue(data) + return spec.vh.HashValue(value) } -func (spec *TrieSpec) serialize(node trieNode) (data []byte) { +// encodeNode serializes a node into a byte slice +func (spec *TrieSpec) encodeNode(node trieNode) (data []byte) { switch n := node.(type) { case *lazyNode: - panic("serialize(lazyNode)") + panic("Encoding a lazyNode is not supported") case *leafNode: return encodeLeafNode(n.path, n.valueHash) case *innerNode: - lchild := spec.hashNode(n.leftChild) - rchild := spec.hashNode(n.rightChild) - return encodeInnerNode(lchild, rchild) + leftChild := spec.digestNode(n.leftChild) + rightChild := spec.digestNode(n.rightChild) + return encodeInnerNode(leftChild, rightChild) case *extensionNode: - child := spec.hashNode(n.child) + child := spec.digestNode(n.child) return encodeExtensionNode(n.pathBounds, n.path, child) } return nil } -func (spec *TrieSpec) hashNode(node trieNode) []byte { +// digestNode hashes a node returning its digest +func (spec *TrieSpec) digestNode(node trieNode) []byte { if node == nil { return spec.th.placeholder() } @@ -150,12 +154,12 @@ func (spec *TrieSpec) hashNode(node trieNode) []byte { cache = &n.digest case *extensionNode: if n.digest == nil { - n.digest = spec.hashNode(n.expand()) + n.digest = spec.digestNode(n.expand()) } return n.digest } if *cache == nil { - *cache = spec.th.digest(spec.serialize(node)) + *cache = spec.th.digest(spec.encodeNode(node)) } return *cache } diff --git a/utils.go b/utils.go index cda5ae2..801284f 100644 --- a/utils.go +++ b/utils.go @@ -157,7 +157,7 @@ func hashNode(spec *TrieSpec, node trieNode) []byte { if spec.sumTrie { return spec.hashSumNode(node) } - return spec.hashNode(node) + return spec.digestNode(node) } // serialize serializes a node depending on the trie type @@ -165,7 +165,7 @@ func serialize(spec *TrieSpec, node trieNode) []byte { if spec.sumTrie { return 
spec.sumSerialize(node) } - return spec.serialize(node) + return spec.encodeNode(node) } // hashPreimage hashes the serialised data provided depending on the trie type @@ -182,7 +182,7 @@ func hashSerialization(smt *TrieSpec, data []byte) []byte { pathBounds, path, childHash := parseExtension(data, smt.ph) ext := extensionNode{path: path, child: &lazyNode{childHash}} copy(ext.pathBounds[:], pathBounds) - return smt.hashNode(&ext) + return smt.digestNode(&ext) } return smt.th.digest(data) } From 340d0604cb1c3afc08dae862e7ac521a3f1e4747 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Fri, 9 Feb 2024 16:08:44 -0800 Subject: [PATCH 10/40] Added e2e_review.md and started going through the _node.go files --- e2e_review.md | 67 ++++++++++++++++++++++++++++++++++++++++++++ extension_node.go | 1 + godoc.go | 14 +++++---- inner_node.go | 2 ++ lazy_node.go | 3 +- leaf_node.go | 3 +- options.go | 10 +++++-- smst_example_test.go | 28 +++++++++++------- smst_proofs_test.go | 10 +++---- smt.go | 2 +- smt_proofs_test.go | 10 +++---- 11 files changed, 118 insertions(+), 32 deletions(-) create mode 100644 e2e_review.md diff --git a/e2e_review.md b/e2e_review.md new file mode 100644 index 0000000..7c11b85 --- /dev/null +++ b/e2e_review.md @@ -0,0 +1,67 @@ +# E2E Code Review + + -[ ] . 
+ -[ ] ├── LICENSE + -[ ] ├── Makefile + -[ ] ├── README.md + -[ ] ├── benchmarks + -[ ] │   ├── bench_leaf_test.go + -[ ] │   ├── bench_smst_test.go + -[ ] │   ├── bench_smt_test.go + -[ ] │   ├── bench_utils_test.go + -[ ] │   └── proof_sizes_test.go + -[ ] ├── bulk_test.go + -[ ] ├── docs + -[ ] │   ├── audit.md + -[ ] │   ├── badger-store.md + -[ ] │   ├── benchmarks.md + -[ ] │   ├── faq.md + -[ ] │   ├── mapstore.md + -[ ] │   ├── merkle-sum-trie.md + -[ ] │   └── smt.md + -[ ] ├── errors.go + -[ ] ├── extension_node.go + -[ ] ├── fuzz_test.go + -[x] ├── go.mod + -[x] ├── go.sum + -[x] ├── go.work + -[x] ├── go.work.sum + -[x] ├── godoc.go + -[ ] ├── hasher.go + -[x] ├── inner_node.go + -[ ] ├── kvstore + -[ ] │   ├── badger + -[ ] │   │   ├── errors.go + -[ ] │   │   ├── go.mod + -[ ] │   │   ├── go.sum + -[ ] │   │   ├── godoc.go + -[ ] │   │   ├── interface.go + -[ ] │   │   ├── kvstore.go + -[ ] │   │   └── kvstore_test.go + -[ ] │   ├── godoc.go + -[ ] │   ├── interfaces.go + -[ ] │   └── simplemap + -[ ] │   ├── errors.go + -[ ] │   ├── godoc.go + -[ ] │   ├── simplemap.go + -[ ] │   └── simplemap_test.go + -[x] ├── lazy_node.go + -[x] ├── leaf_node.go + -[ ] ├── node_encoders.go + -[ ] ├── options.go + -[ ] ├── proofs.go + -[ ] ├── proofs_test.go + -[ ] ├── reviewpad.yml + -[ ] ├── root_test.go + -[ ] ├── smst.go + -[ ] ├── smst_example_test.go + -[ ] ├── smst_proofs_test.go + -[ ] ├── smst_test.go + -[ ] ├── smst_utils_test.go + -[ ] ├── smt.go + -[ ] ├── smt_example_test.go + -[ ] ├── smt_proofs_test.go + -[ ] ├── smt_test.go + -[ ] ├── smt_utils_test.go + -[ ] ├── types.go + -[ ] └── utils.go diff --git a/extension_node.go b/extension_node.go index f2d2e32..38282d3 100644 --- a/extension_node.go +++ b/extension_node.go @@ -1,5 +1,6 @@ package smt +// Ensure extensionNode satisfies the trieNode interface var _ trieNode = (*extensionNode)(nil) // A compressed chain of singly-linked inner nodes. 
diff --git a/godoc.go b/godoc.go index 7354c89..67fc078 100644 --- a/godoc.go +++ b/godoc.go @@ -1,11 +1,13 @@ // Package smt provides an implementation of a Sparse Merkle Trie for a -// key-value map. +// key-value map or engine. // // The trie implements the same optimizations specified in the JMT -// whitepaper to account for empty and single-node subtrees. Unlike the -// JMT, it only supports binary trees and does not optimise for RockDB -// on-disk storage. +// whitepaper to account for empty and single-node subtrees. +// +// Unlike the JMT, it only supports binary trees and does not implement the +// same RocksDB optimizations as specified in the original JMT library when +// optimizing for disk IOPS. // -// This package implements novel features that include native in-node -// weight sums, as well as support for ClosestProof mechanics. +// This package implements additional SMT specific functionality related to +// tree sums and closest proof mechanics. package smt diff --git a/inner_node.go b/inner_node.go index e6b3a63..96ed2f1 100644 --- a/inner_node.go +++ b/inner_node.go @@ -1,5 +1,6 @@ package smt +// Ensure innerNode satisfies the trieNode interface var _ trieNode = (*innerNode)(nil) // A branch within the binary trie pointing to a left & right child. @@ -17,6 +18,7 @@ func (node *innerNode) Persisted() bool { return node.persisted } // Persisted satisfied the trieNode#CachedDigest interface func (node *innerNode) CachedDigest() []byte { return node.digest } +// setDirty marks the node as dirty (i.e. 
not flushed to disk) and clears the cached digest func (node *innerNode) setDirty() { node.persisted = false node.digest = nil diff --git a/lazy_node.go b/lazy_node.go index a0fc822..fa60fef 100644 --- a/lazy_node.go +++ b/lazy_node.go @@ -1,8 +1,9 @@ package smt +// Ensure lazyNode satisfies the trieNode interface var _ trieNode = (*lazyNode)(nil) -// Represents an uncached, persisted node +// lazyNode represents an uncached persisted node type lazyNode struct { digest []byte } diff --git a/leaf_node.go b/leaf_node.go index 4c894e4..2c84c2f 100644 --- a/leaf_node.go +++ b/leaf_node.go @@ -1,8 +1,9 @@ package smt +// Ensure leafNode satisfies the trieNode interface var _ trieNode = (*leafNode)(nil) -// A leaf node storing a key-value pair for a full path. +// leafNode stores a full key-value pair in the trie type leafNode struct { path []byte valueHash []byte diff --git a/options.go b/options.go index c4eb422..4759021 100644 --- a/options.go +++ b/options.go @@ -17,16 +17,20 @@ func WithValueHasher(vh ValueHasher) Option { return func(ts *TrieSpec) { ts.vh = vh } } -// NoPrehashSpec returns a new TrieSpec that has a nil Value Hasher and a nil -// Path Hasher +// NoHasherSpec returns a new TrieSpec that has nil ValueHasher & PathHasher specs. // NOTE: This should only be used when values are already hashed and a path is // used instead of a key during proof verification, otherwise these will be // double hashed and produce an incorrect leaf digest invalidating the proof. 
-func NoPrehashSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { +func NoHasherSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { spec := newTrieSpec(hasher, sumTrie) + + // Set a nil path hasher opt := WithPathHasher(newNilPathHasher(hasher.Size())) opt(&spec) + // Set a nil value hasher opt = WithValueHasher(nil) opt(&spec) + + // Return the spec return &spec } diff --git a/smst_example_test.go b/smst_example_test.go index 79d1772..05d2425 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -12,28 +12,36 @@ import ( // TestExampleSMT is a test that aims to act as an example of how to use the SMST. func TestExampleSMST(t *testing.T) { - // Initialize a new in-memory key-value store to store the nodes of the trie - // (Note: the trie only stores hashed values, not raw value data) + // Initialize a new in-memory key-value store to store the nodes of the trie. + // NB: The trie only stores hashed values, not raw value data. nodeStore := simplemap.NewSimpleMap() // Initialize the trie trie := smt.NewSparseMerkleSumTrie(nodeStore, sha256.New()) // Update trie with keys, values and their sums - _ = trie.Update([]byte("foo"), []byte("oof"), 10) - _ = trie.Update([]byte("baz"), []byte("zab"), 7) - _ = trie.Update([]byte("bin"), []byte("nib"), 3) + err := trie.Update([]byte("foo"), []byte("oof"), 10) + require.NoError(t, err) + err = trie.Update([]byte("baz"), []byte("zab"), 7) + require.NoError(t, err) + err = trie.Update([]byte("bin"), []byte("nib"), 3) + require.NoError(t, err) // Commit the changes to the nodeStore - _ = trie.Commit() + err = trie.Commit() + require.NoError(t, err) // Calculate the total sum of the trie - _ = trie.Sum() // 20 + sum := trie.Sum() + require.Equal(t, uint64(20), sum) // Generate a Merkle proof for "foo" - proof1, _ := trie.Prove([]byte("foo")) - proof2, _ := trie.Prove([]byte("baz")) - proof3, _ := trie.Prove([]byte("bin")) + proof1, err := trie.Prove([]byte("foo")) + require.NoError(t, err) + proof2, err := 
trie.Prove([]byte("baz")) + require.NoError(t, err) + proof3, err := trie.Prove([]byte("bin")) + require.NoError(t, err) // We also need the current trie root for the proof root := trie.Root() diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 41ad48d..c598d92 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -204,7 +204,7 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smst := NewSparseMerkleSumTrie(smn, sha256.New()) - np := NoPrehashSpec(sha256.New(), true) + np := NoHasherSpec(sha256.New(), true) base := smst.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) @@ -326,7 +326,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), true)) require.NoError(t, err) require.True(t, result) @@ -352,7 +352,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), true)) require.NoError(t, err) require.True(t, result) } @@ -381,7 +381,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), NoHasherSpec(sha256.New(), true)) require.NoError(t, err) require.True(t, result) } @@ -419,7 +419,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + 
result, err := VerifyClosestProof(proof, smst.Root(), NoHasherSpec(sha256.New(), true)) require.NoError(t, err) require.True(t, result) } diff --git a/smt.go b/smt.go index de20e07..cda02c0 100644 --- a/smt.go +++ b/smt.go @@ -7,7 +7,7 @@ import ( "github.com/pokt-network/smt/kvstore" ) -// Make sure the `SMT` struct implements the `SparseMerkleTrie` interface +// Ensure the `SMT` struct implements the `SparseMerkleTrie` interface var _ SparseMerkleTrie = (*SMT)(nil) // SMT is a Sparse Merkle Trie object that implements the SparseMerkleTrie interface diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 2364295..06482c7 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -178,7 +178,7 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) { func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smt := NewSparseMerkleTrie(smn, sha256.New()) - np := NoPrehashSpec(sha256.New(), false) + np := NoHasherSpec(sha256.New(), false) base := smt.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) @@ -287,7 +287,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), false)) require.NoError(t, err) require.True(t, result) closestPath := sha256.Sum256([]byte("testKey2")) @@ -304,7 +304,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), false)) require.NoError(t, err) require.True(t, result) closestPath = sha256.Sum256([]byte("testKey4")) @@ -336,7 +336,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) { 
ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), NoHasherSpec(sha256.New(), false)) require.NoError(t, err) require.True(t, result) } @@ -368,7 +368,7 @@ func TestSMT_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), NoHasherSpec(sha256.New(), false)) require.NoError(t, err) require.True(t, result) } From ae20eab869373b1d643185f16717ab574db1efbf Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Sat, 10 Feb 2024 12:22:03 -0800 Subject: [PATCH 11/40] A lot of things in flux but finished commenting node_encoders --- e2e_review.md | 4 +- hasher.go | 24 +++++++++- kvstore/simplemap/simplemap.go | 8 ++++ node_encoders.go | 87 ++++++++++++++++++++-------------- options.go | 11 +++-- proofs.go | 4 +- smst.go | 47 ++++++++++++------ smst_example_test.go | 73 ++++++++++++++++++++-------- smst_proofs_test.go | 2 +- smst_utils_test.go | 2 +- smt.go | 44 +++++++++-------- smt_proofs_test.go | 2 +- smt_utils_test.go | 2 +- types.go | 6 +-- utils.go | 16 +------ 15 files changed, 213 insertions(+), 119 deletions(-) diff --git a/e2e_review.md b/e2e_review.md index 7c11b85..99da43b 100644 --- a/e2e_review.md +++ b/e2e_review.md @@ -47,8 +47,8 @@ -[ ] │   └── simplemap_test.go -[x] ├── lazy_node.go -[x] ├── leaf_node.go - -[ ] ├── node_encoders.go - -[ ] ├── options.go + -[x] ├── node_encoders.go + -[x] ├── options.go -[ ] ├── proofs.go -[ ] ├── proofs_test.go -[ ] ├── reviewpad.yml diff --git a/hasher.go b/hasher.go index e1b0ce3..3854873 100644 --- a/hasher.go +++ b/hasher.go @@ -4,8 +4,10 @@ import ( "hash" ) +// Ensure the hasher interfaces are satisfied var ( _ PathHasher = (*pathHasher)(nil) + _ PathHasher = (*nilPathHasher)(nil) _ ValueHasher = (*valueHasher)(nil) ) @@ 
-39,13 +41,21 @@ type valueHasher struct { trieHasher } -// newTrieHasher returns a new trie hasher with the given hash function. -func newTrieHasher(hasher hash.Hash) *trieHasher { +type nilPathHasher struct { + hashSize int +} + +// NewTrieHasher returns a new trie hasher with the given hash function. +func NewTrieHasher(hasher hash.Hash) *trieHasher { th := trieHasher{hasher: hasher} th.zeroValue = make([]byte, th.hashSize()) return &th } +func NewNilPathHasher(hasher hash.Hash) PathHasher { + return &nilPathHasher{hashSize: hasher.Size()} +} + // Path returns the digest of a key produced by the path hasher func (ph *pathHasher) Path(key []byte) []byte { return ph.digest(key)[:ph.PathSize()] @@ -62,6 +72,16 @@ func (vh *valueHasher) HashValue(data []byte) []byte { return vh.digest(data) } +// Path satisfies the PathHasher#Path interface +func (n *nilPathHasher) Path(key []byte) []byte { + return key[:n.hashSize] +} + +// PathSize satisfies the PathHasher#PathSize interface +func (n *nilPathHasher) PathSize() int { + return n.hashSize +} + // digest returns the hash of the data provided using the trie hasher. func (th *trieHasher) digest(data []byte) []byte { th.hasher.Write(data) diff --git a/kvstore/simplemap/simplemap.go b/kvstore/simplemap/simplemap.go index 9b548ab..c20e4bf 100644 --- a/kvstore/simplemap/simplemap.go +++ b/kvstore/simplemap/simplemap.go @@ -19,6 +19,14 @@ func NewSimpleMap() kvstore.MapStore { } } +// NewSimpleMap creates a new SimpleMap instance using the map provided. +// This is useful for testing & debugging purposes. +func NewSimpleMapWithMap(m map[string][]byte) kvstore.MapStore { + return &simpleMap{ + m: m, + } +} + // Get gets the value for a key. 
func (sm *simpleMap) Get(key []byte) ([]byte, error) { if len(key) == 0 { diff --git a/node_encoders.go b/node_encoders.go index b9d1bfd..66d6217 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -8,6 +8,10 @@ import ( // TODO_IMPROVE: All of the parsing, encoding and checking functions in this file // can be abstracted out into the `trieNode` interface. +// TODO_IMPROVE: We should create well-defined types & structs for every type of node +// (e.g. protobufs) to streamline the process of encoding & encoding and to improve +// readability. + // NB: In this file, all references to the variable `data` should be treated as `encodedNodeData`. // It was abbreviated to `data` for brevity. @@ -27,83 +31,96 @@ func isExtNode(data []byte) bool { return bytes.Equal(data[:len(extNodePrefix)], extNodePrefix) } -func parseLeafNode(data []byte, ph PathHasher) ([]byte, []byte) { - return data[len(leafNodePrefix) : ph.PathSize()+len(leafNodePrefix)], data[len(leafNodePrefix)+ph.PathSize():] +// parseLeafNode parses a leafNode into its components +func parseLeafNode(data []byte, ph PathHasher) (leftChild, rightChild []byte) { + leftChild = data[len(leafNodePrefix) : len(leafNodePrefix)+ph.PathSize()] + rightChild = data[len(leafNodePrefix)+ph.PathSize():] + return } -func parseExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte) { - return data[len(extNodePrefix) : len(extNodePrefix)+2], // +2 represents the length of the pathBounds - data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()], - data[len(extNodePrefix)+2+ph.PathSize():] +// parseExtNode parses an extNode into its components +func parseExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte) { + // +2 represents the length of the pathBounds + pathBounds = data[len(extNodePrefix) : len(extNodePrefix)+2] + path = data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()] + childData = data[len(extNodePrefix)+2+ph.PathSize():] + return } -func parseSumExtension(data 
[]byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) { +// parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data +func parseSumExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) { + // Extract the sum from the encoded node data var sumBz [sumSizeBits]byte copy(sumBz[:], data[len(data)-sumSizeBits:]) - return data[len(extNodePrefix) : len(extNodePrefix)+2], // +2 represents the length of the pathBounds - data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()], - data[len(extNodePrefix)+2+ph.PathSize() : len(data)-sumSizeBits], - sumBz + + // +2 represents the length of the pathBounds + pathBounds = data[len(extNodePrefix) : len(extNodePrefix)+2] + path = data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()] + childData = data[len(extNodePrefix)+2+ph.PathSize() : len(data)-sumSizeBits] + return } // encodeLeafNode encodes leaf nodes. both normal and sum leaves as in the sum leaf the // sum is appended to the end of the valueHash -func encodeLeafNode(path []byte, leafData []byte) []byte { - data := make([]byte, 0, len(leafNodePrefix)+len(path)+len(leafData)) +func encodeLeafNode(path, leafData []byte) (data []byte) { data = append(data, leafNodePrefix...) data = append(data, path...) data = append(data, leafData...) - return data + return } -func encodeInnerNode(leftData []byte, rightData []byte) []byte { - data := make([]byte, 0, len(innerNodePrefix)+len(leftData)+len(rightData)) +// encodeInnerNode encodes inner node given the data for both children +func encodeInnerNode(leftData, rightData []byte) (data []byte) { data = append(data, innerNodePrefix...) data = append(data, leftData...) data = append(data, rightData...) - return data + return } -func encodeSumInnerNode(leftData []byte, rightData []byte) []byte { - data := make([]byte, 0, len(innerNodePrefix)+len(leftData)+len(rightData)) - data = append(data, innerNodePrefix...) 
- data = append(data, leftData...) - data = append(data, rightData...) - - var sum [sumSizeBits]byte +// encodeSumInnerNode encodes an inner node for an smst given the data for both children +func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { + // Retrieve the sum of the left subtree leftSum := uint64(0) - rightSum := uint64(0) leftSumBz := leftData[len(leftData)-sumSizeBits:] - rightSumBz := rightData[len(rightData)-sumSizeBits:] if !bytes.Equal(leftSumBz, defaultEmptySum[:]) { leftSum = binary.BigEndian.Uint64(leftSumBz) } + + // Retrieve the sum of the right subtree + rightSum := uint64(0) + rightSumBz := rightData[len(rightData)-sumSizeBits:] if !bytes.Equal(rightSumBz, defaultEmptySum[:]) { rightSum = binary.BigEndian.Uint64(rightSumBz) } + + // Compute the sum of the current node + var sum [sumSizeBits]byte binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) + + // Prepare and return the encoded inner node data + data = encodeInnerNode(leftData, rightData) data = append(data, sum[:]...) - return data + return } // encodeExtensionNode encodes the data of an extension nodes -func encodeExtensionNode(pathBounds [2]byte, path []byte, childData []byte) []byte { - data := []byte{} +func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { data = append(data, extNodePrefix...) data = append(data, pathBounds[:]...) data = append(data, path...) data = append(data, childData...) - return data + return } // encodeSumExtensionNode encodes the data of a sum extension nodes -func encodeSumExtensionNode(pathBounds [2]byte, path []byte, childData []byte) []byte { - data := encodeExtensionNode(pathBounds, path, childData) +func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { - // Append the sum to the end of the data + // Compute the sum of the current node var sum [sumSizeBits]byte copy(sum[:], childData[len(childData)-sumSizeBits:]) - data = append(data, sum[:]...) 
- return data + // Prepare and return the encoded extension node data + data = encodeExtensionNode(pathBounds, path, childData) + data = append(data, sum[:]...) + return } diff --git a/options.go b/options.go index 4759021..f2d6bc2 100644 --- a/options.go +++ b/options.go @@ -18,15 +18,18 @@ func WithValueHasher(vh ValueHasher) Option { } // NoHasherSpec returns a new TrieSpec that has nil ValueHasher & PathHasher specs. -// NOTE: This should only be used when values are already hashed and a path is -// used instead of a key during proof verification, otherwise these will be -// double hashed and produce an incorrect leaf digest invalidating the proof. +// NB: This should only be used when values are already hashed and a path is +// used instead of a key during proof verification. Otherwise, this will lead to +// double hashing and produce an incorrect leaf digest, thereby invalidating +// the proof. +// TODO_IN_THIS_PR: Need to understand this part more. func NoHasherSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { spec := newTrieSpec(hasher, sumTrie) // Set a nil path hasher - opt := WithPathHasher(newNilPathHasher(hasher.Size())) + opt := WithPathHasher(NewNilPathHasher(hasher)) opt(&spec) + // Set a nil value hasher opt = WithValueHasher(nil) opt(&spec) diff --git a/proofs.go b/proofs.go index af4126c..3d242d4 100644 --- a/proofs.go +++ b/proofs.go @@ -283,7 +283,7 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) - valueHash := spec.valueDigest(value) + valueHash := spec.valueHash(value) valueHash = append(valueHash, sumBz[:]...) 
if bytes.Equal(value, defaultEmptyValue) && sum == 0 { valueHash = defaultEmptyValue @@ -348,7 +348,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v updates = append(updates, update) } } else { // Membership proof. - valueHash := spec.valueDigest(value) + valueHash := spec.valueHash(value) currentHash, currentData = digestLeaf(spec, path, valueHash) update := make([][]byte, 2) update[0], update[1] = currentHash, currentData diff --git a/smst.go b/smst.go index 58dc58b..dc0d7db 100644 --- a/smst.go +++ b/smst.go @@ -64,31 +64,50 @@ func (smst *SMST) Spec() *TrieSpec { return &smst.TrieSpec } -// Get returns the digest of the value stored at the given key and the weight -// of the leaf node -func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { - valueHash, err := smst.SMT.Get(key) +// Get retrieves the value digest for the given key, +// along with its weight, provided a leaf node exists. +func (smst *SMST) Get(key []byte) (valueDigest []byte, weight uint64, err error) { + // Retrieve the value digest from the trie for the given key + valueDigest, err = smst.SMT.Get(key) if err != nil { return nil, 0, err } - if bytes.Equal(valueHash, defaultEmptyValue) { + + // Check if it is an empty branch + if bytes.Equal(valueDigest, defaultEmptyValue) { return defaultEmptyValue, 0, nil } + + // Retrieve the node weight var weightBz [sumSizeBits]byte - copy(weightBz[:], valueHash[len(valueHash)-sumSizeBits:]) - weight := binary.BigEndian.Uint64(weightBz[:]) - return valueHash[:len(valueHash)-sumSizeBits], weight, nil + copy(weightBz[:], valueDigest[len(valueDigest)-sumSizeBits:]) + weight = binary.BigEndian.Uint64(weightBz[:]) + + // Remove the weight from the value digest + valueDigest = valueDigest[:len(valueDigest)-sumSizeBits] + + // Return the value digest and weight + return valueDigest, weight, nil } -// Update sets the value for the given key, to the digest of the provided value -// appended with the 
binary representation of the weight provided. The weight -// is used to compute the interim and total sum of the trie. +// Update inserts the value and weight into the trie for the given key. +// +// The a digest (i.e. hash) of the value is computed and appended with the byte +// representation of the weight integer provided. + +// The weight is used to compute the interim sum of the node which then percolates +// up to the total sum of the trie. func (smst *SMST) Update(key, value []byte, weight uint64) error { - valueHash := smst.valueDigest(value) + // Convert the node weight to a byte slice var weightBz [sumSizeBits]byte binary.BigEndian.PutUint64(weightBz[:], weight) - valueHash = append(valueHash, weightBz[:]...) - return smst.SMT.Update(key, valueHash) + + // Compute the digest of the value and append the weight to it + valueDigest := smst.valueHash(value) + valueDigest = append(valueDigest, weightBz[:]...) + + // Return the result of the trie update + return smst.SMT.Update(key, valueDigest) } // Delete removes the node at the path corresponding to the given key diff --git a/smst_example_test.go b/smst_example_test.go index 05d2425..c2c5b53 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -1,71 +1,104 @@ -package smt_test +package smt import ( "crypto/sha256" + "fmt" "testing" "github.com/stretchr/testify/require" - "github.com/pokt-network/smt" + "github.com/pokt-network/smt/kvstore" "github.com/pokt-network/smt/kvstore/simplemap" ) // TestExampleSMT is a test that aims to act as an example of how to use the SMST. func TestExampleSMST(t *testing.T) { + dataMap := make(map[string]string) + // Initialize a new in-memory key-value store to store the nodes of the trie. - // NB: The trie only stores hashed values, not raw value data. + // NB: The trie only stores hashed values and not raw value data. 
nodeStore := simplemap.NewSimpleMap() - // Initialize the trie - trie := smt.NewSparseMerkleSumTrie(nodeStore, sha256.New()) + // Initialize the smst + smst := NewSparseMerkleSumTrie(nodeStore, sha256.New()) //, smt.WithValueHasher(nil), smt.WithPathHasher(smt.NewNilPathHasher(sha256.New()))) // Update trie with keys, values and their sums - err := trie.Update([]byte("foo"), []byte("oof"), 10) + err := smst.Update([]byte("foo"), []byte("oof"), 10) require.NoError(t, err) - err = trie.Update([]byte("baz"), []byte("zab"), 7) + dataMap["foo"] = "oof" + err = smst.Update([]byte("baz"), []byte("zab"), 7) require.NoError(t, err) - err = trie.Update([]byte("bin"), []byte("nib"), 3) + dataMap["baz"] = "zab" + err = smst.Update([]byte("bin"), []byte("nib"), 3) require.NoError(t, err) + dataMap["bin"] = "nib" // Commit the changes to the nodeStore - err = trie.Commit() + err = smst.Commit() require.NoError(t, err) // Calculate the total sum of the trie - sum := trie.Sum() + sum := smst.Sum() require.Equal(t, uint64(20), sum) // Generate a Merkle proof for "foo" - proof1, err := trie.Prove([]byte("foo")) + proof1, err := smst.Prove([]byte("foo")) require.NoError(t, err) - proof2, err := trie.Prove([]byte("baz")) + proof2, err := smst.Prove([]byte("baz")) require.NoError(t, err) - proof3, err := trie.Prove([]byte("bin")) + proof3, err := smst.Prove([]byte("bin")) require.NoError(t, err) // We also need the current trie root for the proof - root := trie.Root() + root := smst.Root() // Verify the Merkle proof for "foo"="oof" where "foo" has a sum of 10 - valid_true1, err := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, trie.Spec()) + valid_true1, err := VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, smst.Spec()) require.NoError(t, err) require.True(t, valid_true1) // Verify the Merkle proof for "baz"="zab" where "baz" has a sum of 7 - valid_true2, err := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, trie.Spec()) + 
valid_true2, err := VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, smst.Spec()) require.NoError(t, err) require.True(t, valid_true2) // Verify the Merkle proof for "bin"="nib" where "bin" has a sum of 3 - valid_true3, err := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, trie.Spec()) + valid_true3, err := VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, smst.Spec()) require.NoError(t, err) require.True(t, valid_true3) // Fail to verify the Merkle proof for "foo"="oof" where "foo" has a sum of 11 - valid_false1, err := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, trie.Spec()) + valid_false1, err := VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, smst.Spec()) require.NoError(t, err) require.False(t, valid_false1) - // Output: true true true false - t.Log(valid_true1, valid_true2, valid_true3, valid_false1) + exportToCSV(smst, dataMap, nodeStore) +} + +func exportToCSV( + smst SparseMerkleSumTrie, + innerMap map[string]string, + nodeStore kvstore.MapStore, +) { + // hasher := sha256.New() + fmt.Println("Exporting to CSV", smst.Root()) + // rootBits := smst.Root() + for key, value := range innerMap { + // fmt.Println(key, smst.Spec().) + v, s, err := smst.Get([]byte(key)) + if err != nil { + panic(err) + } + // parseLeafNode() + fmt.Println(v, s) + fmt.Println(value) + fmt.Println("") + fmt.Println("") + } + + // Export the trie to a CSV file + // err := smt.ExportToCSV("export.csv") + // if err != nil { + // panic(err) + // } } diff --git a/smst_proofs_test.go b/smst_proofs_test.go index c598d92..e951cbb 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -105,7 +105,7 @@ func TestSMST_Proof_Operations(t *testing.T) { // Try proving a default value for a non-default leaf. var sum [sumSizeBits]byte binary.BigEndian.PutUint64(sum[:], 5) - tval := base.valueDigest([]byte("testValue")) + tval := base.valueHash([]byte("testValue")) tval = append(tval, sum[:]...) 
_, leafData := base.th.digestSumLeaf(base.ph.Path([]byte("testKey2")), tval) proof = &SparseMerkleProof{ diff --git a/smst_utils_test.go b/smst_utils_test.go index 4526fcc..8a12b44 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -26,7 +26,7 @@ func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { if err := smst.SMST.Update(key, value, sum); err != nil { return err } - valueHash := smst.valueDigest(value) + valueHash := smst.valueHash(value) var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) diff --git a/smt.go b/smt.go index cda02c0..9312c99 100644 --- a/smt.go +++ b/smt.go @@ -57,6 +57,11 @@ func ImportSparseMerkleTrie( return smt } +// Root returns the root hash of the trie +func (smt *SMT) Root() MerkleRoot { + return hashNode(smt.Spec(), smt.root) +} + // Get returns the hash (i.e. digest) of the leaf value stored at the given key func (smt *SMT) Get(key []byte) ([]byte, error) { path := smt.ph.Path(key) @@ -105,16 +110,21 @@ func (smt *SMT) Get(key []byte) ([]byte, error) { } // Update inserts the `value` for the given `key` into the SMT -func (smt *SMT) Update(key []byte, value []byte) error { - // Expand +func (smt *SMT) Update(key, value []byte) error { + // Expand the key into a path by computing its digest path := smt.ph.Path(key) - valueHash := smt.valueDigest(value) + + // Convert the value into a hash by computing its digest + valueHash := smt.valueHash(value) + + // Update the trie with the new key-value pair var orphans orphanNodes - trie, err := smt.update(smt.root, 0, path, valueHash, &orphans) + // Compute the new root by inserting (path, valueHash) starting + newRoot, err := smt.update(smt.root, 0, path, valueHash, &orphans) if err != nil { return err } - smt.root = trie + smt.root = newRoot if len(orphans) > 0 { smt.orphans = append(smt.orphans, orphans) } @@ -140,7 +150,8 @@ func (smt *SMT) update( } if leaf, ok := node.(*leafNode); ok { prefixLen := 
countCommonPrefixBits(path, leaf.path, depth) - if prefixLen == smt.depth() { // replace leaf if paths are equal + // replace leaf if paths are equal + if prefixLen == smt.depth() { smt.addOrphan(orphans, node) return newLeaf, nil } @@ -511,7 +522,7 @@ func (smt *SMT) ProveClosest(path []byte) ( return proof, nil } -// resolves a stub into a cached node +// resolveLazy resolves a stub into a cached node func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) { stub, ok := node.(*lazyNode) if !ok { @@ -523,13 +534,13 @@ func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) { return smt.resolve(stub.digest) } -func (smt *SMT) resolve(hash []byte) (ret trieNode, err error) { +func (smt *SMT) resolve(hash []byte) (trieNode, error) { if bytes.Equal(smt.th.placeholder(), hash) { - return + return nil, nil } data, err := smt.nodes.Get(hash) if err != nil { - return + return nil, err } if isLeafNode(data) { leaf := leafNode{persisted: true, digest: hash} @@ -538,7 +549,7 @@ func (smt *SMT) resolve(hash []byte) (ret trieNode, err error) { } if isExtNode(data) { extNode := extensionNode{persisted: true, digest: hash} - pathBounds, path, childHash := parseExtension(data, smt.ph) + pathBounds, path, childHash := parseExtNode(data, smt.ph) extNode.path = path copy(extNode.pathBounds[:], pathBounds) extNode.child = &lazyNode{childHash} @@ -552,9 +563,9 @@ func (smt *SMT) resolve(hash []byte) (ret trieNode, err error) { } // resolveSum resolves -func (smt *SMT) resolveSum(hash []byte) (ret trieNode, err error) { +func (smt *SMT) resolveSum(hash []byte) (trieNode, error) { if bytes.Equal(placeholder(smt.Spec()), hash) { - return + return nil, nil } data, err := smt.nodes.Get(hash) if err != nil { @@ -567,7 +578,7 @@ func (smt *SMT) resolveSum(hash []byte) (ret trieNode, err error) { if isExtNode(data) { extNode := extensionNode{persisted: true, digest: hash} - pathBounds, path, childHash, _ 
:= parseSumExtNode(data, smt.ph) extNode.path = path copy(extNode.pathBounds[:], pathBounds) extNode.child = &lazyNode{childHash} @@ -626,11 +637,6 @@ func (smt *SMT) commit(node trieNode) error { return smt.nodes.Set(hashNode(smt.Spec(), node), preimage) } -// Root returns the root hash of the trie -func (smt *SMT) Root() MerkleRoot { - return hashNode(smt.Spec(), smt.root) -} - func (smt *SMT) addOrphan(orphans *[][]byte, node trieNode) { if node.Persisted() { *orphans = append(*orphans, node.CachedDigest()) diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 06482c7..6b2c254 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -84,7 +84,7 @@ func TestSMT_Proof_Operations(t *testing.T) { require.False(t, result) // Try proving a default value for a non-default leaf. - _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.valueDigest([]byte("testValue"))) + _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.valueHash([]byte("testValue"))) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, diff --git a/smt_utils_test.go b/smt_utils_test.go index 7137fa8..dc65a95 100644 --- a/smt_utils_test.go +++ b/smt_utils_test.go @@ -24,7 +24,7 @@ func (smt *SMTWithStorage) Update(key, value []byte) error { if err := smt.SMT.Update(key, value); err != nil { return err } - valueHash := smt.valueDigest(value) + valueHash := smt.valueHash(value) return smt.preimages.Set(valueHash, value) } diff --git a/types.go b/types.go index 8d438b2..653fc12 100644 --- a/types.go +++ b/types.go @@ -94,7 +94,7 @@ type TrieSpec struct { // newTrieSpec returns a new TrieSpec with the given hasher and sumTrie flag func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec { - spec := TrieSpec{th: *newTrieHasher(hasher)} + spec := TrieSpec{th: *NewTrieHasher(hasher)} spec.ph = &pathHasher{spec.th} spec.vh = &valueHasher{spec.th} spec.sumTrie = sumTrie @@ -113,8 +113,8 @@ func (spec *TrieSpec) depth() int { 
return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits } -// valueDigest returns the hash of a value, or the value itself if no value hasher is specified. -func (spec *TrieSpec) valueDigest(value []byte) []byte { +// valueHash returns the hash of a value, or the value itself if no value hasher is specified. +func (spec *TrieSpec) valueHash(value []byte) []byte { if spec.vh == nil { return value } diff --git a/utils.go b/utils.go index 801284f..e607edf 100644 --- a/utils.go +++ b/utils.go @@ -4,18 +4,6 @@ import ( "encoding/binary" ) -type nilPathHasher struct { - hashSize int -} - -func (n *nilPathHasher) Path(key []byte) []byte { return key[:n.hashSize] } - -func (n *nilPathHasher) PathSize() int { return n.hashSize } - -func newNilPathHasher(hashSize int) PathHasher { - return &nilPathHasher{hashSize: hashSize} -} - // getPathBit gets the bit at an offset (see position) in the data // provided relative to the most significant bit func getPathBit(data []byte, position int) int { @@ -179,7 +167,7 @@ func hashPreimage(spec *TrieSpec, data []byte) []byte { // Used for verification of serialized proof data func hashSerialization(smt *TrieSpec, data []byte) []byte { if isExtNode(data) { - pathBounds, path, childHash := parseExtension(data, smt.ph) + pathBounds, path, childHash := parseExtNode(data, smt.ph) ext := extensionNode{path: path, child: &lazyNode{childHash}} copy(ext.pathBounds[:], pathBounds) return smt.digestNode(&ext) @@ -190,7 +178,7 @@ func hashSerialization(smt *TrieSpec, data []byte) []byte { // Used for verification of serialized proof data for sum trie nodes func hashSumSerialization(smt *TrieSpec, data []byte) []byte { if isExtNode(data) { - pathBounds, path, childHash, _ := parseSumExtension(data, smt.ph) + pathBounds, path, childHash, _ := parseSumExtNode(data, smt.ph) ext := extensionNode{path: path, child: &lazyNode{childHash}} copy(ext.pathBounds[:], pathBounds) return smt.hashSumNode(&ext) From 
bff8a76f257a4ba0946b81d8a2ea3793ac3b792f Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Sat, 10 Feb 2024 12:29:52 -0800 Subject: [PATCH 12/40] Add prefixLen in node encoders --- node_encoders.go | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/node_encoders.go b/node_encoders.go index 66d6217..dab773d 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -19,31 +19,43 @@ var ( leafNodePrefix = []byte{0} innerNodePrefix = []byte{1} extNodePrefix = []byte{2} + prefixLen = 1 ) +// NB: We use `prefixLen` a lot through this file, so to make the code more readable, we +// define it as a constant but need to assert on its length just in case the code evolves +// in the future. +func init() { + if len(leafNodePrefix) != prefixLen || + len(innerNodePrefix) != prefixLen || + len(extNodePrefix) != prefixLen { + panic("invalid prefix length") + } +} + // isLeafNode returns true if the encoded node data is a leaf node func isLeafNode(data []byte) bool { - return bytes.Equal(data[:len(leafNodePrefix)], leafNodePrefix) + return bytes.Equal(data[:prefixLen], leafNodePrefix) } // isExtNode returns true if the encoded node data is an extension node func isExtNode(data []byte) bool { - return bytes.Equal(data[:len(extNodePrefix)], extNodePrefix) + return bytes.Equal(data[:prefixLen], extNodePrefix) } // parseLeafNode parses a leafNode into its components -func parseLeafNode(data []byte, ph PathHasher) (leftChild, rightChild []byte) { - leftChild = data[len(leafNodePrefix) : len(leafNodePrefix)+ph.PathSize()] - rightChild = data[len(leafNodePrefix)+ph.PathSize():] +func parseLeafNode(data []byte, ph PathHasher) (path, value []byte) { + path = data[prefixLen : prefixLen+ph.PathSize()] + value = data[prefixLen+ph.PathSize():] return } // parseExtNode parses an extNode into its components func parseExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte) { // +2 represents the length of the pathBounds - 
pathBounds = data[len(extNodePrefix) : len(extNodePrefix)+2] - path = data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()] - childData = data[len(extNodePrefix)+2+ph.PathSize():] + pathBounds = data[prefixLen : prefixLen+2] + path = data[prefixLen+2 : prefixLen+2+ph.PathSize()] + childData = data[prefixLen+2+ph.PathSize():] return } @@ -54,9 +66,9 @@ func parseSumExtNode(data []byte, ph PathHasher) (pathBounds, path, childData [] copy(sumBz[:], data[len(data)-sumSizeBits:]) // +2 represents the length of the pathBounds - pathBounds = data[len(extNodePrefix) : len(extNodePrefix)+2] - path = data[len(extNodePrefix)+2 : len(extNodePrefix)+2+ph.PathSize()] - childData = data[len(extNodePrefix)+2+ph.PathSize() : len(data)-sumSizeBits] + pathBounds = data[prefixLen : prefixLen+2] + path = data[prefixLen+2 : prefixLen+2+ph.PathSize()] + childData = data[prefixLen+2+ph.PathSize() : len(data)-sumSizeBits] return } From cd24a787820a0f5010517c9f947ecfdcbe2b139f Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Sat, 10 Feb 2024 19:23:58 -0800 Subject: [PATCH 13/40] Interim checkpoint commit - modified lots of code and all tests pass but struggling with parsing in smst_example_test.go --- hasher.go | 63 +++++++++++++++++++++++++------------------- node_encoders.go | 45 ++++++++++++++++++++++--------- options.go | 8 +++--- smst.go | 30 +++++++++++---------- smst_example_test.go | 28 +++++++++++++------- smst_proofs_test.go | 2 +- smst_test.go | 2 +- smt.go | 8 +++--- smt_proofs_test.go | 2 +- smt_test.go | 2 +- types.go | 22 +++++++++------- utils.go | 4 +-- 12 files changed, 129 insertions(+), 87 deletions(-) diff --git a/hasher.go b/hasher.go index 3854873..56ec5d6 100644 --- a/hasher.go +++ b/hasher.go @@ -4,6 +4,9 @@ import ( "hash" ) +// TODO_IN_THIS_PR: Improve how the `hasher` file is consolidated (or not) +// with `node_encoders.go` since the two are very similar. 
+ // Ensure the hasher interfaces are satisfied var ( _ PathHasher = (*pathHasher)(nil) @@ -58,7 +61,7 @@ func NewNilPathHasher(hasher hash.Hash) PathHasher { // Path returns the digest of a key produced by the path hasher func (ph *pathHasher) Path(key []byte) []byte { - return ph.digest(key)[:ph.PathSize()] + return ph.digestData(key)[:ph.PathSize()] } // PathSize returns the length (in bytes) of digests produced by the path hasher @@ -69,7 +72,7 @@ func (ph *pathHasher) PathSize() int { // HashValue hashes the data provided using the value hasher func (vh *valueHasher) HashValue(data []byte) []byte { - return vh.digest(data) + return vh.digestData(data) } // Path satisfies the PathHasher#Path interface @@ -82,46 +85,52 @@ func (n *nilPathHasher) PathSize() int { return n.hashSize } -// digest returns the hash of the data provided using the trie hasher. -func (th *trieHasher) digest(data []byte) []byte { +// digestData returns the hash of the data provided using the trie hasher. +func (th *trieHasher) digestData(data []byte) []byte { th.hasher.Write(data) - sum := th.hasher.Sum(nil) + digest := th.hasher.Sum(nil) th.hasher.Reset() - return sum + return digest } -// digestLeaf returns the hash of the leaf data & pathprovided using the trie hasher. -func (th *trieHasher) digestLeaf(path, data []byte) ([]byte, []byte) { - value := encodeLeafNode(path, data) - return th.digest(value), value +// digestLeaf returns the encoded leaf data as well as its hash (i.e. digest) +func (th *trieHasher) digestLeaf(path, data []byte) (digest, value []byte) { + value = encodeLeafNode(path, data) + digest = th.digestData(value) + return } -func (th *trieHasher) digestSumLeaf(path []byte, leafData []byte) ([]byte, []byte) { - value := encodeLeafNode(path, leafData) - digest := th.digest(value) - digest = append(digest, value[len(value)-sumSizeBits:]...) 
- return digest, value +func (th *trieHasher) digestNode(leftData, rightData []byte) (digest, value []byte) { + value = encodeInnerNode(leftData, rightData) + digest = th.digestData(value) + return } -func (th *trieHasher) digestNode(leftData []byte, rightData []byte) ([]byte, []byte) { - value := encodeInnerNode(leftData, rightData) - return th.digest(value), value +func (th *trieHasher) digestSumLeaf(path, leafData []byte) (digest, value []byte) { + value = encodeLeafNode(path, leafData) + digest = th.digestData(value) + digest = append(digest, value[len(value)-sumSizeBits:]...) + return } -func (th *trieHasher) digestSumNode(leftData []byte, rightData []byte) ([]byte, []byte) { - value := encodeSumInnerNode(leftData, rightData) - digest := th.digest(value) +func (th *trieHasher) digestSumNode(leftData, rightData []byte) (digest, value []byte) { + value = encodeSumInnerNode(leftData, rightData) + digest = th.digestData(value) digest = append(digest, value[len(value)-sumSizeBits:]...) - return digest, value + return } -func (th *trieHasher) parseNode(data []byte) ([]byte, []byte) { - return data[len(innerNodePrefix) : th.hashSize()+len(innerNodePrefix)], data[len(innerNodePrefix)+th.hashSize():] +func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) { + leftData = data[len(innerNodePrefix) : th.hashSize()+len(innerNodePrefix)] + rightData = data[len(innerNodePrefix)+th.hashSize():] + return } -func (th *trieHasher) parseSumNode(data []byte) ([]byte, []byte) { - sumless := data[:len(data)-sumSizeBits] - return sumless[len(innerNodePrefix) : th.hashSize()+sumSizeBits+len(innerNodePrefix)], sumless[len(innerNodePrefix)+th.hashSize()+sumSizeBits:] +func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte) { + dataWithoutSum := data[:len(data)-sumSizeBits] + leftData = dataWithoutSum[len(innerNodePrefix) : th.hashSize()+sumSizeBits+len(innerNodePrefix)] + rightData = 
dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBits:] + return } func (th *trieHasher) hashSize() int { diff --git a/node_encoders.go b/node_encoders.go index dab773d..fffe7a8 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -5,7 +5,7 @@ import ( "encoding/binary" ) -// TODO_IMPROVE: All of the parsing, encoding and checking functions in this file +// TODO_TECHDEBT: All of the parsing, encoding and checking functions in this file // can be abstracted out into the `trieNode` interface. // TODO_IMPROVE: We should create well-defined types & structs for every type of node @@ -43,8 +43,16 @@ func isExtNode(data []byte) bool { return bytes.Equal(data[:prefixLen], extNodePrefix) } +// isInnerNode returns true if the encoded node data is an inner node +func isInnerNode(data []byte) bool { + return bytes.Equal(data[:prefixLen], innerNodePrefix) +} + // parseLeafNode parses a leafNode into its components func parseLeafNode(data []byte, ph PathHasher) (path, value []byte) { + // panics if not a leaf node + checkPrefix(data, leafNodePrefix) + path = data[prefixLen : prefixLen+ph.PathSize()] value = data[prefixLen+ph.PathSize():] return @@ -52,6 +60,9 @@ func parseLeafNode(data []byte, ph PathHasher) (path, value []byte) { // parseExtNode parses an extNode into its components func parseExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte) { + // panics if not an extension node + checkPrefix(data, extNodePrefix) + // +2 represents the length of the pathBounds pathBounds = data[prefixLen : prefixLen+2] path = data[prefixLen+2 : prefixLen+2+ph.PathSize()] @@ -61,6 +72,9 @@ func parseExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byt // parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data func parseSumExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) { + // panics if not an extension node + checkPrefix(data, extNodePrefix) + // 
Extract the sum from the encoded node data var sumBz [sumSizeBits]byte copy(sumBz[:], data[len(data)-sumSizeBits:]) @@ -72,8 +86,8 @@ func parseSumExtNode(data []byte, ph PathHasher) (pathBounds, path, childData [] return } -// encodeLeafNode encodes leaf nodes. both normal and sum leaves as in the sum leaf the -// sum is appended to the end of the valueHash +// encodeLeafNode encodes leaf nodes. This function applies to both the SMT and +// SMST since the weight of the node is appended to the end of the valueHash. func encodeLeafNode(path, leafData []byte) (data []byte) { data = append(data, leafNodePrefix...) data = append(data, path...) @@ -89,6 +103,15 @@ func encodeInnerNode(leftData, rightData []byte) (data []byte) { return } +// encodeExtensionNode encodes the data of an extension nodes +func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { + data = append(data, extNodePrefix...) + data = append(data, pathBounds[:]...) + data = append(data, path...) + data = append(data, childData...) + return +} + // encodeSumInnerNode encodes an inner node for an smst given the data for both children func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { // Retrieve the sum of the left subtree @@ -115,15 +138,6 @@ func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { return } -// encodeExtensionNode encodes the data of an extension nodes -func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { - data = append(data, extNodePrefix...) - data = append(data, pathBounds[:]...) - data = append(data, path...) - data = append(data, childData...) - return -} - // encodeSumExtensionNode encodes the data of a sum extension nodes func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { @@ -136,3 +150,10 @@ func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data [] data = append(data, sum[:]...) 
return } + +// checkPrefix panics if the prefix of the data does not match the expected prefix +func checkPrefix(data, prefix []byte) { + if !bytes.Equal(data[:prefixLen], prefix) { + panic("invalid prefix") + } +} diff --git a/options.go b/options.go index f2d6bc2..6164817 100644 --- a/options.go +++ b/options.go @@ -4,16 +4,16 @@ import ( "hash" ) -// Option is a function that configures SparseMerkleTrie. -type Option func(*TrieSpec) +// TrieSpecOption is a function that configures SparseMerkleTrie. +type TrieSpecOption func(*TrieSpec) // WithPathHasher returns an Option that sets the PathHasher to the one provided -func WithPathHasher(ph PathHasher) Option { +func WithPathHasher(ph PathHasher) TrieSpecOption { return func(ts *TrieSpec) { ts.ph = ph } } // WithValueHasher returns an Option that sets the ValueHasher to the one provided -func WithValueHasher(vh ValueHasher) Option { +func WithValueHasher(vh ValueHasher) TrieSpecOption { return func(ts *TrieSpec) { ts.vh = vh } } diff --git a/smst.go b/smst.go index dc0d7db..64f8e94 100644 --- a/smst.go +++ b/smst.go @@ -25,25 +25,27 @@ type SMST struct { func NewSparseMerkleSumTrie( nodes kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMST { + trieSpec := newTrieSpec(hasher, true) + for _, option := range options { + option(&trieSpec) + } + + // Initialize a non-sum SMT and modify it to have a nil value hasher + // TODO_IN_THIS_PR: Understand the purpose of the nilValueHasher and why + // we're not applying it to the smst but we need it for the smt. 
smt := &SMT{ - TrieSpec: newTrieSpec(hasher, true), + TrieSpec: trieSpec, nodes: nodes, } - for _, option := range options { - option(&smt.TrieSpec) - } - nvh := WithValueHasher(nil) - nvh(&smt.TrieSpec) - smst := &SMST{ - TrieSpec: newTrieSpec(hasher, true), + nilValueHasher := WithValueHasher(nil) + nilValueHasher(&smt.TrieSpec) + + return &SMST{ + TrieSpec: trieSpec, SMT: smt, } - for _, option := range options { - option(&smst.TrieSpec) - } - return smst } // ImportSparseMerkleSumTrie returns a pointer to an SMST struct with the root hash provided @@ -51,7 +53,7 @@ func ImportSparseMerkleSumTrie( nodes kvstore.MapStore, hasher hash.Hash, root []byte, - options ...Option, + options ...TrieSpecOption, ) *SMST { smst := NewSparseMerkleSumTrie(nodes, hasher, options...) smst.root = &lazyNode{root} diff --git a/smst_example_test.go b/smst_example_test.go index c2c5b53..a412569 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -20,7 +20,7 @@ func TestExampleSMST(t *testing.T) { nodeStore := simplemap.NewSimpleMap() // Initialize the smst - smst := NewSparseMerkleSumTrie(nodeStore, sha256.New()) //, smt.WithValueHasher(nil), smt.WithPathHasher(smt.NewNilPathHasher(sha256.New()))) + smst := NewSparseMerkleSumTrie(nodeStore, sha256.New()) // Update trie with keys, values and their sums err := smst.Update([]byte("foo"), []byte("oof"), 10) @@ -72,24 +72,32 @@ func TestExampleSMST(t *testing.T) { require.NoError(t, err) require.False(t, valid_false1) - exportToCSV(smst, dataMap, nodeStore) + exportToCSV(t, smst, dataMap, nodeStore) } func exportToCSV( + t *testing.T, smst SparseMerkleSumTrie, innerMap map[string]string, nodeStore kvstore.MapStore, ) { - // hasher := sha256.New() - fmt.Println("Exporting to CSV", smst.Root()) - // rootBits := smst.Root() + t.Helper() + // rootHash := smst.Root() + // rootNode, err := nodeStore.Get(rootHash) + // require.NoError(t, err) + + // Testing + // fmt.Println(isExtNode(rootNode), isLeafNode(rootNode), 
isInnerNode(rootNode)) + // leftChild, rightChild := smst.Spec().th.parseInnerNode(rootNode) + // // fmt.Println(isExtNode(leftChild), isExtNode(rightChild), rightChild, leftChild) + // fmt.Println(leftChild[:1], isExtNode(leftChild), isInnerNode(leftChild), isLeafNode(leftChild)) + // path, value := parseLeafNode(rightChild, smst.Spec().ph) + // path2, value2 := parseLeafNode(leftChild, smst.Spec().ph) + // fmt.Println(path, "~~~", value, "~~~", path2, "~~~", value2) + for key, value := range innerMap { - // fmt.Println(key, smst.Spec().) v, s, err := smst.Get([]byte(key)) - if err != nil { - panic(err) - } - // parseLeafNode() + require.NoError(t, err) fmt.Println(v, s) fmt.Println(value) fmt.Println("") diff --git a/smst_proofs_test.go b/smst_proofs_test.go index e951cbb..d10838d 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -187,7 +187,7 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { // Case: incorrect non-nil sibling data proof, _ = smst.Prove([]byte("testKey1")) - proof.SiblingData = base.th.digest(proof.SiblingData) + proof.SiblingData = base.th.digestData(proof.SiblingData) require.EqualError( t, proof.validateBasic(base), diff --git a/smst_test.go b/smst_test.go index fc0c75d..7458ff5 100644 --- a/smst_test.go +++ b/smst_test.go @@ -17,7 +17,7 @@ import ( func NewSMSTWithStorage( nodes, preimages kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMSTWithStorage { return &SMSTWithStorage{ SMST: NewSparseMerkleSumTrie(nodes, hasher, options...), diff --git a/smt.go b/smt.go index 9312c99..8892887 100644 --- a/smt.go +++ b/smt.go @@ -31,7 +31,7 @@ type orphanNodes = [][]byte func NewSparseMerkleTrie( nodes kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMT { smt := SMT{ TrieSpec: newTrieSpec(hasher, false), @@ -49,7 +49,7 @@ func ImportSparseMerkleTrie( nodes kvstore.MapStore, hasher hash.Hash, root []byte, - options ...Option, + options ...TrieSpecOption, ) 
*SMT { smt := NewSparseMerkleTrie(nodes, hasher, options...) smt.root = &lazyNode{root} @@ -555,7 +555,7 @@ func (smt *SMT) resolve(hash []byte) (trieNode, error) { extNode.child = &lazyNode{childHash} return &extNode, nil } - leftHash, rightHash := smt.th.parseNode(data) + leftHash, rightHash := smt.th.parseInnerNode(data) inner := innerNode{persisted: true, digest: hash} inner.leftChild = &lazyNode{leftHash} inner.rightChild = &lazyNode{rightHash} @@ -584,7 +584,7 @@ func (smt *SMT) resolveSum(hash []byte) (trieNode, error) { extNode.child = &lazyNode{childHash} return &extNode, nil } - leftHash, rightHash := smt.th.parseSumNode(data) + leftHash, rightHash := smt.th.parseSumInnerNode(data) inner := innerNode{persisted: true, digest: hash} inner.leftChild = &lazyNode{leftHash} inner.rightChild = &lazyNode{rightHash} diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 6b2c254..74cac1f 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -161,7 +161,7 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) { // Case: incorrect non-nil sibling data proof, _ = smt.Prove([]byte("testKey1")) - proof.SiblingData = base.th.digest(proof.SiblingData) + proof.SiblingData = base.th.digestData(proof.SiblingData) require.EqualError( t, proof.validateBasic(base), diff --git a/smt_test.go b/smt_test.go index 28f1c6b..cf0e48a 100644 --- a/smt_test.go +++ b/smt_test.go @@ -15,7 +15,7 @@ import ( func NewSMTWithStorage( nodes, preimages kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMTWithStorage { return &SMTWithStorage{ SMT: NewSparseMerkleTrie(nodes, hasher, options...), diff --git a/types.go b/types.go index 653fc12..2800165 100644 --- a/types.go +++ b/types.go @@ -122,7 +122,7 @@ func (spec *TrieSpec) valueHash(value []byte) []byte { } // encodeNode serializes a node into a byte slice -func (spec *TrieSpec) encodeNode(node trieNode) (data []byte) { +func (spec *TrieSpec) encodeNode(node trieNode) []byte { switch n := 
node.(type) { case *lazyNode: panic("Encoding a lazyNode is not supported") @@ -135,33 +135,35 @@ func (spec *TrieSpec) encodeNode(node trieNode) (data []byte) { case *extensionNode: child := spec.digestNode(n.child) return encodeExtensionNode(n.pathBounds, n.path, child) + default: + panic("Unknown node type") } - return nil } -// digestNode hashes a node returning its digest +// digestNode hashes a node and returns its digest func (spec *TrieSpec) digestNode(node trieNode) []byte { if node == nil { return spec.th.placeholder() } - var cache *[]byte + + var cachedDigest *[]byte switch n := node.(type) { case *lazyNode: return n.digest case *leafNode: - cache = &n.digest + cachedDigest = &n.digest case *innerNode: - cache = &n.digest + cachedDigest = &n.digest case *extensionNode: if n.digest == nil { n.digest = spec.digestNode(n.expand()) } return n.digest } - if *cache == nil { - *cache = spec.th.digest(spec.encodeNode(node)) + if *cachedDigest == nil { + *cachedDigest = spec.th.digestData(spec.encodeNode(node)) } - return *cache + return *cachedDigest } // sumSerialize serializes a node returning the preimage hash, its sum and any @@ -206,7 +208,7 @@ func (spec *TrieSpec) hashSumNode(node trieNode) []byte { } if *cache == nil { preImage := spec.sumSerialize(node) - *cache = spec.th.digest(preImage) + *cache = spec.th.digestData(preImage) *cache = append(*cache, preImage[len(preImage)-sumSizeBits:]...) 
} return *cache diff --git a/utils.go b/utils.go index e607edf..18bfb7f 100644 --- a/utils.go +++ b/utils.go @@ -172,7 +172,7 @@ func hashSerialization(smt *TrieSpec, data []byte) []byte { copy(ext.pathBounds[:], pathBounds) return smt.digestNode(&ext) } - return smt.th.digest(data) + return smt.th.digestData(data) } // Used for verification of serialized proof data for sum trie nodes @@ -183,7 +183,7 @@ func hashSumSerialization(smt *TrieSpec, data []byte) []byte { copy(ext.pathBounds[:], pathBounds) return smt.hashSumNode(&ext) } - digest := smt.th.digest(data) + digest := smt.th.digestData(data) digest = append(digest, data[len(data)-sumSizeBits:]...) return digest } From 03d85ea37765ba4d0ca60dd5adde8a98cb9113fe Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Sat, 10 Feb 2024 20:01:50 -0800 Subject: [PATCH 14/40] Improved how placeholder values are maintained and how the node is resolved --- hasher.go | 2 +- node_encoders.go | 28 +++++----- proofs.go | 6 +-- smst_example_test.go | 18 +++---- smst_proofs_test.go | 4 +- smt.go | 121 +++++++++++++++++++++++++++---------------- smt_proofs_test.go | 2 +- types.go | 12 ++++- utils.go | 10 ---- 9 files changed, 118 insertions(+), 85 deletions(-) diff --git a/hasher.go b/hasher.go index 56ec5d6..49621b1 100644 --- a/hasher.go +++ b/hasher.go @@ -128,7 +128,7 @@ func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) { func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte) { dataWithoutSum := data[:len(data)-sumSizeBits] - leftData = dataWithoutSum[len(innerNodePrefix) : th.hashSize()+sumSizeBits+len(innerNodePrefix)] + leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBits] rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBits:] return } diff --git a/node_encoders.go b/node_encoders.go index fffe7a8..ddb0a61 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -15,6 +15,8 @@ import ( // NB: In this 
file, all references to the variable `data` should be treated as `encodedNodeData`. // It was abbreviated to `data` for brevity. +// TODO_TECHDEBT: We can easily use `iota` and ENUMS to create a wait to have +// more expressive code, and leverage switches statements throughout. var ( leafNodePrefix = []byte{0} innerNodePrefix = []byte{1} @@ -114,22 +116,10 @@ func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byt // encodeSumInnerNode encodes an inner node for an smst given the data for both children func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { - // Retrieve the sum of the left subtree - leftSum := uint64(0) - leftSumBz := leftData[len(leftData)-sumSizeBits:] - if !bytes.Equal(leftSumBz, defaultEmptySum[:]) { - leftSum = binary.BigEndian.Uint64(leftSumBz) - } - - // Retrieve the sum of the right subtree - rightSum := uint64(0) - rightSumBz := rightData[len(rightData)-sumSizeBits:] - if !bytes.Equal(rightSumBz, defaultEmptySum[:]) { - rightSum = binary.BigEndian.Uint64(rightSumBz) - } - // Compute the sum of the current node var sum [sumSizeBits]byte + leftSum := parseSum(leftData) + rightSum := parseSum(rightData) binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) // Prepare and return the encoded inner node data @@ -157,3 +147,13 @@ func checkPrefix(data, prefix []byte) { panic("invalid prefix") } } + +// parseSum parses the sum from the encoded node data +func parseSum(data []byte) uint64 { + sum := uint64(0) + sumBz := data[len(data)-sumSizeBits:] + if !bytes.Equal(sumBz, defaultEmptySum[:]) { + sum = binary.BigEndian.Uint64(sumBz) + } + return sum +} diff --git a/proofs.go b/proofs.go index 3d242d4..08b0de9 100644 --- a/proofs.go +++ b/proofs.go @@ -333,7 +333,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v var currentHash, currentData []byte if bytes.Equal(value, defaultEmptyValue) { // Non-membership proof. 
if proof.NonMembershipLeafData == nil { // Leaf is a placeholder value. - currentHash = placeholder(spec) + currentHash = spec.placeholder() } else { // Leaf is an unrelated leaf. var actualPath, valueHash []byte actualPath, valueHash = parseLeafNode(proof.NonMembershipLeafData, spec.ph) @@ -412,7 +412,7 @@ func CompactProof(proof *SparseMerkleProof, spec *TrieSpec) (*SparseCompactMerkl for i := 0; i < len(proof.SideNodes); i++ { node := make([]byte, hashSize(spec)) copy(node, proof.SideNodes[i]) - if bytes.Equal(node, placeholder(spec)) { + if bytes.Equal(node, spec.placeholder()) { setPathBit(bitMask, i) } else { compactedSideNodes = append(compactedSideNodes, node) @@ -438,7 +438,7 @@ func DecompactProof(proof *SparseCompactMerkleProof, spec *TrieSpec) (*SparseMer position := 0 for i := 0; i < proof.NumSideNodes; i++ { if getPathBit(proof.BitMask, i) == 1 { - decompactedSideNodes[i] = placeholder(spec) + decompactedSideNodes[i] = spec.placeholder() } else { decompactedSideNodes[i] = proof.SideNodes[position] position++ diff --git a/smst_example_test.go b/smst_example_test.go index a412569..f7d2d88 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -82,15 +82,15 @@ func exportToCSV( nodeStore kvstore.MapStore, ) { t.Helper() - // rootHash := smst.Root() - // rootNode, err := nodeStore.Get(rootHash) - // require.NoError(t, err) - - // Testing - // fmt.Println(isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode)) - // leftChild, rightChild := smst.Spec().th.parseInnerNode(rootNode) - // // fmt.Println(isExtNode(leftChild), isExtNode(rightChild), rightChild, leftChild) - // fmt.Println(leftChild[:1], isExtNode(leftChild), isInnerNode(leftChild), isLeafNode(leftChild)) + rootHash := smst.Root() + rootNode, err := nodeStore.Get(rootHash) + require.NoError(t, err) + + leftChild, rightChild := smst.Spec().th.parseSumInnerNode(rootNode) + fmt.Println("Prefix", "isExt", "isLeaf", "isInner") + fmt.Println(rootNode[:1], isExtNode(rootNode), 
isLeafNode(rootNode), isInnerNode(rootNode)) + fmt.Println(leftChild[:1], isExtNode(leftChild), isLeafNode(leftChild), isInnerNode(leftChild)) + fmt.Println(rightChild[:1], isExtNode(rightChild), isLeafNode(rightChild), isInnerNode(rightChild)) // path, value := parseLeafNode(rightChild, smst.Spec().ph) // path2, value2 := parseLeafNode(leftChild, smst.Spec().ph) // fmt.Println(path, "~~~", value, "~~~", path2, "~~~", value2) diff --git a/smst_proofs_test.go b/smst_proofs_test.go index d10838d..58efeac 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -33,7 +33,7 @@ func TestSMST_Proof_Operations(t *testing.T) { proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, placeholder(base), []byte("testKey3"), defaultEmptyValue, 0, base) + result, err = VerifySumProof(proof, base.placeholder(), []byte("testKey3"), defaultEmptyValue, 0, base) require.NoError(t, err) require.True(t, result) result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, base) @@ -377,7 +377,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) { Path: path[:], FlippedBits: []int{0}, Depth: 0, - ClosestPath: placeholder(smst.Spec()), + ClosestPath: smst.placeholder(), ClosestProof: &SparseMerkleProof{}, }) diff --git a/smt.go b/smt.go index 8892887..e20d27b 100644 --- a/smt.go +++ b/smt.go @@ -490,7 +490,7 @@ func (smt *SMT) ProveClosest(path []byte) ( // Retrieve the closest path and value hash if found if node == nil { // trie was empty - proof.ClosestPath, proof.ClosestValueHash = placeholder(smt.Spec()), nil + proof.ClosestPath, proof.ClosestValueHash = smt.placeholder(), nil proof.ClosestProof = &SparseMerkleProof{} return proof, nil } @@ -529,66 +529,99 @@ func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) { return node, nil } if smt.sumTrie { - return smt.resolveSum(stub.digest) + return smt.resolveSumNode(stub.digest) } - return smt.resolve(stub.digest) 
+ return smt.resolveNode(stub.digest) } -func (smt *SMT) resolve(hash []byte) (trieNode, error) { - if bytes.Equal(smt.th.placeholder(), hash) { +// resolveNode returns a trieNode (inner, leaf, or extension) based on what they +// keyHash points to. +func (smt *SMT) resolveNode(digest []byte) (trieNode, error) { + // Check if the keyHash is the empty zero value of an empty subtree + if bytes.Equal(smt.placeholder(), digest) { return nil, nil } - data, err := smt.nodes.Get(hash) + + // Retrieve the encoded noe data + data, err := smt.nodes.Get(digest) if err != nil { return nil, err } + + // Return the appropriate node type based on the first byte of the data if isLeafNode(data) { - leaf := leafNode{persisted: true, digest: hash} - leaf.path, leaf.valueHash = parseLeafNode(data, smt.ph) - return &leaf, nil - } - if isExtNode(data) { - extNode := extensionNode{persisted: true, digest: hash} - pathBounds, path, childHash := parseExtNode(data, smt.ph) - extNode.path = path - copy(extNode.pathBounds[:], pathBounds) - extNode.child = &lazyNode{childHash} - return &extNode, nil - } - leftHash, rightHash := smt.th.parseInnerNode(data) - inner := innerNode{persisted: true, digest: hash} - inner.leftChild = &lazyNode{leftHash} - inner.rightChild = &lazyNode{rightHash} - return &inner, nil + path, valueHash := parseLeafNode(data, smt.ph) + return &leafNode{ + path: path, + valueHash: valueHash, + persisted: true, + digest: digest, + }, nil + } else if isExtNode(data) { + pathBounds, path, childData := parseExtNode(data, smt.ph) + return &extensionNode{ + path: path, + pathBounds: [2]byte(pathBounds), + child: &lazyNode{childData}, + persisted: true, + digest: digest, + }, nil + } else if isInnerNode(data) { + leftData, rightData := smt.th.parseInnerNode(data) + return &innerNode{ + leftChild: &lazyNode{leftData}, + rightChild: &lazyNode{rightData}, + persisted: true, + digest: digest, + }, nil + } else { + panic("invalid node type") + } } -// resolveSum resolves -func (smt 
*SMT) resolveSum(hash []byte) (trieNode, error) { - if bytes.Equal(placeholder(smt.Spec()), hash) { +// resolveNode returns a trieNode (inner, leaf, or extension) based on what they +// keyHash points to. +func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) { + // Check if the keyHash is the empty zero value of an empty subtree + if bytes.Equal(smt.placeholder(), digest) { return nil, nil } - data, err := smt.nodes.Get(hash) + + // Retrieve the encoded noe data + data, err := smt.nodes.Get(digest) if err != nil { return nil, err } + + // Return the appropriate node type based on the first byte of the data if isLeafNode(data) { - leaf := leafNode{persisted: true, digest: hash} - leaf.path, leaf.valueHash = parseLeafNode(data, smt.ph) - return &leaf, nil - } - if isExtNode(data) { - extNode := extensionNode{persisted: true, digest: hash} - pathBounds, path, childHash, _ := parseSumExtNode(data, smt.ph) - extNode.path = path - copy(extNode.pathBounds[:], pathBounds) - extNode.child = &lazyNode{childHash} - return &extNode, nil - } - leftHash, rightHash := smt.th.parseSumInnerNode(data) - inner := innerNode{persisted: true, digest: hash} - inner.leftChild = &lazyNode{leftHash} - inner.rightChild = &lazyNode{rightHash} - return &inner, nil + path, valueHash := parseLeafNode(data, smt.ph) + return &leafNode{ + path: path, + valueHash: valueHash, + persisted: true, + digest: digest, + }, nil + } else if isExtNode(data) { + pathBounds, path, childData, _ := parseSumExtNode(data, smt.ph) + return &extensionNode{ + path: path, + pathBounds: [2]byte(pathBounds), + child: &lazyNode{childData}, + persisted: true, + digest: digest, + }, nil + } else if isInnerNode(data) { + leftData, rightData := smt.th.parseSumInnerNode(data) + return &innerNode{ + leftChild: &lazyNode{leftData}, + rightChild: &lazyNode{rightData}, + persisted: true, + digest: digest, + }, nil + } else { + panic("invalid node type") + } } // Commit persists all dirty nodes in the trie, deletes all 
orphaned diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 74cac1f..e159b72 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -332,7 +332,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) { Path: path[:], FlippedBits: []int{0}, Depth: 0, - ClosestPath: placeholder(smt.Spec()), + ClosestPath: smt.placeholder(), ClosestProof: &SparseMerkleProof{}, }) diff --git a/types.go b/types.go index 2800165..b75cd70 100644 --- a/types.go +++ b/types.go @@ -106,6 +106,16 @@ func (spec *TrieSpec) Spec() *TrieSpec { return spec } +// placeholder returns the default placeholder value depending on the trie type +func (spec *TrieSpec) placeholder() []byte { + if spec.sumTrie { + placeholder := spec.th.placeholder() + placeholder = append(placeholder, defaultEmptySum[:]...) + return placeholder + } + return spec.th.placeholder() +} + // depth returns the maximum depth of the trie. // Since this tree is a binary tree, the depth is the number of bits in the path // TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output @@ -190,7 +200,7 @@ func (spec *TrieSpec) sumSerialize(node trieNode) (preImage []byte) { // digest = [node hash]+[8 byte sum] func (spec *TrieSpec) hashSumNode(node trieNode) []byte { if node == nil { - return placeholder(spec) + return spec.placeholder() } var cache *[]byte switch n := node.(type) { diff --git a/utils.go b/utils.go index 18bfb7f..1ce8cdc 100644 --- a/utils.go +++ b/utils.go @@ -106,16 +106,6 @@ func bytesToInt(bz []byte) int { return int(u) } -// placeholder returns the default placeholder value depending on the trie type -func placeholder(spec *TrieSpec) []byte { - if spec.sumTrie { - placeholder := spec.th.placeholder() - placeholder = append(placeholder, defaultEmptySum[:]...) 
- return placeholder - } - return spec.th.placeholder() -} - // hashSize returns the hash size depending on the trie type func hashSize(spec *TrieSpec) int { if spec.sumTrie { From ee677404c3601fe872d589b6e3535b1644a4799a Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Sat, 10 Feb 2024 20:41:28 -0800 Subject: [PATCH 15/40] Moved TrieSpec functions from utils into receiver functions --- hasher.go | 4 +-- proofs.go | 16 +++++----- smt.go | 14 ++++----- types.go | 94 ++++++++++++++++++++++++++++++++++++++++++++++++------- utils.go | 72 ------------------------------------------ 5 files changed, 99 insertions(+), 101 deletions(-) diff --git a/hasher.go b/hasher.go index 49621b1..7ac250a 100644 --- a/hasher.go +++ b/hasher.go @@ -100,7 +100,7 @@ func (th *trieHasher) digestLeaf(path, data []byte) (digest, value []byte) { return } -func (th *trieHasher) digestNode(leftData, rightData []byte) (digest, value []byte) { +func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value []byte) { value = encodeInnerNode(leftData, rightData) digest = th.digestData(value) return @@ -113,7 +113,7 @@ func (th *trieHasher) digestSumLeaf(path, leafData []byte) (digest, value []byte return } -func (th *trieHasher) digestSumNode(leftData, rightData []byte) (digest, value []byte) { +func (th *trieHasher) digestSumInnerNode(leftData, rightData []byte) (digest, value []byte) { value = encodeSumInnerNode(leftData, rightData) digest = th.digestData(value) digest = append(digest, value[len(value)-sumSizeBits:]...) diff --git a/proofs.go b/proofs.go index 08b0de9..6a31afd 100644 --- a/proofs.go +++ b/proofs.go @@ -66,8 +66,8 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { // Check that all supplied sidenodes are the correct size. 
for _, v := range proof.SideNodes { - if len(v) != hashSize(spec) { - return fmt.Errorf("invalid side node size: got %d but want %d", len(v), hashSize(spec)) + if len(v) != spec.hashSize() { + return fmt.Errorf("invalid side node size: got %d but want %d", len(v), spec.hashSize()) } } @@ -341,7 +341,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v // This is not an unrelated leaf; non-membership proof failed. return false, nil, errors.Join(ErrBadProof, errors.New("non-membership proof on related leaf")) } - currentHash, currentData = digestLeaf(spec, actualPath, valueHash) + currentHash, currentData = spec.digestLeaf(actualPath, valueHash) update := make([][]byte, 2) update[0], update[1] = currentHash, currentData @@ -349,7 +349,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v } } else { // Membership proof. valueHash := spec.valueHash(value) - currentHash, currentData = digestLeaf(spec, path, valueHash) + currentHash, currentData = spec.digestLeaf(path, valueHash) update := make([][]byte, 2) update[0], update[1] = currentHash, currentData updates = append(updates, update) @@ -357,13 +357,13 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v // Recompute root. 
for i := 0; i < len(proof.SideNodes); i++ { - node := make([]byte, hashSize(spec)) + node := make([]byte, spec.hashSize()) copy(node, proof.SideNodes[i]) if getPathBit(path, len(proof.SideNodes)-1-i) == leftChildBit { - currentHash, currentData = digestNode(spec, currentHash, node) + currentHash, currentData = spec.digestInnerNode(currentHash, node) } else { - currentHash, currentData = digestNode(spec, node, currentHash) + currentHash, currentData = spec.digestInnerNode(node, currentHash) } update := make([][]byte, 2) @@ -410,7 +410,7 @@ func CompactProof(proof *SparseMerkleProof, spec *TrieSpec) (*SparseCompactMerkl bitMask := make([]byte, int(math.Ceil(float64(len(proof.SideNodes))/float64(8)))) var compactedSideNodes [][]byte for i := 0; i < len(proof.SideNodes); i++ { - node := make([]byte, hashSize(spec)) + node := make([]byte, spec.hashSize()) copy(node, proof.SideNodes[i]) if bytes.Equal(node, spec.placeholder()) { setPathBit(bitMask, i) diff --git a/smt.go b/smt.go index e20d27b..de9cd0c 100644 --- a/smt.go +++ b/smt.go @@ -59,7 +59,7 @@ func ImportSparseMerkleTrie( // Root returns the root hash of the trie func (smt *SMT) Root() MerkleRoot { - return hashNode(smt.Spec(), smt.root) + return smt.digest(smt.root) } // Get returns the hash (i.e. 
digest) of the leaf value stored at the given key @@ -365,7 +365,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { for i := range siblings { var sideNode []byte sibling := siblings[len(siblings)-i-1] - sideNode = hashNode(smt.Spec(), sibling) + sideNode = smt.digest(sibling) sideNodes = append(sideNodes, sideNode) } @@ -378,7 +378,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { if err != nil { return nil, err } - proof.SiblingData = serialize(smt.Spec(), sib) + proof.SiblingData = smt.encode(sib) } return proof, nil } @@ -505,7 +505,7 @@ func (smt *SMT) ProveClosest(path []byte) ( for i := range siblings { var sideNode []byte sibling := siblings[len(siblings)-i-1] - sideNode = hashNode(smt.Spec(), sibling) + sideNode = smt.digest(sibling) sideNodes = append(sideNodes, sideNode) } proof.ClosestProof = &SparseMerkleProof{ @@ -516,7 +516,7 @@ func (smt *SMT) ProveClosest(path []byte) ( if err != nil { return nil, err } - proof.ClosestProof.SiblingData = serialize(smt.Spec(), sib) + proof.ClosestProof.SiblingData = smt.encode(sib) } return proof, nil @@ -666,8 +666,8 @@ func (smt *SMT) commit(node trieNode) error { default: return nil } - preimage := serialize(smt.Spec(), node) - return smt.nodes.Set(hashNode(smt.Spec(), node), preimage) + preimage := smt.encode(node) + return smt.nodes.Set(smt.digest(node), preimage) } func (smt *SMT) addOrphan(orphans *[][]byte, node trieNode) { diff --git a/types.go b/types.go index b75cd70..8dc3c30 100644 --- a/types.go +++ b/types.go @@ -116,6 +116,78 @@ func (spec *TrieSpec) placeholder() []byte { return spec.th.placeholder() } +// hashSize returns the hash size depending on the trie type +func (spec *TrieSpec) hashSize() int { + if spec.sumTrie { + return spec.th.hashSize() + sumSizeBits + } + return spec.th.hashSize() +} + +// digestLeaf returns the hash and preimage of a leaf node depending on the trie type +func (spec *TrieSpec) digestLeaf(path, value []byte) 
([]byte, []byte) { + if spec.sumTrie { + return spec.th.digestSumLeaf(path, value) + } + return spec.th.digestLeaf(path, value) +} + +// digestNode returns the hash and preimage of a node depending on the trie type +func (spec *TrieSpec) digestInnerNode(left, right []byte) ([]byte, []byte) { + if spec.sumTrie { + return spec.th.digestSumInnerNode(left, right) + } + return spec.th.digestInnerNode(left, right) +} + +// digest hashes a node depending on the trie type +func (spec *TrieSpec) digest(node trieNode) []byte { + if spec.sumTrie { + return spec.digestSumNode(node) + } + return spec.digestNode(node) +} + +// encode serializes a node depending on the trie type +func (spec *TrieSpec) encode(node trieNode) []byte { + if spec.sumTrie { + return spec.encodeSumNode(node) + } + return spec.encodeNode(node) +} + +// hashPreimage hashes the serialised data provided depending on the trie type +func hashPreimage(spec *TrieSpec, data []byte) []byte { + if spec.sumTrie { + return hashSumSerialization(spec, data) + } + return hashSerialization(spec, data) +} + +// Used for verification of serialized proof data +func hashSerialization(smt *TrieSpec, data []byte) []byte { + if isExtNode(data) { + pathBounds, path, childHash := parseExtNode(data, smt.ph) + ext := extensionNode{path: path, child: &lazyNode{childHash}} + copy(ext.pathBounds[:], pathBounds) + return smt.digestNode(&ext) + } + return smt.th.digestData(data) +} + +// Used for verification of serialized proof data for sum trie nodes +func hashSumSerialization(smt *TrieSpec, data []byte) []byte { + if isExtNode(data) { + pathBounds, path, childHash, _ := parseSumExtNode(data, smt.ph) + ext := extensionNode{path: path, child: &lazyNode{childHash}} + copy(ext.pathBounds[:], pathBounds) + return smt.digestSumNode(&ext) + } + digest := smt.th.digestData(data) + digest = append(digest, data[len(data)-sumSizeBits:]...) + return digest +} + // depth returns the maximum depth of the trie. 
// Since this tree is a binary tree, the depth is the number of bits in the path // TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output @@ -176,29 +248,27 @@ func (spec *TrieSpec) digestNode(node trieNode) []byte { return *cachedDigest } -// sumSerialize serializes a node returning the preimage hash, its sum and any -// errors encountered -func (spec *TrieSpec) sumSerialize(node trieNode) (preImage []byte) { +// encodeSumNode serializes a sum node and returns the preImage hash. +func (spec *TrieSpec) encodeSumNode(node trieNode) (preImage []byte) { switch n := node.(type) { case *lazyNode: - panic("serialize(lazyNode)") + panic("encodeSumNode(lazyNode)") case *leafNode: return encodeLeafNode(n.path, n.valueHash) case *innerNode: - leftChild := spec.hashSumNode(n.leftChild) - rightChild := spec.hashSumNode(n.rightChild) + leftChild := spec.digestSumNode(n.leftChild) + rightChild := spec.digestSumNode(n.rightChild) preImage = encodeSumInnerNode(leftChild, rightChild) return preImage case *extensionNode: - child := spec.hashSumNode(n.child) + child := spec.digestSumNode(n.child) return encodeSumExtensionNode(n.pathBounds, n.path, child) } return nil } -// hashSumNode hashes a node returning its digest in the following form -// digest = [node hash]+[8 byte sum] -func (spec *TrieSpec) hashSumNode(node trieNode) []byte { +// digestSumNode hashes a sum node returning its digest in the following form: [node hash]+[8 byte sum] +func (spec *TrieSpec) digestSumNode(node trieNode) []byte { if node == nil { return spec.placeholder() } @@ -212,12 +282,12 @@ func (spec *TrieSpec) hashSumNode(node trieNode) []byte { cache = &n.digest case *extensionNode: if n.digest == nil { - n.digest = spec.hashSumNode(n.expand()) + n.digest = spec.digestSumNode(n.expand()) } return n.digest } if *cache == nil { - preImage := spec.sumSerialize(node) + preImage := spec.encodeSumNode(node) *cache = spec.th.digestData(preImage) *cache = append(*cache, 
preImage[len(preImage)-sumSizeBits:]...) } diff --git a/utils.go b/utils.go index 1ce8cdc..710e117 100644 --- a/utils.go +++ b/utils.go @@ -105,75 +105,3 @@ func bytesToInt(bz []byte) int { u := binary.BigEndian.Uint64(b) return int(u) } - -// hashSize returns the hash size depending on the trie type -func hashSize(spec *TrieSpec) int { - if spec.sumTrie { - return spec.th.hashSize() + sumSizeBits - } - return spec.th.hashSize() -} - -// digestLeaf returns the hash and preimage of a leaf node depending on the trie type -func digestLeaf(spec *TrieSpec, path, value []byte) ([]byte, []byte) { - if spec.sumTrie { - return spec.th.digestSumLeaf(path, value) - } - return spec.th.digestLeaf(path, value) -} - -// digestNode returns the hash and preimage of a node depending on the trie type -func digestNode(spec *TrieSpec, left, right []byte) ([]byte, []byte) { - if spec.sumTrie { - return spec.th.digestSumNode(left, right) - } - return spec.th.digestNode(left, right) -} - -// hashNode hashes a node depending on the trie type -func hashNode(spec *TrieSpec, node trieNode) []byte { - if spec.sumTrie { - return spec.hashSumNode(node) - } - return spec.digestNode(node) -} - -// serialize serializes a node depending on the trie type -func serialize(spec *TrieSpec, node trieNode) []byte { - if spec.sumTrie { - return spec.sumSerialize(node) - } - return spec.encodeNode(node) -} - -// hashPreimage hashes the serialised data provided depending on the trie type -func hashPreimage(spec *TrieSpec, data []byte) []byte { - if spec.sumTrie { - return hashSumSerialization(spec, data) - } - return hashSerialization(spec, data) -} - -// Used for verification of serialized proof data -func hashSerialization(smt *TrieSpec, data []byte) []byte { - if isExtNode(data) { - pathBounds, path, childHash := parseExtNode(data, smt.ph) - ext := extensionNode{path: path, child: &lazyNode{childHash}} - copy(ext.pathBounds[:], pathBounds) - return smt.digestNode(&ext) - } - return smt.th.digestData(data) 
-} - -// Used for verification of serialized proof data for sum trie nodes -func hashSumSerialization(smt *TrieSpec, data []byte) []byte { - if isExtNode(data) { - pathBounds, path, childHash, _ := parseSumExtNode(data, smt.ph) - ext := extensionNode{path: path, child: &lazyNode{childHash}} - copy(ext.pathBounds[:], pathBounds) - return smt.hashSumNode(&ext) - } - digest := smt.th.digestData(data) - digest = append(digest, data[len(data)-sumSizeBits:]...) - return digest -} From c263cd75922f05ab5bf7f11dc228d63ed1c6b295 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Sun, 11 Feb 2024 16:48:07 -0800 Subject: [PATCH 16/40] Moved trie spec into its own file --- hasher.go | 17 ++- node_encoders.go | 39 ------- proofs.go | 46 ++++---- smst_example_test.go | 82 ++++++++++---- smst_proofs_test.go | 2 +- smt.go | 12 +- smt_proofs_test.go | 2 +- trie_spec.go | 255 +++++++++++++++++++++++++++++++++++++++++++ types.go | 215 ------------------------------------ 9 files changed, 361 insertions(+), 309 deletions(-) create mode 100644 trie_spec.go diff --git a/hasher.go b/hasher.go index 7ac250a..570a03e 100644 --- a/hasher.go +++ b/hasher.go @@ -1,6 +1,7 @@ package smt import ( + "encoding/binary" "hash" ) @@ -93,8 +94,8 @@ func (th *trieHasher) digestData(data []byte) []byte { return digest } -// digestLeaf returns the encoded leaf data as well as its hash (i.e. digest) -func (th *trieHasher) digestLeaf(path, data []byte) (digest, value []byte) { +// digestLeafNode returns the encoded leaf data as well as its hash (i.e. 
digest) +func (th *trieHasher) digestLeafNode(path, data []byte) (digest, value []byte) { value = encodeLeafNode(path, data) digest = th.digestData(value) return @@ -106,8 +107,8 @@ func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value return } -func (th *trieHasher) digestSumLeaf(path, leafData []byte) (digest, value []byte) { - value = encodeLeafNode(path, leafData) +func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte) { + value = encodeLeafNode(path, data) digest = th.digestData(value) digest = append(digest, value[len(value)-sumSizeBits:]...) return @@ -126,7 +127,13 @@ func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) { return } -func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte) { +func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum uint64) { + // Extract the sum from the encoded node data + var sumBz [sumSizeBits]byte + copy(sumBz[:], data[len(data)-sumSizeBits:]) + binary.BigEndian.PutUint64(sumBz[:], sum) + + // Extract the left and right children dataWithoutSum := data[:len(data)-sumSizeBits] leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBits] rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBits:] diff --git a/node_encoders.go b/node_encoders.go index ddb0a61..65ef604 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -50,44 +50,6 @@ func isInnerNode(data []byte) bool { return bytes.Equal(data[:prefixLen], innerNodePrefix) } -// parseLeafNode parses a leafNode into its components -func parseLeafNode(data []byte, ph PathHasher) (path, value []byte) { - // panics if not a leaf node - checkPrefix(data, leafNodePrefix) - - path = data[prefixLen : prefixLen+ph.PathSize()] - value = data[prefixLen+ph.PathSize():] - return -} - -// parseExtNode parses an extNode into its components -func parseExtNode(data []byte, ph PathHasher) (pathBounds, 
path, childData []byte) { - // panics if not an extension node - checkPrefix(data, extNodePrefix) - - // +2 represents the length of the pathBounds - pathBounds = data[prefixLen : prefixLen+2] - path = data[prefixLen+2 : prefixLen+2+ph.PathSize()] - childData = data[prefixLen+2+ph.PathSize():] - return -} - -// parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data -func parseSumExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) { - // panics if not an extension node - checkPrefix(data, extNodePrefix) - - // Extract the sum from the encoded node data - var sumBz [sumSizeBits]byte - copy(sumBz[:], data[len(data)-sumSizeBits:]) - - // +2 represents the length of the pathBounds - pathBounds = data[prefixLen : prefixLen+2] - path = data[prefixLen+2 : prefixLen+2+ph.PathSize()] - childData = data[prefixLen+2+ph.PathSize() : len(data)-sumSizeBits] - return -} - // encodeLeafNode encodes leaf nodes. This function applies to both the SMT and // SMST since the weight of the node is appended to the end of the valueHash. func encodeLeafNode(path, leafData []byte) (data []byte) { @@ -130,7 +92,6 @@ func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { // encodeSumExtensionNode encodes the data of a sum extension nodes func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { - // Compute the sum of the current node var sum [sumSizeBits]byte copy(sum[:], childData[len(childData)-sumSizeBits:]) diff --git a/proofs.go b/proofs.go index 6a31afd..9a49b54 100644 --- a/proofs.go +++ b/proofs.go @@ -49,33 +49,34 @@ func (proof *SparseMerkleProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } +// validateBasic performs basic sanity check on the proof so that a malicious +// proof cannot cause the verifier to fatally exit (e.g. due to an index +// out-of-range error) or cause a CPU DoS attack. 
func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { - // Do a basic sanity check on the proof, so that a malicious proof cannot - // cause the verifier to fatally exit (e.g. due to an index out-of-range - // error) or cause a CPU DoS attack. - - // Check that the number of supplied sidenodes does not exceed the maximum possible. + // Verify the number of supplied sideNodes does not exceed the possible maximum. if len(proof.SideNodes) > spec.ph.PathSize()*8 { return fmt.Errorf("too many side nodes: got %d but max is %d", len(proof.SideNodes), spec.ph.PathSize()*8) } - // Check that leaf data for non-membership proofs is a valid size. - lps := len(leafNodePrefix) + spec.ph.PathSize() - if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps { - return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps) - } - // Check that all supplied sidenodes are the correct size. - for _, v := range proof.SideNodes { - if len(v) != spec.hashSize() { - return fmt.Errorf("invalid side node size: got %d but want %d", len(v), spec.hashSize()) - } + // Verify that the non-membership leaf data is of the correct size. + leafPathSize := len(leafNodePrefix) + spec.ph.PathSize() + if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < leafPathSize { + return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), leafPathSize) } // Check that the sibling data hashes to the first side node if not nil if proof.SiblingData == nil || len(proof.SideNodes) == 0 { return nil } - siblingHash := hashPreimage(spec, proof.SiblingData) + + // Check that all supplied sideNodes are the correct size. 
+ for _, sideNodeValue := range proof.SideNodes { + if len(sideNodeValue) != spec.hashSize() { + return fmt.Errorf("invalid side node size: got %d but want %d", len(sideNodeValue), spec.hashSize()) + } + } + + siblingHash := spec.hashPreimage(proof.SiblingData) if eq := bytes.Equal(proof.SideNodes[0], siblingHash); !eq { return fmt.Errorf("invalid sibling data hash: got %x but want %x", siblingHash, proof.SideNodes[0]) } @@ -320,7 +321,13 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec) } -func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TrieSpec) (bool, [][][]byte, error) { +// verifyProofWithUpdates +func verifyProofWithUpdates( + proof *SparseMerkleProof, + root, key, value []byte, + spec *TrieSpec, +) (bool, [][][]byte, error) { + // Retrieve the trie path for the give key path := spec.ph.Path(key) if err := proof.validateBasic(spec); err != nil { @@ -336,7 +343,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v currentHash = spec.placeholder() } else { // Leaf is an unrelated leaf. var actualPath, valueHash []byte - actualPath, valueHash = parseLeafNode(proof.NonMembershipLeafData, spec.ph) + actualPath, valueHash = spec.parseLeafNode(proof.NonMembershipLeafData) if bytes.Equal(actualPath, path) { // This is not an unrelated leaf; non-membership proof failed. return false, nil, errors.Join(ErrBadProof, errors.New("non-membership proof on related leaf")) @@ -347,7 +354,8 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v update[0], update[1] = currentHash, currentData updates = append(updates, update) } - } else { // Membership proof. + } else { + // Membership proof. 
valueHash := spec.valueHash(value) currentHash, currentData = spec.digestLeaf(path, valueHash) update := make([][]byte, 2) diff --git a/smst_example_test.go b/smst_example_test.go index f7d2d88..65dc8ff 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -82,31 +82,67 @@ func exportToCSV( nodeStore kvstore.MapStore, ) { t.Helper() - rootHash := smst.Root() - rootNode, err := nodeStore.Get(rootHash) + /* + rootHash := smst.Root() + rootNode, err := nodeStore.Get(rootHash) + require.NoError(t, err) + leftData, rightData := smst.Spec().th.parseSumInnerNode(rootNode) + leftChild, err := nodeStore.Get(leftData) + require.NoError(t, err) + rightChild, err := nodeStore.Get(rightData) + require.NoError(t, err) + fmt.Println("Prefix", "isExt", "isLeaf", "isInner") + // false false true + fmt.Println("root", isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode), rootNode) + fmt.Println() + // false false false + fmt.Println("left", isExtNode(leftChild), isLeafNode(leftChild), isInnerNode(leftChild), leftChild) + fmt.Println() + // false false false + fmt.Println("right", isExtNode(rightChild), isLeafNode(rightChild), isInnerNode(rightChild), rightChild) + fmt.Println() + */ + + /* + for key, value := range innerMap { + v, s, err := smst.Get([]byte(key)) + require.NoError(t, err) + fmt.Println(v, s, []byte(key)) + fmt.Println(value) + fmt.Println("") + fmt.Println("") + } + */ + + helper(t, smst, nodeStore, smst.Root()) +} + +func helper(t *testing.T, smst SparseMerkleSumTrie, nodeStore kvstore.MapStore, nodeDigest []byte) { + t.Helper() + + node, err := nodeStore.Get(nodeDigest) require.NoError(t, err) - leftChild, rightChild := smst.Spec().th.parseSumInnerNode(rootNode) - fmt.Println("Prefix", "isExt", "isLeaf", "isInner") - fmt.Println(rootNode[:1], isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode)) - fmt.Println(leftChild[:1], isExtNode(leftChild), isLeafNode(leftChild), isInnerNode(leftChild)) - fmt.Println(rightChild[:1], 
isExtNode(rightChild), isLeafNode(rightChild), isInnerNode(rightChild)) - // path, value := parseLeafNode(rightChild, smst.Spec().ph) - // path2, value2 := parseLeafNode(leftChild, smst.Spec().ph) - // fmt.Println(path, "~~~", value, "~~~", path2, "~~~", value2) - - for key, value := range innerMap { - v, s, err := smst.Get([]byte(key)) - require.NoError(t, err) - fmt.Println(v, s) - fmt.Println(value) - fmt.Println("") - fmt.Println("") + fmt.Println() + if isExtNode(node) { + pathBounds, path, childData, sum := smst.Spec().parseSumExtNode(node) + fmt.Println("ext node sum", sum) + fmt.Println(pathBounds, path) + helper(t, smst, nodeStore, childData) + return + } else if isLeafNode(node) { + path, value := smst.Spec().parseLeafNode(node) + fmt.Println("leaf node sum", 0) + fmt.Println(path, value) + } else if isInnerNode(node) { + leftData, rightData, sum := smst.Spec().th.parseSumInnerNode(node) + fmt.Println("inner node sum", sum) + helper(t, smst, nodeStore, leftData) + helper(t, smst, nodeStore, rightData) } - // Export the trie to a CSV file - // err := smt.ExportToCSV("export.csv") - // if err != nil { - // panic(err) - // } + // v, s, err := smst.Get([]byte(key)) + // require.NoError(t, err) + // require.Equal(t, []byte(value), v) + // require.Equal(t, sum, s) } diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 58efeac..d2b06ac 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -107,7 +107,7 @@ func TestSMST_Proof_Operations(t *testing.T) { binary.BigEndian.PutUint64(sum[:], 5) tval := base.valueHash([]byte("testValue")) tval = append(tval, sum[:]...) 
- _, leafData := base.th.digestSumLeaf(base.ph.Path([]byte("testKey2")), tval) + _, leafData := base.th.digestSumLeafNode(base.ph.Path([]byte("testKey2")), tval) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, diff --git a/smt.go b/smt.go index de9cd0c..fd1fac4 100644 --- a/smt.go +++ b/smt.go @@ -522,7 +522,7 @@ func (smt *SMT) ProveClosest(path []byte) ( return proof, nil } -// resolveLazy resolves resolves a stub into a cached node +// resolveLazy resolves a lazy note into a cached node depending on the tree type func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) { stub, ok := node.(*lazyNode) if !ok { @@ -550,7 +550,7 @@ func (smt *SMT) resolveNode(digest []byte) (trieNode, error) { // Return the appropriate node type based on the first byte of the data if isLeafNode(data) { - path, valueHash := parseLeafNode(data, smt.ph) + path, valueHash := smt.parseLeafNode(data) return &leafNode{ path: path, valueHash: valueHash, @@ -558,7 +558,7 @@ func (smt *SMT) resolveNode(digest []byte) (trieNode, error) { digest: digest, }, nil } else if isExtNode(data) { - pathBounds, path, childData := parseExtNode(data, smt.ph) + pathBounds, path, childData := smt.parseExtNode(data) return &extensionNode{ path: path, pathBounds: [2]byte(pathBounds), @@ -595,7 +595,7 @@ func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) { // Return the appropriate node type based on the first byte of the data if isLeafNode(data) { - path, valueHash := parseLeafNode(data, smt.ph) + path, valueHash := smt.parseLeafNode(data) return &leafNode{ path: path, valueHash: valueHash, @@ -603,7 +603,7 @@ func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) { digest: digest, }, nil } else if isExtNode(data) { - pathBounds, path, childData, _ := parseSumExtNode(data, smt.ph) + pathBounds, path, childData, _ := smt.parseSumExtNode(data) return &extensionNode{ path: path, pathBounds: [2]byte(pathBounds), @@ -612,7 +612,7 @@ func (smt *SMT) 
resolveSumNode(digest []byte) (trieNode, error) { digest: digest, }, nil } else if isInnerNode(data) { - leftData, rightData := smt.th.parseSumInnerNode(data) + leftData, rightData, _ := smt.th.parseSumInnerNode(data) return &innerNode{ leftChild: &lazyNode{leftData}, rightChild: &lazyNode{rightData}, diff --git a/smt_proofs_test.go b/smt_proofs_test.go index e159b72..61e3f80 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -84,7 +84,7 @@ func TestSMT_Proof_Operations(t *testing.T) { require.False(t, result) // Try proving a default value for a non-default leaf. - _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.valueHash([]byte("testValue"))) + _, leafData := base.th.digestLeafNode(base.ph.Path([]byte("testKey2")), base.valueHash([]byte("testValue"))) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, diff --git a/trie_spec.go b/trie_spec.go new file mode 100644 index 0000000..f5008fc --- /dev/null +++ b/trie_spec.go @@ -0,0 +1,255 @@ +package smt + +import ( + "encoding/binary" + "hash" +) + +// TrieSpec specifies the hashing functions used by a trie instance to encode +// leaf paths and stored values, and the corresponding maximum trie depth. +type TrieSpec struct { + th trieHasher + ph PathHasher + vh ValueHasher + sumTrie bool +} + +// newTrieSpec returns a new TrieSpec with the given hasher and sumTrie flag +func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec { + spec := TrieSpec{th: *NewTrieHasher(hasher)} + spec.ph = &pathHasher{spec.th} + spec.vh = &valueHasher{spec.th} + spec.sumTrie = sumTrie + return spec +} + +// Spec returns the TrieSpec associated with the given trie +func (spec *TrieSpec) Spec() *TrieSpec { + return spec +} + +// placeholder returns the default placeholder value depending on the trie type +func (spec *TrieSpec) placeholder() []byte { + if spec.sumTrie { + placeholder := spec.th.placeholder() + placeholder = append(placeholder, defaultEmptySum[:]...) 
+ return placeholder + } + return spec.th.placeholder() +} + +// hashSize returns the hash size depending on the trie type +func (spec *TrieSpec) hashSize() int { + if spec.sumTrie { + return spec.th.hashSize() + sumSizeBits + } + return spec.th.hashSize() +} + +// digestLeaf returns the hash and preimage of a leaf node depending on the trie type +func (spec *TrieSpec) digestLeaf(path, value []byte) ([]byte, []byte) { + if spec.sumTrie { + return spec.th.digestSumLeafNode(path, value) + } + return spec.th.digestLeafNode(path, value) +} + +// digestNode returns the hash and preimage of a node depending on the trie type +func (spec *TrieSpec) digestInnerNode(left, right []byte) ([]byte, []byte) { + if spec.sumTrie { + return spec.th.digestSumInnerNode(left, right) + } + return spec.th.digestInnerNode(left, right) +} + +// digest hashes a node depending on the trie type +func (spec *TrieSpec) digest(node trieNode) []byte { + if spec.sumTrie { + return spec.digestSumNode(node) + } + return spec.digestNode(node) +} + +// encode serializes a node depending on the trie type +func (spec *TrieSpec) encode(node trieNode) []byte { + if spec.sumTrie { + return spec.encodeSumNode(node) + } + return spec.encodeNode(node) +} + +// hashPreimage hashes the serialised data provided depending on the trie type +func (spec *TrieSpec) hashPreimage(data []byte) []byte { + if spec.sumTrie { + return spec.hashSumSerialization(data) + } + return spec.hashSerialization(data) +} + +// Used for verification of serialized proof data +func (spec *TrieSpec) hashSerialization(data []byte) []byte { + if isExtNode(data) { + pathBounds, path, childHash := spec.parseExtNode(data) + ext := extensionNode{path: path, child: &lazyNode{childHash}} + copy(ext.pathBounds[:], pathBounds) + return spec.digestNode(&ext) + } + return spec.th.digestData(data) +} + +// Used for verification of serialized proof data for sum trie nodes +func (spec *TrieSpec) hashSumSerialization(data []byte) []byte { + if 
isExtNode(data) { + pathBounds, path, childHash, _ := spec.parseSumExtNode(data) + ext := extensionNode{path: path, child: &lazyNode{childHash}} + copy(ext.pathBounds[:], pathBounds) + return spec.digestSumNode(&ext) + } + digest := spec.th.digestData(data) + digest = append(digest, data[len(data)-sumSizeBits:]...) + return digest +} + +// depth returns the maximum depth of the trie. +// Since this tree is a binary tree, the depth is the number of bits in the path +// TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output +func (spec *TrieSpec) depth() int { + return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits +} + +// valueHash returns the hash of a value, or the value itself if no value hasher is specified. +func (spec *TrieSpec) valueHash(value []byte) []byte { + if spec.vh == nil { + return value + } + return spec.vh.HashValue(value) +} + +// encodeNode serializes a node into a byte slice +func (spec *TrieSpec) encodeNode(node trieNode) []byte { + switch n := node.(type) { + case *lazyNode: + panic("Encoding a lazyNode is not supported") + case *leafNode: + return encodeLeafNode(n.path, n.valueHash) + case *innerNode: + leftChild := spec.digestNode(n.leftChild) + rightChild := spec.digestNode(n.rightChild) + return encodeInnerNode(leftChild, rightChild) + case *extensionNode: + child := spec.digestNode(n.child) + return encodeExtensionNode(n.pathBounds, n.path, child) + default: + panic("Unknown node type") + } +} + +// digestNode hashes a node and returns its digest +func (spec *TrieSpec) digestNode(node trieNode) []byte { + if node == nil { + return spec.th.placeholder() + } + + var cachedDigest *[]byte + switch n := node.(type) { + case *lazyNode: + return n.digest + case *leafNode: + cachedDigest = &n.digest + case *innerNode: + cachedDigest = &n.digest + case *extensionNode: + if n.digest == nil { + n.digest = spec.digestNode(n.expand()) + } + return n.digest + } + if *cachedDigest == nil { + 
*cachedDigest = spec.th.digestData(spec.encodeNode(node)) + } + return *cachedDigest +} + +// encodeSumNode serializes a sum node and returns the preImage hash. +func (spec *TrieSpec) encodeSumNode(node trieNode) (preImage []byte) { + switch n := node.(type) { + case *lazyNode: + panic("encodeSumNode(lazyNode)") + case *leafNode: + return encodeLeafNode(n.path, n.valueHash) + case *innerNode: + leftChild := spec.digestSumNode(n.leftChild) + rightChild := spec.digestSumNode(n.rightChild) + return encodeSumInnerNode(leftChild, rightChild) + case *extensionNode: + child := spec.digestSumNode(n.child) + return encodeSumExtensionNode(n.pathBounds, n.path, child) + } + return nil +} + +// digestSumNode hashes a sum node returning its digest in the following form: [node hash]+[8 byte sum] +func (spec *TrieSpec) digestSumNode(node trieNode) []byte { + if node == nil { + return spec.placeholder() + } + var cache *[]byte + switch n := node.(type) { + case *lazyNode: + return n.digest + case *leafNode: + cache = &n.digest + case *innerNode: + cache = &n.digest + case *extensionNode: + if n.digest == nil { + n.digest = spec.digestSumNode(n.expand()) + } + return n.digest + } + if *cache == nil { + preImage := spec.encodeSumNode(node) + *cache = spec.th.digestData(preImage) + *cache = append(*cache, preImage[len(preImage)-sumSizeBits:]...) 
+ } + return *cache +} + +// parseLeafNode parses a leafNode into its components +func (spec *TrieSpec) parseLeafNode(data []byte) (path, value []byte) { + // panics if not a leaf node + checkPrefix(data, leafNodePrefix) + + path = data[prefixLen : prefixLen+spec.ph.PathSize()] + value = data[prefixLen+spec.ph.PathSize():] + return +} + +// parseExtNode parses an extNode into its components +func (spec *TrieSpec) parseExtNode(data []byte) (pathBounds, path, childData []byte) { + // panics if not an extension node + checkPrefix(data, extNodePrefix) + + // +2 represents the length of the pathBounds + pathBounds = data[prefixLen : prefixLen+2] + path = data[prefixLen+2 : prefixLen+2+spec.ph.PathSize()] + childData = data[prefixLen+2+spec.ph.PathSize():] + return +} + +// parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data +func (spec *TrieSpec) parseSumExtNode(data []byte) (pathBounds, path, childData []byte, sum uint64) { + // panics if not an extension node + checkPrefix(data, extNodePrefix) + + // Extract the sum from the encoded node data + var sumBz [sumSizeBits]byte + copy(sumBz[:], data[len(data)-sumSizeBits:]) + binary.BigEndian.PutUint64(sumBz[:], sum) + + // +2 represents the length of the pathBounds + pathBounds = data[prefixLen : prefixLen+2] + path = data[prefixLen+2 : prefixLen+2+spec.ph.PathSize()] + childData = data[prefixLen+2+spec.ph.PathSize() : len(data)-sumSizeBits] + return +} diff --git a/types.go b/types.go index 8dc3c30..0be9de6 100644 --- a/types.go +++ b/types.go @@ -1,9 +1,5 @@ package smt -import ( - "hash" -) - // TODO_DISCUSS_IN_THIS_PR_IMPROVEMENTS: // 1. Should we rename all instances of digest to hash? // 2. Should we introduce a shared interface between SparseMerkleTrie and SparseMerkleSumTrie? 
@@ -82,214 +78,3 @@ type SparseMerkleSumTrie interface { // Spec returns the TrieSpec for the trie Spec() *TrieSpec } - -// TrieSpec specifies the hashing functions used by a trie instance to encode -// leaf paths and stored values, and the corresponding maximum trie depth. -type TrieSpec struct { - th trieHasher - ph PathHasher - vh ValueHasher - sumTrie bool -} - -// newTrieSpec returns a new TrieSpec with the given hasher and sumTrie flag -func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec { - spec := TrieSpec{th: *NewTrieHasher(hasher)} - spec.ph = &pathHasher{spec.th} - spec.vh = &valueHasher{spec.th} - spec.sumTrie = sumTrie - return spec -} - -// Spec returns the TrieSpec associated with the given trie -func (spec *TrieSpec) Spec() *TrieSpec { - return spec -} - -// placeholder returns the default placeholder value depending on the trie type -func (spec *TrieSpec) placeholder() []byte { - if spec.sumTrie { - placeholder := spec.th.placeholder() - placeholder = append(placeholder, defaultEmptySum[:]...) 
- return placeholder - } - return spec.th.placeholder() -} - -// hashSize returns the hash size depending on the trie type -func (spec *TrieSpec) hashSize() int { - if spec.sumTrie { - return spec.th.hashSize() + sumSizeBits - } - return spec.th.hashSize() -} - -// digestLeaf returns the hash and preimage of a leaf node depending on the trie type -func (spec *TrieSpec) digestLeaf(path, value []byte) ([]byte, []byte) { - if spec.sumTrie { - return spec.th.digestSumLeaf(path, value) - } - return spec.th.digestLeaf(path, value) -} - -// digestNode returns the hash and preimage of a node depending on the trie type -func (spec *TrieSpec) digestInnerNode(left, right []byte) ([]byte, []byte) { - if spec.sumTrie { - return spec.th.digestSumInnerNode(left, right) - } - return spec.th.digestInnerNode(left, right) -} - -// digest hashes a node depending on the trie type -func (spec *TrieSpec) digest(node trieNode) []byte { - if spec.sumTrie { - return spec.digestSumNode(node) - } - return spec.digestNode(node) -} - -// encode serializes a node depending on the trie type -func (spec *TrieSpec) encode(node trieNode) []byte { - if spec.sumTrie { - return spec.encodeSumNode(node) - } - return spec.encodeNode(node) -} - -// hashPreimage hashes the serialised data provided depending on the trie type -func hashPreimage(spec *TrieSpec, data []byte) []byte { - if spec.sumTrie { - return hashSumSerialization(spec, data) - } - return hashSerialization(spec, data) -} - -// Used for verification of serialized proof data -func hashSerialization(smt *TrieSpec, data []byte) []byte { - if isExtNode(data) { - pathBounds, path, childHash := parseExtNode(data, smt.ph) - ext := extensionNode{path: path, child: &lazyNode{childHash}} - copy(ext.pathBounds[:], pathBounds) - return smt.digestNode(&ext) - } - return smt.th.digestData(data) -} - -// Used for verification of serialized proof data for sum trie nodes -func hashSumSerialization(smt *TrieSpec, data []byte) []byte { - if isExtNode(data) { - 
pathBounds, path, childHash, _ := parseSumExtNode(data, smt.ph) - ext := extensionNode{path: path, child: &lazyNode{childHash}} - copy(ext.pathBounds[:], pathBounds) - return smt.digestSumNode(&ext) - } - digest := smt.th.digestData(data) - digest = append(digest, data[len(data)-sumSizeBits:]...) - return digest -} - -// depth returns the maximum depth of the trie. -// Since this tree is a binary tree, the depth is the number of bits in the path -// TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output -func (spec *TrieSpec) depth() int { - return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits -} - -// valueHash returns the hash of a value, or the value itself if no value hasher is specified. -func (spec *TrieSpec) valueHash(value []byte) []byte { - if spec.vh == nil { - return value - } - return spec.vh.HashValue(value) -} - -// encodeNode serializes a node into a byte slice -func (spec *TrieSpec) encodeNode(node trieNode) []byte { - switch n := node.(type) { - case *lazyNode: - panic("Encoding a lazyNode is not supported") - case *leafNode: - return encodeLeafNode(n.path, n.valueHash) - case *innerNode: - leftChild := spec.digestNode(n.leftChild) - rightChild := spec.digestNode(n.rightChild) - return encodeInnerNode(leftChild, rightChild) - case *extensionNode: - child := spec.digestNode(n.child) - return encodeExtensionNode(n.pathBounds, n.path, child) - default: - panic("Unknown node type") - } -} - -// digestNode hashes a node and returns its digest -func (spec *TrieSpec) digestNode(node trieNode) []byte { - if node == nil { - return spec.th.placeholder() - } - - var cachedDigest *[]byte - switch n := node.(type) { - case *lazyNode: - return n.digest - case *leafNode: - cachedDigest = &n.digest - case *innerNode: - cachedDigest = &n.digest - case *extensionNode: - if n.digest == nil { - n.digest = spec.digestNode(n.expand()) - } - return n.digest - } - if *cachedDigest == nil { - *cachedDigest = 
spec.th.digestData(spec.encodeNode(node)) - } - return *cachedDigest -} - -// encodeSumNode serializes a sum node and returns the preImage hash. -func (spec *TrieSpec) encodeSumNode(node trieNode) (preImage []byte) { - switch n := node.(type) { - case *lazyNode: - panic("encodeSumNode(lazyNode)") - case *leafNode: - return encodeLeafNode(n.path, n.valueHash) - case *innerNode: - leftChild := spec.digestSumNode(n.leftChild) - rightChild := spec.digestSumNode(n.rightChild) - preImage = encodeSumInnerNode(leftChild, rightChild) - return preImage - case *extensionNode: - child := spec.digestSumNode(n.child) - return encodeSumExtensionNode(n.pathBounds, n.path, child) - } - return nil -} - -// digestSumNode hashes a sum node returning its digest in the following form: [node hash]+[8 byte sum] -func (spec *TrieSpec) digestSumNode(node trieNode) []byte { - if node == nil { - return spec.placeholder() - } - var cache *[]byte - switch n := node.(type) { - case *lazyNode: - return n.digest - case *leafNode: - cache = &n.digest - case *innerNode: - cache = &n.digest - case *extensionNode: - if n.digest == nil { - n.digest = spec.digestSumNode(n.expand()) - } - return n.digest - } - if *cache == nil { - preImage := spec.encodeSumNode(node) - *cache = spec.th.digestData(preImage) - *cache = append(*cache, preImage[len(preImage)-sumSizeBits:]...) 
- } - return *cache -} From 5e7ed766138b9e7ec2e60613e1c7e2579a0e82e6 Mon Sep 17 00:00:00 2001 From: h5law Date: Tue, 19 Mar 2024 11:34:46 +0000 Subject: [PATCH 17/40] chore: add context on hashing algorithms used by the trie --- docs/smt.md | 116 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 68 insertions(+), 48 deletions(-) diff --git a/docs/smt.md b/docs/smt.md index bae85d2..19cdbf7 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -4,31 +4,31 @@ - [Overview](#overview) - [Implementation](#implementation) - * [Inner Nodes](#inner-nodes) - * [Extension Nodes](#extension-nodes) - * [Leaf Nodes](#leaf-nodes) - * [Lazy Nodes](#lazy-nodes) - * [Lazy Loading](#lazy-loading) - * [Visualisations](#visualisations) - + [General Trie Structure](#general-trie-structure) - + [Lazy Nodes](#lazy-nodes-1) + - [Inner Nodes](#inner-nodes) + - [Extension Nodes](#extension-nodes) + - [Leaf Nodes](#leaf-nodes) + - [Lazy Nodes](#lazy-nodes) + - [Lazy Loading](#lazy-loading) + - [Visualisations](#visualisations) + - [General Trie Structure](#general-trie-structure) + - [Lazy Nodes](#lazy-nodes-1) - [Paths](#paths) - * [Visualisation](#visualisation) + - [Visualisation](#visualisation) - [Values](#values) - * [Nil values](#nil-values) + - [Nil values](#nil-values) - [Hashers & Digests](#hashers--digests) - [Roots](#roots) - [Proofs](#proofs) - * [Verification](#verification) - * [Closest Proof](#closest-proof) - + [Closest Proof Use Cases](#closest-proof-use-cases) - * [Compression](#compression) - * [Serialisation](#serialisation) + - [Verification](#verification) + - [Closest Proof](#closest-proof) + - [Closest Proof Use Cases](#closest-proof-use-cases) + - [Compression](#compression) + - [Serialisation](#serialisation) - [Database](#database) - * [Database Submodules](#database-submodules) - + [SimpleMap](#simplemap) - + [Badger](#badger) - * [Data Loss](#data-loss) + - [Database Submodules](#database-submodules) + - [SimpleMap](#simplemap) + - [Badger](#badger) + 
- [Data Loss](#data-loss) - [Sparse Merkle Sum Trie](#sparse-merkle-sum-trie) @@ -44,6 +44,25 @@ make SMTs valuable in applications like blockchains, decentralized databases, and authenticated data structures, providing optimized and trustworthy data storage and verification. +Although any hash function that satisfies the `hash.Hash` interface can be used +to construct the trie it is **strongly recommended** to use a hashing function +that provides the following properties: + +- **Collision resistance**: The hash function must be collision resistant, in + order for the inputs of the SMT to be unique. +- **Preimage resistance**: The hash function must be preimage resistant, to + protect against the attack of the Merkle tree construction attacks where the + attacker can modify unknown data. +- **Efficiency**: The hash function must be efficient, as it is used to compute + the hash of many nodes in the trie. + +Therefore it is recommended to use a hashing function such as: + +- `sha256` +- `sha3_256`/`keccak256` + +Or another sufficiently secure hashing algorithm. + See [smt.go](../smt.go) for more details on the implementation. ## Implementation @@ -66,9 +85,9 @@ The SMT has 4 node types that are used to construct the trie: ### Inner Nodes -Inner nodes represent a branch in the trie with two **non-nil** child nodes. -The inner node has an internal `digest` which represents the hash of the child -nodes concatenated hashes. +Inner nodes represent a branch in the trie with two **non-nil** child nodes. The +inner node has an internal `digest` which represents the hash of the child nodes +concatenated hashes. ### Extension Nodes @@ -307,8 +326,8 @@ described in the [implementation](#implementation) section. The following diagram represents the creation of a leaf node in an abstracted and simplified manner. 
-_Note: This diagram is not entirely accurate regarding the process of creating -a leaf node, but is a good representation of the process._ +_Note: This diagram is not entirely accurate regarding the process of creating a +leaf node, but is a good representation of the process._ ```mermaid graph TD @@ -346,17 +365,17 @@ graph TD ## Roots The root of the tree is a slice of bytes. `MerkleRoot` is an alias for `[]byte`. -This design enables easily passing around the data (e.g. on-chain) -while maintaining primitive usage in different use cases (e.g. proofs). +This design enables easily passing around the data (e.g. on-chain) while +maintaining primitive usage in different use cases (e.g. proofs). `MerkleRoot` provides helpers, such as retrieving the `Sum() uint64` to -interface with data it captures. However, for the SMT it **always** panics, -as there is no sum. +interface with data it captures. However, for the SMT it **always** panics, as +there is no sum. ## Proofs -The `SparseMerkleProof` type contains the information required for inclusion -and exclusion proofs, depending on the key provided to the trie method +The `SparseMerkleProof` type contains the information required for inclusion and +exclusion proofs, depending on the key provided to the trie method `Prove(key []byte)` either an inclusion or exclusion proof will be generated. _NOTE: The inclusion and exclusion proof are the same type, just constructed @@ -397,29 +416,29 @@ using the `VerifyClosestProof` function which requires the proof and root hash of the trie. Since the `ClosestProof` method takes a hash as input, it is possible to place a -leaf in the trie according to the hash's path, if it is known. Depending on -the use case of this function this may expose a vulnerability. **It is not -intendend to be used as a general purpose proof mechanism**, but instead as a -**Commit and Reveal** mechanism, as detailed below. +leaf in the trie according to the hash's path, if it is known. 
Depending on the +use case of this function this may expose a vulnerability. **It is not intended +to be used as a general purpose proof mechanism**, but instead as a **Commit and +Reveal** mechanism, as detailed below. #### Closest Proof Use Cases The `CloestProof` function is intended for use as a `commit & reveal` mechanism. Where there are two actors involved, the **prover** and **verifier**. -_NOTE: Throughout this document, `commitment` of the the trie's root hash is also -referred to as closing the trie, such that no more updates are made to it once -committed._ +_NOTE: Throughout this document, `commitment` of the trie's root hash is +also referred to as closing the trie, such that no more updates are made to it +once committed._ Consider the following attack vector (**without** a commit prior to a reveal) into consideration: 1. The **verifier** picks the hash (i.e. a single branch) they intend to check -1. The **prover** inserts a leaf (i.e. a value) whose key (determined via the +2. The **prover** inserts a leaf (i.e. a value) whose key (determined via the hasher) has a longer common prefix than any other leaf in the trie. -1. Due to the deterministic nature of the `ClosestProof`, method this leaf will +3. Due to the deterministic nature of the `ClosestProof` method, this leaf will **always** be returned given the identified hash. -1. The **verifier** then verifies the revealed `ClosestProof`, which returns a +4. The **verifier** then verifies the revealed `ClosestProof`, which returns a branch the **prover** inserted after knowing which leaf was going to be checked. @@ -428,16 +447,16 @@ Consider the following normal flow (**with** a commit prior to reveal) as 1. The **prover** commits to the state of their trie by publishes their root hash, thereby _closing_ their trie and not being able to make further changes. -1. The **verifier** selects a hash to be used in the `commit & reveal` process +2. 
The **verifier** selects a hash to be used in the `commit & reveal` process that the **prover** must provide a closest proof for. -1. The **prover** utilises this hash and computes the `ClosestProof` on their +3. The **prover** utilises this hash and computes the `ClosestProof` on their _closed_ trie, producing a `ClosestProof`, thus revealing a deterministic, pseudo-random leaf that existed in the tree prior to commitment, yet -1. The **verifier** verifies the proof, in turn, verifying the commitment - made by the **prover** to the state of the trie in the first step. -1. The **prover** had no opportunity to insert a new leaf into the trie - after learning which hash the **verifier** was going to require a - `ClosestProof` for. +4. The **verifier** verifies the proof, in turn, verifying the commitment made + by the **prover** to the state of the trie in the first step. +5. The **prover** had no opportunity to insert a new leaf into the trie after + learning which hash the **verifier** was going to require a `ClosestProof` + for. ### Compression @@ -497,7 +516,8 @@ database. It's interface exposes numerous extra methods not used by the trie, However it can still be used as a node-store with both in-memory and persistent options. -See [badger-store.md](./badger-store.md.md) for the details of the implementation. +See [badger-store.md](./badger-store.md.md) for the details of the +implementation. 
### Data Loss From 9e9d9b3f1437002c27196eacedf827d8ca717618 Mon Sep 17 00:00:00 2001 From: h5law Date: Tue, 19 Mar 2024 11:35:10 +0000 Subject: [PATCH 18/40] chore: add context on external kvstore writeability --- docs/mapstore.md | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/mapstore.md b/docs/mapstore.md index a47c670..d36fc95 100644 --- a/docs/mapstore.md +++ b/docs/mapstore.md @@ -3,8 +3,8 @@ - [Implementations](#implementations) - * [SimpleMap](#simplemap) - * [BadgerV4](#badgerv4) + - [SimpleMap](#simplemap) + - [BadgerV4](#badgerv4) @@ -12,6 +12,10 @@ The `MapStore` is a simple interface used by the SM(S)T to store, delete and retrieve key-value pairs. It is intentionally simple and minimalistic so as to enable different key-value engines to implement and back the trie database. +Any key-value store used by the tries should **not** be able to be externally +writeable in production. This opens the possibility to attacks where the writer +can modify the trie database and prove values that were not inserted. + See: [the interface](../kvstore/interfaces.go) for a more detailed description of the simple interface required by the SM(S)T. @@ -31,11 +35,11 @@ See [simplemap.go](../kvstore/simplemap/simplemap.go) for more details. ### BadgerV4 -This library provides a wrapper around [dgraph-io/badger][badgerv4] to adhere -to the `MapStore` interface. See the [full documentation](./badger-store.md) -for additional functionality and implementation details. +This library provides a wrapper around [dgraph-io/badger][badgerv4] to adhere to +the `MapStore` interface. See the [full documentation](./badger-store.md) for +additional functionality and implementation details. -See: [badger](../kvstore/badger/) for more details on the implementation of -this submodule. +See: [badger](../kvstore/badger/) for more details on the implementation of this +submodule. 
[badgerv4]: https://github.com/dgraph-io/badger From 4831b52ba5f12955b8ac94e0ae96ea189116d249 Mon Sep 17 00:00:00 2001 From: h5law Date: Tue, 19 Mar 2024 11:38:17 +0000 Subject: [PATCH 19/40] feat: consolidate ClosestProof verification and remove the NilPathHasher method --- options.go | 18 ------------- proofs.go | 66 ++++++++++++++++++++++++++++++++++++--------- smst_proofs_test.go | 46 ++++++++++++++++++++++--------- smt_proofs_test.go | 19 +++++++------ 4 files changed, 96 insertions(+), 53 deletions(-) diff --git a/options.go b/options.go index c4eb422..884d559 100644 --- a/options.go +++ b/options.go @@ -1,9 +1,5 @@ package smt -import ( - "hash" -) - // Option is a function that configures SparseMerkleTrie. type Option func(*TrieSpec) @@ -16,17 +12,3 @@ func WithPathHasher(ph PathHasher) Option { func WithValueHasher(vh ValueHasher) Option { return func(ts *TrieSpec) { ts.vh = vh } } - -// NoPrehashSpec returns a new TrieSpec that has a nil Value Hasher and a nil -// Path Hasher -// NOTE: This should only be used when values are already hashed and a path is -// used instead of a key during proof verification, otherwise these will be -// double hashed and produce an incorrect leaf digest invalidating the proof. -func NoPrehashSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { - spec := newTrieSpec(hasher, sumTrie) - opt := WithPathHasher(newNilPathHasher(hasher.Size())) - opt(&spec) - opt = WithValueHasher(nil) - opt(&spec) - return &spec -} diff --git a/proofs.go b/proofs.go index 64ce171..3a2b690 100644 --- a/proofs.go +++ b/proofs.go @@ -61,7 +61,11 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { // Check that leaf data for non-membership proofs is a valid size. 
lps := len(leafPrefix) + spec.ph.PathSize() if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps { - return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps) + return fmt.Errorf( + "invalid non-membership leaf data size: got %d but min is %d", + len(proof.NonMembershipLeafData), + lps, + ) } // Check that all supplied sidenodes are the correct size. @@ -133,7 +137,11 @@ func (proof *SparseCompactMerkleProof) validateBasic(spec *TrieSpec) error { // Compact proofs: check that NumSideNodes is within the right range. if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 { - return fmt.Errorf("invalid number of side nodes: got %d, min is 0 and max is %d", len(proof.SideNodes), spec.ph.PathSize()*8) + return fmt.Errorf( + "invalid number of side nodes: got %d, min is 0 and max is %d", + len(proof.SideNodes), + spec.ph.PathSize()*8, + ) } // Compact proofs: check that the length of the bit mask is as expected @@ -185,6 +193,17 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } +// GetValueHash returns the value hash of the closest proof. 
+func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte { + if proof.ClosestValueHash == nil { + return nil + } + if spec.sumTrie { + return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] + } + return proof.ClosestValueHash +} + func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error { // ensure the depth of the leaf node being proven is within the path size if proof.Depth < 0 || proof.Depth > spec.ph.PathSize()*8 { @@ -246,7 +265,12 @@ func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) erro } for i, b := range proof.FlippedBits { if len(b) > maxSliceLen { - return fmt.Errorf("invalid compressed flipped bit index %d: got length %d, max is %d]", i, bytesToInt(b), maxSliceLen) + return fmt.Errorf( + "invalid compressed flipped bit index %d: got length %d, max is %d]", + i, + bytesToInt(b), + maxSliceLen, + ) } } // perform a sanity check on the closest proof @@ -301,26 +325,38 @@ func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint6 // VerifyClosestProof verifies a Merkle proof for a proof of inclusion for a leaf // found to have the closest path to the one provided to the proof structure -// -// TO_AUDITOR: This is akin to an inclusion proof with N (num flipped bits) exclusion -// proof wrapped into one and needs to be reviewed from an algorithm POV. func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *TrieSpec) (bool, error) { if err := proof.validateBasic(spec); err != nil { return false, errors.Join(ErrBadProof, err) } - if !spec.sumTrie { - return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, spec) + // Create a new TrieSpec with a nil path hasher - as the ClosestProof + // already contains a hashed path, double hashing it will invalidate the + // proof. 
+ nilSpec := &TrieSpec{ + th: spec.th, + ph: newNilPathHasher(spec.ph.PathSize()), + vh: spec.vh, + sumTrie: spec.sumTrie, + } + if !nilSpec.sumTrie { + return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, nilSpec) } if proof.ClosestValueHash == nil { - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, spec) + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec) } sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:] sum := binary.BigEndian.Uint64(sumBz) valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec) + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec) } -func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TrieSpec) (bool, [][][]byte, error) { +func verifyProofWithUpdates( + proof *SparseMerkleProof, + root []byte, + key []byte, + value []byte, + spec *TrieSpec, +) (bool, [][][]byte, error) { path := spec.ph.Path(key) if err := proof.validateBasic(spec); err != nil { @@ -384,7 +420,13 @@ func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value } // VerifyCompactSumProof is similar to VerifySumProof but for a compacted Merkle proof. 
-func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { +func VerifyCompactSumProof( + proof *SparseCompactMerkleProof, + root []byte, + key, value []byte, + sum uint64, + spec *TrieSpec, +) (bool, error) { decompactedProof, err := DecompactProof(proof, spec) if err != nil { return false, errors.Join(ErrBadProof, err) diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 21fd454..c909d23 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -79,7 +79,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof( + randomiseSumProof(proof), + root, + []byte("testKey"), + []byte("testValue"), + 5, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) @@ -98,7 +105,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey2"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof( + randomiseSumProof(proof), + root, + []byte("testKey2"), + []byte("testValue"), + 5, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) @@ -129,7 +143,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultValue, 0, base) // invalid proof + result, err 
= VerifySumProof( + randomiseSumProof(proof), + root, + []byte("testKey3"), + defaultValue, + 0, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) } @@ -204,7 +225,6 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smst := NewSparseMerkleSumTrie(smn, sha256.New()) - np := NoPrehashSpec(sha256.New(), true) base := smst.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) @@ -227,14 +247,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.Depth = -1 require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]") - result, err := VerifyClosestProof(proof, root, np) + result, err := VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.Depth = 257 require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -244,14 +264,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.FlippedBits[0] = -1 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.FlippedBits[0] = 9 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = 
VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -265,7 +285,7 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { proof.validateBasic(base), "invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)", ) - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -326,7 +346,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.NoError(t, err) require.True(t, result) @@ -352,7 +372,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.NoError(t, err) require.True(t, result) } @@ -381,7 +401,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec()) require.NoError(t, err) require.True(t, result) } @@ -419,7 +439,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec()) require.NoError(t, err) require.True(t, result) } diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 
2cf70c8..1c353df 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -178,7 +178,6 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) { func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smt := NewSparseMerkleTrie(smn, sha256.New()) - np := NoPrehashSpec(sha256.New(), false) base := smt.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) @@ -201,14 +200,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.Depth = -1 require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]") - result, err := VerifyClosestProof(proof, root, np) + result, err := VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.Depth = 257 require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -218,14 +217,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.FlippedBits[0] = -1 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.FlippedBits[0] = 9 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = 
CompactClosestProof(proof, base) @@ -239,7 +238,7 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { proof.validateBasic(base), "invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)", ) - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -287,7 +286,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.NoError(t, err) require.True(t, result) closestPath := sha256.Sum256([]byte("testKey2")) @@ -304,7 +303,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.NoError(t, err) require.True(t, result) closestPath = sha256.Sum256([]byte("testKey4")) @@ -336,7 +335,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec()) require.NoError(t, err) require.True(t, result) } @@ -368,7 +367,7 @@ func TestSMT_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec()) require.NoError(t, err) require.True(t, result) } From 46dc5c53df57a0bb51ad3229ad3b8489d331b575 Mon Sep 17 00:00:00 2001 From: 
h5law Date: Tue, 19 Mar 2024 11:38:55 +0000 Subject: [PATCH 20/40] feat: reorganise extension node insertion to use separate pointers for the child node --- proofs.go | 11 ----------- smt.go | 29 +++++++++++++++++++++++------ 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/proofs.go b/proofs.go index 3a2b690..0b8f11d 100644 --- a/proofs.go +++ b/proofs.go @@ -193,17 +193,6 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } -// GetValueHash returns the value hash of the closest proof. -func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte { - if proof.ClosestValueHash == nil { - return nil - } - if spec.sumTrie { - return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] - } - return proof.ClosestValueHash -} - func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error { // ensure the depth of the leaf node being proven is within the path size if proof.Depth < 0 || proof.Depth > spec.ph.PathSize()*8 { diff --git a/smt.go b/smt.go index 558db63..f55a3f8 100644 --- a/smt.go +++ b/smt.go @@ -178,19 +178,36 @@ func (smt *SMT) update( } // We insert an "extension" representing multiple single-branch inner nodes last := &node + var newInner *innerNode + if getPathBit(path, prefixlen) == left { + newInner = &innerNode{ + leftChild: newLeaf, + rightChild: leaf, + } + } else { + newInner = &innerNode{ + leftChild: leaf, + rightChild: newLeaf, + } + } + // Determine if we need to insert an extension or a branch if depth < prefixlen { // note: this keeps path slice alive - GC inefficiency? 
if depth > 0xff { panic("invalid depth") } - ext := extensionNode{path: path, pathBounds: [2]byte{byte(depth), byte(prefixlen)}} + ext := extensionNode{ + child: newInner, + path: path, + pathBounds: [2]byte{ + byte(depth), byte(prefixlen), + }, + } + // Dereference the last node to replace it with the extension node *last = &ext - last = &ext.child - } - if getPathBit(path, prefixlen) == left { - *last = &innerNode{leftChild: newLeaf, rightChild: leaf} } else { - *last = &innerNode{leftChild: leaf, rightChild: newLeaf} + // Dereference the last node to replace it with the new inner node + *last = newInner } return node, nil } From 38623451c940d4eaf832bc1e035ba3822be496a8 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Thu, 21 Mar 2024 13:23:20 -0700 Subject: [PATCH 21/40] Old changes --- proofs.go | 24 ++++++++++++------------ smst_example_test.go | 19 +++++++++++++++---- trie_spec.go | 16 ++++++++++++++++ 3 files changed, 43 insertions(+), 16 deletions(-) diff --git a/proofs.go b/proofs.go index 9a49b54..a7eed6e 100644 --- a/proofs.go +++ b/proofs.go @@ -327,7 +327,7 @@ func verifyProofWithUpdates( root, key, value []byte, spec *TrieSpec, ) (bool, [][][]byte, error) { - // Retrieve the trie path for the give key + // Retrieve the trie path for the key being proven path := spec.ph.Path(key) if err := proof.validateBasic(spec); err != nil { @@ -338,10 +338,13 @@ func verifyProofWithUpdates( // Determine what the leaf hash should be. var currentHash, currentData []byte - if bytes.Equal(value, defaultEmptyValue) { // Non-membership proof. - if proof.NonMembershipLeafData == nil { // Leaf is a placeholder value. + if bytes.Equal(value, defaultEmptyValue) { + // Non-membership proof if `value` is empty. + if proof.NonMembershipLeafData == nil { + // Leaf is a placeholder value. currentHash = spec.placeholder() - } else { // Leaf is an unrelated leaf. + } else { + // Leaf is an unrelated leaf. 
var actualPath, valueHash []byte actualPath, valueHash = spec.parseLeafNode(proof.NonMembershipLeafData) if bytes.Equal(actualPath, path) { @@ -349,20 +352,17 @@ func verifyProofWithUpdates( return false, nil, errors.Join(ErrBadProof, errors.New("non-membership proof on related leaf")) } currentHash, currentData = spec.digestLeaf(actualPath, valueHash) - - update := make([][]byte, 2) - update[0], update[1] = currentHash, currentData - updates = append(updates, update) } } else { - // Membership proof. + // Membership proof if `value` is non-empty. valueHash := spec.valueHash(value) currentHash, currentData = spec.digestLeaf(path, valueHash) - update := make([][]byte, 2) - update[0], update[1] = currentHash, currentData - updates = append(updates, update) } + update := make([][]byte, 2) + update[0], update[1] = currentHash, currentData + updates = append(updates, update) + // Recompute root. for i := 0; i < len(proof.SideNodes); i++ { node := make([]byte, spec.hashSize()) diff --git a/smst_example_test.go b/smst_example_test.go index 65dc8ff..a640c2c 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -32,6 +32,9 @@ func TestExampleSMST(t *testing.T) { err = smst.Update([]byte("bin"), []byte("nib"), 3) require.NoError(t, err) dataMap["bin"] = "nib" + err = smst.Update([]byte("binn"), []byte("nnib"), 3) + require.NoError(t, err) + dataMap["binn"] = "nnib" // Commit the changes to the nodeStore err = smst.Commit() @@ -39,7 +42,7 @@ func TestExampleSMST(t *testing.T) { // Calculate the total sum of the trie sum := smst.Sum() - require.Equal(t, uint64(20), sum) + require.Equal(t, uint64(23), sum) // Generate a Merkle proof for "foo" proof1, err := smst.Prove([]byte("foo")) @@ -114,6 +117,7 @@ func exportToCSV( } */ + fmt.Println("Root sum", smst.Sum()) helper(t, smst, nodeStore, smst.Root()) } @@ -131,12 +135,19 @@ func helper(t *testing.T, smst SparseMerkleSumTrie, nodeStore kvstore.MapStore, helper(t, smst, nodeStore, childData) return } else if 
isLeafNode(node) { - path, value := smst.Spec().parseLeafNode(node) - fmt.Println("leaf node sum", 0) - fmt.Println(path, value) + path, valueHash, weight := smst.Spec().parseSumLeafNode(node) + fmt.Println("leaf node weight", weight) + fmt.Println(path, valueHash) } else if isInnerNode(node) { leftData, rightData, sum := smst.Spec().th.parseSumInnerNode(node) fmt.Println("inner node sum", sum) + + // leftChild, err := nodeStore.Get(leftData) + // require.NoError(t, err) + + // rightChild, err := nodeStore.Get(rightData) + // require.NoError(t, err) + helper(t, smst, nodeStore, leftData) helper(t, smst, nodeStore, rightData) } diff --git a/trie_spec.go b/trie_spec.go index f5008fc..55a0af0 100644 --- a/trie_spec.go +++ b/trie_spec.go @@ -237,6 +237,22 @@ func (spec *TrieSpec) parseExtNode(data []byte) (pathBounds, path, childData []b return } +// parseSumLeafNode parses a leafNode and returns its weight as well +func (spec *TrieSpec) parseSumLeafNode(data []byte) (path, value []byte, weight uint64) { + // panics if not a leaf node + checkPrefix(data, leafNodePrefix) + + path = data[prefixLen : prefixLen+spec.ph.PathSize()] + value = data[prefixLen+spec.ph.PathSize():] + + // Extract the sum from the encoded node data + var weightBz [sumSizeBits]byte + copy(weightBz[:], value[len(value)-sumSizeBits:]) + binary.BigEndian.PutUint64(weightBz[:], weight) + + return +} + // parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data func (spec *TrieSpec) parseSumExtNode(data []byte) (pathBounds, path, childData []byte, sum uint64) { // panics if not an extension node From 51999693457b1f0014afa42d1c69765377a21c95 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 8 Apr 2024 16:00:05 -0700 Subject: [PATCH 22/40] Review submittd PR --- .gitignore | 3 +++ docs/mapstore.md | 20 ++++++++++---------- docs/smt.md | 42 +++++++++++++++++------------------------- proofs.go | 6 +++--- smt.go | 4 ++-- 5 files changed, 35 insertions(+), 
40 deletions(-) diff --git a/.gitignore b/.gitignore index aa638db..55b4fed 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ # Ignore Goland and JetBrains IDE files .idea/ + +# Visual Studio Code +.vscode diff --git a/docs/mapstore.md b/docs/mapstore.md index d36fc95..ebaf965 100644 --- a/docs/mapstore.md +++ b/docs/mapstore.md @@ -1,21 +1,17 @@ -# MapStore - - +# MapStore +- [Introduction](#introduction) - [Implementations](#implementations) - [SimpleMap](#simplemap) - [BadgerV4](#badgerv4) +- [Note On External Writability](#note-on-external-writability) - +## Introduction The `MapStore` is a simple interface used by the SM(S)T to store, delete and retrieve key-value pairs. It is intentionally simple and minimalistic so as to enable different key-value engines to implement and back the trie database. -Any key-value store used by the tries should **not** be able to be externally -writeable in production. This opens the possibility to attacks where the writer -can modify the trie database and prove values that were not inserted. - See: [the interface](../kvstore/interfaces.go) for a more detailed description of the simple interface required by the SM(S)T. @@ -35,11 +31,15 @@ See [simplemap.go](../kvstore/simplemap/simplemap.go) for more details. ### BadgerV4 -This library provides a wrapper around [dgraph-io/badger][badgerv4] to adhere to +This library provides a wrapper around [dgraph-io/badger][https://github.com/dgraph-io/badger] to adhere to the `MapStore` interface. See the [full documentation](./badger-store.md) for additional functionality and implementation details. See: [badger](../kvstore/badger/) for more details on the implementation of this submodule. -[badgerv4]: https://github.com/dgraph-io/badger +## Note On External Writability + +Any key-value store used by the tries should **not** be able to be externally +writeable in production. 
This opens the possibility to attacks where the writer +can modify the trie database and prove values that were not inserted. diff --git a/docs/smt.md b/docs/smt.md index 19cdbf7..b7bdfad 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -1,6 +1,4 @@ -# smt - - +# smt - [Overview](#overview) - [Implementation](#implementation) @@ -16,7 +14,8 @@ - [Visualisation](#visualisation) - [Values](#values) - [Nil values](#nil-values) -- [Hashers & Digests](#hashers--digests) +- [Hashers \& Digests](#hashers--digests) + - [Hash Function Recommendations](#hash-function-recommendations) - [Roots](#roots) - [Proofs](#proofs) - [Verification](#verification) @@ -31,8 +30,6 @@ - [Data Loss](#data-loss) - [Sparse Merkle Sum Trie](#sparse-merkle-sum-trie) - - ## Overview Sparse Merkle Tries (SMTs) are efficient and secure data structures for storing @@ -44,25 +41,6 @@ make SMTs valuable in applications like blockchains, decentralized databases, and authenticated data structures, providing optimized and trustworthy data storage and verification. -Although any hash function that satisfies the `hash.Hash` interface can be used -to construct the trie it is **strongly recommended** to use a hashing function -that provides the following properties: - -- **Collision resistance**: The hash function must be collision resistant, in - order for the inputs of the SMT to be unique. -- **Preimage resistance**: The hash function must be preimage resistant, to - protect against the attack of the Merkle tree construction attacks where the - attacker can modify unknown data. -- **Efficiency**: The hash function must be efficient, as it is used to compute - the hash of many nodes in the trie. - -Therefore it is recommended to use a hashing function such as: - -- `sha256` -- `sha3_256`/`keccak256` - -Or another sufficiently secure hashing algorithm. - See [smt.go](../smt.go) for more details on the implementation. 
## Implementation @@ -362,6 +340,20 @@ graph TD VH --ValueHash-->L ``` +### Hash Function Recommendations + +Although any hash function that satisfies the `hash.Hash` interface can be used +to construct the trie, it is **strongly recommended** to use a hashing function +that provides the following properties: + +- **Collision resistance**: The hash function must be collision resistant. This + is needed in order for the inputs of the SMT to be unique. +- **Preimage resistance**: The hash function must be preimage resistant. This + is needed to protect against the Merkle tree construction attacks where + the attacker can modify unknown data. +- **Efficiency**: The hash function must be efficient, as it is used to compute + the hash of many nodes in the trie. + ## Roots The root of the tree is a slice of bytes. `MerkleRoot` is an alias for `[]byte`. diff --git a/proofs.go b/proofs.go index 12d3de8..6a09fe9 100644 --- a/proofs.go +++ b/proofs.go @@ -341,9 +341,9 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie if err := proof.validateBasic(spec); err != nil { return false, errors.Join(ErrBadProof, err) } - // Create a new TrieSpec with a nil path hasher - as the ClosestProof - // already contains a hashed path, double hashing it will invalidate the - // proof. + // Create a new TrieSpec with a nil path hasher. + // Since the ClosestProof already contains a hashed path, double hashing it + // will invalidate the proof. 
nilSpec := &TrieSpec{ th: spec.th, ph: newNilPathHasher(spec.ph.PathSize()), diff --git a/smt.go b/smt.go index 7928c40..0bcc6df 100644 --- a/smt.go +++ b/smt.go @@ -177,7 +177,6 @@ func (smt *SMT) update( return newLeaf, nil } // We insert an "extension" representing multiple single-branch inner nodes - last := &node var newInner *innerNode if getPathBit(path, prefixlen) == left { newInner = &innerNode{ @@ -191,6 +190,7 @@ func (smt *SMT) update( } } // Determine if we need to insert an extension or a branch + last := &node if depth < prefixlen { // note: this keeps path slice alive - GC inefficiency? if depth > 0xff { @@ -419,7 +419,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { // node is encountered, the traversal backsteps and flips the path bit for that // depth (ie tries left if it tried right and vice versa). This guarantees that // a proof of inclusion is found that has the most common bits with the path -// provided, biased to the longest common prefix +// provided, biased to the longest common prefix. 
func (smt *SMT) ProveClosest(path []byte) ( proof *SparseMerkleClosestProof, // proof of the key-value pair found err error, // the error value encountered From c96cb4a9d769565e1c06c13c1d5863faa301b6e5 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 11:56:39 -0700 Subject: [PATCH 23/40] Fix tests --- smt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smt.go b/smt.go index a541d51..897707c 100644 --- a/smt.go +++ b/smt.go @@ -34,7 +34,7 @@ func NewSparseMerkleTrie( options ...TrieSpecOption, ) *SMT { smt := SMT{ - TrieSpec: NewTrieSpec(hasher, false), + TrieSpec: newTrieSpec(hasher, false), nodes: nodes, } for _, option := range options { From 47b74de518347f5470ce28bf2e2e2a3a71b02cd4 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 12:48:06 -0700 Subject: [PATCH 24/40] Self review --- docs/audit.md | 91 --------------------------------------------------- docs/faq.md | 2 -- docs/smt.md | 4 +-- e2e_review.md | 67 ------------------------------------- hasher.go | 4 +-- options.go | 2 +- smst.go | 4 +-- trie_spec.go | 2 +- types.go | 2 +- 9 files changed, 9 insertions(+), 169 deletions(-) delete mode 100644 docs/audit.md delete mode 100644 e2e_review.md diff --git a/docs/audit.md b/docs/audit.md deleted file mode 100644 index f1852f1..0000000 --- a/docs/audit.md +++ /dev/null @@ -1,91 +0,0 @@ -# Audit - -- [Audit](#audit) - -## Pre-Audit Checklist - -### Checklist Requirements - -Pre-Audit Checklist: - -A reminder to provide us with the must haves 3 business days prior to the audit start days to avoid delays: - -**Must haves:** - -- A URL of the repository containing the source code -- The release branch / commit hash to be reviewed -- An explicit list of files in scope and out of scope for the security audit -- Robust and comprehensive documentation describing the intended functionality of the system - -**Nice-to-haves:** - -- Clear instructions for setting up the system and run the tests (usually in the 
README file) -- Any past audits -- Any tooling output logs -- Output generated by running the test suite -- Test coverage report - -Please disregard what may be irrelevant in the above list for this particular audit. -Reminder that we cannot accept any scope changes during the course of the audit. - -And more on audit preparation in this blogpost: https://medium.com/thesis-defense/why-crypto-needs-security-audits-d12f3909ac21 - thanks! - -### Checklist Response - -**Repository**: https://github.com/pokt-network/smt - -- **Branch**: `main` -- **Hash**: `868237978c0b3c0e2added161b36eeb7a3dc93b0` - -**Documentation** - -- **Background**: [Relay Mining](https://arxiv.org/abs/2305.10672) -- **Technical Documentation**: https://github.com/pokt-network/smt/tree/main/docs -- _NOTE: we may integrate this into [https://dev.poktroll.com](https://dev.poktroll.com/) (which is out of scope) in the future_ - -**Files** - -- Nothing is explicitly out of scope but the focus should be on the files below -- The following is a manually filtered list of files after running `tree -P '*.go' -I '*_test.go'` - -```bash - . - ├── errors.go - ├── hasher.go - ├── kvstore - │   ├── badger - │   │   ├── errors.go - │   │   ├── godoc.go - │   │   ├── interface.go - │   │   └── kvstore.go - │   ├── interfaces.go - ├── options.go - ├── proofs.go - ├── smst.go - ├── smt.go - ├── types.go - └── utils.go -``` - -**Makefile** - -- Running `make` in the root of the repo shows a list of options -- This gives access to tests, benchmarks, etc... 
- -```bash -make -help Prints all the targets in all the Makefiles -list List all make targets -test_all runs the test suite -test_badger runs the badger KVStore submodule's test suite -mod_tidy runs go mod tidy for all (sub)modules -go_docs Generate documentation for the project -benchmark_all runs all benchmarks -benchmark_smt runs all benchmarks for the SMT -benchmark_smt_fill runs a benchmark on filling the SMT with different amounts of values -benchmark_smt_ops runs the benchmarks testing different operations on the SMT against different sized tries -benchmark_smst runs all benchmarks for the SMST -benchmark_smst_fill runs a benchmark on filling the SMST with different amounts of values -benchmark_smst_ops runs the benchmarks test different operations on the SMST against different sized tries -benchmark_proof_sizes runs the benchmarks test the proof sizes for different sized tries -``` diff --git a/docs/faq.md b/docs/faq.md index d66b496..7dcd4c4 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -23,5 +23,3 @@ The [SMT extension node](./smt.md#extension-nodes) is very similar to that of Ethereum's [Modified Merkle Patricia Trie](https://ethereum.org/developers/docs/data-structures-and-encoding/patricia-merkle-trie). A quick primer on it can be found in this [5P;1R post](https://olshansky.substack.com/p/5p1r-ethereums-modified-merkle-patricia). - -WIP diff --git a/docs/smt.md b/docs/smt.md index eae4ddc..01b2ecc 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -86,8 +86,8 @@ Extension nodes represent a singly linked chain of inner nodes, with a single child. In other words, they are an optimization to avoid having a long chain of inner nodes where each inner node only has one child. -In other words, they are used to -represent a common path in the trie and as such contain the path and bounds of the path they represent. 
The `digest` of an extension +In other words, they are used to represent a common path in the trie and as such +contain the path and bounds of the path they represent. The `digest` of an extension node is the hash of its path bounds, the path itself and the child nodes digest concatenated. diff --git a/e2e_review.md b/e2e_review.md deleted file mode 100644 index 99da43b..0000000 --- a/e2e_review.md +++ /dev/null @@ -1,67 +0,0 @@ -# E2E Code Review - - -[ ] . - -[ ] ├── LICENSE - -[ ] ├── Makefile - -[ ] ├── README.md - -[ ] ├── benchmarks - -[ ] │   ├── bench_leaf_test.go - -[ ] │   ├── bench_smst_test.go - -[ ] │   ├── bench_smt_test.go - -[ ] │   ├── bench_utils_test.go - -[ ] │   └── proof_sizes_test.go - -[ ] ├── bulk_test.go - -[ ] ├── docs - -[ ] │   ├── audit.md - -[ ] │   ├── badger-store.md - -[ ] │   ├── benchmarks.md - -[ ] │   ├── faq.md - -[ ] │   ├── mapstore.md - -[ ] │   ├── merkle-sum-trie.md - -[ ] │   └── smt.md - -[ ] ├── errors.go - -[ ] ├── extension_node.go - -[ ] ├── fuzz_test.go - -[x] ├── go.mod - -[x] ├── go.sum - -[x] ├── go.work - -[x] ├── go.work.sum - -[x] ├── godoc.go - -[ ] ├── hasher.go - -[x] ├── inner_node.go - -[ ] ├── kvstore - -[ ] │   ├── badger - -[ ] │   │   ├── errors.go - -[ ] │   │   ├── go.mod - -[ ] │   │   ├── go.sum - -[ ] │   │   ├── godoc.go - -[ ] │   │   ├── interface.go - -[ ] │   │   ├── kvstore.go - -[ ] │   │   └── kvstore_test.go - -[ ] │   ├── godoc.go - -[ ] │   ├── interfaces.go - -[ ] │   └── simplemap - -[ ] │   ├── errors.go - -[ ] │   ├── godoc.go - -[ ] │   ├── simplemap.go - -[ ] │   └── simplemap_test.go - -[x] ├── lazy_node.go - -[x] ├── leaf_node.go - -[x] ├── node_encoders.go - -[x] ├── options.go - -[ ] ├── proofs.go - -[ ] ├── proofs_test.go - -[ ] ├── reviewpad.yml - -[ ] ├── root_test.go - -[ ] ├── smst.go - -[ ] ├── smst_example_test.go - -[ ] ├── smst_proofs_test.go - -[ ] ├── smst_test.go - -[ ] ├── smst_utils_test.go - -[ ] ├── smt.go - -[ ] ├── smt_example_test.go - -[ ] ├── 
smt_proofs_test.go - -[ ] ├── smt_test.go - -[ ] ├── smt_utils_test.go - -[ ] ├── types.go - -[ ] └── utils.go diff --git a/hasher.go b/hasher.go index 220eddd..dace46b 100644 --- a/hasher.go +++ b/hasher.go @@ -5,8 +5,8 @@ import ( "hash" ) -// TODO_IN_THIS_PR: Improve how the `hasher` file is consolidated (or not) -// with `node_encoders.go` since the two are very similar. +// TODO_IMPROVE:: Improve how the `hasher` file is consolidated with +// `node_encoders.go` since the two are very similar. // Ensure the hasher interfaces are satisfied var ( diff --git a/options.go b/options.go index 4556a4c..c3fea18 100644 --- a/options.go +++ b/options.go @@ -22,7 +22,7 @@ func WithValueHasher(vh ValueHasher) TrieSpecOption { // used instead of a key during proof verification. Otherwise, this will lead // double hashing and product an incorrect leaf digest, thereby invalidating // the proof. -// TODO_IN_THIS_PR: Need to understand this part more. +// TODO_TECHDEBT: Document better when/why this is needed. func NoHasherSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { spec := newTrieSpec(hasher, sumTrie) diff --git a/smst.go b/smst.go index 64f8e94..27d6c28 100644 --- a/smst.go +++ b/smst.go @@ -33,8 +33,8 @@ func NewSparseMerkleSumTrie( } // Initialize a non-sum SMT and modify it to have a nil value hasher - // TODO_IN_THIS_PR: Understand the purpose of the nilValueHasher and why - // we're not applying it to the smst but we need it for the smt. + // TODO_UPNEXT(@Olshansk): Understand the purpose of the nilValueHasher and + // why we're not applying it to the smst but we need it for the smt. smt := &SMT{ TrieSpec: trieSpec, nodes: nodes, diff --git a/trie_spec.go b/trie_spec.go index 55a0af0..2b3ab73 100644 --- a/trie_spec.go +++ b/trie_spec.go @@ -112,7 +112,7 @@ func (spec *TrieSpec) hashSumSerialization(data []byte) []byte { // depth returns the maximum depth of the trie. 
// Since this tree is a binary tree, the depth is the number of bits in the path -// TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output +// TODO_UPNEXT(@Olshansk): Try to understand why we're not taking the log of the output func (spec *TrieSpec) depth() int { return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits } diff --git a/types.go b/types.go index 0be9de6..93ed09f 100644 --- a/types.go +++ b/types.go @@ -1,6 +1,6 @@ package smt -// TODO_DISCUSS_IN_THIS_PR_IMPROVEMENTS: +// TODO_DISCUSS_IN_THE_FUTURE: // 1. Should we rename all instances of digest to hash? // 2. Should we introduce a shared interface between SparseMerkleTrie and SparseMerkleSumTrie? // 3. Should we rename Commit to FlushToDisk? From 92d467ea4f2190b3b4a7d1cb336f9fb5ac82f62e Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 12:48:27 -0700 Subject: [PATCH 25/40] Revert example --- smst_example_test.go | 155 ++++++------------------------------------- 1 file changed, 22 insertions(+), 133 deletions(-) diff --git a/smst_example_test.go b/smst_example_test.go index a640c2c..f1cdbca 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -1,159 +1,48 @@ -package smt +package smt_test import ( "crypto/sha256" "fmt" - "testing" - "github.com/stretchr/testify/require" - - "github.com/pokt-network/smt/kvstore" + "github.com/pokt-network/smt" "github.com/pokt-network/smt/kvstore/simplemap" ) -// TestExampleSMT is a test that aims to act as an example of how to use the SMST. -func TestExampleSMST(t *testing.T) { - dataMap := make(map[string]string) - - // Initialize a new in-memory key-value store to store the nodes of the trie. - // NB: The trie only stores hashed values and not raw value data.
+func ExampleSMST() { + // Initialise a new in-memory key-value store to store the nodes of the trie + // (Note: the trie only stores hashed values, not raw value data) nodeStore := simplemap.NewSimpleMap() - // Initialize the smst - smst := NewSparseMerkleSumTrie(nodeStore, sha256.New()) + // Initialise the trie + trie := smt.NewSparseMerkleSumTrie(nodeStore, sha256.New()) // Update trie with keys, values and their sums - err := smst.Update([]byte("foo"), []byte("oof"), 10) - require.NoError(t, err) - dataMap["foo"] = "oof" - err = smst.Update([]byte("baz"), []byte("zab"), 7) - require.NoError(t, err) - dataMap["baz"] = "zab" - err = smst.Update([]byte("bin"), []byte("nib"), 3) - require.NoError(t, err) - dataMap["bin"] = "nib" - err = smst.Update([]byte("binn"), []byte("nnib"), 3) - require.NoError(t, err) - dataMap["binn"] = "nnib" + _ = trie.Update([]byte("foo"), []byte("oof"), 10) + _ = trie.Update([]byte("baz"), []byte("zab"), 7) + _ = trie.Update([]byte("bin"), []byte("nib"), 3) // Commit the changes to the nodeStore - err = smst.Commit() - require.NoError(t, err) + _ = trie.Commit() // Calculate the total sum of the trie - sum := smst.Sum() - require.Equal(t, uint64(23), sum) + _ = trie.Sum() // 20 // Generate a Merkle proof for "foo" - proof1, err := smst.Prove([]byte("foo")) - require.NoError(t, err) - proof2, err := smst.Prove([]byte("baz")) - require.NoError(t, err) - proof3, err := smst.Prove([]byte("bin")) - require.NoError(t, err) + proof1, _ := trie.Prove([]byte("foo")) + proof2, _ := trie.Prove([]byte("baz")) + proof3, _ := trie.Prove([]byte("bin")) // We also need the current trie root for the proof - root := smst.Root() + root := trie.Root() // Verify the Merkle proof for "foo"="oof" where "foo" has a sum of 10 - valid_true1, err := VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, smst.Spec()) - require.NoError(t, err) - require.True(t, valid_true1) - + valid_true1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 
10, trie.Spec()) // Verify the Merkle proof for "baz"="zab" where "baz" has a sum of 7 - valid_true2, err := VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, smst.Spec()) - require.NoError(t, err) - require.True(t, valid_true2) - + valid_true2, _ := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, trie.Spec()) // Verify the Merkle proof for "bin"="nib" where "bin" has a sum of 3 - valid_true3, err := VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, smst.Spec()) - require.NoError(t, err) - require.True(t, valid_true3) - + valid_true3, _ := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, trie.Spec()) // Fail to verify the Merkle proof for "foo"="oof" where "foo" has a sum of 11 - valid_false1, err := VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, smst.Spec()) - require.NoError(t, err) - require.False(t, valid_false1) - - exportToCSV(t, smst, dataMap, nodeStore) -} - -func exportToCSV( - t *testing.T, - smst SparseMerkleSumTrie, - innerMap map[string]string, - nodeStore kvstore.MapStore, -) { - t.Helper() - /* - rootHash := smst.Root() - rootNode, err := nodeStore.Get(rootHash) - require.NoError(t, err) - leftData, rightData := smst.Spec().th.parseSumInnerNode(rootNode) - leftChild, err := nodeStore.Get(leftData) - require.NoError(t, err) - rightChild, err := nodeStore.Get(rightData) - require.NoError(t, err) - fmt.Println("Prefix", "isExt", "isLeaf", "isInner") - // false false true - fmt.Println("root", isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode), rootNode) - fmt.Println() - // false false false - fmt.Println("left", isExtNode(leftChild), isLeafNode(leftChild), isInnerNode(leftChild), leftChild) - fmt.Println() - // false false false - fmt.Println("right", isExtNode(rightChild), isLeafNode(rightChild), isInnerNode(rightChild), rightChild) - fmt.Println() - */ - - /* - for key, value := range innerMap { - v, s, err := smst.Get([]byte(key)) - require.NoError(t, err) - 
fmt.Println(v, s, []byte(key)) - fmt.Println(value) - fmt.Println("") - fmt.Println("") - } - */ - - fmt.Println("Root sum", smst.Sum()) - helper(t, smst, nodeStore, smst.Root()) -} - -func helper(t *testing.T, smst SparseMerkleSumTrie, nodeStore kvstore.MapStore, nodeDigest []byte) { - t.Helper() - - node, err := nodeStore.Get(nodeDigest) - require.NoError(t, err) - - fmt.Println() - if isExtNode(node) { - pathBounds, path, childData, sum := smst.Spec().parseSumExtNode(node) - fmt.Println("ext node sum", sum) - fmt.Println(pathBounds, path) - helper(t, smst, nodeStore, childData) - return - } else if isLeafNode(node) { - path, valueHash, weight := smst.Spec().parseSumLeafNode(node) - fmt.Println("leaf node weight", weight) - fmt.Println(path, valueHash) - } else if isInnerNode(node) { - leftData, rightData, sum := smst.Spec().th.parseSumInnerNode(node) - fmt.Println("inner node sum", sum) - - // leftChild, err := nodeStore.Get(leftData) - // require.NoError(t, err) - - // rightChild, err := nodeStore.Get(rightData) - // require.NoError(t, err) - - helper(t, smst, nodeStore, leftData) - helper(t, smst, nodeStore, rightData) - } - - // v, s, err := smst.Get([]byte(key)) - // require.NoError(t, err) - // require.Equal(t, []byte(value), v) - // require.Equal(t, sum, s) + valid_false1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, trie.Spec()) + fmt.Println(valid_true1, valid_true2, valid_true3, valid_false1) + // Output: true true true false } From b322e33246e4fbaa06806263307aad52343f2650 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:03:53 -0700 Subject: [PATCH 26/40] Update extension_node.go Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- extension_node.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension_node.go b/extension_node.go index 38282d3..379793d 100644 --- a/extension_node.go +++ b/extension_node.go @@ -131,7 +131,7 @@ func (extNode *extensionNode) split(path []byte) 
(trieNode, *trieNode, int) { return head, &b, pathIdx } -// expand returns the inner node that represents the start of the singly +// expand returns the inner node that represents the end of the singly // linked list that this extension node represents func (extNode *extensionNode) expand() trieNode { last := extNode.child From 186ca40f000bf16664ef30b9037f97d7177cbc6f Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:04:01 -0700 Subject: [PATCH 27/40] Update docs/smt.md Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- docs/smt.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/smt.md b/docs/smt.md index 01b2ecc..4c15c06 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -61,7 +61,7 @@ Leaf nodes store the full path associated with the `key`. A leaf node also store the hash of the `value` stored. The `digest` of a leaf node is the hash of concatenation of the leaf node's -path and value. +prefix, path and value. By default, the SMT only stores the hashes of the values in the trie, and not the raw values themselves. In order to store the raw values in the underlying database, From 2594f1b0c497044a38bf735fed15f9913c837834 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:04:33 -0700 Subject: [PATCH 28/40] Update docs/smt.md Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- docs/smt.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/smt.md b/docs/smt.md index 4c15c06..edb02b9 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -86,7 +86,7 @@ Extension nodes represent a singly linked chain of inner nodes, with a single child. In other words, they are an optimization to avoid having a long chain of inner nodes where each inner node only has one child. -In other words, they are used to represent a common path in the trie and as such +They are used to represent a common path in the trie and as such contain the path and bounds of the path they represent. 
The `digest` of an extension node is the hash of its path bounds, the path itself and the child nodes digest concatenated. From 77ba35df22c36ba7abfed7b20c20dd47471a6777 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:04:44 -0700 Subject: [PATCH 29/40] Update hasher.go Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- hasher.go | 1 + 1 file changed, 1 insertion(+) diff --git a/hasher.go b/hasher.go index dace46b..8930bf0 100644 --- a/hasher.go +++ b/hasher.go @@ -47,6 +47,7 @@ type valueHasher struct { trieHasher } +// nilPathHasher is a dummy hasher that returns its input - it should not be used outside of the closest proof verification logic type nilPathHasher struct { hashSize int } From a54669b509f7132b94c62dc03539d48601cb7452 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:13:52 -0700 Subject: [PATCH 30/40] Before adding helper --- Makefile | 4 ++-- options.go | 25 ------------------------- root_test.go | 2 ++ types.go | 5 ++++- 4 files changed, 8 insertions(+), 28 deletions(-) diff --git a/Makefile b/Makefile index 6d5ab99..c9d4774 100644 --- a/Makefile +++ b/Makefile @@ -34,11 +34,11 @@ check_godoc: .PHONY: test_all test_all: ## runs the test suite - go test -v -p 1 ./... -mod=readonly -race + go test -v -p 1 -count=1 ./... -mod=readonly -race .PHONY: test_badger test_badger: ## runs the badger KVStore submodule's test suite - go test -v -p 1 ./kvstore/badger/... -mod=readonly -race + go test -v -p 1 -count=1 ./kvstore/badger/... -mod=readonly -race ##################### diff --git a/options.go b/options.go index c3fea18..ed6a75a 100644 --- a/options.go +++ b/options.go @@ -1,9 +1,5 @@ package smt -import ( - "hash" -) - // TrieSpecOption is a function that configures SparseMerkleTrie. 
type TrieSpecOption func(*TrieSpec) @@ -16,24 +12,3 @@ func WithPathHasher(ph PathHasher) TrieSpecOption { func WithValueHasher(vh ValueHasher) TrieSpecOption { return func(ts *TrieSpec) { ts.vh = vh } } - -// NoHasherSpec returns a new TrieSpec that has nil ValueHasher & PathHasher specs. -// NB: This should only be used when values are already hashed and a path is -// used instead of a key during proof verification. Otherwise, this will lead -// double hashing and product an incorrect leaf digest, thereby invalidating -// the proof. -// TODO_TECHDEBT: Document better when/why this is needed. -func NoHasherSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { - spec := newTrieSpec(hasher, sumTrie) - - // Set a nil path hasher - opt := WithPathHasher(NewNilPathHasher(hasher.Size())) - opt(&spec) - - // Set a nil value hasher - opt = WithValueHasher(nil) - opt(&spec) - - // Return the spec - return &spec -} diff --git a/root_test.go b/root_test.go index d6e588b..08f8b6d 100644 --- a/root_test.go +++ b/root_test.go @@ -59,6 +59,8 @@ func TestMerkleRoot_TrieTypes(t *testing.T) { require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i)) } require.NotNil(t, trie.Sum()) + require.EqualValues(t, 45, trie.Sum()) + return } trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher) diff --git a/types.go b/types.go index 93ed09f..8ef9a8b 100644 --- a/types.go +++ b/types.go @@ -1,9 +1,12 @@ package smt -// TODO_DISCUSS_IN_THE_FUTURE: +// TODO_DISCUSS_CONSIDERIN_THE_FUTURE: // 1. Should we rename all instances of digest to hash? +// > digest is the correct term for the output of a hashing function IIRC // 2. Should we introduce a shared interface between SparseMerkleTrie and SparseMerkleSumTrie? +// > Sum() would have to be no-op but could be done // 3. Should we rename Commit to FlushToDisk? +// > No because what if this is an in memory trie? 
const ( // The bit value use to distinguish an inner nodes left child and right child From 9b48300c30abd193698321f5c9f3f5fb1afcadf3 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:20:10 -0700 Subject: [PATCH 31/40] Added helpers --- smt.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/smt.go b/smt.go index 897707c..8bac991 100644 --- a/smt.go +++ b/smt.go @@ -558,7 +558,12 @@ func (smt *SMT) resolveNode(digest []byte) (trieNode, error) { return nil, err } - // Return the appropriate node type based on the first byte of the data + return smt.parseTrieNode(data, digest) +} + +// parseTrieNode returns a trieNode (inner, leaf, or extension) based on the +// first byte of the data. +func (smt *SMT) parseTrieNode(data, digest []byte) (trieNode, error) { if isLeafNode(data) { path, valueHash := smt.parseLeafNode(data) return &leafNode{ @@ -603,7 +608,12 @@ func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) { return nil, err } - // Return the appropriate node type based on the first byte of the data + return smt.parseSumTrieNode(data, digest) +} + +// parseSumTrieNode returns a trieNode (inner, leaf, or extension) based on the +// first byte of the data.
+func (smt *SMT) parseSumTrieNode(data, digest []byte) (trieNode, error) { if isLeafNode(data) { path, valueHash := smt.parseLeafNode(data) return &leafNode{ From 1a144d1fe1fb98599b853be7f552b8d458140c4e Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:30:23 -0700 Subject: [PATCH 32/40] Reply to harry's comments --- hasher.go | 9 ++++++++- node_encoders.go | 8 +++++--- proofs.go | 2 +- smst.go | 15 ++++++++++----- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/hasher.go b/hasher.go index 8930bf0..6cc9574 100644 --- a/hasher.go +++ b/hasher.go @@ -59,7 +59,9 @@ func NewTrieHasher(hasher hash.Hash) *trieHasher { return &th } -func NewNilPathHasher(hasherSize int) PathHasher { +// newNilPathHasher returns a new nil path hasher with the given hash size. +// It is not exported the validation logic for the ClosestProof automatically handles this. +func newNilPathHasher(hasherSize int) PathHasher { return &nilPathHasher{hashSize: hasherSize} } @@ -112,12 +114,14 @@ func (th *trieHasher) digestLeafNode(path, data []byte) (digest, value []byte) { return } +// digestInnerNode returns the encoded inner node data as well as its hash (i.e. digest) func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value []byte) { value = encodeInnerNode(leftData, rightData) digest = th.digestData(value) return } +// digestSumLeafNode returns the encoded leaf node data as well as its hash (i.e. digest) func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte) { value = encodeLeafNode(path, data) digest = th.digestData(value) @@ -125,6 +129,7 @@ func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte return } +// digestSumInnerNode returns the encoded inner node data as well as its hash (i.e.
digest) func (th *trieHasher) digestSumInnerNode(leftData, rightData []byte) (digest, value []byte) { value = encodeSumInnerNode(leftData, rightData) digest = th.digestData(value) @@ -132,12 +137,14 @@ func (th *trieHasher) digestSumInnerNode(leftData, rightData []byte) (digest, va return } +// parseInnerNode returns the encoded left and right nodes func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) { leftData = data[len(innerNodePrefix) : th.hashSize()+len(innerNodePrefix)] rightData = data[len(innerNodePrefix)+th.hashSize():] return } +// parseSumInnerNode returns the encoded left and right nodes as well as the sum of the current node func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum uint64) { // Extract the sum from the encoded node data var sumBz [sumSizeBits]byte diff --git a/node_encoders.go b/node_encoders.go index 65ef604..27bf24d 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -8,9 +8,11 @@ import ( // TODO_TECHDEBT: All of the parsing, encoding and checking functions in this file // can be abstracted out into the `trieNode` interface. -// TODO_IMPROVE: We should create well-defined types & structs for every type of node -// (e.g. protobufs) to streamline the process of encoding & encoding and to improve -// readability. +// TODO_IMPROVE: We should create well-defined structs for every type of node +// to streamline the process of encoding & encoding and to improve readability. +// If decoding needs to be language agnostic (to implement POKT clients), in other +// languages, protobufs should be considered. If decoding does not need to be +// language agnostic, we can use Go's gob package for more efficient serialization. // NB: In this file, all references to the variable `data` should be treated as `encodedNodeData`. // It was abbreviated to `data` for brevity. 
diff --git a/proofs.go b/proofs.go index 3b43dbc..c8f0497 100644 --- a/proofs.go +++ b/proofs.go @@ -353,7 +353,7 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie // will invalidate the proof. nilSpec := &TrieSpec{ th: spec.th, - ph: NewNilPathHasher(spec.ph.PathSize()), + ph: newNilPathHasher(spec.ph.PathSize()), vh: spec.vh, sumTrie: spec.sumTrie, } diff --git a/smst.go b/smst.go index 27d6c28..b87647e 100644 --- a/smst.go +++ b/smst.go @@ -32,9 +32,14 @@ func NewSparseMerkleSumTrie( option(&trieSpec) } - // Initialize a non-sum SMT and modify it to have a nil value hasher - // TODO_UPNEXT(@Olshansk): Understand the purpose of the nilValueHasher and - // why we're not applying it to the smst but we need it for the smt. + // Initialize a non-sum SMT and modify it to have a nil value hasher. + // NB: We are using a nil value hasher because the SMST pre-hashes its paths. + // This results result in double path hashing because the SMST is a wrapper + // around the SMT. The reason the SMST uses its own path hashing logic is + // to account for the additional sum in the encoding/decoding process. + // Therefore, the underlying SMT underneath needs a nil path hasher, while + // the outer SMST does all the (non nil) path hashing itself. + // TODO_TECHDEBT(@Olshansk): Look for ways to simplify / cleanup the above. smt := &SMT{ TrieSpec: trieSpec, nodes: nodes, @@ -146,8 +151,8 @@ func (smst *SMST) Root() MerkleRoot { // If the tree is not a sum tree, it will panic. 
func (smst *SMST) Sum() uint64 { rootDigest := smst.Root() - if len(rootDigest) != smst.th.hashSize()+sumSizeBits { - panic("roo#sum: not a merkle sum trie") + if !smst.Spec().sumTrie { + panic("SMST: not a merkle sum trie") } var sumbz [sumSizeBits]byte copy(sumbz[:], []byte(rootDigest)[len([]byte(rootDigest))-sumSizeBits:]) From 3f436405d0f11e6fa1dafb321a84af04c9b8a647 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Mon, 3 Jun 2024 18:34:21 -0700 Subject: [PATCH 33/40] Added nolint --- trie_spec.go | 1 + 1 file changed, 1 insertion(+) diff --git a/trie_spec.go b/trie_spec.go index 2b3ab73..3cb821b 100644 --- a/trie_spec.go +++ b/trie_spec.go @@ -238,6 +238,7 @@ func (spec *TrieSpec) parseExtNode(data []byte) (pathBounds, path, childData []b } // parseSumLeafNode parses a leafNode and returns its weight as well +// // nolint: unused func (spec *TrieSpec) parseSumLeafNode(data []byte) (path, value []byte, weight uint64) { // panics if not a leaf node checkPrefix(data, leafNodePrefix) From 32c5f24967fa605dedb3721e3c7c089d59e47e60 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Tue, 4 Jun 2024 19:45:34 -0700 Subject: [PATCH 34/40] Update docs/smt.md Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- docs/smt.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/smt.md b/docs/smt.md index edb02b9..eb9b55d 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -209,7 +209,7 @@ Paths are **only** stored in two types of nodes: `Leaf` nodes and `Extension` no - `Leaf` nodes contain: - The full path which it represent - - The value stored at that path + - The (hashed) value stored at that path - `Extension` nodes contain: - not only the path they represent but also the path bounds (ie. the start and end of the path they cover). 
From 5ebf731005b4addaf50bc30dfb51cc5f0bdf68ed Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Tue, 4 Jun 2024 19:45:50 -0700 Subject: [PATCH 35/40] Update docs/smt.md Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- docs/smt.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/smt.md b/docs/smt.md index eb9b55d..606928a 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -51,7 +51,7 @@ The SMT has 4 node types that are used to construct the trie: - [Extension Nodes](#extension-nodes) - [Leaf Nodes](#leaf-nodes) - [Lazy Nodes](#lazy-nodes) - - Prefix of the actual node type is stored in the persisted digest as + - Prefix of the actual node type is stored in the persisted preimage as determined above - `digest = persistedDigest` From cf7b7d5225954250c676ed75e2ad29d0071a8ef6 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Tue, 4 Jun 2024 19:45:57 -0700 Subject: [PATCH 36/40] Update docs/smt.md Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- docs/smt.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/smt.md b/docs/smt.md index 606928a..0dc75bb 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -212,7 +212,7 @@ Paths are **only** stored in two types of nodes: `Leaf` nodes and `Extension` no - The (hashed) value stored at that path - `Extension` nodes contain: - not only the path they represent but also the path - bounds (ie. the start and end of the path they cover). + bounds (ie. the start and end of the path that they cover). Inner nodes do **not** contain a path, as they represent a branch in the trie and not a path. 
As such their children, _if they are extension nodes or leaf From 583f57d5cc5e24104f178204487755a8deffd38d Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Tue, 4 Jun 2024 19:47:20 -0700 Subject: [PATCH 37/40] Update extension_node.go Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- extension_node.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension_node.go b/extension_node.go index 379793d..38282d3 100644 --- a/extension_node.go +++ b/extension_node.go @@ -131,7 +131,7 @@ func (extNode *extensionNode) split(path []byte) (trieNode, *trieNode, int) { return head, &b, pathIdx } -// expand returns the inner node that represents the end of the singly +// expand returns the inner node that represents the start of the singly // linked list that this extension node represents func (extNode *extensionNode) expand() trieNode { last := extNode.child From eb543a31a1badac405fedbe712f7efc31f2e0464 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Tue, 4 Jun 2024 20:31:22 -0700 Subject: [PATCH 38/40] Update hasher.go Co-authored-by: h5law Signed-off-by: Daniel Olshansky --- hasher.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hasher.go b/hasher.go index 6cc9574..5b28250 100644 --- a/hasher.go +++ b/hasher.go @@ -60,7 +60,7 @@ func NewTrieHasher(hasher hash.Hash) *trieHasher { } // newNilPathHasher returns a new nil path hasher with the given hash size. -// It is not exported the validation logic for the ClosestProof automatically handles this. +// It is not exported as the validation logic for the ClosestProof automatically handles this case. 
func newNilPathHasher(hasherSize int) PathHasher { return &nilPathHasher{hashSize: hasherSize} } From 3d3e3c8d5f837c7caada667cfc741427f3a07c21 Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Tue, 4 Jun 2024 20:45:02 -0700 Subject: [PATCH 39/40] Replied to remaining review comments --- docs/smt.md | 10 ++++++---- extension_node.go | 6 +++--- hasher.go | 14 +++++++------- node_encoders.go | 10 ++++++---- options.go | 1 + proofs.go | 8 ++++---- proofs_test.go | 4 ++-- smst.go | 18 +++++++++--------- smst_proofs_test.go | 6 +++--- smst_test.go | 2 +- smst_utils_test.go | 8 ++++---- smt.go | 17 +++++++++++++---- trie_spec.go | 16 ++++++++-------- types.go | 2 +- 14 files changed, 68 insertions(+), 54 deletions(-) diff --git a/docs/smt.md b/docs/smt.md index 0dc75bb..af421fd 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -86,10 +86,12 @@ Extension nodes represent a singly linked chain of inner nodes, with a single child. In other words, they are an optimization to avoid having a long chain of inner nodes where each inner node only has one child. -They are used to represent a common path in the trie and as such -contain the path and bounds of the path they represent. The `digest` of an extension -node is the hash of its path bounds, the path itself and the child nodes digest -concatenated. +They are used to represent a common path in the trie and as such contain the path +and bounds of the path they represent. + +The `digest` of an extension node is the hash of its path bounds, the path itself +and the child node digest. Note that an extension node can only have exactly one +child node. - _Prefix_: `[]byte{2}` - _Digest_: `hash([]byte{2} + pathBounds + path + child.digest)` diff --git a/extension_node.go b/extension_node.go index 38282d3..016ed84 100644 --- a/extension_node.go +++ b/extension_node.go @@ -16,7 +16,7 @@ type extensionNode struct { // inner nodes that this single extension node replaces. pathBounds [2]byte // A child node from this extension node. 
- // It MUST be either an innerNode or a lazyNode. + // It will always be an innerNode, leafNode or lazyNode. child trieNode // Bool whether or not the node has been flushed to disk persisted bool @@ -72,8 +72,8 @@ func (extNode *extensionNode) boundsMatch(path []byte, depth int) (int, bool) { return extNode.length(), true } -// split splits the node in-place by returning a new node at the extension node, -// a child node at the split and split depth. +// split splits the node in-place by returning a new extensionNode and a child +// node at the split and split depth. func (extNode *extensionNode) split(path []byte) (trieNode, *trieNode, int) { // Start path to extNode.pathBounds until there is no match var extNodeBit, pathBit int diff --git a/hasher.go b/hasher.go index 5b28250..3676a49 100644 --- a/hasher.go +++ b/hasher.go @@ -125,7 +125,7 @@ func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte) { value = encodeLeafNode(path, data) digest = th.digestData(value) - digest = append(digest, value[len(value)-sumSizeBits:]...) + digest = append(digest, value[len(value)-sumSizeBytes:]...) return } @@ -133,7 +133,7 @@ func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte func (th *trieHasher) digestSumInnerNode(leftData, rightData []byte) (digest, value []byte) { value = encodeSumInnerNode(leftData, rightData) digest = th.digestData(value) - digest = append(digest, value[len(value)-sumSizeBits:]...) + digest = append(digest, value[len(value)-sumSizeBytes:]...) 
return } @@ -147,14 +147,14 @@ func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) { // parseSumInnerNode returns the encoded left and right nodes as well as the sum of the current node func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum uint64) { // Extract the sum from the encoded node data - var sumBz [sumSizeBits]byte - copy(sumBz[:], data[len(data)-sumSizeBits:]) + var sumBz [sumSizeBytes]byte + copy(sumBz[:], data[len(data)-sumSizeBytes:]) binary.BigEndian.PutUint64(sumBz[:], sum) // Extract the left and right children - dataWithoutSum := data[:len(data)-sumSizeBits] - leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBits] - rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBits:] + dataWithoutSum := data[:len(data)-sumSizeBytes] + leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBytes] + rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBytes:] return } diff --git a/node_encoders.go b/node_encoders.go index 27bf24d..fdc6a06 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -81,9 +81,11 @@ func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byt // encodeSumInnerNode encodes an inner node for an smst given the data for both children func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { // Compute the sum of the current node - var sum [sumSizeBits]byte + var sum [sumSizeBytes]byte leftSum := parseSum(leftData) rightSum := parseSum(rightData) + // TODO_CONSIDERATION: ` I chose BigEndian for readability but most computers + // now are optimized for LittleEndian encoding could be a micro optimization one day.` binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) // Prepare and return the encoded inner node data @@ -95,8 +97,8 @@ func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { // encodeSumExtensionNode encodes the data of 
a sum extension nodes func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { // Compute the sum of the current node - var sum [sumSizeBits]byte - copy(sum[:], childData[len(childData)-sumSizeBits:]) + var sum [sumSizeBytes]byte + copy(sum[:], childData[len(childData)-sumSizeBytes:]) // Prepare and return the encoded inner node data data = encodeExtensionNode(pathBounds, path, childData) @@ -114,7 +116,7 @@ func checkPrefix(data, prefix []byte) { // parseSum parses the sum from the encoded node data func parseSum(data []byte) uint64 { sum := uint64(0) - sumBz := data[len(data)-sumSizeBits:] + sumBz := data[len(data)-sumSizeBytes:] if !bytes.Equal(sumBz, defaultEmptySum[:]) { sum = binary.BigEndian.Uint64(sumBz) } diff --git a/options.go b/options.go index ed6a75a..92ec8fb 100644 --- a/options.go +++ b/options.go @@ -4,6 +4,7 @@ package smt type TrieSpecOption func(*TrieSpec) // WithPathHasher returns an Option that sets the PathHasher to the one provided +// this MUST not be nil or unknown behaviour will occur. func WithPathHasher(ph PathHasher) TrieSpecOption { return func(ts *TrieSpec) { ts.ph = ph } } diff --git a/proofs.go b/proofs.go index c8f0497..b834468 100644 --- a/proofs.go +++ b/proofs.go @@ -206,7 +206,7 @@ func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte { return nil } if spec.sumTrie { - return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBits] + return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes] } return proof.ClosestValueHash } @@ -324,7 +324,7 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp // VerifySumProof verifies a Merkle proof for a sum trie. 
func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { - var sumBz [sumSizeBits]byte + var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], sum) valueHash := spec.valueHash(value) valueHash = append(valueHash, sumBz[:]...) @@ -363,10 +363,10 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie if proof.ClosestValueHash == nil { return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec) } - sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSizeBits:] + sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSizeBytes:] sum := binary.BigEndian.Uint64(sumBz) - valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBits] + valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes] return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec) } diff --git a/proofs_test.go b/proofs_test.go index dddd41f..b673b2f 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -177,9 +177,9 @@ func randomizeProof(proof *SparseMerkleProof) *SparseMerkleProof { func randomizeSumProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range sideNodes { - sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSizeBits) + sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSizeBytes) rand.Read(sideNodes[i]) // nolint: errcheck - sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSizeBits:]...) + sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSizeBytes:]...) 
} return &SparseMerkleProof{ SideNodes: sideNodes, diff --git a/smst.go b/smst.go index b87647e..a2bc49c 100644 --- a/smst.go +++ b/smst.go @@ -10,7 +10,7 @@ import ( const ( // The number of bits used to represent the sum of a node - sumSizeBits = 8 + sumSizeBytes = 8 ) var _ SparseMerkleSumTrie = (*SMST)(nil) @@ -34,7 +34,7 @@ func NewSparseMerkleSumTrie( // Initialize a non-sum SMT and modify it to have a nil value hasher. // NB: We are using a nil value hasher because the SMST pre-hashes its paths. - // This results result in double path hashing because the SMST is a wrapper + // This results in double path hashing because the SMST is a wrapper // around the SMT. The reason the SMST uses its own path hashing logic is // to account for the additional sum in the encoding/decoding process. // Therefore, the underlying SMT underneath needs a nil path hasher, while @@ -86,12 +86,12 @@ func (smst *SMST) Get(key []byte) (valueDigest []byte, weight uint64, err error) } // Retrieve the node weight - var weightBz [sumSizeBits]byte - copy(weightBz[:], valueDigest[len(valueDigest)-sumSizeBits:]) + var weightBz [sumSizeBytes]byte + copy(weightBz[:], valueDigest[len(valueDigest)-sumSizeBytes:]) weight = binary.BigEndian.Uint64(weightBz[:]) // Remove the weight from the value digest - valueDigest = valueDigest[:len(valueDigest)-sumSizeBits] + valueDigest = valueDigest[:len(valueDigest)-sumSizeBytes] // Return the value digest and weight return valueDigest, weight, nil @@ -106,7 +106,7 @@ func (smst *SMST) Get(key []byte) (valueDigest []byte, weight uint64, err error) // up to the total sum of the trie. 
func (smst *SMST) Update(key, value []byte, weight uint64) error { // Convert the node weight to a byte slice - var weightBz [sumSizeBits]byte + var weightBz [sumSizeBytes]byte binary.BigEndian.PutUint64(weightBz[:], weight) // Compute the digest of the value and append the weight to it @@ -154,7 +154,7 @@ func (smst *SMST) Sum() uint64 { if !smst.Spec().sumTrie { panic("SMST: not a merkle sum trie") } - var sumbz [sumSizeBits]byte - copy(sumbz[:], []byte(rootDigest)[len([]byte(rootDigest))-sumSizeBits:]) - return binary.BigEndian.Uint64(sumbz[:]) + var sumBz [sumSizeBytes]byte + copy(sumBz[:], []byte(rootDigest)[len([]byte(rootDigest))-sumSizeBytes:]) + return binary.BigEndian.Uint64(sumBz[:]) } diff --git a/smst_proofs_test.go b/smst_proofs_test.go index cf5e41f..d0d8c9d 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -117,7 +117,7 @@ func TestSMST_Proof_Operations(t *testing.T) { require.False(t, result) // Try proving a default value for a non-default leaf. - var sum [sumSizeBits]byte + var sum [sumSizeBytes]byte binary.BigEndian.PutUint64(sum[:], 5) tval := base.valueHash([]byte("testValue")) tval = append(tval, sum[:]...) @@ -301,7 +301,7 @@ func TestSMST_ProveClosest(t *testing.T) { var result bool var root []byte var err error - var sumBz [sumSizeBits]byte + var sumBz [sumSizeBytes]byte smn = simplemap.NewSimpleMap() require.NoError(t, err) @@ -427,7 +427,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { closestPath := sha256.Sum256([]byte("foo")) closestValueHash := []byte("bar") - var sumBz [sumSizeBits]byte + var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], 5) closestValueHash = append(closestValueHash, sumBz[:]...) 
require.Equal(t, proof, &SparseMerkleClosestProof{ diff --git a/smst_test.go b/smst_test.go index 7458ff5..e331994 100644 --- a/smst_test.go +++ b/smst_test.go @@ -441,7 +441,7 @@ func TestSMST_TotalSum(t *testing.T) { // Check root hash contains the correct hex sum root1 := smst.Root() - sumBz := root1[len(root1)-sumSizeBits:] + sumBz := root1[len(root1)-sumSizeBytes:] rootSum := binary.BigEndian.Uint64(sumBz) require.NoError(t, err) diff --git a/smst_utils_test.go b/smst_utils_test.go index 8a12b44..db2acd0 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -27,7 +27,7 @@ func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { return err } valueHash := smst.valueHash(value) - var sumBz [sumSizeBits]byte + var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) return smst.preimages.Set(valueHash, value) @@ -57,13 +57,13 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { // Otherwise percolate up any other error return nil, 0, err } - var sumBz [sumSizeBits]byte - copy(sumBz[:], value[len(value)-sumSizeBits:]) + var sumBz [sumSizeBytes]byte + copy(sumBz[:], value[len(value)-sumSizeBytes:]) storedSum := binary.BigEndian.Uint64(sumBz[:]) if storedSum != sum { return nil, 0, fmt.Errorf("sum mismatch for %s: got %d, expected %d", string(key), storedSum, sum) } - return value[:len(value)-sumSizeBits], storedSum, nil + return value[:len(value)-sumSizeBytes], storedSum, nil } // Has returns true if the value at the given key is non-default, false otherwise. 
diff --git a/smt.go b/smt.go index 8bac991..120186f 100644 --- a/smt.go +++ b/smt.go @@ -111,7 +111,7 @@ func (smt *SMT) Get(key []byte) ([]byte, error) { // Update inserts the `value` for the given `key` into the SMT func (smt *SMT) Update(key, value []byte) error { - // Expand the key into a path by computing its digest + // Convert the key into a path by computing its digest path := smt.ph.Path(key) // Convert the value into a hash by computing its digest @@ -119,7 +119,9 @@ func (smt *SMT) Update(key, value []byte) error { // Update the trie with the new key-value pair var orphans orphanNodes - // Compute the new root by inserting (path, valueHash) starting + + // Compute the new root by inserting (path, valueHash) starting from the + // root of the tree in order to find the correct position of the new leaf. newRoot, err := smt.update(smt.root, 0, path, valueHash, &orphans) if err != nil { return err @@ -155,7 +157,8 @@ func (smt *SMT) update( smt.addOrphan(orphans, node) return newLeaf, nil } - // We insert an "extension" representing multiple single-branch inner nodes + // Create a new innerNode where a previous leafNode was, branching + // based on the path bit at the current depth in the path. var newInner *innerNode if getPathBit(path, prefixLen) == leftChildBit { newInner = &innerNode{ @@ -168,7 +171,9 @@ func (smt *SMT) update( rightChild: newLeaf, } } - // Determine if we need to insert an extension or a branch + // Determine if we need to insert the new innerNode as the child + // of an extensionNode or insert the new innerNode in place of + // a pre-existing leafNode with a common prefix. last := &node if depth < prefixLen { // note: this keeps path slice alive - GC inefficiency? @@ -193,6 +198,8 @@ func (smt *SMT) update( smt.addOrphan(orphans, node) + // If the node is an extensionNode, split it by the path provided; we + // call update() on the results to place the newLeaf correctly. 
if extNode, ok := node.(*extensionNode); ok { var branch *trieNode node, branch, depth = extNode.split(path) @@ -204,6 +211,8 @@ func (smt *SMT) update( return node, nil } + // The node must be an innerNode. Depending on which side of the branch inner + // node the newLeaf should be added to, call update() accordingly. inner := node.(*innerNode) var child *trieNode if getPathBit(path, depth) == leftChildBit { diff --git a/trie_spec.go b/trie_spec.go index 3cb821b..a9f047e 100644 --- a/trie_spec.go +++ b/trie_spec.go @@ -41,7 +41,7 @@ func (spec *TrieSpec) placeholder() []byte { // hashSize returns the hash size depending on the trie type func (spec *TrieSpec) hashSize() int { if spec.sumTrie { - return spec.th.hashSize() + sumSizeBits + return spec.th.hashSize() + sumSizeBytes } return spec.th.hashSize() } @@ -106,7 +106,7 @@ func (spec *TrieSpec) hashSumSerialization(data []byte) []byte { return spec.digestSumNode(&ext) } digest := spec.th.digestData(data) - digest = append(digest, data[len(data)-sumSizeBits:]...) + digest = append(digest, data[len(data)-sumSizeBytes:]...) return digest } @@ -210,7 +210,7 @@ func (spec *TrieSpec) digestSumNode(node trieNode) []byte { if *cache == nil { preImage := spec.encodeSumNode(node) *cache = spec.th.digestData(preImage) - *cache = append(*cache, preImage[len(preImage)-sumSizeBits:]...) + *cache = append(*cache, preImage[len(preImage)-sumSizeBytes:]...) 
} return *cache } @@ -247,8 +247,8 @@ func (spec *TrieSpec) parseSumLeafNode(data []byte) (path, value []byte, weight value = data[prefixLen+spec.ph.PathSize():] // Extract the sum from the encoded node data - var weightBz [sumSizeBits]byte - copy(weightBz[:], value[len(value)-sumSizeBits:]) + var weightBz [sumSizeBytes]byte + copy(weightBz[:], value[len(value)-sumSizeBytes:]) binary.BigEndian.PutUint64(weightBz[:], weight) return @@ -260,13 +260,13 @@ func (spec *TrieSpec) parseSumExtNode(data []byte) (pathBounds, path, childData checkPrefix(data, extNodePrefix) // Extract the sum from the encoded node data - var sumBz [sumSizeBits]byte - copy(sumBz[:], data[len(data)-sumSizeBits:]) + var sumBz [sumSizeBytes]byte + copy(sumBz[:], data[len(data)-sumSizeBytes:]) binary.BigEndian.PutUint64(sumBz[:], sum) // +2 represents the length of the pathBounds pathBounds = data[prefixLen : prefixLen+2] path = data[prefixLen+2 : prefixLen+2+spec.ph.PathSize()] - childData = data[prefixLen+2+spec.ph.PathSize() : len(data)-sumSizeBits] + childData = data[prefixLen+2+spec.ph.PathSize() : len(data)-sumSizeBytes] return } diff --git a/types.go b/types.go index 8ef9a8b..fd6dacf 100644 --- a/types.go +++ b/types.go @@ -17,7 +17,7 @@ var ( // defaultEmptyValue is the default value for a leaf node defaultEmptyValue []byte // defaultEmptySum is the default sum value for a leaf node - defaultEmptySum [sumSizeBits]byte + defaultEmptySum [sumSizeBytes]byte ) // MerkleRoot is a type alias for a byte slice returned from the Root method From db19822cdbc4699645b3daa333288cd8c4be3a3b Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Tue, 4 Jun 2024 20:45:58 -0700 Subject: [PATCH 40/40] Added TODO --- extension_node.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/extension_node.go b/extension_node.go index 016ed84..fd2dc03 100644 --- a/extension_node.go +++ b/extension_node.go @@ -8,7 +8,9 @@ var _ trieNode = (*extensionNode)(nil) // Extension nodes are used to captures 
a series of inner nodes that only // have one child in a succinct `pathBounds` for optimization purposes. // -// Assumption: the path is <=256 bits +// TODO_TECHDEBT(@Olshansk): Does this assumption still hold? +// +// Assumption: the path is <=256 bits type extensionNode struct { // The path (starting at the root) to this extension node. path []byte